214 lines
7.4 KiB
Python
214 lines
7.4 KiB
Python
import datetime
|
|
import os
|
|
import sys
|
|
import urllib
|
|
from unittest import skip
|
|
from django.contrib.staticfiles.testing import StaticLiveServerTestCase
|
|
from django.test import SimpleTestCase, TestCase, tag
|
|
from django.urls import reverse
|
|
from selenium import webdriver
|
|
from selenium.webdriver.support.ui import WebDriverWait
|
|
from selenium.webdriver.support import expected_conditions as EC
|
|
|
|
|
|
def skip_unless_tag_option():
    """Return a decorator that skips a test unless a ``--tag`` option was given.

    ``sys.argv`` is inspected when this function is called (for a decorator
    expression, that is at class-definition time). Both spellings argparse
    accepts for an option with a value are recognised: ``--tag NAME`` and
    ``--tag=NAME``. ``--exclude-tag`` does not count.

    If the option is present the returned decorator leaves the test unchanged;
    otherwise it is ``unittest.skip('Skipped unless --tag option is used')``.
    """
    tag_given = any(arg == '--tag' or arg.startswith('--tag=')
                    for arg in sys.argv)
    if tag_given:
        return lambda func: func
    return skip('Skipped unless --tag option is used')
|
|
|
|
|
|
class Url(object):
    """One row of the URL table checked by ``UrlsTestCase``.

    :param location: literal path to request; a falsy value skips the
        location-based check.
    :param name: URL pattern name passed to ``reverse()``; a falsy value
        skips the name-based check.
    :param func: view the URL should resolve to (only ``__name__`` is
        compared); a falsy value skips the check.
    :keyword redirect: expected redirect target. When truthy the response is
        checked with ``assertRedirects`` instead of comparing the status
        code. Defaults to ``False``.
    :keyword status_code: expected status code when ``redirect`` is not set.
        Defaults to ``200``.
    :keyword follow: passed as ``follow`` to the test client. Defaults to
        ``False``.

    Any other keyword arguments are accepted and ignored.
    """

    def __init__(self, location, name=None, func=None, **kwargs):
        self.location = location
        self.name = name
        self.func = func
        # Optional expectations, in the order they are assigned.
        option_defaults = (
            ('redirect', False),
            ('status_code', 200),
            ('follow', False),
        )
        for attr_name, fallback in option_defaults:
            setattr(self, attr_name, kwargs.get(attr_name, fallback))
|
|
|
|
|
|
class UrlsTestCase(SimpleTestCase):
    """Table-driven smoke test for URLs.

    Subclasses set ``urls`` to an iterable of ``Url`` objects. Each entry is
    requested by its literal ``location`` (``test_locations``) and/or by
    reversing its ``name`` (``test_names``). The response is then checked
    against the entry's ``redirect`` / ``status_code`` / ``func``
    expectations.
    """

    urls = ()

    def _check_response(self, response, url, label):
        """Assert that *response* satisfies the expectations stored on *url*.

        :param response: response returned by the test client.
        :param url: the ``Url`` entry that was requested.
        :param label: how the request is described in failure messages,
            e.g. ``"'/about/'"`` or ``"url named 'about'"``.
        """
        if url.redirect:
            # msg_prefix identifies which table entry failed; without it a
            # redirect failure inside a long ``urls`` table is hard to trace.
            self.assertRedirects(response, url.redirect,
                                 msg_prefix='Getting {}'.format(label))
        else:
            self.assertEqual(response.status_code, url.status_code,
                             'Getting {} is not OK'.format(label))

        if url.func:
            # Only the view's __name__ is compared, not the function object.
            self.assertEqual(response.resolver_match.func.__name__,
                             url.func.__name__,
                             'Getting {} resolve to wrong function'.format(label))

    def test_locations(self):
        """Request every entry's ``location`` literally and check the response."""
        for url in self.urls:
            if not url.location:
                continue
            response = self.client.get(url.location, follow=url.follow)
            self._check_response(response, url, "'{}'".format(url.location))

    def test_names(self):
        """Request every entry's ``name`` via ``reverse()`` and check the response."""
        for url in self.urls:
            if not url.name:
                continue
            response = self.client.get(reverse(url.name), follow=url.follow)
            self._check_response(response, url, "url named '{}'".format(url.name))
|
|
|
|
|
|
class FormDataSet(object):
    """A single input case for ``FormsTestCase``.

    :param data: value passed to the form as its ``data`` argument.
    :param expected_errors: for invalid data, an optional iterable of
        ``(field_name, error_code)`` pairs that must each appear among the
        form's errors.
    :param form_kwargs: optional extra keyword arguments for the form
        constructor, applied on top of any passed to the test method.
    """

    def __init__(self, data, expected_errors=None, form_kwargs=None):
        self.data, self.expected_errors, self.form_kwargs = (
            data, expected_errors, form_kwargs)
