"""Generic base classes and mixins for Django test cases.

Provides reusable test cases for app settings, e-mails, forms, URLs,
validators and Selenium driven browser tests (including screenshots).
"""
import datetime
import os
from urllib.parse import quote

from django.apps import apps
from django.contrib.auth.models import AbstractUser
from django.contrib.staticfiles.testing import StaticLiveServerTestCase
from django.core import mail as django_mail
from django.core.exceptions import ValidationError
from django.test import SimpleTestCase, TestCase, tag
from django.urls import reverse
from selenium import webdriver
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as ExpectedConditions


class AppSetting:  # pylint: disable=too-few-public-methods
    """Describe one expected setting: its name and, optionally, its type."""

    def __init__(self, name, of=None):
        # name: attribute name looked up on the configured settings object.
        # of: type (or tuple of types) passed to assertIsInstance; None skips
        #     the type check.
        self.name = name
        self.of = of


class AppsTestCase(SimpleTestCase):
    """Check that an app config exposes every setting listed in `settings`."""

    # Subclasses set these: the app config under test and an iterable of
    # AppSetting instances.
    app_config = None
    settings = ()

    def setUp(self):
        """Cache `app_config.settings` (or None when no app config is set)."""
        super().setUp()
        if self.app_config:
            self.configured_settings = self.app_config.settings
        else:
            self.configured_settings = None

    def test_settings(self):
        """Assert each declared setting exists and has the declared type."""
        config = self.configured_settings
        for setting in self.settings:
            name = setting.name
            self.assertTrue(hasattr(config, name), 'Settings do not contain {}'.format(name))
            value = getattr(config, name)
            of = setting.of
            if of is not None:
                self.assertIsInstance(value, of)


class EmailTestMixin:
    """Assertion helpers for mails collected in `django.core.mail.outbox`."""

    # NOTE(review): this value ends with a space and carries no address part;
    # it may have lost a trailing `<address>` - confirm the intended sender.
    email_sender = 'Automatic Software Test '
    email_base_url = 'http://localhost'
    email_subject_prefix = '[Test]'

    @staticmethod
    def _format_address(address):
        """Return '"Full Name" <email>' for user objects, else `address` unchanged."""
        if isinstance(address, AbstractUser):
            return '"%s" <%s>' % (address.get_full_name(), address.email)
        return address

    def get_mail_for_user(self, user):
        """Return all outbox mails that list `user` among their recipients."""
        recipient = '"{fullname}" <{email}>'.format(fullname=user.get_full_name(), email=user.email)
        return [mail for mail in django_mail.outbox if recipient in mail.recipients()]

    def assertSender(self, mail):  # pylint: disable=invalid-name
        """Assert that `mail` was sent from `email_sender`."""
        self.assertEqual(mail.from_email, self.email_sender)

    def assertReplyTo(self, mail, addresses):  # pylint: disable=invalid-name
        """Assert that `mail.reply_to` holds exactly the given addresses.

        Each entry of `addresses` is either an address string or a user
        object, which is formatted as '"Full Name" <email>'.
        """
        self.assertEqual(len(mail.reply_to), len(addresses))
        for expected_address in addresses:
            self.assertIn(self._format_address(expected_address), mail.reply_to)

    def assertRecipients(self, mail, recipients):  # pylint: disable=invalid-name
        """Assert that `mail.recipients()` holds exactly the given recipients.

        Each entry of `recipients` is either an address string or a user
        object, which is formatted as '"Full Name" <email>'.
        """
        # Fetch the actual recipients once, under a name that does not shadow
        # the `recipients` parameter being iterated.
        actual_recipients = mail.recipients()
        self.assertEqual(len(actual_recipients), len(recipients))
        for expected_recipient in recipients:
            self.assertIn(self._format_address(expected_recipient), actual_recipients)

    def assertSubject(self, mail, subject):  # pylint: disable=invalid-name
        """Assert the subject equals '<email_subject_prefix> <subject>'."""
        expected_subject = '{} {}'.format(self.email_subject_prefix, subject)
        self.assertEqual(mail.subject, expected_subject)

    def assertBody(self, mail, body):  # pylint: disable=invalid-name
        """Assert that the mail body equals `body`.

        Lines are compared one by one first so that a failure names the
        (1-based) number of the first differing or missing line.
        """
        expected_lines = body.splitlines()
        lines = mail.body.splitlines()
        for number, expected_line in enumerate(expected_lines, start=1):
            try:
                line = lines[number - 1]
            except IndexError:
                self.fail('line %d: no such line: %s' % (number, expected_line))
            try:
                self.assertEqual(line, expected_line)
            except AssertionError as e:
                self.fail('line %d: %s' % (number, e))
        # Final whole-body comparison catches what the per-line pass cannot,
        # e.g. surplus lines in the actual body.
        self.assertEqual(mail.body, body)

    def setUp(self):  # pylint: disable=invalid-name
        """Write the mixin's e-mail settings into the 'dav_base' app config.

        The previous values are not restored by this mixin.
        """
        # Chain cooperatively so that setUp() of classes later in the MRO
        # still runs; tolerate a base without setUp().
        parent_set_up = getattr(super(), 'setUp', None)
        if parent_set_up is not None:
            parent_set_up()
        app_config = apps.get_app_config('dav_base')
        app_config.settings.email_sender = self.email_sender
        app_config.settings.email_base_url = self.email_base_url
        app_config.settings.email_subject_prefix = self.email_subject_prefix


