diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..f7320f1
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,53 @@
+exclude: "^docs/|/migrations/"
+default_stages: [ commit ]
+fail_fast: true
+
+repos:
+ - repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v4.5.0
+ hooks:
+ - id: trailing-whitespace
+ - id: end-of-file-fixer
+ exclude: ".svg$|.min.*|.idea*"
+ - id: check-yaml
+ - id: check-added-large-files
+
+ - repo: https://github.com/asottile/pyupgrade
+ rev: v3.15.0
+ hooks:
+ - id: pyupgrade
+ args: [ --py39-plus ]
+
+ - repo: https://github.com/psf/black
+ rev: 23.12.1
+ hooks:
+ - id: black
+
+ - repo: https://github.com/PyCQA/isort
+ rev: 5.13.2
+ hooks:
+ - id: isort
+ args: [ "--profile", "black", "-l=88" ]
+
+ - repo: https://github.com/PyCQA/flake8
+ rev: 7.0.0
+ hooks:
+ - id: flake8
+ args: [ "--config=setup.cfg" ]
+ additional_dependencies:
+ - flake8-isort
+ - flake8-bugbear
+ - flake8-print
+
+ - repo: https://github.com/adamchainz/django-upgrade
+ rev: 1.15.0
+ hooks:
+ - id: django-upgrade
+ args: [ --target-version, "4.2" ]
+
+
+# sets up .pre-commit-ci.yaml to ensure pre-commit dependencies stay up to date
+ci:
+ autoupdate_schedule: weekly
+ skip: [ ]
+ submodules: false
diff --git a/allauth_cas/__init__.py b/allauth_cas/__init__.py
index 04e5073..4562d67 100644
--- a/allauth_cas/__init__.py
+++ b/allauth_cas/__init__.py
@@ -1,7 +1,3 @@
-# -*- coding: utf-8 -*-
+__version__ = "1.0.1"
-__version__ = '1.0.0'
-
-default_app_config = 'allauth_cas.apps.CASAccountConfig'
-
-CAS_PROVIDER_SESSION_KEY = 'allauth_cas__provider_id'
+CAS_PROVIDER_SESSION_KEY = "allauth_cas__provider_id"
diff --git a/allauth_cas/apps.py b/allauth_cas/apps.py
index 2ee7f0c..55d0dbb 100644
--- a/allauth_cas/apps.py
+++ b/allauth_cas/apps.py
@@ -1,10 +1,9 @@
-# -*- coding: utf-8 -*-
from django.apps import AppConfig
-from django.utils.translation import ugettext_lazy as _
+from django.utils.translation import gettext_lazy as _
class CASAccountConfig(AppConfig):
- name = 'allauth_cas'
+ name = "allauth_cas"
verbose_name = _("CAS Accounts")
def ready(self):
diff --git a/allauth_cas/exceptions.py b/allauth_cas/exceptions.py
index 2ed507e..d096400 100644
--- a/allauth_cas/exceptions.py
+++ b/allauth_cas/exceptions.py
@@ -1,6 +1,3 @@
-# -*- coding: utf-8 -*-
-
-
class CASAuthenticationError(Exception):
"""
Base exception to signal CAS authentication failure.
diff --git a/allauth_cas/providers.py b/allauth_cas/providers.py
index a46b4f1..3acf831 100644
--- a/allauth_cas/providers.py
+++ b/allauth_cas/providers.py
@@ -1,26 +1,18 @@
-# -*- coding: utf-8 -*-
-from six.moves.urllib.parse import parse_qsl
+from urllib.parse import parse_qsl
-import django
+from allauth.socialaccount.providers.base import Provider
from django.contrib import messages
from django.template.loader import render_to_string
+from django.urls import reverse
from django.utils.http import urlencode
from django.utils.safestring import mark_safe
-from allauth.socialaccount.providers.base import Provider
-
-if django.VERSION >= (1, 10):
- from django.urls import reverse
-else:
- from django.core.urlresolvers import reverse
-
class CASProvider(Provider):
-
def get_auth_params(self, request, action):
settings = self.get_settings()
- ret = dict(settings.get('AUTH_PARAMS', {}))
- dynamic_auth_params = request.GET.get('auth_params')
+ ret = dict(settings.get("AUTH_PARAMS", {}))
+ dynamic_auth_params = request.GET.get("auth_params")
if dynamic_auth_params:
ret.update(dict(parse_qsl(dynamic_auth_params)))
return ret
@@ -68,11 +60,11 @@ class CASProvider(Provider):
"""
uid, extra = data
return {
- 'username': extra.get('username', uid),
- 'email': extra.get('email'),
- 'first_name': extra.get('first_name'),
- 'last_name': extra.get('last_name'),
- 'name': extra.get('name'),
+ "username": extra.get("username", uid),
+ "email": extra.get("email"),
+ "first_name": extra.get("first_name"),
+ "last_name": extra.get("last_name"),
+ "name": extra.get("name"),
}
def extract_email_addresses(self, data):
@@ -99,7 +91,7 @@ class CASProvider(Provider):
]
"""
- return super(CASProvider, self).extract_email_addresses(data)
+ return super().extract_email_addresses(data)
def extract_extra_data(self, data):
"""Extract the data to save to `SocialAccount.extra_data`.
@@ -119,7 +111,10 @@ class CASProvider(Provider):
##
def add_message_suggest_caslogout(
- self, request, next_page=None, level=None,
+ self,
+ request,
+ next_page=None,
+ level=None,
):
"""Add a message with a link for the user to logout of the CAS server.
@@ -144,14 +139,15 @@ class CASProvider(Provider):
# DefaultAccountAdapter.add_message is unusable because it always
# escape the message content.
- template = 'socialaccount/messages/suggest_caslogout.html'
+ template = "socialaccount/messages/suggest_caslogout.html"
context = {
- 'provider': self,
- 'logout_url': logout_url,
+ "provider": self,
+ "logout_url": logout_url,
}
messages.add_message(
- request, level,
+ request,
+ level,
mark_safe(render_to_string(template, context).strip()),
fail_silently=True,
)
@@ -168,10 +164,7 @@ class CASProvider(Provider):
signal ``user_logged_out``.
"""
- return (
- self.get_settings()
- .get('MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT', False)
- )
+ return self.get_settings().get("MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT", False)
def message_suggest_caslogout_on_logout_level(self, request):
"""Level of the logout message issued on user logout.
@@ -185,9 +178,8 @@ class CASProvider(Provider):
signal ``user_logged_out``.
"""
- return (
- self.get_settings()
- .get('MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT_LEVEL', messages.INFO)
+ return self.get_settings().get(
+ "MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT_LEVEL", messages.INFO
)
##
@@ -195,19 +187,19 @@ class CASProvider(Provider):
##
def get_login_url(self, request, **kwargs):
- url = reverse(self.id + '_login')
+ url = reverse(self.id + "_login")
if kwargs:
- url += '?' + urlencode(kwargs)
+ url += "?" + urlencode(kwargs)
return url
def get_callback_url(self, request, **kwargs):
- url = reverse(self.id + '_callback')
+ url = reverse(self.id + "_callback")
if kwargs:
- url += '?' + urlencode(kwargs)
+ url += "?" + urlencode(kwargs)
return url
def get_logout_url(self, request, **kwargs):
- url = reverse(self.id + '_logout')
+ url = reverse(self.id + "_logout")
if kwargs:
- url += '?' + urlencode(kwargs)
+ url += "?" + urlencode(kwargs)
return url
diff --git a/allauth_cas/signals.py b/allauth_cas/signals.py
index 927cbfd..36c9b24 100644
--- a/allauth_cas/signals.py
+++ b/allauth_cas/signals.py
@@ -1,10 +1,8 @@
-# -*- coding: utf-8 -*-
-from django.contrib.auth.signals import user_logged_out
-from django.dispatch import receiver
-
from allauth.account.adapter import get_adapter
from allauth.account.utils import get_next_redirect_url
from allauth.socialaccount import providers
+from django.contrib.auth.signals import user_logged_out
+from django.dispatch import receiver
from . import CAS_PROVIDER_SESSION_KEY
@@ -21,12 +19,12 @@ def cas_account_logout(sender, request, **kwargs):
if not provider.message_suggest_caslogout_on_logout(request):
return
- next_page = (
- get_next_redirect_url(request) or
- get_adapter(request).get_logout_redirect_url(request)
- )
+ next_page = get_next_redirect_url(request) or get_adapter(
+ request
+ ).get_logout_redirect_url(request)
provider.add_message_suggest_caslogout(
- request, next_page=next_page,
+ request,
+ next_page=next_page,
level=provider.message_suggest_caslogout_on_logout_level(request),
)
diff --git a/allauth_cas/test/testcases.py b/allauth_cas/test/testcases.py
index d1e93e9..7ec8c07 100644
--- a/allauth_cas/test/testcases.py
+++ b/allauth_cas/test/testcases.py
@@ -1,29 +1,20 @@
-# -*- coding: utf-8 -*-
try:
from unittest.mock import patch
except ImportError:
- from mock import patch
-
-import django
-from django.conf import settings
-from django.test import TestCase
+ from unittest.mock import patch
import cas
+from django.conf import settings
+from django.test import TestCase
+from django.urls import reverse
from allauth_cas import CAS_PROVIDER_SESSION_KEY
-if django.VERSION >= (1, 10):
- from django.urls import reverse
-else:
- from django.core.urlresolvers import reverse
-
class CASTestCase(TestCase):
-
def client_cas_login(
- self,
- client, provider_id='theid',
- username=None, attributes={}):
+ self, client, provider_id=None, username=None, attributes=None
+ ):
"""
Authenticate client through provider_id.
@@ -32,20 +23,22 @@ class CASTestCase(TestCase):
username and attributes control the CAS server response when ticket is
checked.
"""
- client.get(reverse('{id}_login'.format(id=provider_id)))
+ if attributes is None:
+ attributes = {}
+ if provider_id is None:
+ provider_id = "theid"
+ client.get(reverse(f"{provider_id}_login"))
self.patch_cas_response(
- valid_ticket='__all__',
- username=username, attributes=attributes,
+ valid_ticket="__all__",
+ username=username,
+ attributes=attributes,
)
- callback_url = reverse('{id}_callback'.format(id=provider_id))
- r = client.get(callback_url, {'ticket': 'fake-ticket'})
+ callback_url = reverse(f"{provider_id}_callback")
+ r = client.get(callback_url, {"ticket": "fake-ticket"})
self.patch_cas_response_stop()
return r
- def patch_cas_response(
- self,
- valid_ticket,
- username=None, attributes={}):
+ def patch_cas_response(self, valid_ticket, username=None, attributes=None):
"""
Patch the CASClient class used by views of CAS providers.
@@ -68,35 +61,38 @@ class CASTestCase(TestCase):
Note that valid_ticket sould be a string (which is the type of the
ticket retrieved from GET parameter on request on the callback view).
"""
- if hasattr(self, '_patch_cas_client'):
+ if attributes is None:
+ attributes = {}
+ if hasattr(self, "_patch_cas_client"):
self.patch_cas_response_stop()
- class MockCASClient(object):
+ class MockCASClient:
_username = username
def __new__(self_client, *args, **kwargs):
- version = kwargs.pop('version')
- if version in (1, '1'):
+ version = kwargs.pop("version")
+ if version in (1, "1"):
client_class = cas.CASClientV1
- elif version in (2, '2'):
+ elif version in (2, "2"):
client_class = cas.CASClientV2
- elif version in (3, '3'):
+ elif version in (3, "3"):
client_class = cas.CASClientV3
- elif version == 'CAS_2_SAML_1_0':
+ elif version == "CAS_2_SAML_1_0":
client_class = cas.CASClientWithSAMLV1
else:
- raise ValueError('Unsupported CAS_VERSION %r' % version)
+ raise ValueError("Unsupported CAS_VERSION %r" % version)
client_class._username = self_client._username
def verify_ticket(self, ticket):
- if valid_ticket == '__all__' or ticket == valid_ticket:
- username = self._username or 'username'
+ if valid_ticket == "__all__" or ticket == valid_ticket:
+ username = self._username or "username"
return username, attributes, None
return None, {}, None
patcher = patch.object(
- client_class, 'verify_ticket',
+ client_class,
+ "verify_ticket",
new=verify_ticket,
)
patcher.start()
@@ -104,7 +100,7 @@ class CASTestCase(TestCase):
return client_class(*args, **kwargs)
self._patch_cas_client = patch(
- 'allauth_cas.views.cas.CASClient',
+ "allauth_cas.views.cas.CASClient",
MockCASClient,
)
self._patch_cas_client.start()
@@ -114,12 +110,11 @@ class CASTestCase(TestCase):
del self._patch_cas_client
def tearDown(self):
- if hasattr(self, '_patch_cas_client'):
+ if hasattr(self, "_patch_cas_client"):
self.patch_cas_response_stop()
class CASViewTestCase(CASTestCase):
-
def assertLoginSuccess(self, response, redirect_to=None):
"""
Asserts response corresponds to a successful login.
@@ -133,7 +128,8 @@ class CASViewTestCase(CASTestCase):
redirect_to = settings.LOGIN_REDIRECT_URL
self.assertRedirects(
- response, redirect_to,
+ response,
+ redirect_to,
fetch_redirect_response=False,
)
self.assertIn(
@@ -146,6 +142,6 @@ class CASViewTestCase(CASTestCase):
Asserts response corresponds to a failed login.
"""
return self.assertInHTML(
- '
Social Network Login Failure
',
+ "Social Network Login Failure
",
str(response.content),
)
diff --git a/allauth_cas/urls.py b/allauth_cas/urls.py
index 687dca6..292bce7 100644
--- a/allauth_cas/urls.py
+++ b/allauth_cas/urls.py
@@ -1,5 +1,4 @@
-# -*- coding: utf-8 -*-
-from django.conf.urls import include, url
+from django.urls import include, path, re_path
from django.utils.module_loading import import_string
@@ -7,45 +6,44 @@ def default_urlpatterns(provider):
package = provider.get_package()
try:
- login_view = import_string(package + '.views.login')
+ login_view = import_string(package + ".views.login")
except ImportError:
raise ImportError(
"The login view for the '{id}' provider is lacking from the "
"'views' module of its app.\n"
"You may want to add:\n"
"from allauth_cas.views import CASLoginView\n\n"
- "login = CASLoginView.adapter_view()"
- .format(id=provider.id)
+ "login = CASLoginView.adapter_view()".format(
+ id=provider.id
+ )
)
try:
- callback_view = import_string(package + '.views.callback')
+ callback_view = import_string(package + ".views.callback")
except ImportError:
raise ImportError(
"The callback view for the '{id}' provider is lacking from the "
"'views' module of its app.\n"
"You may want to add:\n"
"from allauth_cas.views import CASCallbackView\n\n"
- "callback = CASCallbackView.adapter_view()"
- .format(id=provider.id)
+ "callback = CASCallbackView.adapter_view()".format(
+ id=provider.id
+ )
)
try:
- logout_view = import_string(package + '.views.logout')
+ logout_view = import_string(package + ".views.logout")
except ImportError:
logout_view = None
urlpatterns = [
- url('^login/$', login_view,
- name=provider.id + '_login'),
- url('^login/callback/$', callback_view,
- name=provider.id + '_callback'),
+ path("login/", login_view, name=provider.id + "_login"),
+ path("login/callback/", callback_view, name=provider.id + "_callback"),
]
if logout_view is not None:
urlpatterns += [
- url('^logout/$', logout_view,
- name=provider.id + '_logout'),
+ path("logout/", logout_view, name=provider.id + "_logout"),
]
- return [url('^' + provider.get_slug() + '/', include(urlpatterns))]
+ return [re_path("^" + provider.get_slug() + "/", include(urlpatterns))]
diff --git a/allauth_cas/views.py b/allauth_cas/views.py
index b27b3c7..d08e354 100644
--- a/allauth_cas/views.py
+++ b/allauth_cas/views.py
@@ -1,28 +1,26 @@
-# -*- coding: utf-8 -*-
-from django.http import HttpResponseRedirect
-from django.utils.functional import cached_property
-
+import cas
from allauth.account.adapter import get_adapter
from allauth.account.utils import get_next_redirect_url
from allauth.socialaccount import providers
from allauth.socialaccount.helpers import (
- complete_social_login, render_authentication_error,
+ complete_social_login,
+ render_authentication_error,
)
from allauth.socialaccount.models import SocialLogin
-
-import cas
+from django.http import HttpResponseRedirect
+from django.utils.functional import cached_property
from . import CAS_PROVIDER_SESSION_KEY
from .exceptions import CASAuthenticationError
-class AuthAction(object):
- AUTHENTICATE = 'authenticate'
- REAUTHENTICATE = 'reauthenticate'
- DEAUTHENTICATE = 'deauthenticate'
+class AuthAction:
+ AUTHENTICATE = "authenticate"
+ REAUTHENTICATE = "reauthenticate"
+ DEAUTHENTICATE = "deauthenticate"
-class CASAdapter(object):
+class CASAdapter:
#: CAS server url.
url = None
#: CAS server version.
@@ -92,19 +90,19 @@ class CASAdapter(object):
"""
redirect_to = get_next_redirect_url(request)
- callback_kwargs = {'next': redirect_to} if redirect_to else {}
- callback_url = (
- self.provider.get_callback_url(request, **callback_kwargs))
+ callback_kwargs = {"next": redirect_to} if redirect_to else {}
+ callback_url = self.provider.get_callback_url(request, **callback_kwargs)
service_url = request.build_absolute_uri(callback_url)
return service_url
-class CASView(object):
+class CASView:
"""
Base class for CAS views.
"""
+
@classmethod
def adapter_view(cls, adapter):
"""Transform the view class into a view function.
@@ -124,6 +122,7 @@ class CASView(object):
"""
+
def view(request, *args, **kwargs):
# Prepare the func-view.
self = cls()
@@ -169,19 +168,17 @@ class CASView(object):
class CASLoginView(CASView):
-
def dispatch(self, request):
"""
Redirects to the CAS server login page.
"""
- action = request.GET.get('action', AuthAction.AUTHENTICATE)
+ action = request.GET.get("action", AuthAction.AUTHENTICATE)
SocialLogin.stash_state(request)
client = self.get_client(request, action=action)
return HttpResponseRedirect(client.get_login_url())
class CASCallbackView(CASView):
-
def dispatch(self, request):
"""
The CAS server redirects the user to this view after a successful
@@ -195,11 +192,9 @@ class CASCallbackView(CASView):
# CAS server should let a ticket.
try:
- ticket = request.GET['ticket']
+ ticket = request.GET["ticket"]
except KeyError:
- raise CASAuthenticationError(
- "CAS server didn't respond with a ticket."
- )
+ raise CASAuthenticationError("CAS server didn't respond with a ticket.")
# Check ticket validity.
# Response format on:
@@ -210,9 +205,7 @@ class CASCallbackView(CASView):
uid, extra, _ = response
if not uid:
- raise CASAuthenticationError(
- "CAS server doesn't validate the ticket."
- )
+ raise CASAuthenticationError("CAS server doesn't validate the ticket.")
# Keep tracks of the last used CAS provider.
request.session[CAS_PROVIDER_SESSION_KEY] = self.provider.id
@@ -226,7 +219,6 @@ class CASCallbackView(CASView):
class CASLogoutView(CASView):
-
def dispatch(self, request, next_page=None):
"""
Redirects to the CAS server logout page.
@@ -248,7 +240,6 @@ class CASLogoutView(CASView):
Returns the url to redirect after logout.
"""
request = self.request
- return (
- get_next_redirect_url(request) or
- get_adapter(request).get_logout_redirect_url(request)
- )
+ return get_next_redirect_url(request) or get_adapter(
+ request
+ ).get_logout_redirect_url(request)
diff --git a/runtests.py b/runtests.py
index 5cea84f..a4e0970 100644
--- a/runtests.py
+++ b/runtests.py
@@ -1,5 +1,4 @@
#!/usr/bin/env python
-# -*- coding: utf-8 -*-
import os
import sys
@@ -7,10 +6,10 @@ import django
from django.conf import settings
from django.test.utils import get_runner
-if __name__ == '__main__':
- os.environ['DJANGO_SETTINGS_MODULE'] = 'tests.settings'
+if __name__ == "__main__":
+ os.environ["DJANGO_SETTINGS_MODULE"] = "tests.settings"
django.setup()
TestRunner = get_runner(settings)
test_runner = TestRunner()
- failures = test_runner.run_tests(sys.argv[1:] or ['tests'])
+ failures = test_runner.run_tests(sys.argv[1:] or ["tests"])
sys.exit(bool(failures))
diff --git a/setup.cfg b/setup.cfg
index 7e7ca4d..d4b1006 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,18 +1,45 @@
[flake8]
+max-line-length = 120
+exclude = .tox,.git,*/migrations/*,*/static/CACHE/*,docs,node_modules,venv
# E731: lambda expression
ignore = E731
+[pycodestyle]
+max-line-length = 120
+exclude = .tox,.git,*/migrations/*,*/static/CACHE/*,docs,node_modules,venv
+
[isort]
-combine_as_imports = True
+line_length = 88
+known_first_party = geoip,config
+multi_line_output = 3
default_section = THIRDPARTY
-include_trailing_comma = True
-known_allauth = allauth
-known_future_library = future,six
-known_django = django
-known_first_party = allauth_cas
-multi_line_output = 5
-not_skip = __init__.py
-sections = FUTURE,STDLIB,DJANGO,ALLAUTH,THIRDPARTY,FIRSTPARTY,LOCALFOLDER
+skip = venv/
+skip_glob = **/migrations/*.py
+include_trailing_comma = true
+force_grid_wrap = 0
+use_parentheses = true
+
+[mypy]
+python_version = 3.9
+check_untyped_defs = True
+ignore_missing_imports = True
+warn_unused_ignores = True
+warn_redundant_casts = True
+warn_unused_configs = True
+plugins = mypy_django_plugin.main
+
+[mypy.plugins.django-stubs]
+django_settings_module = config.settings.test
+
+[mypy-*.migrations.*]
+# Django migrations should not produce any errors:
+ignore_errors = True
+
+[coverage:run]
+include = geoip/*
+omit = *migrations*, *tests*
+plugins =
+ django_coverage_plugin
[bdist_wheel]
universal = 1
diff --git a/setup.py b/setup.py
index b256fac..fb06ec0 100644
--- a/setup.py
+++ b/setup.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
import os
from setuptools import find_packages, setup
@@ -7,49 +6,56 @@ from allauth_cas import __version__
BASE_DIR = os.path.dirname(__file__)
-with open(os.path.join(BASE_DIR, 'README.rst')) as readme:
+with open(os.path.join(BASE_DIR, "README.rst")) as readme:
README = readme.read()
setup(
- name='django-allauth-cas',
+ name="django-allauth-cas",
version=__version__,
- description='CAS support for django-allauth.',
- author='Aurélien Delobelle',
- author_email='aurelien.delobelle@gmail.com',
- keywords='django allauth cas authentication',
+ description="CAS support for django-allauth.",
+ author="Aurélien Delobelle",
+ author_email="aurelien.delobelle@gmail.com",
+ keywords="django allauth cas authentication",
long_description=README,
- url='https://github.com/aureplop/django-allauth-cas',
+ url="https://github.com/aureplop/django-allauth-cas",
classifiers=[
- 'Development Status :: 4 - Beta',
- 'Environment :: Web Environment',
- 'Framework :: Django',
- 'Framework :: Django :: 1.8',
- 'Framework :: Django :: 1.9',
- 'Framework :: Django :: 1.10',
- 'Framework :: Django :: 1.11',
- 'Framework :: Django :: 2.0',
- 'Intended Audience :: Developers',
- 'License :: OSI Approved :: MIT License',
- 'Operating System :: OS Independent',
- 'Programming Language :: Python',
- 'Programming Language :: Python :: 2',
- 'Programming Language :: Python :: 2.7',
- 'Programming Language :: Python :: 3',
- 'Programming Language :: Python :: 3.4',
- 'Programming Language :: Python :: 3.5',
- 'Programming Language :: Python :: 3.6',
- 'Topic :: Internet :: WWW/HTTP',
+ "Development Status :: 4 - Beta",
+ "Environment :: Web Environment",
+ "Framework :: Django",
+ "Framework :: Django :: 1.8",
+ "Framework :: Django :: 1.9",
+ "Framework :: Django :: 1.10",
+ "Framework :: Django :: 1.11",
+ "Framework :: Django :: 2.0",
+ "Framework :: Django :: 3.0",
+ "Framework :: Django :: 4.0",
+ "Framework :: Django :: 5.0",
+ "Intended Audience :: Developers",
+ "License :: OSI Approved :: MIT License",
+ "Operating System :: OS Independent",
+ "Programming Language :: Python",
+ "Programming Language :: Python :: 2",
+ "Programming Language :: Python :: 2.7",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.4",
+ "Programming Language :: Python :: 3.5",
+ "Programming Language :: Python :: 3.6",
+ "Programming Language :: Python :: 3.7",
+ "Programming Language :: Python :: 3.8",
+ "Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: 3.11",
+ "Topic :: Internet :: WWW/HTTP",
],
- license='MIT',
- packages=find_packages(exclude=['tests']),
+ license="MIT",
+ packages=find_packages(exclude=["tests"]),
include_package_data=True,
install_requires=[
- 'django-allauth',
- 'python-cas',
- 'six',
+ "django-allauth",
+ "python-cas",
+ "six",
],
extras_require={
- 'docs': ['sphinx'],
- 'tests': ['tox'],
+ "docs": ["sphinx"],
+ "tests": ["tox"],
},
)
diff --git a/tests/example/provider.py b/tests/example/provider.py
index c393b70..72d1c45 100644
--- a/tests/example/provider.py
+++ b/tests/example/provider.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
from allauth.socialaccount.providers.base import ProviderAccount
from allauth_cas.providers import CASProvider
@@ -9,8 +8,8 @@ class ExampleCASAccount(ProviderAccount):
class ExampleCASProvider(CASProvider):
- id = 'theid'
- name = 'The Provider'
+ id = "theid"
+ name = "The Provider"
account_class = ExampleCASAccount
diff --git a/tests/example/urls.py b/tests/example/urls.py
index 335c441..795e835 100644
--- a/tests/example/urls.py
+++ b/tests/example/urls.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
from allauth_cas.urls import default_urlpatterns
from .provider import ExampleCASProvider
diff --git a/tests/example/views.py b/tests/example/views.py
index 6aa916a..1f04fb2 100644
--- a/tests/example/views.py
+++ b/tests/example/views.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
from allauth_cas import views
from .provider import ExampleCASProvider
@@ -6,7 +5,7 @@ from .provider import ExampleCASProvider
class ExampleCASAdapter(views.CASAdapter):
provider_id = ExampleCASProvider.id
- url = 'https://server.cas'
+ url = "https://server.cas"
version = 2
diff --git a/tests/settings.py b/tests/settings.py
index 5acb51d..505c325 100644
--- a/tests/settings.py
+++ b/tests/settings.py
@@ -1,63 +1,54 @@
-# -*- coding: utf-8 -*-
-import django
-
-SECRET_KEY = 'iamabird'
+SECRET_KEY = "iamabird"
INSTALLED_APPS = [
- 'django.contrib.admin',
- 'django.contrib.auth',
- 'django.contrib.contenttypes',
- 'django.contrib.messages',
- 'django.contrib.sessions',
- 'django.contrib.sites',
- 'django.contrib.staticfiles',
-
- 'allauth',
- 'allauth.account',
- 'allauth.socialaccount',
-
- 'allauth_cas',
-
- 'tests.example', # Dummy CAS provider app
+ "django.contrib.admin",
+ "django.contrib.auth",
+ "django.contrib.contenttypes",
+ "django.contrib.messages",
+ "django.contrib.sessions",
+ "django.contrib.sites",
+ "django.contrib.staticfiles",
+ "allauth",
+ "allauth.account",
+ "allauth.socialaccount",
+ "allauth_cas",
+ "tests.example", # Dummy CAS provider app
]
DATABASES = {
- 'default': {
- 'ENGINE': 'django.db.backends.sqlite3',
+ "default": {
+ "ENGINE": "django.db.backends.sqlite3",
},
}
AUTHENTICATION_BACKENDS = [
- 'allauth.account.auth_backends.AuthenticationBackend',
+ "allauth.account.auth_backends.AuthenticationBackend",
]
_MIDDLEWARES = [
- 'django.middleware.common.CommonMiddleware',
- 'django.contrib.sessions.middleware.SessionMiddleware',
- 'django.middleware.csrf.CsrfViewMiddleware',
- 'django.contrib.auth.middleware.AuthenticationMiddleware',
- 'django.contrib.messages.middleware.MessageMiddleware',
+ "django.middleware.common.CommonMiddleware",
+ "django.contrib.sessions.middleware.SessionMiddleware",
+ "django.middleware.csrf.CsrfViewMiddleware",
+ "django.contrib.auth.middleware.AuthenticationMiddleware",
+ "django.contrib.messages.middleware.MessageMiddleware",
]
-if django.VERSION >= (1, 10):
- MIDDLEWARE = _MIDDLEWARES
-else:
- MIDDLEWARE_CLASSES = _MIDDLEWARES
+MIDDLEWARE = _MIDDLEWARES
TEMPLATES = [
{
- 'BACKEND': 'django.template.backends.django.DjangoTemplates',
- 'DIRS': [],
- 'APP_DIRS': True,
- 'OPTIONS': {
- 'context_processors': [
- 'django.template.context_processors.debug',
- 'django.template.context_processors.request',
- 'django.contrib.auth.context_processors.auth',
- 'django.contrib.messages.context_processors.messages',
+ "BACKEND": "django.template.backends.django.DjangoTemplates",
+ "DIRS": [],
+ "APP_DIRS": True,
+ "OPTIONS": {
+ "context_processors": [
+ "django.template.context_processors.debug",
+ "django.template.context_processors.request",
+ "django.contrib.auth.context_processors.auth",
+ "django.contrib.messages.context_processors.messages",
],
},
},
]
-ROOT_URLCONF = 'tests.urls'
+ROOT_URLCONF = "tests.urls"
diff --git a/tests/test_flows.py b/tests/test_flows.py
index fb9fdbd..00d2b63 100644
--- a/tests/test_flows.py
+++ b/tests/test_flows.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
from django.contrib import messages
from django.contrib.auth import get_user_model
from django.contrib.messages.api import get_messages
@@ -13,7 +12,7 @@ User = get_user_model()
class LogoutFlowTests(CASTestCase):
expected_msg_str = (
"To logout of The Provider, please close your browser, or visit this "
- ""
+ ''
"link."
)
@@ -28,42 +27,45 @@ class LogoutFlowTests(CASTestCase):
)
self.assertTemplateNotUsed(
response,
- 'socialaccount/messages/suggest_caslogout.html',
+ "socialaccount/messages/suggest_caslogout.html",
)
- @override_settings(SOCIALACCOUNT_PROVIDERS={
- 'theid': {
- 'MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT': True,
- 'MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT_LEVEL': messages.WARNING,
- },
- })
+ @override_settings(
+ SOCIALACCOUNT_PROVIDERS={
+ "theid": {
+ "MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT": True,
+ "MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT_LEVEL": messages.WARNING,
+ },
+ }
+ )
def test_message_on_logout(self):
"""
Message is sent to propose user to logout of CAS.
"""
- r = self.client.post('/accounts/logout/?next=/redir/')
+ r = self.client.post("/accounts/logout/?next=/redir/")
r_messages = get_messages(r.wsgi_request)
expected_msg = Message(messages.WARNING, self.expected_msg_str)
self.assertIn(expected_msg, r_messages)
- self.assertTemplateUsed(
- r, 'socialaccount/messages/suggest_caslogout.html')
+ self.assertTemplateUsed(r, "socialaccount/messages/suggest_caslogout.html")
def test_message_on_logout_disabled(self):
- r = self.client.post('/accounts/logout/')
+ r = self.client.post("/accounts/logout/")
self.assertCASLogoutNotInMessages(r)
- @override_settings(SOCIALACCOUNT_PROVIDERS={
- 'theid': {'MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT': True},
- })
+ @override_settings(
+ SOCIALACCOUNT_PROVIDERS={
+ "theid": {"MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT": True},
+ }
+ )
def test_other_logout(self):
"""
The CAS logout message doesn't appear with other login methods.
"""
- User.objects.create_user('user', '', 'user')
+ User.objects.create_user("user", "", "user")
client = Client()
- client.login(username='user', password='user')
+ client.login(username="user", password="user")
- r = client.post('/accounts/logout/')
+ r = client.post("/accounts/logout/")
self.assertCASLogoutNotInMessages(r)
diff --git a/tests/test_providers.py b/tests/test_providers.py
index ef68bb7..508f5b3 100644
--- a/tests/test_providers.py
+++ b/tests/test_providers.py
@@ -1,6 +1,6 @@
-# -*- coding: utf-8 -*-
-from six.moves.urllib.parse import urlencode
+from urllib.parse import urlencode
+from allauth.socialaccount.providers import registry
from django.contrib import messages
from django.contrib.messages.api import get_messages
from django.contrib.messages.middleware import MessageMiddleware
@@ -8,21 +8,18 @@ from django.contrib.messages.storage.base import Message
from django.contrib.sessions.middleware import SessionMiddleware
from django.test import RequestFactory, TestCase, override_settings
-from allauth.socialaccount.providers import registry
-
from allauth_cas.views import AuthAction
from .example.provider import ExampleCASProvider
class CASProviderTests(TestCase):
-
def setUp(self):
self.request = self._get_request()
self.provider = ExampleCASProvider(self.request)
def _get_request(self):
- request = RequestFactory().get('/test/')
+ request = RequestFactory().get("/test/")
SessionMiddleware().process_request(request)
MessageMiddleware().process_request(request)
return request
@@ -31,74 +28,81 @@ class CASProviderTests(TestCase):
"""
Example CAS provider is registered as social account provider.
"""
- self.assertIsInstance(registry.by_id('theid'), ExampleCASProvider)
+ self.assertIsInstance(registry.by_id("theid"), ExampleCASProvider)
def test_get_login_url(self):
url = self.provider.get_login_url(self.request)
- self.assertEqual('/accounts/theid/login/', url)
+ self.assertEqual("/accounts/theid/login/", url)
url_with_qs = self.provider.get_login_url(
self.request,
- next='/path?quéry=string&two=whoam%C3%AF',
+ next="/path?quéry=string&two=whoam%C3%AF",
)
self.assertEqual(
url_with_qs,
- '/accounts/theid/login/?next=%2Fpath%3Fqu%C3%A9ry%3Dstring%26two%3'
- 'Dwhoam%25C3%25AF'
+ "/accounts/theid/login/?next=%2Fpath%3Fqu%C3%A9ry%3Dstring%26two%3"
+ "Dwhoam%25C3%25AF",
)
def test_get_callback_url(self):
url = self.provider.get_callback_url(self.request)
- self.assertEqual('/accounts/theid/login/callback/', url)
+ self.assertEqual("/accounts/theid/login/callback/", url)
url_with_qs = self.provider.get_callback_url(
self.request,
- next='/path?quéry=string&two=whoam%C3%AF',
+ next="/path?quéry=string&two=whoam%C3%AF",
)
self.assertEqual(
url_with_qs,
- '/accounts/theid/login/callback/?next=%2Fpath%3Fqu%C3%A9ry%3Dstrin'
- 'g%26two%3Dwhoam%25C3%25AF'
+ "/accounts/theid/login/callback/?next=%2Fpath%3Fqu%C3%A9ry%3Dstrin"
+ "g%26two%3Dwhoam%25C3%25AF",
)
def test_get_logout_url(self):
url = self.provider.get_logout_url(self.request)
- self.assertEqual('/accounts/theid/logout/', url)
+ self.assertEqual("/accounts/theid/logout/", url)
url_with_qs = self.provider.get_logout_url(
self.request,
- next='/path?quéry=string&two=whoam%C3%AF',
+ next="/path?quéry=string&two=whoam%C3%AF",
)
self.assertEqual(
url_with_qs,
- '/accounts/theid/logout/?next=%2Fpath%3Fqu%C3%A9ry%3Dstring%26two%'
- '3Dwhoam%25C3%25AF'
+ "/accounts/theid/logout/?next=%2Fpath%3Fqu%C3%A9ry%3Dstring%26two%"
+ "3Dwhoam%25C3%25AF",
)
- @override_settings(SOCIALACCOUNT_PROVIDERS={
- 'theid': {
- 'AUTH_PARAMS': {'key': 'value'},
- },
- })
+ @override_settings(
+ SOCIALACCOUNT_PROVIDERS={
+ "theid": {
+ "AUTH_PARAMS": {"key": "value"},
+ },
+ }
+ )
def test_get_auth_params(self):
action = AuthAction.AUTHENTICATE
auth_params = self.provider.get_auth_params(self.request, action)
- self.assertDictEqual(auth_params, {
- 'key': 'value',
- })
+ self.assertDictEqual(
+ auth_params,
+ {
+ "key": "value",
+ },
+ )
- @override_settings(SOCIALACCOUNT_PROVIDERS={
- 'theid': {
- 'AUTH_PARAMS': {'key': 'value'},
- },
- })
+ @override_settings(
+ SOCIALACCOUNT_PROVIDERS={
+ "theid": {
+ "AUTH_PARAMS": {"key": "value"},
+ },
+ }
+ )
def test_get_auth_params_with_dynamic(self):
factory = RequestFactory()
request = factory.get(
- '/test/?auth_params=next%3Dtwo%253Dwhoam%2525C3%2525AF%2526qu%2525'
- 'C3%2525A9ry%253Dstring'
+ "/test/?auth_params=next%3Dtwo%253Dwhoam%2525C3%2525AF%2526qu%2525"
+ "C3%2525A9ry%253Dstring"
)
request.session = {}
@@ -106,15 +110,18 @@ class CASProviderTests(TestCase):
auth_params = self.provider.get_auth_params(request, action)
- self.assertDictEqual(auth_params, {
- 'key': 'value',
- 'next': 'two=whoam%C3%AF&qu%C3%A9ry=string',
- })
+ self.assertDictEqual(
+ auth_params,
+ {
+ "key": "value",
+ "next": "two=whoam%C3%AF&qu%C3%A9ry=string",
+ },
+ )
def test_add_message_suggest_caslogout(self):
expected_msg_base_str = (
"To logout of The Provider, please close your browser, or visit "
- "this link."
+ 'this link.'
)
# Defaults.
@@ -124,7 +131,7 @@ class CASProviderTests(TestCase):
expected_msg1 = Message(
messages.INFO,
- expected_msg_base_str.format(urlencode({'next': '/test/'})),
+ expected_msg_base_str.format(urlencode({"next": "/test/"})),
)
self.assertIn(expected_msg1, get_messages(req1))
@@ -132,69 +139,83 @@ class CASProviderTests(TestCase):
req2 = self._get_request()
self.provider.add_message_suggest_caslogout(
- req2, next_page='/redir/', level=messages.WARNING)
+ req2, next_page="/redir/", level=messages.WARNING
+ )
expected_msg2 = Message(
messages.WARNING,
- expected_msg_base_str.format(urlencode({'next': '/redir/'})),
+ expected_msg_base_str.format(urlencode({"next": "/redir/"})),
)
self.assertIn(expected_msg2, get_messages(req2))
def test_message_suggest_caslogout_on_logout(self):
self.assertFalse(
- self.provider.message_suggest_caslogout_on_logout(self.request))
+ self.provider.message_suggest_caslogout_on_logout(self.request)
+ )
- with override_settings(SOCIALACCOUNT_PROVIDERS={
- 'theid': {'MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT': True},
- }):
+ with override_settings(
+ SOCIALACCOUNT_PROVIDERS={
+ "theid": {"MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT": True},
+ }
+ ):
self.assertTrue(
- self.provider
- .message_suggest_caslogout_on_logout(self.request)
+ self.provider.message_suggest_caslogout_on_logout(self.request)
)
- @override_settings(SOCIALACCOUNT_PROVIDERS={
- 'theid': {
- 'MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT_LEVEL': messages.WARNING,
- },
- })
+ @override_settings(
+ SOCIALACCOUNT_PROVIDERS={
+ "theid": {
+ "MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT_LEVEL": messages.WARNING,
+ },
+ }
+ )
def test_message_suggest_caslogout_on_logout_level(self):
- self.assertEqual(messages.WARNING, (
- self.provider
- .message_suggest_caslogout_on_logout_level(self.request)
- ))
+ self.assertEqual(
+ messages.WARNING,
+ (self.provider.message_suggest_caslogout_on_logout_level(self.request)),
+ )
def test_extract_uid(self):
- response = 'useRName', {}
+ response = "useRName", {}
uid = self.provider.extract_uid(response)
- self.assertEqual('useRName', uid)
+ self.assertEqual("useRName", uid)
def test_extract_common_fields(self):
- response = 'useRName', {}
+ response = "useRName", {}
common_fields = self.provider.extract_common_fields(response)
- self.assertDictEqual(common_fields, {
- 'username': 'useRName',
- 'first_name': None,
- 'last_name': None,
- 'name': None,
- 'email': None,
- })
+ self.assertDictEqual(
+ common_fields,
+ {
+ "username": "useRName",
+ "first_name": None,
+ "last_name": None,
+ "name": None,
+ "email": None,
+ },
+ )
def test_extract_common_fields_with_extra(self):
- response = 'useRName', {'username': 'user', 'email': 'user@mail.net'}
+ response = "useRName", {"username": "user", "email": "user@mail.net"}
common_fields = self.provider.extract_common_fields(response)
- self.assertDictEqual(common_fields, {
- 'username': 'user',
- 'first_name': None,
- 'last_name': None,
- 'name': None,
- 'email': 'user@mail.net',
- })
+ self.assertDictEqual(
+ common_fields,
+ {
+ "username": "user",
+ "first_name": None,
+ "last_name": None,
+ "name": None,
+ "email": "user@mail.net",
+ },
+ )
def test_extract_extra_data(self):
- response = 'useRName', {'user_attr': 'thevalue', 'another': 'value'}
+ response = "useRName", {"user_attr": "thevalue", "another": "value"}
extra_data = self.provider.extract_extra_data(response)
- self.assertDictEqual(extra_data, {
- 'user_attr': 'thevalue',
- 'another': 'value',
- 'uid': 'useRName',
- })
+ self.assertDictEqual(
+ extra_data,
+ {
+ "user_attr": "thevalue",
+ "another": "value",
+ "uid": "useRName",
+ },
+ )
diff --git a/tests/test_testcases.py b/tests/test_testcases.py
index c7dacab..98f1673 100644
--- a/tests/test_testcases.py
+++ b/tests/test_testcases.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
from django.test import Client, RequestFactory
from allauth_cas.test.testcases import CASViewTestCase
@@ -8,9 +7,8 @@ from .example.views import ExampleCASAdapter
class CASTestCaseTests(CASViewTestCase):
-
def setUp(self):
- self.client.get('/accounts/theid/login/')
+ self.client.get("/accounts/theid/login/")
def test_patch_cas_response_client_version(self):
"""
@@ -21,20 +19,24 @@ class CASTestCaseTests(CASViewTestCase):
"""
valid_versions = [
- 1, '1',
- 2, '2',
- 3, '3',
- 'CAS_2_SAML_1_0',
+ 1,
+ "1",
+ 2,
+ "2",
+ 3,
+ "3",
+ "CAS_2_SAML_1_0",
]
invalid_versions = [
- 'not_supported',
+ "not_supported",
]
factory = RequestFactory()
- request = factory.get('/path/')
+ request = factory.get("/path/")
request.session = {}
for _version in valid_versions + invalid_versions:
+
class BasicCASAdapter(ExampleCASAdapter):
version = _version
@@ -47,7 +49,7 @@ class CASTestCaseTests(CASViewTestCase):
if _version in valid_versions:
raw_client = view(request)
- self.patch_cas_response(valid_ticket='__all__')
+ self.patch_cas_response(valid_ticket="__all__")
mocked_client = view(request)
self.assertEqual(type(raw_client), type(mocked_client))
@@ -55,58 +57,79 @@ class CASTestCaseTests(CASViewTestCase):
# This is a sanity check.
self.assertRaises(ValueError, view, request)
- self.patch_cas_response(valid_ticket='__all__')
+ self.patch_cas_response(valid_ticket="__all__")
self.assertRaises(ValueError, view, request)
def test_patch_cas_response_verify_success(self):
- self.patch_cas_response(valid_ticket='123456')
- r = self.client.get('/accounts/theid/login/callback/', {
- 'ticket': '123456',
- })
+ self.patch_cas_response(valid_ticket="123456")
+ r = self.client.get(
+ "/accounts/theid/login/callback/",
+ {
+ "ticket": "123456",
+ },
+ )
self.assertLoginSuccess(r)
def test_patch_cas_response_verify_failure(self):
- self.patch_cas_response(valid_ticket='123456')
- r = self.client.get('/accounts/theid/login/callback/', {
- 'ticket': '000000',
- })
+ self.patch_cas_response(valid_ticket="123456")
+ r = self.client.get(
+ "/accounts/theid/login/callback/",
+ {
+ "ticket": "000000",
+ },
+ )
self.assertLoginFailure(r)
def test_patch_cas_response_accept(self):
- self.patch_cas_response(valid_ticket='__all__')
- r = self.client.get('/accounts/theid/login/callback/', {
- 'ticket': '000000',
- })
+ self.patch_cas_response(valid_ticket="__all__")
+ r = self.client.get(
+ "/accounts/theid/login/callback/",
+ {
+ "ticket": "000000",
+ },
+ )
self.assertLoginSuccess(r)
def test_patch_cas_response_reject(self):
self.patch_cas_response(valid_ticket=None)
- r = self.client.get('/accounts/theid/login/callback/', {
- 'ticket': '000000',
- })
+ r = self.client.get(
+ "/accounts/theid/login/callback/",
+ {
+ "ticket": "000000",
+ },
+ )
self.assertLoginFailure(r)
def test_patch_cas_reponse_multiple(self):
- self.patch_cas_response(valid_ticket='__all__')
+ self.patch_cas_response(valid_ticket="__all__")
client_0 = Client()
- client_0.get('/accounts/theid/login/')
- r_0 = client_0.get('/accounts/theid/login/callback/', {
- 'ticket': '000000',
- })
+ client_0.get("/accounts/theid/login/")
+ r_0 = client_0.get(
+ "/accounts/theid/login/callback/",
+ {
+ "ticket": "000000",
+ },
+ )
self.assertLoginSuccess(r_0)
self.patch_cas_response(valid_ticket=None)
client_1 = Client()
- client_1.get('/accounts/theid/login/')
- r_1 = client_1.get('/accounts/theid/login/callback/', {
- 'ticket': '111111',
- })
+ client_1.get("/accounts/theid/login/")
+ r_1 = client_1.get(
+ "/accounts/theid/login/callback/",
+ {
+ "ticket": "111111",
+ },
+ )
self.assertLoginFailure(r_1)
def test_assertLoginSuccess(self):
- self.patch_cas_response(valid_ticket='__all__')
- r = self.client.get('/accounts/theid/login/callback/', {
- 'ticket': '000000',
- 'next': '/path/',
- })
- self.assertLoginSuccess(r, redirect_to='/path/')
+ self.patch_cas_response(valid_ticket="__all__")
+ r = self.client.get(
+ "/accounts/theid/login/callback/",
+ {
+ "ticket": "000000",
+ "next": "/path/",
+ },
+ )
+ self.assertLoginSuccess(r, redirect_to="/path/")
diff --git a/tests/test_views.py b/tests/test_views.py
index e7b04e1..7130645 100644
--- a/tests/test_views.py
+++ b/tests/test_views.py
@@ -1,11 +1,10 @@
-# -*- coding: utf-8 -*-
try:
from unittest.mock import patch
except ImportError:
- from mock import patch
+ from unittest.mock import patch
-import django
from django.test import RequestFactory, override_settings
+from django.urls import reverse
from allauth_cas.exceptions import CASAuthenticationError
from allauth_cas.test.testcases import CASTestCase, CASViewTestCase
@@ -13,17 +12,11 @@ from allauth_cas.views import CASView
from .example.views import ExampleCASAdapter
-if django.VERSION >= (1, 10):
- from django.urls import reverse
-else:
- from django.core.urlresolvers import reverse
-
class CASAdapterTests(CASTestCase):
-
def setUp(self):
factory = RequestFactory()
- self.request = factory.get('/path/')
+ self.request = factory.get("/path/")
self.request.session = {}
self.adapter = ExampleCASAdapter(self.request)
@@ -31,7 +24,7 @@ class CASAdapterTests(CASTestCase):
"""
Service url (used by CAS client) is the callback url.
"""
- expected = 'http://testserver/accounts/theid/login/callback/'
+ expected = "http://testserver/accounts/theid/login/callback/"
service_url = self.adapter.get_service_url(self.request)
self.assertEqual(expected, service_url)
@@ -39,11 +32,9 @@ class CASAdapterTests(CASTestCase):
"""
Current GET paramater next is appended on service url.
"""
- expected = (
- 'http://testserver/accounts/theid/login/callback/?next=%2Fnext%2F'
- )
+ expected = "http://testserver/accounts/theid/login/callback/?next=%2Fnext%2F"
factory = RequestFactory()
- request = factory.get('/path/', {'next': '/next/'})
+ request = factory.get("/path/", {"next": "/next/"})
adapter = ExampleCASAdapter(request)
service_url = adapter.get_service_url(request)
self.assertEqual(expected, service_url)
@@ -67,14 +58,13 @@ class CASAdapterTests(CASTestCase):
class CASViewTests(CASViewTestCase):
-
class BasicCASView(CASView):
def dispatch(self, request, *args, **kwargs):
return self
def setUp(self):
factory = RequestFactory()
- self.request = factory.get('/path/')
+ self.request = factory.get("/path/")
self.request.session = {}
self.cas_view = self.BasicCASView.adapter_view(ExampleCASAdapter)
@@ -85,27 +75,34 @@ class CASViewTests(CASViewTestCase):
"""
view = self.cas_view(
self.request,
- 'arg1', 'arg2',
- kwarg1='kwarg1', kwarg2='kwarg2',
+ "arg1",
+ "arg2",
+ kwarg1="kwarg1",
+ kwarg2="kwarg2",
)
self.assertIsInstance(view, CASView)
self.assertEqual(view.request, self.request)
- self.assertTupleEqual(view.args, ('arg1', 'arg2'))
- self.assertDictEqual(view.kwargs, {
- 'kwarg1': 'kwarg1',
- 'kwarg2': 'kwarg2',
- })
+ self.assertTupleEqual(view.args, ("arg1", "arg2"))
+ self.assertDictEqual(
+ view.kwargs,
+ {
+ "kwarg1": "kwarg1",
+ "kwarg2": "kwarg2",
+ },
+ )
self.assertIsInstance(view.adapter, ExampleCASAdapter)
- @patch('allauth_cas.views.cas.CASClient')
- @override_settings(SOCIALACCOUNT_PROVIDERS={
- 'theid': {
- 'AUTH_PARAMS': {'key': 'value'},
- },
- })
+ @patch("allauth_cas.views.cas.CASClient")
+ @override_settings(
+ SOCIALACCOUNT_PROVIDERS={
+ "theid": {
+ "AUTH_PARAMS": {"key": "value"},
+ },
+ }
+ )
def test_get_client(self, mock_casclient_class):
"""
get_client returns a CAS client, configured from settings.
@@ -114,11 +111,11 @@ class CASViewTests(CASViewTestCase):
view.get_client(self.request)
mock_casclient_class.assert_called_once_with(
- service_url='http://testserver/accounts/theid/login/callback/',
- server_url='https://server.cas',
+ service_url="http://testserver/accounts/theid/login/callback/",
+ server_url="https://server.cas",
version=2,
renew=False,
- extra_login_params={'key': 'value'},
+ extra_login_params={"key": "value"},
)
def test_render_error_on_failure(self):
@@ -126,33 +123,33 @@ class CASViewTests(CASViewTestCase):
A common login failure page is rendered if CASAuthenticationError is
raised by dispatch.
"""
+
def dispatch_raise(self, request):
raise CASAuthenticationError("failure")
- with patch.object(self.BasicCASView, 'dispatch', dispatch_raise):
+ with patch.object(self.BasicCASView, "dispatch", dispatch_raise):
resp = self.cas_view(self.request)
self.assertLoginFailure(resp)
class CASLoginViewTests(CASViewTestCase):
-
def test_reverse(self):
"""
Login view name is "{provider_id}_login".
"""
- url = reverse('theid_login')
- self.assertEqual('/accounts/theid/login/', url)
+ url = reverse("theid_login")
+ self.assertEqual("/accounts/theid/login/", url)
def test_execute(self):
"""
Login view redirects to the CAS server login url.
Service is the callback url, as absolute uri.
"""
- r = self.client.get('/accounts/theid/login/')
+ r = self.client.get("/accounts/theid/login/")
expected = (
- 'https://server.cas/login?service=http%3A%2F%2Ftestserver%2F'
- 'accounts%2Ftheid%2Flogin%2Fcallback%2F'
+ "https://server.cas/login?service=http%3A%2F%2Ftestserver%2F"
+ "accounts%2Ftheid%2Flogin%2Fcallback%2F"
)
self.assertRedirects(r, expected, fetch_redirect_response=False)
@@ -161,54 +158,59 @@ class CASLoginViewTests(CASViewTestCase):
"""
Current GET parameter 'next' is kept on service url.
"""
- r = self.client.get('/accounts/theid/login/?next=/path/')
+ r = self.client.get("/accounts/theid/login/?next=/path/")
expected = (
- 'https://server.cas/login?service=http%3A%2F%2Ftestserver%2F'
- 'accounts%2Ftheid%2Flogin%2Fcallback%2F%3Fnext%3D%252Fpath%252F'
+ "https://server.cas/login?service=http%3A%2F%2Ftestserver%2F"
+ "accounts%2Ftheid%2Flogin%2Fcallback%2F%3Fnext%3D%252Fpath%252F"
)
self.assertRedirects(r, expected, fetch_redirect_response=False)
class CASCallbackViewTests(CASViewTestCase):
-
def setUp(self):
- self.client.get('/accounts/theid/login/')
+ self.client.get("/accounts/theid/login/")
def test_reverse(self):
"""
Callback view name is "{provider_id}_callback".
"""
- url = reverse('theid_callback')
- self.assertEqual('/accounts/theid/login/callback/', url)
+ url = reverse("theid_callback")
+ self.assertEqual("/accounts/theid/login/callback/", url)
def test_ticket_valid(self):
"""
If ticket is valid, the user is logged in.
"""
- self.patch_cas_response(username='username', valid_ticket='123456')
- r = self.client.get('/accounts/theid/login/callback/', {
- 'ticket': '123456',
- })
+ self.patch_cas_response(username="username", valid_ticket="123456")
+ r = self.client.get(
+ "/accounts/theid/login/callback/",
+ {
+ "ticket": "123456",
+ },
+ )
self.assertLoginSuccess(r)
def test_ticket_invalid(self):
"""
Login failure page is returned if the ticket is invalid.
"""
- self.patch_cas_response(username='username', valid_ticket='123456')
- r = self.client.get('/accounts/theid/login/callback/', {
- 'ticket': '000000',
- })
+ self.patch_cas_response(username="username", valid_ticket="123456")
+ r = self.client.get(
+ "/accounts/theid/login/callback/",
+ {
+ "ticket": "000000",
+ },
+ )
self.assertLoginFailure(r)
def test_ticket_missing(self):
"""
Login failure page is returned if request lacks a ticket.
"""
- self.patch_cas_response(username='username', valid_ticket='123456')
- r = self.client.get('/accounts/theid/login/callback/')
+ self.patch_cas_response(username="username", valid_ticket="123456")
+ r = self.client.get("/accounts/theid/login/callback/")
self.assertLoginFailure(r)
def test_attributes_is_none(self):
@@ -216,31 +218,33 @@ class CASCallbackViewTests(CASViewTestCase):
Without extra attributes, CASClientV2 of python-cas returns None.
"""
self.patch_cas_response(
- username='username', valid_ticket='123456', attributes=None
+ username="username", valid_ticket="123456", attributes=None
+ )
+ r = self.client.get(
+ "/accounts/theid/login/callback/",
+ {
+ "ticket": "123456",
+ },
)
- r = self.client.get('/accounts/theid/login/callback/', {
- 'ticket': '123456',
- })
self.assertLoginSuccess(r)
class CASLogoutViewTests(CASViewTestCase):
-
def test_reverse(self):
"""
Callback view name is "{provider_id}_logout".
"""
- url = reverse('theid_logout')
- self.assertEqual('/accounts/theid/logout/', url)
+ url = reverse("theid_logout")
+ self.assertEqual("/accounts/theid/logout/", url)
def test_execute(self):
"""
Logout view redirects to the CAS server logout url.
Service is a url to here, as absolute uri.
"""
- r = self.client.get('/accounts/theid/logout/')
+ r = self.client.get("/accounts/theid/logout/")
- expected = 'https://server.cas/logout?url=http%3A%2F%2Ftestserver%2F'
+ expected = "https://server.cas/logout?url=http%3A%2F%2Ftestserver%2F"
self.assertRedirects(r, expected, fetch_redirect_response=False)
@@ -248,10 +252,8 @@ class CASLogoutViewTests(CASViewTestCase):
"""
GET parameter 'next' is set as service url.
"""
- r = self.client.get('/accounts/theid/logout/?next=/path/')
+ r = self.client.get("/accounts/theid/logout/?next=/path/")
- expected = (
- 'https://server.cas/logout?url=http%3A%2F%2Ftestserver%2Fpath%2F'
- )
+ expected = "https://server.cas/logout?url=http%3A%2F%2Ftestserver%2Fpath%2F"
self.assertRedirects(r, expected, fetch_redirect_response=False)
diff --git a/tests/urls.py b/tests/urls.py
index 1933289..e33212c 100644
--- a/tests/urls.py
+++ b/tests/urls.py
@@ -1,6 +1,5 @@
-# -*- coding: utf-8 -*-
-from django.conf.urls import include, url
+from django.urls import include, path
urlpatterns = [
- url(r'^accounts/', include('allauth.urls')),
+ path("accounts/", include("allauth.urls")),
]