|
|
|
|
|
|
class FormsTestCase(TestCase):
    """Table-driven validation tests for a form class.

    Subclasses set ``form_class`` and fill ``valid_data_sets`` /
    ``invalid_data_sets`` with ``FormDataSet`` objects. Both test methods may
    also be called directly with explicit arguments to check another form
    class or another collection of data sets.
    """

    form_class = None
    valid_data_sets = ()
    invalid_data_sets = ()

    def _build_form(self, form_class, data_set, form_kwargs):
        """Instantiate *form_class* for *data_set*.

        Keyword arguments are layered: *form_kwargs* first, then the data
        set's own ``form_kwargs``, and ``data`` is always taken from
        ``data_set.data`` (overriding any ``data`` key in either mapping).
        """
        kwargs = {}
        if form_kwargs is not None:
            kwargs.update(form_kwargs)
        if data_set.form_kwargs is not None:
            kwargs.update(data_set.form_kwargs)
        kwargs['data'] = data_set.data
        return form_class(**kwargs)

    def test_valid_data(self, form_class=None, data_sets=None, form_kwargs=None):
        """Fail if any data set in *data_sets* does not validate.

        :param form_class: form to test; defaults to ``self.form_class``.
        :param data_sets: ``FormDataSet`` objects; defaults to
            ``self.valid_data_sets``.
        :param form_kwargs: extra constructor kwargs shared by all data sets.
        :returns: ``True`` when there is no form class to test, else ``None``.
        """
        if form_class is None:
            form_class = self.form_class
        if form_class is None:
            # Nothing to test (e.g. this base class itself).
            return True
        if data_sets is None:
            data_sets = self.valid_data_sets

        for data_set in data_sets:
            form = self._build_form(form_class, data_set, form_kwargs)
            if not form.is_valid():
                # Report the field and the rendered message: ValidationError
                # .message is the raw template (e.g. '%(limit_value)d'),
                # while .messages has the params interpolated.
                errors = [
                    u'%s: %s (%s)' % (field, ve.code, u'; '.join(ve.messages))
                    for field, field_errors in form.errors.as_data().items()
                    for ve in field_errors
                ]
                self.fail(u'Invalid form data \'%s\': %s' % (data_set.data, errors))

    def test_invalid_data(self, form_class=None, data_sets=None, form_kwargs=None):
        """Fail if any data set validates or lacks its expected error codes.

        :param form_class: form to test; defaults to ``self.form_class``.
        :param data_sets: ``FormDataSet`` objects; defaults to
            ``self.invalid_data_sets``.
        :param form_kwargs: extra constructor kwargs shared by all data sets.
        :returns: ``True`` when there is no form class to test, else ``None``.
        """
        if form_class is None:
            form_class = self.form_class
        if form_class is None:
            # Nothing to test (e.g. this base class itself).
            return True
        if data_sets is None:
            data_sets = self.invalid_data_sets

        for data_set in data_sets:
            form = self._build_form(form_class, data_set, form_kwargs)
            if form.is_valid():
                self.fail('Valid form data: \'%s\'' % data_set.data)
            if data_set.expected_errors:
                errors_by_field = form.errors.as_data()
                for field, code in data_set.expected_errors:
                    # A field with no errors must be an assertion failure,
                    # not a KeyError that turns the test into an error.
                    codes = [ve.code for ve in errors_by_field.get(field, [])]
                    self.assertIn(
                        code, codes,
                        u'Expected error code \'%s\' on field \'%s\' for data '
                        u'\'%s\', got %s' % (code, field, data_set.data, codes))
|
|
|
|
|
|
class SeleniumTestCase(StaticLiveServerTestCase):
    """Live-server test case that drives a Firefox browser through Selenium.

    A new ``webdriver.Firefox`` is started for every test and stored as
    ``self.selenium``. It is quit after the test unless ``quit_selenium`` is
    false. Set ``quit_selenium = False`` on a subclass, or on the instance
    during a test, to leave the browser open.
    """

    def setUp(self):
        super(SeleniumTestCase, self).setUp()
        self.selenium = webdriver.Firefox()
        # Quit the browser from a cleanup rather than only in tearDown.
        # unittest does not call tearDown when setUp raises, so an error in a
        # subclass setUp (after this browser has started) would otherwise leak
        # the Firefox process. Cleanups run in that case too.
        self.addCleanup(self._quit_selenium)
        if not hasattr(self, 'quit_selenium'):
            self.quit_selenium = True

    def _quit_selenium(self):
        """Quit ``self.selenium`` unless ``quit_selenium`` is false."""
        if getattr(self, 'quit_selenium', False):
            self.selenium.quit()

    def tearDown(self):
        # The browser is quit by the cleanup registered in setUp, which
        # unittest runs after tearDown has finished.
        super(SeleniumTestCase, self).tearDown()

    def complete_url(self, location):
        """Return ``live_server_url`` joined by one '/' to *location*.

        Leading slashes are stripped from *location* first.
        """
        base_url = self.live_server_url
        return '{}/{}'.format(base_url, location.lstrip('/'))

    def wait_until_stale(self, driver, element, timeout=30):
        """Wait up to *timeout* seconds for *element* to become stale.

        Uses ``WebDriverWait(...).until(EC.staleness_of(element))``; ``until``
        raises selenium's ``TimeoutException`` if that does not happen in time.
        """
        return WebDriverWait(driver, timeout).until(EC.staleness_of(element))

    def wait_on_presence(self, driver, locator, timeout=30):
        """Wait up to *timeout* seconds for an element matching *locator*.

        *locator* is a ``(By.<strategy>, value)`` pair as accepted by
        ``EC.presence_of_element_located``. Returns the value produced by
        ``WebDriverWait.until``, which raises ``TimeoutException`` on timeout.
        """
        return WebDriverWait(driver, timeout).until(EC.presence_of_element_located(locator))