class FormDataSet:  # pylint: disable=too-few-public-methods
    """Bundle form input data with expected errors and extra form kwargs."""

    def __init__(self, data, expected_errors=None, form_kwargs=None):
        # data: passed to the form as its `data` argument.
        # expected_errors: iterable of (field key, error code) pairs checked
        #     by FormsTestCase.test_invalid_data; falsy skips the check.
        # form_kwargs: extra keyword arguments for the form constructor.
        self.data = data
        self.expected_errors = expected_errors
        self.form_kwargs = form_kwargs


class FormsTestCase(TestCase):
    """Validate a form class against sets of valid and invalid data."""

    # Subclasses set these: the form class under test and iterables of
    # FormDataSet instances.
    form_class = None
    valid_data_sets = ()
    invalid_data_sets = ()

    @staticmethod
    def _build_form(form_class, data_set, form_kwargs):
        """Instantiate `form_class` for one data set.

        Keyword arguments are merged in this order, later ones winning:
        `form_kwargs`, `data_set.form_kwargs`, then `data=data_set.data`.
        """
        kwargs = {}
        if form_kwargs is not None:
            kwargs.update(form_kwargs)
        if data_set.form_kwargs is not None:
            kwargs.update(data_set.form_kwargs)
        kwargs['data'] = data_set.data
        return form_class(**kwargs)

    def test_valid_data(self, form_class=None, data_sets=None, form_kwargs=None):
        """Fail if any data set does not validate.

        Arguments default to the class attributes `form_class` and
        `valid_data_sets`; without a form class the test is a no-op.
        """
        if form_class is None:
            form_class = self.form_class
        if form_class is None:
            return
        if data_sets is None:
            data_sets = self.valid_data_sets

        for data_set in data_sets:
            form = self._build_form(form_class, data_set, form_kwargs)
            if not form.is_valid():
                errors = []
                for error_list in form.errors.as_data().values():
                    for e in error_list:
                        errors.append('%s (%s)' % (e.code, e.message))
                self.fail('Invalid form data \'%s\': %s' % (data_set.data, errors))

    def test_invalid_data(self, form_class=None, data_sets=None, form_kwargs=None):
        """Fail if any data set validates or lacks an expected error code.

        Arguments default to the class attributes `form_class` and
        `invalid_data_sets`; without a form class the test is a no-op.
        """
        if form_class is None:
            form_class = self.form_class
        if form_class is None:
            return
        if data_sets is None:
            data_sets = self.invalid_data_sets

        for data_set in data_sets:
            form = self._build_form(form_class, data_set, form_kwargs)
            if form.is_valid():
                self.fail('Valid form data: \'%s\'' % data_set.data)

            if data_set.expected_errors:
                error_dicts = form.errors.as_data()
                for key, code in data_set.expected_errors:
                    # Report a missing field as a test failure instead of
                    # letting the lookup below raise a bare KeyError.
                    self.assertIn(key, error_dicts,
                                  'No errors for \'%s\' in form data \'%s\'' % (key, data_set.data))
                    error_codes = [ve.code for ve in error_dicts[key]]
                    self.assertIn(code, error_codes)


