diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..f7320f1 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,53 @@ +exclude: "^docs/|/migrations/" +default_stages: [ commit ] +fail_fast: true + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + exclude: ".svg$|.min.*|.idea*" + - id: check-yaml + - id: check-added-large-files + + - repo: https://github.com/asottile/pyupgrade + rev: v3.15.0 + hooks: + - id: pyupgrade + args: [ --py39-plus ] + + - repo: https://github.com/psf/black + rev: 23.12.1 + hooks: + - id: black + + - repo: https://github.com/PyCQA/isort + rev: 5.13.2 + hooks: + - id: isort + args: [ "--profile", "black", "-l=88" ] + + - repo: https://github.com/PyCQA/flake8 + rev: 7.0.0 + hooks: + - id: flake8 + args: [ "--config=setup.cfg" ] + additional_dependencies: + - flake8-isort + - flake8-bugbear + - flake8-print + + - repo: https://github.com/adamchainz/django-upgrade + rev: 1.15.0 + hooks: + - id: django-upgrade + args: [ --target-version, "4.2" ] + + +# sets up .pre-commit-ci.yaml to ensure pre-commit dependencies stay up to date +ci: + autoupdate_schedule: weekly + skip: [ ] + submodules: false diff --git a/allauth_cas/__init__.py b/allauth_cas/__init__.py index 04e5073..4562d67 100644 --- a/allauth_cas/__init__.py +++ b/allauth_cas/__init__.py @@ -1,7 +1,3 @@ -# -*- coding: utf-8 -*- +__version__ = "1.0.1" -__version__ = '1.0.0' - -default_app_config = 'allauth_cas.apps.CASAccountConfig' - -CAS_PROVIDER_SESSION_KEY = 'allauth_cas__provider_id' +CAS_PROVIDER_SESSION_KEY = "allauth_cas__provider_id" diff --git a/allauth_cas/apps.py b/allauth_cas/apps.py index 2ee7f0c..55d0dbb 100644 --- a/allauth_cas/apps.py +++ b/allauth_cas/apps.py @@ -1,10 +1,9 @@ -# -*- coding: utf-8 -*- from django.apps import AppConfig -from django.utils.translation import ugettext_lazy as _ +from django.utils.translation import gettext_lazy as _ class CASAccountConfig(AppConfig): - name = 'allauth_cas' + name = "allauth_cas" verbose_name = _("CAS Accounts") def ready(self): diff --git a/allauth_cas/exceptions.py b/allauth_cas/exceptions.py index 2ed507e..d096400 100644 --- a/allauth_cas/exceptions.py +++ b/allauth_cas/exceptions.py @@ -1,6 +1,3 @@ -# -*- coding: utf-8 -*- - - class CASAuthenticationError(Exception): """ Base exception to signal CAS authentication failure. diff --git a/allauth_cas/providers.py b/allauth_cas/providers.py index a46b4f1..3acf831 100644 --- a/allauth_cas/providers.py +++ b/allauth_cas/providers.py @@ -1,26 +1,18 @@ -# -*- coding: utf-8 -*- -from six.moves.urllib.parse import parse_qsl +from urllib.parse import parse_qsl -import django +from allauth.socialaccount.providers.base import Provider from django.contrib import messages from django.template.loader import render_to_string +from django.urls import reverse from django.utils.http import urlencode from django.utils.safestring import mark_safe -from allauth.socialaccount.providers.base import Provider - -if django.VERSION >= (1, 10): - from django.urls import reverse -else: - from django.core.urlresolvers import reverse - class CASProvider(Provider): - def get_auth_params(self, request, action): settings = self.get_settings() - ret = dict(settings.get('AUTH_PARAMS', {})) - dynamic_auth_params = request.GET.get('auth_params') + ret = dict(settings.get("AUTH_PARAMS", {})) + dynamic_auth_params = request.GET.get("auth_params") if dynamic_auth_params: ret.update(dict(parse_qsl(dynamic_auth_params))) return ret @@ -68,11 +60,11 @@ class CASProvider(Provider): """ uid, extra = data return { - 'username': extra.get('username', uid), - 'email': extra.get('email'), - 'first_name': extra.get('first_name'), - 'last_name': extra.get('last_name'), - 'name': extra.get('name'), + "username": extra.get("username", uid), + "email": extra.get("email"), + "first_name": extra.get("first_name"), + "last_name": extra.get("last_name"), + "name": extra.get("name"), } def extract_email_addresses(self, data): @@ -99,7 +91,7 @@ class CASProvider(Provider): ] """ - return super(CASProvider, self).extract_email_addresses(data) + return super().extract_email_addresses(data) def extract_extra_data(self, data): """Extract the data to save to `SocialAccount.extra_data`. @@ -119,7 +111,10 @@ class CASProvider(Provider): ## def add_message_suggest_caslogout( - self, request, next_page=None, level=None, + self, + request, + next_page=None, + level=None, ): """Add a message with a link for the user to logout of the CAS server. @@ -144,14 +139,15 @@ class CASProvider(Provider): # DefaultAccountAdapter.add_message is unusable because it always # escape the message content. - template = 'socialaccount/messages/suggest_caslogout.html' + template = "socialaccount/messages/suggest_caslogout.html" context = { - 'provider': self, - 'logout_url': logout_url, + "provider": self, + "logout_url": logout_url, } messages.add_message( - request, level, + request, + level, mark_safe(render_to_string(template, context).strip()), fail_silently=True, ) @@ -168,10 +164,7 @@ class CASProvider(Provider): signal ``user_logged_out``. """ - return ( - self.get_settings() - .get('MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT', False) - ) + return self.get_settings().get("MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT", False) def message_suggest_caslogout_on_logout_level(self, request): """Level of the logout message issued on user logout. @@ -185,9 +178,8 @@ class CASProvider(Provider): signal ``user_logged_out``. """ - return ( - self.get_settings() - .get('MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT_LEVEL', messages.INFO) + return self.get_settings().get( + "MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT_LEVEL", messages.INFO ) ## @@ -195,19 +187,19 @@ class CASProvider(Provider): ## def get_login_url(self, request, **kwargs): - url = reverse(self.id + '_login') + url = reverse(self.id + "_login") if kwargs: - url += '?' + urlencode(kwargs) + url += "?" + urlencode(kwargs) return url def get_callback_url(self, request, **kwargs): - url = reverse(self.id + '_callback') + url = reverse(self.id + "_callback") if kwargs: - url += '?' + urlencode(kwargs) + url += "?" + urlencode(kwargs) return url def get_logout_url(self, request, **kwargs): - url = reverse(self.id + '_logout') + url = reverse(self.id + "_logout") if kwargs: - url += '?' + urlencode(kwargs) + url += "?" + urlencode(kwargs) return url diff --git a/allauth_cas/signals.py b/allauth_cas/signals.py index 927cbfd..36c9b24 100644 --- a/allauth_cas/signals.py +++ b/allauth_cas/signals.py @@ -1,10 +1,8 @@ -# -*- coding: utf-8 -*- -from django.contrib.auth.signals import user_logged_out -from django.dispatch import receiver - from allauth.account.adapter import get_adapter from allauth.account.utils import get_next_redirect_url from allauth.socialaccount import providers +from django.contrib.auth.signals import user_logged_out +from django.dispatch import receiver from . import CAS_PROVIDER_SESSION_KEY @@ -21,12 +19,12 @@ def cas_account_logout(sender, request, **kwargs): if not provider.message_suggest_caslogout_on_logout(request): return - next_page = ( - get_next_redirect_url(request) or - get_adapter(request).get_logout_redirect_url(request) - ) + next_page = get_next_redirect_url(request) or get_adapter( + request + ).get_logout_redirect_url(request) provider.add_message_suggest_caslogout( - request, next_page=next_page, + request, + next_page=next_page, level=provider.message_suggest_caslogout_on_logout_level(request), ) diff --git a/allauth_cas/test/testcases.py b/allauth_cas/test/testcases.py index d1e93e9..7ec8c07 100644 --- a/allauth_cas/test/testcases.py +++ b/allauth_cas/test/testcases.py @@ -1,29 +1,20 @@ -# -*- coding: utf-8 -*- try: from unittest.mock import patch except ImportError: - from mock import patch - -import django -from django.conf import settings -from django.test import TestCase + from unittest.mock import patch import cas +from django.conf import settings +from django.test import TestCase +from django.urls import reverse from allauth_cas import CAS_PROVIDER_SESSION_KEY -if django.VERSION >= (1, 10): - from django.urls import reverse -else: - from django.core.urlresolvers import reverse - class CASTestCase(TestCase): - def client_cas_login( - self, - client, provider_id='theid', - username=None, attributes={}): + self, client, provider_id=None, username=None, attributes=None + ): """ Authenticate client through provider_id. @@ -32,20 +23,22 @@ class CASTestCase(TestCase): username and attributes control the CAS server response when ticket is checked. """ - client.get(reverse('{id}_login'.format(id=provider_id))) + if attributes is None: + attributes = {} + if provider_id is None: + provider_id = "theid" + client.get(reverse(f"{provider_id}_login")) self.patch_cas_response( - valid_ticket='__all__', - username=username, attributes=attributes, + valid_ticket="__all__", + username=username, + attributes=attributes, ) - callback_url = reverse('{id}_callback'.format(id=provider_id)) - r = client.get(callback_url, {'ticket': 'fake-ticket'}) + callback_url = reverse(f"{provider_id}_callback") + r = client.get(callback_url, {"ticket": "fake-ticket"}) self.patch_cas_response_stop() return r - def patch_cas_response( - self, - valid_ticket, - username=None, attributes={}): + def patch_cas_response(self, valid_ticket, username=None, attributes=None): """ Patch the CASClient class used by views of CAS providers. @@ -68,35 +61,38 @@ class CASTestCase(TestCase): Note that valid_ticket sould be a string (which is the type of the ticket retrieved from GET parameter on request on the callback view). """ - if hasattr(self, '_patch_cas_client'): + if attributes is None: + attributes = {} + if hasattr(self, "_patch_cas_client"): self.patch_cas_response_stop() - class MockCASClient(object): + class MockCASClient: _username = username def __new__(self_client, *args, **kwargs): - version = kwargs.pop('version') - if version in (1, '1'): + version = kwargs.pop("version") + if version in (1, "1"): client_class = cas.CASClientV1 - elif version in (2, '2'): + elif version in (2, "2"): client_class = cas.CASClientV2 - elif version in (3, '3'): + elif version in (3, "3"): client_class = cas.CASClientV3 - elif version == 'CAS_2_SAML_1_0': + elif version == "CAS_2_SAML_1_0": client_class = cas.CASClientWithSAMLV1 else: - raise ValueError('Unsupported CAS_VERSION %r' % version) + raise ValueError("Unsupported CAS_VERSION %r" % version) client_class._username = self_client._username def verify_ticket(self, ticket): - if valid_ticket == '__all__' or ticket == valid_ticket: - username = self._username or 'username' + if valid_ticket == "__all__" or ticket == valid_ticket: + username = self._username or "username" return username, attributes, None return None, {}, None patcher = patch.object( - client_class, 'verify_ticket', + client_class, + "verify_ticket", new=verify_ticket, ) patcher.start() @@ -104,7 +100,7 @@ class CASTestCase(TestCase): return client_class(*args, **kwargs) self._patch_cas_client = patch( - 'allauth_cas.views.cas.CASClient', + "allauth_cas.views.cas.CASClient", MockCASClient, ) self._patch_cas_client.start() @@ -114,12 +110,11 @@ class CASTestCase(TestCase): del self._patch_cas_client def tearDown(self): - if hasattr(self, '_patch_cas_client'): + if hasattr(self, "_patch_cas_client"): self.patch_cas_response_stop() class CASViewTestCase(CASTestCase): - def assertLoginSuccess(self, response, redirect_to=None): """ Asserts response corresponds to a successful login. @@ -133,7 +128,8 @@ class CASViewTestCase(CASTestCase): redirect_to = settings.LOGIN_REDIRECT_URL self.assertRedirects( - response, redirect_to, + response, + redirect_to, fetch_redirect_response=False, ) self.assertIn( @@ -146,6 +142,6 @@ class CASViewTestCase(CASTestCase): Asserts response corresponds to a failed login. """ return self.assertInHTML( - '