|
|
|
|
|
|
class ScreenshotTestCase(SeleniumTestCase):
    """Selenium test case that saves PNG screenshots of pages.

    Subclasses list paths in ``locations``. ``test_screenshots`` visits each
    one and saves a screenshot; ``save_screenshot`` can also be called from
    other tests.
    """

    # Inserted into every screenshot file name, after the timestamp.
    screenshot_prefix = ''
    # Browser window size applied in setUp; set either to a falsy value to
    # keep the driver's default size.
    window_width = 1024
    window_height = 768
    # Paths (relative to the live server) visited by test_screenshots.
    locations = ()
    # Directory receiving the screenshots; copied to ``screenshot_path`` in
    # setUpClass. A relative path is relative to the working directory.
    screenshot_base_dir = 'test-screenshots'

    @classmethod
    def setUpClass(cls):
        super(ScreenshotTestCase, cls).setUpClass()
        cls.screenshot_path = cls.screenshot_base_dir
        # Per-sequence counters for save_screenshot(sequence=...). The dict
        # lives on the class, so numbering continues across its tests.
        cls.screenshot_sequences = {}

    def setUp(self):
        super(ScreenshotTestCase, self).setUp()
        if self.window_width and self.window_height:
            self.selenium.set_window_size(self.window_width, self.window_height)

    def sanitize_filename(self, location):
        """Turn a URL path (or any name) into a single file-name component.

        Leading slashes are dropped, and an empty result becomes ``'root'``.
        Anything else is percent-encoded with an empty ``safe`` set, so ``/``
        becomes ``%2F`` and the result contains no path separator.
        """
        # urllib.quote is Python 2 only; Python 3 moved it to urllib.parse.
        try:
            from urllib.parse import quote
        except ImportError:
            from urllib import quote
        location = location.lstrip('/')
        if location == '':
            return 'root'
        return quote(location, '')

    def save_screenshot(self, name=None, sequence=None):
        """Save a screenshot of the current browser window and return its path.

        The name is chosen, in order of precedence, from:

        * *name*, when given;
        * *sequence*: ``'<sequence>-NNNN'``, where NNNN is a per-sequence
          counter starting at 1;
        * otherwise the browser's current URL, minus a leading
          ``live_server_url``.

        The file is written to ``screenshot_path`` as
        ``'<YYYYmmdd-HHMMSS>-<screenshot_prefix><sanitized name>.png'``. The
        directory is created first if missing, with requested mode 0o700.
        """
        if name is None:
            if sequence is not None:
                count = self.screenshot_sequences.get(sequence, 0) + 1
                self.screenshot_sequences[sequence] = count
                name = '%s-%04d' % (sequence, count)
            else:
                name = self.selenium.current_url
                if name.startswith(self.live_server_url):
                    name = name[len(self.live_server_url):]

        base_name = '{timestamp}-{prefix}{name}.png'.format(
            timestamp=datetime.datetime.now().strftime('%Y%m%d-%H%M%S'),
            prefix=self.screenshot_prefix,
            name=self.sanitize_filename(name),
        )
        path = os.path.join(self.screenshot_path, base_name)
        # Create without a check-then-create race: parallel test processes
        # may try to create the directory at the same moment.
        try:
            os.makedirs(self.screenshot_path, 0o700)
        except OSError:
            if not os.path.isdir(self.screenshot_path):
                raise
        self.selenium.save_screenshot(path)
        return path

    @skip_unless_tag_option()
    @tag('screenshots')
    def test_screenshots(self):
        """Visit every entry of ``locations`` and save a screenshot of it.

        Tagged ``screenshots``; skipped unless a ``--tag`` option was passed
        to the test command.
        """
        for location in self.locations:
            self.selenium.get(self.complete_url(location))
            self.save_screenshot()
|