class Url:  # pylint: disable=too-few-public-methods
    """Describe one URL to be requested and the response expected from it."""

    def __init__(self, location, name=None, func=None, **kwargs):
        # location: path requested by UrlsTestCase.test_locations (falsy skips).
        # name: URL name reversed by UrlsTestCase.test_names (falsy skips).
        # func: view whose __name__ must match the resolved view (falsy skips).
        self.location = location
        self.name = name
        self.func = func
        # Optional keyword settings with their defaults.
        self.http_method = kwargs.get('http_method', "GET")
        self.post_data = kwargs.get('post_data', {})
        self.redirect = kwargs.get('redirect', False)
        self.status_code = kwargs.get('status_code', 200)
        self.follow = kwargs.get('follow', False)


class UrlsTestCase(TestCase):
    """Request every entry of `urls` by location and by name."""

    # Subclasses set this to an iterable of Url instances.
    urls = ()

    def _request(self, url, location):
        """Send the request described by `url` to `location`; return the response."""
        if url.http_method == "GET":
            return self.client.get(location, follow=url.follow)
        if url.http_method == "POST":
            return self.client.post(location, data=url.post_data, follow=url.follow)
        raise NotImplementedError(  # pragma: no cover
            "Method {} is not supported".format(url.http_method))

    def _check_url(self, url, location, label):
        """Request `location` and verify redirect/status and the resolved view.

        `label` identifies the URL in failure messages.
        """
        response = self._request(url, location)
        if url.redirect:
            self.assertRedirects(response, url.redirect)
        else:
            self.assertEqual(response.status_code, url.status_code,
                             'Getting {} is not OK'.format(label))
        # NOTE(review): the view check is applied to redirecting URLs as well;
        # the nesting was reconstructed from a whitespace-collapsed source -
        # confirm.
        if url.func:
            self.assertEqual(response.resolver_match.func.__name__, url.func.__name__,
                             'Getting {} resolve to wrong function'.format(label))

    def test_locations(self):
        """Check every URL that declares a location."""
        for url in self.urls:
            if url.location:
                self._check_url(url, url.location, '\'{}\''.format(url.location))

    def test_names(self):
        """Check every URL that declares a name, requesting its reversed path."""
        for url in self.urls:
            if url.name:
                self._check_url(url, reverse(url.name), 'url named \'{}\''.format(url.name))


class ValidatorTestMixin:
    """Assertion helpers for validators that raise ValidationError."""

    def assertValid(self, validator, data):  # pylint: disable=invalid-name
        """Fail on the first value of `data` that `validator` rejects."""
        for val in data:
            try:
                validator(val)
            except ValidationError as e:  # pragma: no cover
                self.fail('%s: %s' % (val, e))

    def assertInvalid(self, validator, data):  # pylint: disable=invalid-name
        """Fail on the first value of `data` that `validator` accepts."""
        for val in data:
            try:
                validator(val)
            except ValidationError:
                pass
            else:  # pragma: no cover
                self.fail('%s: no ValidationError was raised' % val)