Social Network Login Failure

', + "

Social Network Login Failure

", str(response.content), ) diff --git a/allauth_cas/urls.py b/allauth_cas/urls.py index 687dca6..292bce7 100644 --- a/allauth_cas/urls.py +++ b/allauth_cas/urls.py @@ -1,5 +1,4 @@ -# -*- coding: utf-8 -*- -from django.conf.urls import include, url +from django.urls import include, path, re_path from django.utils.module_loading import import_string @@ -7,45 +6,44 @@ def default_urlpatterns(provider): package = provider.get_package() try: - login_view = import_string(package + '.views.login') + login_view = import_string(package + ".views.login") except ImportError: raise ImportError( "The login view for the '{id}' provider is lacking from the " "'views' module of its app.\n" "You may want to add:\n" "from allauth_cas.views import CASLoginView\n\n" - "login = CASLoginView.adapter_view()" - .format(id=provider.id) + "login = CASLoginView.adapter_view()".format( + id=provider.id + ) ) try: - callback_view = import_string(package + '.views.callback') + callback_view = import_string(package + ".views.callback") except ImportError: raise ImportError( "The callback view for the '{id}' provider is lacking from the " "'views' module of its app.\n" "You may want to add:\n" "from allauth_cas.views import CASCallbackView\n\n" - "callback = CASCallbackView.adapter_view()" - .format(id=provider.id) + "callback = CASCallbackView.adapter_view()".format( + id=provider.id + ) ) try: - logout_view = import_string(package + '.views.logout') + logout_view = import_string(package + ".views.logout") except ImportError: logout_view = None urlpatterns = [ - url('^login/$', login_view, - name=provider.id + '_login'), - url('^login/callback/$', callback_view, - name=provider.id + '_callback'), + path("login/", login_view, name=provider.id + "_login"), + path("login/callback/", callback_view, name=provider.id + "_callback"), ] if logout_view is not None: urlpatterns += [ - url('^logout/$', logout_view, - name=provider.id + '_logout'), + path("logout/", logout_view, name=provider.id + "_logout"), ] - return [url('^' + provider.get_slug() + '/', include(urlpatterns))] + return [re_path("^" + provider.get_slug() + "/", include(urlpatterns))] diff --git a/allauth_cas/views.py b/allauth_cas/views.py index b27b3c7..d08e354 100644 --- a/allauth_cas/views.py +++ b/allauth_cas/views.py @@ -1,28 +1,26 @@ -# -*- coding: utf-8 -*- -from django.http import HttpResponseRedirect -from django.utils.functional import cached_property - +import cas from allauth.account.adapter import get_adapter from allauth.account.utils import get_next_redirect_url from allauth.socialaccount import providers from allauth.socialaccount.helpers import ( - complete_social_login, render_authentication_error, + complete_social_login, + render_authentication_error, ) from allauth.socialaccount.models import SocialLogin - -import cas +from django.http import HttpResponseRedirect +from django.utils.functional import cached_property from . import CAS_PROVIDER_SESSION_KEY from .exceptions import CASAuthenticationError -class AuthAction(object): - AUTHENTICATE = 'authenticate' - REAUTHENTICATE = 'reauthenticate' - DEAUTHENTICATE = 'deauthenticate' +class AuthAction: + AUTHENTICATE = "authenticate" + REAUTHENTICATE = "reauthenticate" + DEAUTHENTICATE = "deauthenticate" -class CASAdapter(object): +class CASAdapter: #: CAS server url. url = None #: CAS server version. @@ -92,19 +90,19 @@ class CASAdapter(object): """ redirect_to = get_next_redirect_url(request) - callback_kwargs = {'next': redirect_to} if redirect_to else {} - callback_url = ( - self.provider.get_callback_url(request, **callback_kwargs)) + callback_kwargs = {"next": redirect_to} if redirect_to else {} + callback_url = self.provider.get_callback_url(request, **callback_kwargs) service_url = request.build_absolute_uri(callback_url) return service_url -class CASView(object): +class CASView: """ Base class for CAS views. """ + @classmethod def adapter_view(cls, adapter): """Transform the view class into a view function. @@ -124,6 +122,7 @@ class CASView(object): """ + def view(request, *args, **kwargs): # Prepare the func-view. self = cls() @@ -169,19 +168,17 @@ class CASView(object): class CASLoginView(CASView): - def dispatch(self, request): """ Redirects to the CAS server login page. """ - action = request.GET.get('action', AuthAction.AUTHENTICATE) + action = request.GET.get("action", AuthAction.AUTHENTICATE) SocialLogin.stash_state(request) client = self.get_client(request, action=action) return HttpResponseRedirect(client.get_login_url()) class CASCallbackView(CASView): - def dispatch(self, request): """ The CAS server redirects the user to this view after a successful @@ -195,11 +192,9 @@ class CASCallbackView(CASView): # CAS server should let a ticket. try: - ticket = request.GET['ticket'] + ticket = request.GET["ticket"] except KeyError: - raise CASAuthenticationError( - "CAS server didn't respond with a ticket." - ) + raise CASAuthenticationError("CAS server didn't respond with a ticket.") # Check ticket validity. # Response format on: @@ -210,9 +205,7 @@ class CASCallbackView(CASView): uid, extra, _ = response if not uid: - raise CASAuthenticationError( - "CAS server doesn't validate the ticket." - ) + raise CASAuthenticationError("CAS server doesn't validate the ticket.") # Keep tracks of the last used CAS provider. request.session[CAS_PROVIDER_SESSION_KEY] = self.provider.id @@ -226,7 +219,6 @@ class CASCallbackView(CASView): class CASLogoutView(CASView): - def dispatch(self, request, next_page=None): """ Redirects to the CAS server logout page. @@ -248,7 +240,6 @@ class CASLogoutView(CASView): Returns the url to redirect after logout. """ request = self.request - return ( - get_next_redirect_url(request) or - get_adapter(request).get_logout_redirect_url(request) - ) + return get_next_redirect_url(request) or get_adapter( + request + ).get_logout_redirect_url(request) diff --git a/runtests.py b/runtests.py index 5cea84f..a4e0970 100644 --- a/runtests.py +++ b/runtests.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- import os import sys @@ -7,10 +6,10 @@ import django from django.conf import settings from django.test.utils import get_runner -if __name__ == '__main__': - os.environ['DJANGO_SETTINGS_MODULE'] = 'tests.settings' +if __name__ == "__main__": + os.environ["DJANGO_SETTINGS_MODULE"] = "tests.settings" django.setup() TestRunner = get_runner(settings) test_runner = TestRunner() - failures = test_runner.run_tests(sys.argv[1:] or ['tests']) + failures = test_runner.run_tests(sys.argv[1:] or ["tests"]) sys.exit(bool(failures)) diff --git a/setup.cfg b/setup.cfg index 7e7ca4d..d4b1006 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,18 +1,45 @@ [flake8] +max-line-length = 120 +exclude = .tox,.git,*/migrations/*,*/static/CACHE/*,docs,node_modules,venv # E731: lambda expression ignore = E731 +[pycodestyle] +max-line-length = 120 +exclude = .tox,.git,*/migrations/*,*/static/CACHE/*,docs,node_modules,venv + [isort] -combine_as_imports = True +line_length = 88 +known_first_party = geoip,config +multi_line_output = 3 default_section = THIRDPARTY -include_trailing_comma = True -known_allauth = allauth -known_future_library = future,six -known_django = django -known_first_party = allauth_cas -multi_line_output = 5 -not_skip = __init__.py -sections = FUTURE,STDLIB,DJANGO,ALLAUTH,THIRDPARTY,FIRSTPARTY,LOCALFOLDER +skip = venv/ +skip_glob = **/migrations/*.py +include_trailing_comma = true +force_grid_wrap = 0 +use_parentheses = true + +[mypy] +python_version = 3.9 +check_untyped_defs = True +ignore_missing_imports = True +warn_unused_ignores = True +warn_redundant_casts = True +warn_unused_configs = True +plugins = mypy_django_plugin.main + +[mypy.plugins.django-stubs] +django_settings_module = config.settings.test + +[mypy-*.migrations.*] +# Django migrations should not produce any errors: +ignore_errors = True + +[coverage:run] +include = geoip/* +omit = *migrations*, *tests* +plugins = + django_coverage_plugin [bdist_wheel] universal = 1 diff --git a/setup.py b/setup.py index b256fac..fb06ec0 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import os from setuptools import find_packages, setup @@ -7,49 +6,56 @@ from allauth_cas import __version__ BASE_DIR = os.path.dirname(__file__) -with open(os.path.join(BASE_DIR, 'README.rst')) as readme: +with open(os.path.join(BASE_DIR, "README.rst")) as readme: README = readme.read() setup( - name='django-allauth-cas', + name="django-allauth-cas", version=__version__, - description='CAS support for django-allauth.', - author='Aurélien Delobelle', - author_email='aurelien.delobelle@gmail.com', - keywords='django allauth cas authentication', + description="CAS support for django-allauth.", + author="Aurélien Delobelle", + author_email="aurelien.delobelle@gmail.com", + keywords="django allauth cas authentication", long_description=README, - url='https://github.com/aureplop/django-allauth-cas', + url="https://github.com/aureplop/django-allauth-cas", classifiers=[ - 'Development Status :: 4 - Beta', - 'Environment :: Web Environment', - 'Framework :: Django', - 'Framework :: Django :: 1.8', - 'Framework :: Django :: 1.9', - 'Framework :: Django :: 1.10', - 'Framework :: Django :: 1.11', - 'Framework :: Django :: 2.0', - 'Intended Audience :: Developers', - 'License :: OSI Approved :: MIT License', - 'Operating System :: OS Independent', - 'Programming Language :: Python', - 'Programming Language :: Python :: 2', - 'Programming Language :: Python :: 2.7', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.4', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: 3.6', - 'Topic :: Internet :: WWW/HTTP', + "Development Status :: 4 - Beta", + "Environment :: Web Environment", + "Framework :: Django", + "Framework :: Django :: 1.8", + "Framework :: Django :: 1.9", + "Framework :: Django :: 1.10", + "Framework :: Django :: 1.11", + "Framework :: Django :: 2.0", + "Framework :: Django :: 3.0", + "Framework :: Django :: 4.0", + "Framework :: Django :: 5.0", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Programming Language :: Python :: 2", + "Programming Language :: Python :: 2.7", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.4", + "Programming Language :: Python :: 3.5", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.11", + "Topic :: Internet :: WWW/HTTP", ], - license='MIT', - packages=find_packages(exclude=['tests']), + license="MIT", + packages=find_packages(exclude=["tests"]), include_package_data=True, install_requires=[ - 'django-allauth', - 'python-cas', - 'six', + "django-allauth", + "python-cas", + "six", ], extras_require={ - 'docs': ['sphinx'], - 'tests': ['tox'], + "docs": ["sphinx"], + "tests": ["tox"], }, ) diff --git a/tests/example/provider.py b/tests/example/provider.py index c393b70..72d1c45 100644 --- a/tests/example/provider.py +++ b/tests/example/provider.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- from allauth.socialaccount.providers.base import ProviderAccount from allauth_cas.providers import CASProvider @@ -9,8 +8,8 @@ class ExampleCASAccount(ProviderAccount): class ExampleCASProvider(CASProvider): - id = 'theid' - name = 'The Provider' + id = "theid" + name = "The Provider" account_class = ExampleCASAccount diff --git a/tests/example/urls.py b/tests/example/urls.py index 335c441..795e835 100644 --- a/tests/example/urls.py +++ b/tests/example/urls.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- from allauth_cas.urls import default_urlpatterns from .provider import ExampleCASProvider diff --git a/tests/example/views.py b/tests/example/views.py index 6aa916a..1f04fb2 100644 --- a/tests/example/views.py +++ b/tests/example/views.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- from allauth_cas import views from .provider import ExampleCASProvider @@ -6,7 +5,7 @@ from .provider import ExampleCASProvider class ExampleCASAdapter(views.CASAdapter): provider_id = ExampleCASProvider.id - url = 'https://server.cas' + url = "https://server.cas" version = 2 diff --git a/tests/settings.py b/tests/settings.py index 5acb51d..505c325 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -1,63 +1,54 @@ -# -*- coding: utf-8 -*- -import django - -SECRET_KEY = 'iamabird' +SECRET_KEY = "iamabird" INSTALLED_APPS = [ - 'django.contrib.admin', - 'django.contrib.auth', - 'django.contrib.contenttypes', - 'django.contrib.messages', - 'django.contrib.sessions', - 'django.contrib.sites', - 'django.contrib.staticfiles', - - 'allauth', - 'allauth.account', - 'allauth.socialaccount', - - 'allauth_cas', - - 'tests.example', # Dummy CAS provider app + "django.contrib.admin", + "django.contrib.auth", + "django.contrib.contenttypes", + "django.contrib.messages", + "django.contrib.sessions", + "django.contrib.sites", + "django.contrib.staticfiles", + "allauth", + "allauth.account", + "allauth.socialaccount", + "allauth_cas", + "tests.example", # Dummy CAS provider app ] DATABASES = { - 'default': { - 'ENGINE': 'django.db.backends.sqlite3', + "default": { + "ENGINE": "django.db.backends.sqlite3", }, } AUTHENTICATION_BACKENDS = [ - 'allauth.account.auth_backends.AuthenticationBackend', + "allauth.account.auth_backends.AuthenticationBackend", ] _MIDDLEWARES = [ - 'django.middleware.common.CommonMiddleware', - 'django.contrib.sessions.middleware.SessionMiddleware', - 'django.middleware.csrf.CsrfViewMiddleware', - 'django.contrib.auth.middleware.AuthenticationMiddleware', - 'django.contrib.messages.middleware.MessageMiddleware', + "django.middleware.common.CommonMiddleware", + "django.contrib.sessions.middleware.SessionMiddleware", + "django.middleware.csrf.CsrfViewMiddleware", + "django.contrib.auth.middleware.AuthenticationMiddleware", + "django.contrib.messages.middleware.MessageMiddleware", ] -if django.VERSION >= (1, 10): - MIDDLEWARE = _MIDDLEWARES -else: - MIDDLEWARE_CLASSES = _MIDDLEWARES +MIDDLEWARE = _MIDDLEWARES TEMPLATES = [ { - 'BACKEND': 'django.template.backends.django.DjangoTemplates', - 'DIRS': [], - 'APP_DIRS': True, - 'OPTIONS': { - 'context_processors': [ - 'django.template.context_processors.debug', - 'django.template.context_processors.request', - 'django.contrib.auth.context_processors.auth', - 'django.contrib.messages.context_processors.messages', + "BACKEND": "django.template.backends.django.DjangoTemplates", + "DIRS": [], + "APP_DIRS": True, + "OPTIONS": { + "context_processors": [ + "django.template.context_processors.debug", + "django.template.context_processors.request", + "django.contrib.auth.context_processors.auth", + "django.contrib.messages.context_processors.messages", ], }, }, ] -ROOT_URLCONF = 'tests.urls' +ROOT_URLCONF = "tests.urls" diff --git a/tests/test_flows.py b/tests/test_flows.py index fb9fdbd..00d2b63 100644 --- a/tests/test_flows.py +++ b/tests/test_flows.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- from django.contrib import messages from django.contrib.auth import get_user_model from django.contrib.messages.api import get_messages @@ -13,7 +12,7 @@ User = get_user_model() class LogoutFlowTests(CASTestCase): expected_msg_str = ( "To logout of The Provider, please close your browser, or visit this " - "" + '' "link." ) @@ -28,42 +27,45 @@ class LogoutFlowTests(CASTestCase): ) self.assertTemplateNotUsed( response, - 'socialaccount/messages/suggest_caslogout.html', + "socialaccount/messages/suggest_caslogout.html", ) - @override_settings(SOCIALACCOUNT_PROVIDERS={ - 'theid': { - 'MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT': True, - 'MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT_LEVEL': messages.WARNING, - }, - }) + @override_settings( + SOCIALACCOUNT_PROVIDERS={ + "theid": { + "MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT": True, + "MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT_LEVEL": messages.WARNING, + }, + } + ) def test_message_on_logout(self): """ Message is sent to propose user to logout of CAS. """ - r = self.client.post('/accounts/logout/?next=/redir/') + r = self.client.post("/accounts/logout/?next=/redir/") r_messages = get_messages(r.wsgi_request) expected_msg = Message(messages.WARNING, self.expected_msg_str) self.assertIn(expected_msg, r_messages) - self.assertTemplateUsed( - r, 'socialaccount/messages/suggest_caslogout.html') + self.assertTemplateUsed(r, "socialaccount/messages/suggest_caslogout.html") def test_message_on_logout_disabled(self): - r = self.client.post('/accounts/logout/') + r = self.client.post("/accounts/logout/") self.assertCASLogoutNotInMessages(r) - @override_settings(SOCIALACCOUNT_PROVIDERS={ - 'theid': {'MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT': True}, - }) + @override_settings( + SOCIALACCOUNT_PROVIDERS={ + "theid": {"MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT": True}, + } + ) def test_other_logout(self): """ The CAS logout message doesn't appear with other login methods. """ - User.objects.create_user('user', '', 'user') + User.objects.create_user("user", "", "user") client = Client() - client.login(username='user', password='user') + client.login(username="user", password="user") - r = client.post('/accounts/logout/') + r = client.post("/accounts/logout/") self.assertCASLogoutNotInMessages(r) diff --git a/tests/test_providers.py b/tests/test_providers.py index ef68bb7..508f5b3 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -1,6 +1,6 @@ -# -*- coding: utf-8 -*- -from six.moves.urllib.parse import urlencode +from urllib.parse import urlencode +from allauth.socialaccount.providers import registry from django.contrib import messages from django.contrib.messages.api import get_messages from django.contrib.messages.middleware import MessageMiddleware @@ -8,21 +8,18 @@ from django.contrib.messages.storage.base import Message from django.contrib.sessions.middleware import SessionMiddleware from django.test import RequestFactory, TestCase, override_settings -from allauth.socialaccount.providers import registry - from allauth_cas.views import AuthAction from .example.provider import ExampleCASProvider class CASProviderTests(TestCase): - def setUp(self): self.request = self._get_request() self.provider = ExampleCASProvider(self.request) def _get_request(self): - request = RequestFactory().get('/test/') + request = RequestFactory().get("/test/") SessionMiddleware().process_request(request) MessageMiddleware().process_request(request) return request @@ -31,74 +28,81 @@ class CASProviderTests(TestCase): """ Example CAS provider is registered as social account provider. """ - self.assertIsInstance(registry.by_id('theid'), ExampleCASProvider) + self.assertIsInstance(registry.by_id("theid"), ExampleCASProvider) def test_get_login_url(self): url = self.provider.get_login_url(self.request) - self.assertEqual('/accounts/theid/login/', url) + self.assertEqual("/accounts/theid/login/", url) url_with_qs = self.provider.get_login_url( self.request, - next='/path?quéry=string&two=whoam%C3%AF', + next="/path?quéry=string&two=whoam%C3%AF", ) self.assertEqual( url_with_qs, - '/accounts/theid/login/?next=%2Fpath%3Fqu%C3%A9ry%3Dstring%26two%3' - 'Dwhoam%25C3%25AF' + "/accounts/theid/login/?next=%2Fpath%3Fqu%C3%A9ry%3Dstring%26two%3" + "Dwhoam%25C3%25AF", ) def test_get_callback_url(self): url = self.provider.get_callback_url(self.request) - self.assertEqual('/accounts/theid/login/callback/', url) + self.assertEqual("/accounts/theid/login/callback/", url) url_with_qs = self.provider.get_callback_url( self.request, - next='/path?quéry=string&two=whoam%C3%AF', + next="/path?quéry=string&two=whoam%C3%AF", ) self.assertEqual( url_with_qs, - '/accounts/theid/login/callback/?next=%2Fpath%3Fqu%C3%A9ry%3Dstrin' - 'g%26two%3Dwhoam%25C3%25AF' + "/accounts/theid/login/callback/?next=%2Fpath%3Fqu%C3%A9ry%3Dstrin" + "g%26two%3Dwhoam%25C3%25AF", ) def test_get_logout_url(self): url = self.provider.get_logout_url(self.request) - self.assertEqual('/accounts/theid/logout/', url) + self.assertEqual("/accounts/theid/logout/", url) url_with_qs = self.provider.get_logout_url( self.request, - next='/path?quéry=string&two=whoam%C3%AF', + next="/path?quéry=string&two=whoam%C3%AF", ) self.assertEqual( url_with_qs, - '/accounts/theid/logout/?next=%2Fpath%3Fqu%C3%A9ry%3Dstring%26two%' - '3Dwhoam%25C3%25AF' + "/accounts/theid/logout/?next=%2Fpath%3Fqu%C3%A9ry%3Dstring%26two%" + "3Dwhoam%25C3%25AF", ) - @override_settings(SOCIALACCOUNT_PROVIDERS={ - 'theid': { - 'AUTH_PARAMS': {'key': 'value'}, - }, - }) + @override_settings( + SOCIALACCOUNT_PROVIDERS={ + "theid": { + "AUTH_PARAMS": {"key": "value"}, + }, + } + ) def test_get_auth_params(self): action = AuthAction.AUTHENTICATE auth_params = self.provider.get_auth_params(self.request, action) - self.assertDictEqual(auth_params, { - 'key': 'value', - }) + self.assertDictEqual( + auth_params, + { + "key": "value", + }, + ) - @override_settings(SOCIALACCOUNT_PROVIDERS={ - 'theid': { - 'AUTH_PARAMS': {'key': 'value'}, - }, - }) + @override_settings( + SOCIALACCOUNT_PROVIDERS={ + "theid": { + "AUTH_PARAMS": {"key": "value"}, + }, + } + ) def test_get_auth_params_with_dynamic(self): factory = RequestFactory() request = factory.get( - '/test/?auth_params=next%3Dtwo%253Dwhoam%2525C3%2525AF%2526qu%2525' - 'C3%2525A9ry%253Dstring' + "/test/?auth_params=next%3Dtwo%253Dwhoam%2525C3%2525AF%2526qu%2525" + "C3%2525A9ry%253Dstring" ) request.session = {} @@ -106,15 +110,18 @@ class CASProviderTests(TestCase): auth_params = self.provider.get_auth_params(request, action) - self.assertDictEqual(auth_params, { - 'key': 'value', - 'next': 'two=whoam%C3%AF&qu%C3%A9ry=string', - }) + self.assertDictEqual( + auth_params, + { + "key": "value", + "next": "two=whoam%C3%AF&qu%C3%A9ry=string", + }, + ) def test_add_message_suggest_caslogout(self): expected_msg_base_str = ( "To logout of The Provider, please close your browser, or visit " - "this link." + 'this link.' ) # Defaults. @@ -124,7 +131,7 @@ class CASProviderTests(TestCase): expected_msg1 = Message( messages.INFO, - expected_msg_base_str.format(urlencode({'next': '/test/'})), + expected_msg_base_str.format(urlencode({"next": "/test/"})), ) self.assertIn(expected_msg1, get_messages(req1)) @@ -132,69 +139,83 @@ class CASProviderTests(TestCase): req2 = self._get_request() self.provider.add_message_suggest_caslogout( - req2, next_page='/redir/', level=messages.WARNING) + req2, next_page="/redir/", level=messages.WARNING + ) expected_msg2 = Message( messages.WARNING, - expected_msg_base_str.format(urlencode({'next': '/redir/'})), + expected_msg_base_str.format(urlencode({"next": "/redir/"})), ) self.assertIn(expected_msg2, get_messages(req2)) def test_message_suggest_caslogout_on_logout(self): self.assertFalse( - self.provider.message_suggest_caslogout_on_logout(self.request)) + self.provider.message_suggest_caslogout_on_logout(self.request) + ) - with override_settings(SOCIALACCOUNT_PROVIDERS={ - 'theid': {'MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT': True}, - }): + with override_settings( + SOCIALACCOUNT_PROVIDERS={ + "theid": {"MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT": True}, + } + ): self.assertTrue( - self.provider - .message_suggest_caslogout_on_logout(self.request) + self.provider.message_suggest_caslogout_on_logout(self.request) ) - @override_settings(SOCIALACCOUNT_PROVIDERS={ - 'theid': { - 'MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT_LEVEL': messages.WARNING, - }, - }) + @override_settings( + SOCIALACCOUNT_PROVIDERS={ + "theid": { + "MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT_LEVEL": messages.WARNING, + }, + } + ) def test_message_suggest_caslogout_on_logout_level(self): - self.assertEqual(messages.WARNING, ( - self.provider - .message_suggest_caslogout_on_logout_level(self.request) - )) + self.assertEqual( + messages.WARNING, + (self.provider.message_suggest_caslogout_on_logout_level(self.request)), + ) def test_extract_uid(self): - response = 'useRName', {} + response = "useRName", {} uid = self.provider.extract_uid(response) - self.assertEqual('useRName', uid) + self.assertEqual("useRName", uid) def test_extract_common_fields(self): - response = 'useRName', {} + response = "useRName", {} common_fields = self.provider.extract_common_fields(response) - self.assertDictEqual(common_fields, { - 'username': 'useRName', - 'first_name': None, - 'last_name': None, - 'name': None, - 'email': None, - }) + self.assertDictEqual( + common_fields, + { + "username": "useRName", + "first_name": None, + "last_name": None, + "name": None, + "email": None, + }, + ) def test_extract_common_fields_with_extra(self): - response = 'useRName', {'username': 'user', 'email': 'user@mail.net'} + response = "useRName", {"username": "user", "email": "user@mail.net"} common_fields = self.provider.extract_common_fields(response) - self.assertDictEqual(common_fields, { - 'username': 'user', - 'first_name': None, - 'last_name': None, - 'name': None, - 'email': 'user@mail.net', - }) + self.assertDictEqual( + common_fields, + { + "username": "user", + "first_name": None, + "last_name": None, + "name": None, + "email": "user@mail.net", + }, + ) def test_extract_extra_data(self): - response = 'useRName', {'user_attr': 'thevalue', 'another': 'value'} + response = "useRName", {"user_attr": "thevalue", "another": "value"} extra_data = self.provider.extract_extra_data(response) - self.assertDictEqual(extra_data, { - 'user_attr': 'thevalue', - 'another': 'value', - 'uid': 'useRName', - }) + self.assertDictEqual( + extra_data, + { + "user_attr": "thevalue", + "another": "value", + "uid": "useRName", + }, + ) diff --git a/tests/test_testcases.py b/tests/test_testcases.py index c7dacab..98f1673 100644 --- a/tests/test_testcases.py +++ b/tests/test_testcases.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- from django.test import Client, RequestFactory from allauth_cas.test.testcases import CASViewTestCase @@ -8,9 +7,8 @@ from .example.views import ExampleCASAdapter class CASTestCaseTests(CASViewTestCase): - def setUp(self): - self.client.get('/accounts/theid/login/') + self.client.get("/accounts/theid/login/") def test_patch_cas_response_client_version(self): """ @@ -21,20 +19,24 @@ class CASTestCaseTests(CASViewTestCase): """ valid_versions = [ - 1, '1', - 2, '2', - 3, '3', - 'CAS_2_SAML_1_0', + 1, + "1", + 2, + "2", + 3, + "3", + "CAS_2_SAML_1_0", ] invalid_versions = [ - 'not_supported', + "not_supported", ] factory = RequestFactory() - request = factory.get('/path/') + request = factory.get("/path/") request.session = {} for _version in valid_versions + invalid_versions: + class BasicCASAdapter(ExampleCASAdapter): version = _version @@ -47,7 +49,7 @@ class CASTestCaseTests(CASViewTestCase): if _version in valid_versions: raw_client = view(request) - self.patch_cas_response(valid_ticket='__all__') + self.patch_cas_response(valid_ticket="__all__") mocked_client = view(request) self.assertEqual(type(raw_client), type(mocked_client)) @@ -55,58 +57,79 @@ class CASTestCaseTests(CASViewTestCase): # This is a sanity check. self.assertRaises(ValueError, view, request) - self.patch_cas_response(valid_ticket='__all__') + self.patch_cas_response(valid_ticket="__all__") self.assertRaises(ValueError, view, request) def test_patch_cas_response_verify_success(self): - self.patch_cas_response(valid_ticket='123456') - r = self.client.get('/accounts/theid/login/callback/', { - 'ticket': '123456', - }) + self.patch_cas_response(valid_ticket="123456") + r = self.client.get( + "/accounts/theid/login/callback/", + { + "ticket": "123456", + }, + ) self.assertLoginSuccess(r) def test_patch_cas_response_verify_failure(self): - self.patch_cas_response(valid_ticket='123456') - r = self.client.get('/accounts/theid/login/callback/', { - 'ticket': '000000', - }) + self.patch_cas_response(valid_ticket="123456") + r = self.client.get( + "/accounts/theid/login/callback/", + { + "ticket": "000000", + }, + ) self.assertLoginFailure(r) def test_patch_cas_response_accept(self): - self.patch_cas_response(valid_ticket='__all__') - r = self.client.get('/accounts/theid/login/callback/', { - 'ticket': '000000', - }) + self.patch_cas_response(valid_ticket="__all__") + r = self.client.get( + "/accounts/theid/login/callback/", + { + "ticket": "000000", + }, + ) self.assertLoginSuccess(r) def test_patch_cas_response_reject(self): self.patch_cas_response(valid_ticket=None) - r = self.client.get('/accounts/theid/login/callback/', { - 'ticket': '000000', - }) + r = self.client.get( + "/accounts/theid/login/callback/", + { + "ticket": "000000", + }, + ) self.assertLoginFailure(r) def test_patch_cas_reponse_multiple(self): - self.patch_cas_response(valid_ticket='__all__') + self.patch_cas_response(valid_ticket="__all__") client_0 = Client() - client_0.get('/accounts/theid/login/') - r_0 = client_0.get('/accounts/theid/login/callback/', { - 'ticket': '000000', - }) + client_0.get("/accounts/theid/login/") + r_0 = client_0.get( + "/accounts/theid/login/callback/", + { + "ticket": "000000", + }, + ) self.assertLoginSuccess(r_0) self.patch_cas_response(valid_ticket=None) client_1 = Client() - client_1.get('/accounts/theid/login/') - r_1 = client_1.get('/accounts/theid/login/callback/', { - 'ticket': '111111', - }) + client_1.get("/accounts/theid/login/") + r_1 = client_1.get( + "/accounts/theid/login/callback/", + { + "ticket": "111111", + }, + ) self.assertLoginFailure(r_1) def test_assertLoginSuccess(self): - self.patch_cas_response(valid_ticket='__all__') - r = self.client.get('/accounts/theid/login/callback/', { - 'ticket': '000000', - 'next': '/path/', - }) - self.assertLoginSuccess(r, redirect_to='/path/') + self.patch_cas_response(valid_ticket="__all__") + r = self.client.get( + "/accounts/theid/login/callback/", + { + "ticket": "000000", + "next": "/path/", + }, + ) + self.assertLoginSuccess(r, redirect_to="/path/") diff --git a/tests/test_views.py b/tests/test_views.py index e7b04e1..7130645 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -1,11 +1,10 @@ -# -*- coding: utf-8 -*- try: from unittest.mock import patch except ImportError: - from mock import patch + from unittest.mock import patch -import django from django.test import RequestFactory, override_settings +from django.urls import reverse from allauth_cas.exceptions import CASAuthenticationError from allauth_cas.test.testcases import CASTestCase, CASViewTestCase @@ -13,17 +12,11 @@ from allauth_cas.views import CASView from .example.views import ExampleCASAdapter -if django.VERSION >= (1, 10): - from django.urls import reverse -else: - from django.core.urlresolvers import reverse - class CASAdapterTests(CASTestCase): - def setUp(self): factory = RequestFactory() - self.request = factory.get('/path/') + self.request = factory.get("/path/") self.request.session = {} self.adapter = ExampleCASAdapter(self.request) @@ -31,7 +24,7 @@ class CASAdapterTests(CASTestCase): """ Service url (used by CAS client) is the callback url. """ - expected = 'http://testserver/accounts/theid/login/callback/' + expected = "http://testserver/accounts/theid/login/callback/" service_url = self.adapter.get_service_url(self.request) self.assertEqual(expected, service_url) @@ -39,11 +32,9 @@ class CASAdapterTests(CASTestCase): """ Current GET paramater next is appended on service url. """ - expected = ( - 'http://testserver/accounts/theid/login/callback/?next=%2Fnext%2F' - ) + expected = "http://testserver/accounts/theid/login/callback/?next=%2Fnext%2F" factory = RequestFactory() - request = factory.get('/path/', {'next': '/next/'}) + request = factory.get("/path/", {"next": "/next/"}) adapter = ExampleCASAdapter(request) service_url = adapter.get_service_url(request) self.assertEqual(expected, service_url) @@ -67,14 +58,13 @@ class CASAdapterTests(CASTestCase): class CASViewTests(CASViewTestCase): - class BasicCASView(CASView): def dispatch(self, request, *args, **kwargs): return self def setUp(self): factory = RequestFactory() - self.request = factory.get('/path/') + self.request = factory.get("/path/") self.request.session = {} self.cas_view = self.BasicCASView.adapter_view(ExampleCASAdapter) @@ -85,27 +75,34 @@ class CASViewTests(CASViewTestCase): """ view = self.cas_view( self.request, - 'arg1', 'arg2', - kwarg1='kwarg1', kwarg2='kwarg2', + "arg1", + "arg2", + kwarg1="kwarg1", + kwarg2="kwarg2", ) self.assertIsInstance(view, CASView) self.assertEqual(view.request, self.request) - self.assertTupleEqual(view.args, ('arg1', 'arg2')) - self.assertDictEqual(view.kwargs, { - 'kwarg1': 'kwarg1', - 'kwarg2': 'kwarg2', - }) + self.assertTupleEqual(view.args, ("arg1", "arg2")) + self.assertDictEqual( + view.kwargs, + { + "kwarg1": "kwarg1", + "kwarg2": "kwarg2", + }, + ) self.assertIsInstance(view.adapter, ExampleCASAdapter) - @patch('allauth_cas.views.cas.CASClient') - @override_settings(SOCIALACCOUNT_PROVIDERS={ - 'theid': { - 'AUTH_PARAMS': {'key': 'value'}, - }, - }) + @patch("allauth_cas.views.cas.CASClient") + @override_settings( + SOCIALACCOUNT_PROVIDERS={ + "theid": { + "AUTH_PARAMS": {"key": "value"}, + }, + } + ) def test_get_client(self, mock_casclient_class): """ get_client returns a CAS client, configured from settings. @@ -114,11 +111,11 @@ class CASViewTests(CASViewTestCase): view.get_client(self.request) mock_casclient_class.assert_called_once_with( - service_url='http://testserver/accounts/theid/login/callback/', - server_url='https://server.cas', + service_url="http://testserver/accounts/theid/login/callback/", + server_url="https://server.cas", version=2, renew=False, - extra_login_params={'key': 'value'}, + extra_login_params={"key": "value"}, ) def test_render_error_on_failure(self): @@ -126,33 +123,33 @@ class CASViewTests(CASViewTestCase): A common login failure page is rendered if CASAuthenticationError is raised by dispatch. """ + def dispatch_raise(self, request): raise CASAuthenticationError("failure") - with patch.object(self.BasicCASView, 'dispatch', dispatch_raise): + with patch.object(self.BasicCASView, "dispatch", dispatch_raise): resp = self.cas_view(self.request) self.assertLoginFailure(resp) class CASLoginViewTests(CASViewTestCase): - def test_reverse(self): """ Login view name is "{provider_id}_login". """ - url = reverse('theid_login') - self.assertEqual('/accounts/theid/login/', url) + url = reverse("theid_login") + self.assertEqual("/accounts/theid/login/", url) def test_execute(self): """ Login view redirects to the CAS server login url. Service is the callback url, as absolute uri. """ - r = self.client.get('/accounts/theid/login/') + r = self.client.get("/accounts/theid/login/") expected = ( - 'https://server.cas/login?service=http%3A%2F%2Ftestserver%2F' - 'accounts%2Ftheid%2Flogin%2Fcallback%2F' + "https://server.cas/login?service=http%3A%2F%2Ftestserver%2F" + "accounts%2Ftheid%2Flogin%2Fcallback%2F" ) self.assertRedirects(r, expected, fetch_redirect_response=False) @@ -161,54 +158,59 @@ class CASLoginViewTests(CASViewTestCase): """ Current GET parameter 'next' is kept on service url. """ - r = self.client.get('/accounts/theid/login/?next=/path/') + r = self.client.get("/accounts/theid/login/?next=/path/") expected = ( - 'https://server.cas/login?service=http%3A%2F%2Ftestserver%2F' - 'accounts%2Ftheid%2Flogin%2Fcallback%2F%3Fnext%3D%252Fpath%252F' + "https://server.cas/login?service=http%3A%2F%2Ftestserver%2F" + "accounts%2Ftheid%2Flogin%2Fcallback%2F%3Fnext%3D%252Fpath%252F" ) self.assertRedirects(r, expected, fetch_redirect_response=False) class CASCallbackViewTests(CASViewTestCase): - def setUp(self): - self.client.get('/accounts/theid/login/') + self.client.get("/accounts/theid/login/") def test_reverse(self): """ Callback view name is "{provider_id}_callback". """ - url = reverse('theid_callback') - self.assertEqual('/accounts/theid/login/callback/', url) + url = reverse("theid_callback") + self.assertEqual("/accounts/theid/login/callback/", url) def test_ticket_valid(self): """ If ticket is valid, the user is logged in. """ - self.patch_cas_response(username='username', valid_ticket='123456') - r = self.client.get('/accounts/theid/login/callback/', { - 'ticket': '123456', - }) + self.patch_cas_response(username="username", valid_ticket="123456") + r = self.client.get( + "/accounts/theid/login/callback/", + { + "ticket": "123456", + }, + ) self.assertLoginSuccess(r) def test_ticket_invalid(self): """ Login failure page is returned if the ticket is invalid. """ - self.patch_cas_response(username='username', valid_ticket='123456') - r = self.client.get('/accounts/theid/login/callback/', { - 'ticket': '000000', - }) + self.patch_cas_response(username="username", valid_ticket="123456") + r = self.client.get( + "/accounts/theid/login/callback/", + { + "ticket": "000000", + }, + ) self.assertLoginFailure(r) def test_ticket_missing(self): """ Login failure page is returned if request lacks a ticket. """ - self.patch_cas_response(username='username', valid_ticket='123456') - r = self.client.get('/accounts/theid/login/callback/') + self.patch_cas_response(username="username", valid_ticket="123456") + r = self.client.get("/accounts/theid/login/callback/") self.assertLoginFailure(r) def test_attributes_is_none(self): @@ -216,31 +218,33 @@ class CASCallbackViewTests(CASViewTestCase): Without extra attributes, CASClientV2 of python-cas returns None. """ self.patch_cas_response( - username='username', valid_ticket='123456', attributes=None + username="username", valid_ticket="123456", attributes=None + ) + r = self.client.get( + "/accounts/theid/login/callback/", + { + "ticket": "123456", + }, ) - r = self.client.get('/accounts/theid/login/callback/', { - 'ticket': '123456', - }) self.assertLoginSuccess(r) class CASLogoutViewTests(CASViewTestCase): - def test_reverse(self): """ Callback view name is "{provider_id}_logout". """ - url = reverse('theid_logout') - self.assertEqual('/accounts/theid/logout/', url) + url = reverse("theid_logout") + self.assertEqual("/accounts/theid/logout/", url) def test_execute(self): """ Logout view redirects to the CAS server logout url. Service is a url to here, as absolute uri. """ - r = self.client.get('/accounts/theid/logout/') + r = self.client.get("/accounts/theid/logout/") - expected = 'https://server.cas/logout?url=http%3A%2F%2Ftestserver%2F' + expected = "https://server.cas/logout?url=http%3A%2F%2Ftestserver%2F" self.assertRedirects(r, expected, fetch_redirect_response=False) @@ -248,10 +252,8 @@ class CASLogoutViewTests(CASViewTestCase): """ GET parameter 'next' is set as service url. """ - r = self.client.get('/accounts/theid/logout/?next=/path/') + r = self.client.get("/accounts/theid/logout/?next=/path/") - expected = ( - 'https://server.cas/logout?url=http%3A%2F%2Ftestserver%2Fpath%2F' - ) + expected = "https://server.cas/logout?url=http%3A%2F%2Ftestserver%2Fpath%2F" self.assertRedirects(r, expected, fetch_redirect_response=False) diff --git a/tests/urls.py b/tests/urls.py index 1933289..e33212c 100644 --- a/tests/urls.py +++ b/tests/urls.py @@ -1,6 +1,5 @@ -# -*- coding: utf-8 -*- -from django.conf.urls import include, url +from django.urls import include, path urlpatterns = [ - url(r'^accounts/', include('allauth.urls')), + path("accounts/", include("allauth.urls")), ]