class SeleniumTestCase(StaticLiveServerTestCase):
    """Live server test case with a lazily created Firefox web driver."""

    headless = True
    # Window size applied to a newly created driver when both are truthy.
    window_width = 1024
    window_height = 768

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._driver = None
        self._driver_options = webdriver.FirefoxOptions()
        # None means "undecided"; it becomes True once this class has created
        # the driver itself, so that tearDown() quits it.
        self.quit_selenium = None

    @property
    def selenium(self):
        """Return the web driver, creating and sizing it on first access."""
        if self._driver is None:
            if self.headless:
                self._driver_options.add_argument('--headless')
            self._driver = webdriver.Firefox(options=self._driver_options)
            if self.quit_selenium is None:
                self.quit_selenium = True
            if self.window_width and self.window_height:
                self._driver.set_window_size(self.window_width, self.window_height)
        return self._driver

    def tearDown(self):
        """Quit the driver if `quit_selenium` is truthy, then tear down."""
        try:
            if self.quit_selenium:
                self.selenium.quit()
        finally:
            # Always run the parent teardown, even if quitting the browser fails.
            super().tearDown()

    def complete_url(self, location):
        """Return `location` as an absolute URL on the live server."""
        base_url = self.live_server_url
        return '{}/{}'.format(base_url, location.lstrip('/'))

    def get(self, location):
        """Load `location` (relative to the live server) in the browser."""
        return self.selenium.get(self.complete_url(location))

    def wait_on(self, driver, ec_name, ec_argument, timeout=30):
        """Wait until the named expected condition holds and return its result.

        `ec_name` is the name of a factory in selenium's `expected_conditions`
        module; it is called with `ec_argument`.
        """
        expected_condition = getattr(ExpectedConditions, ec_name)
        return WebDriverWait(driver, timeout).until(expected_condition(ec_argument))

    def wait_on_presence(self, driver, locator, timeout=30):
        """Wait on `presence_of_element_located` for `locator`."""
        ec_name = 'presence_of_element_located'
        return self.wait_on(driver, ec_name, locator, timeout)

    def wait_until_stale(self, driver, element, timeout=30):
        """Wait on `staleness_of` for `element`."""
        ec_name = 'staleness_of'
        return self.wait_on(driver, ec_name, element, timeout)


class ScreenshotTestCase(SeleniumTestCase):
    """Selenium test case that saves screenshots of `locations`."""

    # Prefix inserted into every screenshot file name.
    screenshot_prefix = ''
    # Locations visited by test_screenshots().
    locations = ()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        screenshot_base_dir = os.path.join('tmp', 'test-screenshots')
        self.screenshot_path = screenshot_base_dir
        # Maps a sequence name to the number of screenshots taken for it.
        self.screenshot_sequences = {}

    def sanitize_filename(self, location):
        """Return `location` URL-quoted, with '/' replaced by '--'."""
        return quote(location).replace('/', '--')

    def save_screenshot(self, title=None, sequence=None, resize=True):
        """Save a screenshot of the current page below `screenshot_path`.

        title: file name part; defaults to the current URL relative to the
            live server ('root' for the empty path).
        sequence: optional sequence name; screenshots sharing a name get an
            increasing four digit counter.
        resize: if true, the window is temporarily enlarged when the document
            is taller than the viewport, and restored afterwards.
        """
        if sequence is None:
            sequence = ''
        else:
            n = self.screenshot_sequences.get(sequence, 0) + 1
            self.screenshot_sequences[sequence] = n
            sequence = '%s-%04d-' % (sequence, n)

        if title is None:
            location = self.selenium.current_url
            if location.startswith(self.live_server_url):
                location = location[len(self.live_server_url):]
            location = location.lstrip('/')
            if location == '':
                location = 'root'
            title = location

        base_name = '{timestamp}-{prefix}{sequence}{title}.png'.format(
            timestamp=datetime.datetime.now().strftime('%Y%m%d-%H%M%S.%f'),
            prefix=self.screenshot_prefix,
            sequence=sequence,
            title=title,
        )
        path = os.path.join(self.screenshot_path, self.sanitize_filename(base_name))
        # exist_ok avoids the race between checking for and creating the directory.
        os.makedirs(self.screenshot_path, 0o700, exist_ok=True)

        restore_size = False
        if resize:
            window_size = self.selenium.get_window_size()
            # Height taken up by browser decoration around the viewport.
            deco_height = self.selenium.execute_script('return window.outerHeight - window.innerHeight')
            doc_height = self.selenium.execute_script('return document.body.scrollHeight')
            if (window_size['height'] - deco_height) < doc_height:
                self.selenium.set_window_size(window_size['width'], doc_height + deco_height)
                restore_size = True

        try:
            self.selenium.save_screenshot(path)
        finally:
            # Put the window back even if taking the screenshot failed.
            if restore_size:
                self.selenium.set_window_size(window_size['width'], window_size['height'])

    @tag('screenshots', 'browser')
    def test_screenshots(self):
        """Visit every entry of `locations` and save a screenshot of it."""
        for location in self.locations:
            self.get(location)
            self.save_screenshot()