version 1.0.1. Fixed package for Django 4.2.

This commit is contained in:
João Pires 2024-01-25 14:33:24 +00:00
parent 6657ec6042
commit 77e02f3796
22 changed files with 548 additions and 453 deletions

8
.idea/.gitignore vendored Normal file
View file

@ -0,0 +1,8 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

53
.pre-commit-config.yaml Normal file
View file

@ -0,0 +1,53 @@
exclude: "^docs/|/migrations/"
default_stages: [ commit ]
fail_fast: true
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
exclude: ".svg$|.min.*|.idea*"
- id: check-yaml
- id: check-added-large-files
- repo: https://github.com/asottile/pyupgrade
rev: v3.15.0
hooks:
- id: pyupgrade
args: [ --py39-plus ]
- repo: https://github.com/psf/black
rev: 23.12.1
hooks:
- id: black
- repo: https://github.com/PyCQA/isort
rev: 5.13.2
hooks:
- id: isort
args: [ "--profile", "black", "-l=88" ]
- repo: https://github.com/PyCQA/flake8
rev: 7.0.0
hooks:
- id: flake8
args: [ "--config=setup.cfg" ]
additional_dependencies:
- flake8-isort
- flake8-bugbear
- flake8-print
- repo: https://github.com/adamchainz/django-upgrade
rev: 1.15.0
hooks:
- id: django-upgrade
args: [ --target-version, "4.2" ]
# sets up .pre-commit-ci.yaml to ensure pre-commit dependencies stay up to date
ci:
autoupdate_schedule: weekly
skip: [ ]
submodules: false

View file

@ -1,7 +1,3 @@
# -*- coding: utf-8 -*- __version__ = "1.0.1"
__version__ = '1.0.0' CAS_PROVIDER_SESSION_KEY = "allauth_cas__provider_id"
default_app_config = 'allauth_cas.apps.CASAccountConfig'
CAS_PROVIDER_SESSION_KEY = 'allauth_cas__provider_id'

View file

@ -1,10 +1,9 @@
# -*- coding: utf-8 -*-
from django.apps import AppConfig from django.apps import AppConfig
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import gettext_lazy as _
class CASAccountConfig(AppConfig): class CASAccountConfig(AppConfig):
name = 'allauth_cas' name = "allauth_cas"
verbose_name = _("CAS Accounts") verbose_name = _("CAS Accounts")
def ready(self): def ready(self):

View file

@ -1,6 +1,3 @@
# -*- coding: utf-8 -*-
class CASAuthenticationError(Exception): class CASAuthenticationError(Exception):
""" """
Base exception to signal CAS authentication failure. Base exception to signal CAS authentication failure.

View file

@ -1,26 +1,18 @@
# -*- coding: utf-8 -*- from urllib.parse import parse_qsl
from six.moves.urllib.parse import parse_qsl
import django from allauth.socialaccount.providers.base import Provider
from django.contrib import messages from django.contrib import messages
from django.template.loader import render_to_string from django.template.loader import render_to_string
from django.urls import reverse
from django.utils.http import urlencode from django.utils.http import urlencode
from django.utils.safestring import mark_safe from django.utils.safestring import mark_safe
from allauth.socialaccount.providers.base import Provider
if django.VERSION >= (1, 10):
from django.urls import reverse
else:
from django.core.urlresolvers import reverse
class CASProvider(Provider): class CASProvider(Provider):
def get_auth_params(self, request, action): def get_auth_params(self, request, action):
settings = self.get_settings() settings = self.get_settings()
ret = dict(settings.get('AUTH_PARAMS', {})) ret = dict(settings.get("AUTH_PARAMS", {}))
dynamic_auth_params = request.GET.get('auth_params') dynamic_auth_params = request.GET.get("auth_params")
if dynamic_auth_params: if dynamic_auth_params:
ret.update(dict(parse_qsl(dynamic_auth_params))) ret.update(dict(parse_qsl(dynamic_auth_params)))
return ret return ret
@ -68,11 +60,11 @@ class CASProvider(Provider):
""" """
uid, extra = data uid, extra = data
return { return {
'username': extra.get('username', uid), "username": extra.get("username", uid),
'email': extra.get('email'), "email": extra.get("email"),
'first_name': extra.get('first_name'), "first_name": extra.get("first_name"),
'last_name': extra.get('last_name'), "last_name": extra.get("last_name"),
'name': extra.get('name'), "name": extra.get("name"),
} }
def extract_email_addresses(self, data): def extract_email_addresses(self, data):
@ -99,7 +91,7 @@ class CASProvider(Provider):
] ]
""" """
return super(CASProvider, self).extract_email_addresses(data) return super().extract_email_addresses(data)
def extract_extra_data(self, data): def extract_extra_data(self, data):
"""Extract the data to save to `SocialAccount.extra_data`. """Extract the data to save to `SocialAccount.extra_data`.
@ -119,7 +111,10 @@ class CASProvider(Provider):
## ##
def add_message_suggest_caslogout( def add_message_suggest_caslogout(
self, request, next_page=None, level=None, self,
request,
next_page=None,
level=None,
): ):
"""Add a message with a link for the user to logout of the CAS server. """Add a message with a link for the user to logout of the CAS server.
@ -144,14 +139,15 @@ class CASProvider(Provider):
# DefaultAccountAdapter.add_message is unusable because it always # DefaultAccountAdapter.add_message is unusable because it always
# escape the message content. # escape the message content.
template = 'socialaccount/messages/suggest_caslogout.html' template = "socialaccount/messages/suggest_caslogout.html"
context = { context = {
'provider': self, "provider": self,
'logout_url': logout_url, "logout_url": logout_url,
} }
messages.add_message( messages.add_message(
request, level, request,
level,
mark_safe(render_to_string(template, context).strip()), mark_safe(render_to_string(template, context).strip()),
fail_silently=True, fail_silently=True,
) )
@ -168,10 +164,7 @@ class CASProvider(Provider):
signal ``user_logged_out``. signal ``user_logged_out``.
""" """
return ( return self.get_settings().get("MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT", False)
self.get_settings()
.get('MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT', False)
)
def message_suggest_caslogout_on_logout_level(self, request): def message_suggest_caslogout_on_logout_level(self, request):
"""Level of the logout message issued on user logout. """Level of the logout message issued on user logout.
@ -185,9 +178,8 @@ class CASProvider(Provider):
signal ``user_logged_out``. signal ``user_logged_out``.
""" """
return ( return self.get_settings().get(
self.get_settings() "MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT_LEVEL", messages.INFO
.get('MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT_LEVEL', messages.INFO)
) )
## ##
@ -195,19 +187,19 @@ class CASProvider(Provider):
## ##
def get_login_url(self, request, **kwargs): def get_login_url(self, request, **kwargs):
url = reverse(self.id + '_login') url = reverse(self.id + "_login")
if kwargs: if kwargs:
url += '?' + urlencode(kwargs) url += "?" + urlencode(kwargs)
return url return url
def get_callback_url(self, request, **kwargs): def get_callback_url(self, request, **kwargs):
url = reverse(self.id + '_callback') url = reverse(self.id + "_callback")
if kwargs: if kwargs:
url += '?' + urlencode(kwargs) url += "?" + urlencode(kwargs)
return url return url
def get_logout_url(self, request, **kwargs): def get_logout_url(self, request, **kwargs):
url = reverse(self.id + '_logout') url = reverse(self.id + "_logout")
if kwargs: if kwargs:
url += '?' + urlencode(kwargs) url += "?" + urlencode(kwargs)
return url return url

View file

@ -1,10 +1,8 @@
# -*- coding: utf-8 -*-
from django.contrib.auth.signals import user_logged_out
from django.dispatch import receiver
from allauth.account.adapter import get_adapter from allauth.account.adapter import get_adapter
from allauth.account.utils import get_next_redirect_url from allauth.account.utils import get_next_redirect_url
from allauth.socialaccount import providers from allauth.socialaccount import providers
from django.contrib.auth.signals import user_logged_out
from django.dispatch import receiver
from . import CAS_PROVIDER_SESSION_KEY from . import CAS_PROVIDER_SESSION_KEY
@ -21,12 +19,12 @@ def cas_account_logout(sender, request, **kwargs):
if not provider.message_suggest_caslogout_on_logout(request): if not provider.message_suggest_caslogout_on_logout(request):
return return
next_page = ( next_page = get_next_redirect_url(request) or get_adapter(
get_next_redirect_url(request) or request
get_adapter(request).get_logout_redirect_url(request) ).get_logout_redirect_url(request)
)
provider.add_message_suggest_caslogout( provider.add_message_suggest_caslogout(
request, next_page=next_page, request,
next_page=next_page,
level=provider.message_suggest_caslogout_on_logout_level(request), level=provider.message_suggest_caslogout_on_logout_level(request),
) )

View file

@ -1,29 +1,20 @@
# -*- coding: utf-8 -*-
try: try:
from unittest.mock import patch from unittest.mock import patch
except ImportError: except ImportError:
from mock import patch from unittest.mock import patch
import django
from django.conf import settings
from django.test import TestCase
import cas import cas
from django.conf import settings
from django.test import TestCase
from django.urls import reverse
from allauth_cas import CAS_PROVIDER_SESSION_KEY from allauth_cas import CAS_PROVIDER_SESSION_KEY
if django.VERSION >= (1, 10):
from django.urls import reverse
else:
from django.core.urlresolvers import reverse
class CASTestCase(TestCase): class CASTestCase(TestCase):
def client_cas_login( def client_cas_login(
self, self, client, provider_id=None, username=None, attributes=None
client, provider_id='theid', ):
username=None, attributes={}):
""" """
Authenticate client through provider_id. Authenticate client through provider_id.
@ -32,20 +23,22 @@ class CASTestCase(TestCase):
username and attributes control the CAS server response when ticket is username and attributes control the CAS server response when ticket is
checked. checked.
""" """
client.get(reverse('{id}_login'.format(id=provider_id))) if attributes is None:
attributes = {}
if provider_id is None:
provider_id = "theid"
client.get(reverse(f"{provider_id}_login"))
self.patch_cas_response( self.patch_cas_response(
valid_ticket='__all__', valid_ticket="__all__",
username=username, attributes=attributes, username=username,
attributes=attributes,
) )
callback_url = reverse('{id}_callback'.format(id=provider_id)) callback_url = reverse(f"{provider_id}_callback")
r = client.get(callback_url, {'ticket': 'fake-ticket'}) r = client.get(callback_url, {"ticket": "fake-ticket"})
self.patch_cas_response_stop() self.patch_cas_response_stop()
return r return r
def patch_cas_response( def patch_cas_response(self, valid_ticket, username=None, attributes=None):
self,
valid_ticket,
username=None, attributes={}):
""" """
Patch the CASClient class used by views of CAS providers. Patch the CASClient class used by views of CAS providers.
@ -68,35 +61,38 @@ class CASTestCase(TestCase):
Note that valid_ticket sould be a string (which is the type of the Note that valid_ticket sould be a string (which is the type of the
ticket retrieved from GET parameter on request on the callback view). ticket retrieved from GET parameter on request on the callback view).
""" """
if hasattr(self, '_patch_cas_client'): if attributes is None:
attributes = {}
if hasattr(self, "_patch_cas_client"):
self.patch_cas_response_stop() self.patch_cas_response_stop()
class MockCASClient(object): class MockCASClient:
_username = username _username = username
def __new__(self_client, *args, **kwargs): def __new__(self_client, *args, **kwargs):
version = kwargs.pop('version') version = kwargs.pop("version")
if version in (1, '1'): if version in (1, "1"):
client_class = cas.CASClientV1 client_class = cas.CASClientV1
elif version in (2, '2'): elif version in (2, "2"):
client_class = cas.CASClientV2 client_class = cas.CASClientV2
elif version in (3, '3'): elif version in (3, "3"):
client_class = cas.CASClientV3 client_class = cas.CASClientV3
elif version == 'CAS_2_SAML_1_0': elif version == "CAS_2_SAML_1_0":
client_class = cas.CASClientWithSAMLV1 client_class = cas.CASClientWithSAMLV1
else: else:
raise ValueError('Unsupported CAS_VERSION %r' % version) raise ValueError("Unsupported CAS_VERSION %r" % version)
client_class._username = self_client._username client_class._username = self_client._username
def verify_ticket(self, ticket): def verify_ticket(self, ticket):
if valid_ticket == '__all__' or ticket == valid_ticket: if valid_ticket == "__all__" or ticket == valid_ticket:
username = self._username or 'username' username = self._username or "username"
return username, attributes, None return username, attributes, None
return None, {}, None return None, {}, None
patcher = patch.object( patcher = patch.object(
client_class, 'verify_ticket', client_class,
"verify_ticket",
new=verify_ticket, new=verify_ticket,
) )
patcher.start() patcher.start()
@ -104,7 +100,7 @@ class CASTestCase(TestCase):
return client_class(*args, **kwargs) return client_class(*args, **kwargs)
self._patch_cas_client = patch( self._patch_cas_client = patch(
'allauth_cas.views.cas.CASClient', "allauth_cas.views.cas.CASClient",
MockCASClient, MockCASClient,
) )
self._patch_cas_client.start() self._patch_cas_client.start()
@ -114,12 +110,11 @@ class CASTestCase(TestCase):
del self._patch_cas_client del self._patch_cas_client
def tearDown(self): def tearDown(self):
if hasattr(self, '_patch_cas_client'): if hasattr(self, "_patch_cas_client"):
self.patch_cas_response_stop() self.patch_cas_response_stop()
class CASViewTestCase(CASTestCase): class CASViewTestCase(CASTestCase):
def assertLoginSuccess(self, response, redirect_to=None): def assertLoginSuccess(self, response, redirect_to=None):
""" """
Asserts response corresponds to a successful login. Asserts response corresponds to a successful login.
@ -133,7 +128,8 @@ class CASViewTestCase(CASTestCase):
redirect_to = settings.LOGIN_REDIRECT_URL redirect_to = settings.LOGIN_REDIRECT_URL
self.assertRedirects( self.assertRedirects(
response, redirect_to, response,
redirect_to,
fetch_redirect_response=False, fetch_redirect_response=False,
) )
self.assertIn( self.assertIn(
@ -146,6 +142,6 @@ class CASViewTestCase(CASTestCase):
Asserts response corresponds to a failed login. Asserts response corresponds to a failed login.
""" """
return self.assertInHTML( return self.assertInHTML(
'<h1>Social Network Login Failure</h1>', "<h1>Social Network Login Failure</h1>",
str(response.content), str(response.content),
) )

View file

@ -1,5 +1,4 @@
# -*- coding: utf-8 -*- from django.urls import include, path, re_path
from django.conf.urls import include, url
from django.utils.module_loading import import_string from django.utils.module_loading import import_string
@ -7,45 +6,44 @@ def default_urlpatterns(provider):
package = provider.get_package() package = provider.get_package()
try: try:
login_view = import_string(package + '.views.login') login_view = import_string(package + ".views.login")
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"The login view for the '{id}' provider is lacking from the " "The login view for the '{id}' provider is lacking from the "
"'views' module of its app.\n" "'views' module of its app.\n"
"You may want to add:\n" "You may want to add:\n"
"from allauth_cas.views import CASLoginView\n\n" "from allauth_cas.views import CASLoginView\n\n"
"login = CASLoginView.adapter_view(<LocalCASAdapter>)" "login = CASLoginView.adapter_view(<LocalCASAdapter>)".format(
.format(id=provider.id) id=provider.id
)
) )
try: try:
callback_view = import_string(package + '.views.callback') callback_view = import_string(package + ".views.callback")
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"The callback view for the '{id}' provider is lacking from the " "The callback view for the '{id}' provider is lacking from the "
"'views' module of its app.\n" "'views' module of its app.\n"
"You may want to add:\n" "You may want to add:\n"
"from allauth_cas.views import CASCallbackView\n\n" "from allauth_cas.views import CASCallbackView\n\n"
"callback = CASCallbackView.adapter_view(<LocalCASAdapter>)" "callback = CASCallbackView.adapter_view(<LocalCASAdapter>)".format(
.format(id=provider.id) id=provider.id
)
) )
try: try:
logout_view = import_string(package + '.views.logout') logout_view = import_string(package + ".views.logout")
except ImportError: except ImportError:
logout_view = None logout_view = None
urlpatterns = [ urlpatterns = [
url('^login/$', login_view, path("login/", login_view, name=provider.id + "_login"),
name=provider.id + '_login'), path("login/callback/", callback_view, name=provider.id + "_callback"),
url('^login/callback/$', callback_view,
name=provider.id + '_callback'),
] ]
if logout_view is not None: if logout_view is not None:
urlpatterns += [ urlpatterns += [
url('^logout/$', logout_view, path("logout/", logout_view, name=provider.id + "_logout"),
name=provider.id + '_logout'),
] ]
return [url('^' + provider.get_slug() + '/', include(urlpatterns))] return [re_path("^" + provider.get_slug() + "/", include(urlpatterns))]

View file

@ -1,28 +1,26 @@
# -*- coding: utf-8 -*- import cas
from django.http import HttpResponseRedirect
from django.utils.functional import cached_property
from allauth.account.adapter import get_adapter from allauth.account.adapter import get_adapter
from allauth.account.utils import get_next_redirect_url from allauth.account.utils import get_next_redirect_url
from allauth.socialaccount import providers from allauth.socialaccount import providers
from allauth.socialaccount.helpers import ( from allauth.socialaccount.helpers import (
complete_social_login, render_authentication_error, complete_social_login,
render_authentication_error,
) )
from allauth.socialaccount.models import SocialLogin from allauth.socialaccount.models import SocialLogin
from django.http import HttpResponseRedirect
import cas from django.utils.functional import cached_property
from . import CAS_PROVIDER_SESSION_KEY from . import CAS_PROVIDER_SESSION_KEY
from .exceptions import CASAuthenticationError from .exceptions import CASAuthenticationError
class AuthAction(object): class AuthAction:
AUTHENTICATE = 'authenticate' AUTHENTICATE = "authenticate"
REAUTHENTICATE = 'reauthenticate' REAUTHENTICATE = "reauthenticate"
DEAUTHENTICATE = 'deauthenticate' DEAUTHENTICATE = "deauthenticate"
class CASAdapter(object): class CASAdapter:
#: CAS server url. #: CAS server url.
url = None url = None
#: CAS server version. #: CAS server version.
@ -92,19 +90,19 @@ class CASAdapter(object):
""" """
redirect_to = get_next_redirect_url(request) redirect_to = get_next_redirect_url(request)
callback_kwargs = {'next': redirect_to} if redirect_to else {} callback_kwargs = {"next": redirect_to} if redirect_to else {}
callback_url = ( callback_url = self.provider.get_callback_url(request, **callback_kwargs)
self.provider.get_callback_url(request, **callback_kwargs))
service_url = request.build_absolute_uri(callback_url) service_url = request.build_absolute_uri(callback_url)
return service_url return service_url
class CASView(object): class CASView:
""" """
Base class for CAS views. Base class for CAS views.
""" """
@classmethod @classmethod
def adapter_view(cls, adapter): def adapter_view(cls, adapter):
"""Transform the view class into a view function. """Transform the view class into a view function.
@ -124,6 +122,7 @@ class CASView(object):
""" """
def view(request, *args, **kwargs): def view(request, *args, **kwargs):
# Prepare the func-view. # Prepare the func-view.
self = cls() self = cls()
@ -169,19 +168,17 @@ class CASView(object):
class CASLoginView(CASView): class CASLoginView(CASView):
def dispatch(self, request): def dispatch(self, request):
""" """
Redirects to the CAS server login page. Redirects to the CAS server login page.
""" """
action = request.GET.get('action', AuthAction.AUTHENTICATE) action = request.GET.get("action", AuthAction.AUTHENTICATE)
SocialLogin.stash_state(request) SocialLogin.stash_state(request)
client = self.get_client(request, action=action) client = self.get_client(request, action=action)
return HttpResponseRedirect(client.get_login_url()) return HttpResponseRedirect(client.get_login_url())
class CASCallbackView(CASView): class CASCallbackView(CASView):
def dispatch(self, request): def dispatch(self, request):
""" """
The CAS server redirects the user to this view after a successful The CAS server redirects the user to this view after a successful
@ -195,11 +192,9 @@ class CASCallbackView(CASView):
# CAS server should let a ticket. # CAS server should let a ticket.
try: try:
ticket = request.GET['ticket'] ticket = request.GET["ticket"]
except KeyError: except KeyError:
raise CASAuthenticationError( raise CASAuthenticationError("CAS server didn't respond with a ticket.")
"CAS server didn't respond with a ticket."
)
# Check ticket validity. # Check ticket validity.
# Response format on: # Response format on:
@ -210,9 +205,7 @@ class CASCallbackView(CASView):
uid, extra, _ = response uid, extra, _ = response
if not uid: if not uid:
raise CASAuthenticationError( raise CASAuthenticationError("CAS server doesn't validate the ticket.")
"CAS server doesn't validate the ticket."
)
# Keep tracks of the last used CAS provider. # Keep tracks of the last used CAS provider.
request.session[CAS_PROVIDER_SESSION_KEY] = self.provider.id request.session[CAS_PROVIDER_SESSION_KEY] = self.provider.id
@ -226,7 +219,6 @@ class CASCallbackView(CASView):
class CASLogoutView(CASView): class CASLogoutView(CASView):
def dispatch(self, request, next_page=None): def dispatch(self, request, next_page=None):
""" """
Redirects to the CAS server logout page. Redirects to the CAS server logout page.
@ -248,7 +240,6 @@ class CASLogoutView(CASView):
Returns the url to redirect after logout. Returns the url to redirect after logout.
""" """
request = self.request request = self.request
return ( return get_next_redirect_url(request) or get_adapter(
get_next_redirect_url(request) or request
get_adapter(request).get_logout_redirect_url(request) ).get_logout_redirect_url(request)
)

View file

@ -1,5 +1,4 @@
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*-
import os import os
import sys import sys
@ -7,10 +6,10 @@ import django
from django.conf import settings from django.conf import settings
from django.test.utils import get_runner from django.test.utils import get_runner
if __name__ == '__main__': if __name__ == "__main__":
os.environ['DJANGO_SETTINGS_MODULE'] = 'tests.settings' os.environ["DJANGO_SETTINGS_MODULE"] = "tests.settings"
django.setup() django.setup()
TestRunner = get_runner(settings) TestRunner = get_runner(settings)
test_runner = TestRunner() test_runner = TestRunner()
failures = test_runner.run_tests(sys.argv[1:] or ['tests']) failures = test_runner.run_tests(sys.argv[1:] or ["tests"])
sys.exit(bool(failures)) sys.exit(bool(failures))

View file

@ -1,18 +1,45 @@
[flake8] [flake8]
max-line-length = 120
exclude = .tox,.git,*/migrations/*,*/static/CACHE/*,docs,node_modules,venv
# E731: lambda expression # E731: lambda expression
ignore = E731 ignore = E731
[pycodestyle]
max-line-length = 120
exclude = .tox,.git,*/migrations/*,*/static/CACHE/*,docs,node_modules,venv
[isort] [isort]
combine_as_imports = True line_length = 88
known_first_party = geoip,config
multi_line_output = 3
default_section = THIRDPARTY default_section = THIRDPARTY
include_trailing_comma = True skip = venv/
known_allauth = allauth skip_glob = **/migrations/*.py
known_future_library = future,six include_trailing_comma = true
known_django = django force_grid_wrap = 0
known_first_party = allauth_cas use_parentheses = true
multi_line_output = 5
not_skip = __init__.py [mypy]
sections = FUTURE,STDLIB,DJANGO,ALLAUTH,THIRDPARTY,FIRSTPARTY,LOCALFOLDER python_version = 3.9
check_untyped_defs = True
ignore_missing_imports = True
warn_unused_ignores = True
warn_redundant_casts = True
warn_unused_configs = True
plugins = mypy_django_plugin.main
[mypy.plugins.django-stubs]
django_settings_module = config.settings.test
[mypy-*.migrations.*]
# Django migrations should not produce any errors:
ignore_errors = True
[coverage:run]
include = geoip/*
omit = *migrations*, *tests*
plugins =
django_coverage_plugin
[bdist_wheel] [bdist_wheel]
universal = 1 universal = 1

View file

@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
import os import os
from setuptools import find_packages, setup from setuptools import find_packages, setup
@ -7,49 +6,56 @@ from allauth_cas import __version__
BASE_DIR = os.path.dirname(__file__) BASE_DIR = os.path.dirname(__file__)
with open(os.path.join(BASE_DIR, 'README.rst')) as readme: with open(os.path.join(BASE_DIR, "README.rst")) as readme:
README = readme.read() README = readme.read()
setup( setup(
name='django-allauth-cas', name="django-allauth-cas",
version=__version__, version=__version__,
description='CAS support for django-allauth.', description="CAS support for django-allauth.",
author='Aurélien Delobelle', author="Aurélien Delobelle",
author_email='aurelien.delobelle@gmail.com', author_email="aurelien.delobelle@gmail.com",
keywords='django allauth cas authentication', keywords="django allauth cas authentication",
long_description=README, long_description=README,
url='https://github.com/aureplop/django-allauth-cas', url="https://github.com/aureplop/django-allauth-cas",
classifiers=[ classifiers=[
'Development Status :: 4 - Beta', "Development Status :: 4 - Beta",
'Environment :: Web Environment', "Environment :: Web Environment",
'Framework :: Django', "Framework :: Django",
'Framework :: Django :: 1.8', "Framework :: Django :: 1.8",
'Framework :: Django :: 1.9', "Framework :: Django :: 1.9",
'Framework :: Django :: 1.10', "Framework :: Django :: 1.10",
'Framework :: Django :: 1.11', "Framework :: Django :: 1.11",
'Framework :: Django :: 2.0', "Framework :: Django :: 2.0",
'Intended Audience :: Developers', "Framework :: Django :: 3.0",
'License :: OSI Approved :: MIT License', "Framework :: Django :: 4.0",
'Operating System :: OS Independent', "Framework :: Django :: 5.0",
'Programming Language :: Python', "Intended Audience :: Developers",
'Programming Language :: Python :: 2', "License :: OSI Approved :: MIT License",
'Programming Language :: Python :: 2.7', "Operating System :: OS Independent",
'Programming Language :: Python :: 3', "Programming Language :: Python",
'Programming Language :: Python :: 3.4', "Programming Language :: Python :: 2",
'Programming Language :: Python :: 3.5', "Programming Language :: Python :: 2.7",
'Programming Language :: Python :: 3.6', "Programming Language :: Python :: 3",
'Topic :: Internet :: WWW/HTTP', "Programming Language :: Python :: 3.4",
"Programming Language :: Python :: 3.5",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.11",
"Topic :: Internet :: WWW/HTTP",
], ],
license='MIT', license="MIT",
packages=find_packages(exclude=['tests']), packages=find_packages(exclude=["tests"]),
include_package_data=True, include_package_data=True,
install_requires=[ install_requires=[
'django-allauth', "django-allauth",
'python-cas', "python-cas",
'six', "six",
], ],
extras_require={ extras_require={
'docs': ['sphinx'], "docs": ["sphinx"],
'tests': ['tox'], "tests": ["tox"],
}, },
) )

View file

@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
from allauth.socialaccount.providers.base import ProviderAccount from allauth.socialaccount.providers.base import ProviderAccount
from allauth_cas.providers import CASProvider from allauth_cas.providers import CASProvider
@ -9,8 +8,8 @@ class ExampleCASAccount(ProviderAccount):
class ExampleCASProvider(CASProvider): class ExampleCASProvider(CASProvider):
id = 'theid' id = "theid"
name = 'The Provider' name = "The Provider"
account_class = ExampleCASAccount account_class = ExampleCASAccount

View file

@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
from allauth_cas.urls import default_urlpatterns from allauth_cas.urls import default_urlpatterns
from .provider import ExampleCASProvider from .provider import ExampleCASProvider

View file

@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
from allauth_cas import views from allauth_cas import views
from .provider import ExampleCASProvider from .provider import ExampleCASProvider
@ -6,7 +5,7 @@ from .provider import ExampleCASProvider
class ExampleCASAdapter(views.CASAdapter): class ExampleCASAdapter(views.CASAdapter):
provider_id = ExampleCASProvider.id provider_id = ExampleCASProvider.id
url = 'https://server.cas' url = "https://server.cas"
version = 2 version = 2

View file

@ -1,63 +1,54 @@
# -*- coding: utf-8 -*- SECRET_KEY = "iamabird"
import django
SECRET_KEY = 'iamabird'
INSTALLED_APPS = [ INSTALLED_APPS = [
'django.contrib.admin', "django.contrib.admin",
'django.contrib.auth', "django.contrib.auth",
'django.contrib.contenttypes', "django.contrib.contenttypes",
'django.contrib.messages', "django.contrib.messages",
'django.contrib.sessions', "django.contrib.sessions",
'django.contrib.sites', "django.contrib.sites",
'django.contrib.staticfiles', "django.contrib.staticfiles",
"allauth",
'allauth', "allauth.account",
'allauth.account', "allauth.socialaccount",
'allauth.socialaccount', "allauth_cas",
"tests.example", # Dummy CAS provider app
'allauth_cas',
'tests.example', # Dummy CAS provider app
] ]
DATABASES = { DATABASES = {
'default': { "default": {
'ENGINE': 'django.db.backends.sqlite3', "ENGINE": "django.db.backends.sqlite3",
}, },
} }
AUTHENTICATION_BACKENDS = [ AUTHENTICATION_BACKENDS = [
'allauth.account.auth_backends.AuthenticationBackend', "allauth.account.auth_backends.AuthenticationBackend",
] ]
_MIDDLEWARES = [ _MIDDLEWARES = [
'django.middleware.common.CommonMiddleware', "django.middleware.common.CommonMiddleware",
'django.contrib.sessions.middleware.SessionMiddleware', "django.contrib.sessions.middleware.SessionMiddleware",
'django.middleware.csrf.CsrfViewMiddleware', "django.middleware.csrf.CsrfViewMiddleware",
'django.contrib.auth.middleware.AuthenticationMiddleware', "django.contrib.auth.middleware.AuthenticationMiddleware",
'django.contrib.messages.middleware.MessageMiddleware', "django.contrib.messages.middleware.MessageMiddleware",
] ]
if django.VERSION >= (1, 10):
MIDDLEWARE = _MIDDLEWARES MIDDLEWARE = _MIDDLEWARES
else:
MIDDLEWARE_CLASSES = _MIDDLEWARES
TEMPLATES = [ TEMPLATES = [
{ {
'BACKEND': 'django.template.backends.django.DjangoTemplates', "BACKEND": "django.template.backends.django.DjangoTemplates",
'DIRS': [], "DIRS": [],
'APP_DIRS': True, "APP_DIRS": True,
'OPTIONS': { "OPTIONS": {
'context_processors': [ "context_processors": [
'django.template.context_processors.debug', "django.template.context_processors.debug",
'django.template.context_processors.request', "django.template.context_processors.request",
'django.contrib.auth.context_processors.auth', "django.contrib.auth.context_processors.auth",
'django.contrib.messages.context_processors.messages', "django.contrib.messages.context_processors.messages",
], ],
}, },
}, },
] ]
ROOT_URLCONF = 'tests.urls' ROOT_URLCONF = "tests.urls"

View file

@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
from django.contrib import messages from django.contrib import messages
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from django.contrib.messages.api import get_messages from django.contrib.messages.api import get_messages
@ -13,7 +12,7 @@ User = get_user_model()
class LogoutFlowTests(CASTestCase): class LogoutFlowTests(CASTestCase):
expected_msg_str = ( expected_msg_str = (
"To logout of The Provider, please close your browser, or visit this " "To logout of The Provider, please close your browser, or visit this "
"<a href=\"/accounts/theid/logout/?next=%2Fredir%2F\">" '<a href="/accounts/theid/logout/?next=%2Fredir%2F">'
"link</a>." "link</a>."
) )
@ -28,42 +27,45 @@ class LogoutFlowTests(CASTestCase):
) )
self.assertTemplateNotUsed( self.assertTemplateNotUsed(
response, response,
'socialaccount/messages/suggest_caslogout.html', "socialaccount/messages/suggest_caslogout.html",
) )
@override_settings(SOCIALACCOUNT_PROVIDERS={ @override_settings(
'theid': { SOCIALACCOUNT_PROVIDERS={
'MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT': True, "theid": {
'MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT_LEVEL': messages.WARNING, "MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT": True,
"MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT_LEVEL": messages.WARNING,
}, },
}) }
)
def test_message_on_logout(self): def test_message_on_logout(self):
""" """
Message is sent to propose user to logout of CAS. Message is sent to propose user to logout of CAS.
""" """
r = self.client.post('/accounts/logout/?next=/redir/') r = self.client.post("/accounts/logout/?next=/redir/")
r_messages = get_messages(r.wsgi_request) r_messages = get_messages(r.wsgi_request)
expected_msg = Message(messages.WARNING, self.expected_msg_str) expected_msg = Message(messages.WARNING, self.expected_msg_str)
self.assertIn(expected_msg, r_messages) self.assertIn(expected_msg, r_messages)
self.assertTemplateUsed( self.assertTemplateUsed(r, "socialaccount/messages/suggest_caslogout.html")
r, 'socialaccount/messages/suggest_caslogout.html')
def test_message_on_logout_disabled(self): def test_message_on_logout_disabled(self):
r = self.client.post('/accounts/logout/') r = self.client.post("/accounts/logout/")
self.assertCASLogoutNotInMessages(r) self.assertCASLogoutNotInMessages(r)
@override_settings(SOCIALACCOUNT_PROVIDERS={ @override_settings(
'theid': {'MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT': True}, SOCIALACCOUNT_PROVIDERS={
}) "theid": {"MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT": True},
}
)
def test_other_logout(self): def test_other_logout(self):
""" """
The CAS logout message doesn't appear with other login methods. The CAS logout message doesn't appear with other login methods.
""" """
User.objects.create_user('user', '', 'user') User.objects.create_user("user", "", "user")
client = Client() client = Client()
client.login(username='user', password='user') client.login(username="user", password="user")
r = client.post('/accounts/logout/') r = client.post("/accounts/logout/")
self.assertCASLogoutNotInMessages(r) self.assertCASLogoutNotInMessages(r)

View file

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*- from urllib.parse import urlencode
from six.moves.urllib.parse import urlencode
from allauth.socialaccount.providers import registry
from django.contrib import messages from django.contrib import messages
from django.contrib.messages.api import get_messages from django.contrib.messages.api import get_messages
from django.contrib.messages.middleware import MessageMiddleware from django.contrib.messages.middleware import MessageMiddleware
@ -8,21 +8,18 @@ from django.contrib.messages.storage.base import Message
from django.contrib.sessions.middleware import SessionMiddleware from django.contrib.sessions.middleware import SessionMiddleware
from django.test import RequestFactory, TestCase, override_settings from django.test import RequestFactory, TestCase, override_settings
from allauth.socialaccount.providers import registry
from allauth_cas.views import AuthAction from allauth_cas.views import AuthAction
from .example.provider import ExampleCASProvider from .example.provider import ExampleCASProvider
class CASProviderTests(TestCase): class CASProviderTests(TestCase):
def setUp(self): def setUp(self):
self.request = self._get_request() self.request = self._get_request()
self.provider = ExampleCASProvider(self.request) self.provider = ExampleCASProvider(self.request)
def _get_request(self): def _get_request(self):
request = RequestFactory().get('/test/') request = RequestFactory().get("/test/")
SessionMiddleware().process_request(request) SessionMiddleware().process_request(request)
MessageMiddleware().process_request(request) MessageMiddleware().process_request(request)
return request return request
@ -31,74 +28,81 @@ class CASProviderTests(TestCase):
""" """
Example CAS provider is registered as social account provider. Example CAS provider is registered as social account provider.
""" """
self.assertIsInstance(registry.by_id('theid'), ExampleCASProvider) self.assertIsInstance(registry.by_id("theid"), ExampleCASProvider)
def test_get_login_url(self): def test_get_login_url(self):
url = self.provider.get_login_url(self.request) url = self.provider.get_login_url(self.request)
self.assertEqual('/accounts/theid/login/', url) self.assertEqual("/accounts/theid/login/", url)
url_with_qs = self.provider.get_login_url( url_with_qs = self.provider.get_login_url(
self.request, self.request,
next='/path?quéry=string&two=whoam%C3%AF', next="/path?quéry=string&two=whoam%C3%AF",
) )
self.assertEqual( self.assertEqual(
url_with_qs, url_with_qs,
'/accounts/theid/login/?next=%2Fpath%3Fqu%C3%A9ry%3Dstring%26two%3' "/accounts/theid/login/?next=%2Fpath%3Fqu%C3%A9ry%3Dstring%26two%3"
'Dwhoam%25C3%25AF' "Dwhoam%25C3%25AF",
) )
def test_get_callback_url(self): def test_get_callback_url(self):
url = self.provider.get_callback_url(self.request) url = self.provider.get_callback_url(self.request)
self.assertEqual('/accounts/theid/login/callback/', url) self.assertEqual("/accounts/theid/login/callback/", url)
url_with_qs = self.provider.get_callback_url( url_with_qs = self.provider.get_callback_url(
self.request, self.request,
next='/path?quéry=string&two=whoam%C3%AF', next="/path?quéry=string&two=whoam%C3%AF",
) )
self.assertEqual( self.assertEqual(
url_with_qs, url_with_qs,
'/accounts/theid/login/callback/?next=%2Fpath%3Fqu%C3%A9ry%3Dstrin' "/accounts/theid/login/callback/?next=%2Fpath%3Fqu%C3%A9ry%3Dstrin"
'g%26two%3Dwhoam%25C3%25AF' "g%26two%3Dwhoam%25C3%25AF",
) )
def test_get_logout_url(self): def test_get_logout_url(self):
url = self.provider.get_logout_url(self.request) url = self.provider.get_logout_url(self.request)
self.assertEqual('/accounts/theid/logout/', url) self.assertEqual("/accounts/theid/logout/", url)
url_with_qs = self.provider.get_logout_url( url_with_qs = self.provider.get_logout_url(
self.request, self.request,
next='/path?quéry=string&two=whoam%C3%AF', next="/path?quéry=string&two=whoam%C3%AF",
) )
self.assertEqual( self.assertEqual(
url_with_qs, url_with_qs,
'/accounts/theid/logout/?next=%2Fpath%3Fqu%C3%A9ry%3Dstring%26two%' "/accounts/theid/logout/?next=%2Fpath%3Fqu%C3%A9ry%3Dstring%26two%"
'3Dwhoam%25C3%25AF' "3Dwhoam%25C3%25AF",
) )
@override_settings(SOCIALACCOUNT_PROVIDERS={ @override_settings(
'theid': { SOCIALACCOUNT_PROVIDERS={
'AUTH_PARAMS': {'key': 'value'}, "theid": {
"AUTH_PARAMS": {"key": "value"},
}, },
}) }
)
def test_get_auth_params(self): def test_get_auth_params(self):
action = AuthAction.AUTHENTICATE action = AuthAction.AUTHENTICATE
auth_params = self.provider.get_auth_params(self.request, action) auth_params = self.provider.get_auth_params(self.request, action)
self.assertDictEqual(auth_params, { self.assertDictEqual(
'key': 'value', auth_params,
}) {
"key": "value",
@override_settings(SOCIALACCOUNT_PROVIDERS={
'theid': {
'AUTH_PARAMS': {'key': 'value'},
}, },
}) )
@override_settings(
SOCIALACCOUNT_PROVIDERS={
"theid": {
"AUTH_PARAMS": {"key": "value"},
},
}
)
def test_get_auth_params_with_dynamic(self): def test_get_auth_params_with_dynamic(self):
factory = RequestFactory() factory = RequestFactory()
request = factory.get( request = factory.get(
'/test/?auth_params=next%3Dtwo%253Dwhoam%2525C3%2525AF%2526qu%2525' "/test/?auth_params=next%3Dtwo%253Dwhoam%2525C3%2525AF%2526qu%2525"
'C3%2525A9ry%253Dstring' "C3%2525A9ry%253Dstring"
) )
request.session = {} request.session = {}
@ -106,15 +110,18 @@ class CASProviderTests(TestCase):
auth_params = self.provider.get_auth_params(request, action) auth_params = self.provider.get_auth_params(request, action)
self.assertDictEqual(auth_params, { self.assertDictEqual(
'key': 'value', auth_params,
'next': 'two=whoam%C3%AF&qu%C3%A9ry=string', {
}) "key": "value",
"next": "two=whoam%C3%AF&qu%C3%A9ry=string",
},
)
def test_add_message_suggest_caslogout(self): def test_add_message_suggest_caslogout(self):
expected_msg_base_str = ( expected_msg_base_str = (
"To logout of The Provider, please close your browser, or visit " "To logout of The Provider, please close your browser, or visit "
"this <a href=\"/accounts/theid/logout/?{}\">link</a>." 'this <a href="/accounts/theid/logout/?{}">link</a>.'
) )
# Defaults. # Defaults.
@ -124,7 +131,7 @@ class CASProviderTests(TestCase):
expected_msg1 = Message( expected_msg1 = Message(
messages.INFO, messages.INFO,
expected_msg_base_str.format(urlencode({'next': '/test/'})), expected_msg_base_str.format(urlencode({"next": "/test/"})),
) )
self.assertIn(expected_msg1, get_messages(req1)) self.assertIn(expected_msg1, get_messages(req1))
@ -132,69 +139,83 @@ class CASProviderTests(TestCase):
req2 = self._get_request() req2 = self._get_request()
self.provider.add_message_suggest_caslogout( self.provider.add_message_suggest_caslogout(
req2, next_page='/redir/', level=messages.WARNING) req2, next_page="/redir/", level=messages.WARNING
)
expected_msg2 = Message( expected_msg2 = Message(
messages.WARNING, messages.WARNING,
expected_msg_base_str.format(urlencode({'next': '/redir/'})), expected_msg_base_str.format(urlencode({"next": "/redir/"})),
) )
self.assertIn(expected_msg2, get_messages(req2)) self.assertIn(expected_msg2, get_messages(req2))
def test_message_suggest_caslogout_on_logout(self): def test_message_suggest_caslogout_on_logout(self):
self.assertFalse( self.assertFalse(
self.provider.message_suggest_caslogout_on_logout(self.request)) self.provider.message_suggest_caslogout_on_logout(self.request)
with override_settings(SOCIALACCOUNT_PROVIDERS={
'theid': {'MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT': True},
}):
self.assertTrue(
self.provider
.message_suggest_caslogout_on_logout(self.request)
) )
@override_settings(SOCIALACCOUNT_PROVIDERS={ with override_settings(
'theid': { SOCIALACCOUNT_PROVIDERS={
'MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT_LEVEL': messages.WARNING, "theid": {"MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT": True},
}
):
self.assertTrue(
self.provider.message_suggest_caslogout_on_logout(self.request)
)
@override_settings(
SOCIALACCOUNT_PROVIDERS={
"theid": {
"MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT_LEVEL": messages.WARNING,
}, },
}) }
)
def test_message_suggest_caslogout_on_logout_level(self): def test_message_suggest_caslogout_on_logout_level(self):
self.assertEqual(messages.WARNING, ( self.assertEqual(
self.provider messages.WARNING,
.message_suggest_caslogout_on_logout_level(self.request) (self.provider.message_suggest_caslogout_on_logout_level(self.request)),
)) )
def test_extract_uid(self): def test_extract_uid(self):
response = 'useRName', {} response = "useRName", {}
uid = self.provider.extract_uid(response) uid = self.provider.extract_uid(response)
self.assertEqual('useRName', uid) self.assertEqual("useRName", uid)
def test_extract_common_fields(self): def test_extract_common_fields(self):
response = 'useRName', {} response = "useRName", {}
common_fields = self.provider.extract_common_fields(response) common_fields = self.provider.extract_common_fields(response)
self.assertDictEqual(common_fields, { self.assertDictEqual(
'username': 'useRName', common_fields,
'first_name': None, {
'last_name': None, "username": "useRName",
'name': None, "first_name": None,
'email': None, "last_name": None,
}) "name": None,
"email": None,
},
)
def test_extract_common_fields_with_extra(self): def test_extract_common_fields_with_extra(self):
response = 'useRName', {'username': 'user', 'email': 'user@mail.net'} response = "useRName", {"username": "user", "email": "user@mail.net"}
common_fields = self.provider.extract_common_fields(response) common_fields = self.provider.extract_common_fields(response)
self.assertDictEqual(common_fields, { self.assertDictEqual(
'username': 'user', common_fields,
'first_name': None, {
'last_name': None, "username": "user",
'name': None, "first_name": None,
'email': 'user@mail.net', "last_name": None,
}) "name": None,
"email": "user@mail.net",
},
)
def test_extract_extra_data(self): def test_extract_extra_data(self):
response = 'useRName', {'user_attr': 'thevalue', 'another': 'value'} response = "useRName", {"user_attr": "thevalue", "another": "value"}
extra_data = self.provider.extract_extra_data(response) extra_data = self.provider.extract_extra_data(response)
self.assertDictEqual(extra_data, { self.assertDictEqual(
'user_attr': 'thevalue', extra_data,
'another': 'value', {
'uid': 'useRName', "user_attr": "thevalue",
}) "another": "value",
"uid": "useRName",
},
)

View file

@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
from django.test import Client, RequestFactory from django.test import Client, RequestFactory
from allauth_cas.test.testcases import CASViewTestCase from allauth_cas.test.testcases import CASViewTestCase
@ -8,9 +7,8 @@ from .example.views import ExampleCASAdapter
class CASTestCaseTests(CASViewTestCase): class CASTestCaseTests(CASViewTestCase):
def setUp(self): def setUp(self):
self.client.get('/accounts/theid/login/') self.client.get("/accounts/theid/login/")
def test_patch_cas_response_client_version(self): def test_patch_cas_response_client_version(self):
""" """
@ -21,20 +19,24 @@ class CASTestCaseTests(CASViewTestCase):
""" """
valid_versions = [ valid_versions = [
1, '1', 1,
2, '2', "1",
3, '3', 2,
'CAS_2_SAML_1_0', "2",
3,
"3",
"CAS_2_SAML_1_0",
] ]
invalid_versions = [ invalid_versions = [
'not_supported', "not_supported",
] ]
factory = RequestFactory() factory = RequestFactory()
request = factory.get('/path/') request = factory.get("/path/")
request.session = {} request.session = {}
for _version in valid_versions + invalid_versions: for _version in valid_versions + invalid_versions:
class BasicCASAdapter(ExampleCASAdapter): class BasicCASAdapter(ExampleCASAdapter):
version = _version version = _version
@ -47,7 +49,7 @@ class CASTestCaseTests(CASViewTestCase):
if _version in valid_versions: if _version in valid_versions:
raw_client = view(request) raw_client = view(request)
self.patch_cas_response(valid_ticket='__all__') self.patch_cas_response(valid_ticket="__all__")
mocked_client = view(request) mocked_client = view(request)
self.assertEqual(type(raw_client), type(mocked_client)) self.assertEqual(type(raw_client), type(mocked_client))
@ -55,58 +57,79 @@ class CASTestCaseTests(CASViewTestCase):
# This is a sanity check. # This is a sanity check.
self.assertRaises(ValueError, view, request) self.assertRaises(ValueError, view, request)
self.patch_cas_response(valid_ticket='__all__') self.patch_cas_response(valid_ticket="__all__")
self.assertRaises(ValueError, view, request) self.assertRaises(ValueError, view, request)
def test_patch_cas_response_verify_success(self): def test_patch_cas_response_verify_success(self):
self.patch_cas_response(valid_ticket='123456') self.patch_cas_response(valid_ticket="123456")
r = self.client.get('/accounts/theid/login/callback/', { r = self.client.get(
'ticket': '123456', "/accounts/theid/login/callback/",
}) {
"ticket": "123456",
},
)
self.assertLoginSuccess(r) self.assertLoginSuccess(r)
def test_patch_cas_response_verify_failure(self): def test_patch_cas_response_verify_failure(self):
self.patch_cas_response(valid_ticket='123456') self.patch_cas_response(valid_ticket="123456")
r = self.client.get('/accounts/theid/login/callback/', { r = self.client.get(
'ticket': '000000', "/accounts/theid/login/callback/",
}) {
"ticket": "000000",
},
)
self.assertLoginFailure(r) self.assertLoginFailure(r)
def test_patch_cas_response_accept(self): def test_patch_cas_response_accept(self):
self.patch_cas_response(valid_ticket='__all__') self.patch_cas_response(valid_ticket="__all__")
r = self.client.get('/accounts/theid/login/callback/', { r = self.client.get(
'ticket': '000000', "/accounts/theid/login/callback/",
}) {
"ticket": "000000",
},
)
self.assertLoginSuccess(r) self.assertLoginSuccess(r)
def test_patch_cas_response_reject(self): def test_patch_cas_response_reject(self):
self.patch_cas_response(valid_ticket=None) self.patch_cas_response(valid_ticket=None)
r = self.client.get('/accounts/theid/login/callback/', { r = self.client.get(
'ticket': '000000', "/accounts/theid/login/callback/",
}) {
"ticket": "000000",
},
)
self.assertLoginFailure(r) self.assertLoginFailure(r)
def test_patch_cas_reponse_multiple(self): def test_patch_cas_reponse_multiple(self):
self.patch_cas_response(valid_ticket='__all__') self.patch_cas_response(valid_ticket="__all__")
client_0 = Client() client_0 = Client()
client_0.get('/accounts/theid/login/') client_0.get("/accounts/theid/login/")
r_0 = client_0.get('/accounts/theid/login/callback/', { r_0 = client_0.get(
'ticket': '000000', "/accounts/theid/login/callback/",
}) {
"ticket": "000000",
},
)
self.assertLoginSuccess(r_0) self.assertLoginSuccess(r_0)
self.patch_cas_response(valid_ticket=None) self.patch_cas_response(valid_ticket=None)
client_1 = Client() client_1 = Client()
client_1.get('/accounts/theid/login/') client_1.get("/accounts/theid/login/")
r_1 = client_1.get('/accounts/theid/login/callback/', { r_1 = client_1.get(
'ticket': '111111', "/accounts/theid/login/callback/",
}) {
"ticket": "111111",
},
)
self.assertLoginFailure(r_1) self.assertLoginFailure(r_1)
def test_assertLoginSuccess(self): def test_assertLoginSuccess(self):
self.patch_cas_response(valid_ticket='__all__') self.patch_cas_response(valid_ticket="__all__")
r = self.client.get('/accounts/theid/login/callback/', { r = self.client.get(
'ticket': '000000', "/accounts/theid/login/callback/",
'next': '/path/', {
}) "ticket": "000000",
self.assertLoginSuccess(r, redirect_to='/path/') "next": "/path/",
},
)
self.assertLoginSuccess(r, redirect_to="/path/")

View file

@ -1,11 +1,10 @@
# -*- coding: utf-8 -*-
try: try:
from unittest.mock import patch from unittest.mock import patch
except ImportError: except ImportError:
from mock import patch from unittest.mock import patch
import django
from django.test import RequestFactory, override_settings from django.test import RequestFactory, override_settings
from django.urls import reverse
from allauth_cas.exceptions import CASAuthenticationError from allauth_cas.exceptions import CASAuthenticationError
from allauth_cas.test.testcases import CASTestCase, CASViewTestCase from allauth_cas.test.testcases import CASTestCase, CASViewTestCase
@ -13,17 +12,11 @@ from allauth_cas.views import CASView
from .example.views import ExampleCASAdapter from .example.views import ExampleCASAdapter
if django.VERSION >= (1, 10):
from django.urls import reverse
else:
from django.core.urlresolvers import reverse
class CASAdapterTests(CASTestCase): class CASAdapterTests(CASTestCase):
def setUp(self): def setUp(self):
factory = RequestFactory() factory = RequestFactory()
self.request = factory.get('/path/') self.request = factory.get("/path/")
self.request.session = {} self.request.session = {}
self.adapter = ExampleCASAdapter(self.request) self.adapter = ExampleCASAdapter(self.request)
@ -31,7 +24,7 @@ class CASAdapterTests(CASTestCase):
""" """
Service url (used by CAS client) is the callback url. Service url (used by CAS client) is the callback url.
""" """
expected = 'http://testserver/accounts/theid/login/callback/' expected = "http://testserver/accounts/theid/login/callback/"
service_url = self.adapter.get_service_url(self.request) service_url = self.adapter.get_service_url(self.request)
self.assertEqual(expected, service_url) self.assertEqual(expected, service_url)
@ -39,11 +32,9 @@ class CASAdapterTests(CASTestCase):
""" """
Current GET paramater next is appended on service url. Current GET paramater next is appended on service url.
""" """
expected = ( expected = "http://testserver/accounts/theid/login/callback/?next=%2Fnext%2F"
'http://testserver/accounts/theid/login/callback/?next=%2Fnext%2F'
)
factory = RequestFactory() factory = RequestFactory()
request = factory.get('/path/', {'next': '/next/'}) request = factory.get("/path/", {"next": "/next/"})
adapter = ExampleCASAdapter(request) adapter = ExampleCASAdapter(request)
service_url = adapter.get_service_url(request) service_url = adapter.get_service_url(request)
self.assertEqual(expected, service_url) self.assertEqual(expected, service_url)
@ -67,14 +58,13 @@ class CASAdapterTests(CASTestCase):
class CASViewTests(CASViewTestCase): class CASViewTests(CASViewTestCase):
class BasicCASView(CASView): class BasicCASView(CASView):
def dispatch(self, request, *args, **kwargs): def dispatch(self, request, *args, **kwargs):
return self return self
def setUp(self): def setUp(self):
factory = RequestFactory() factory = RequestFactory()
self.request = factory.get('/path/') self.request = factory.get("/path/")
self.request.session = {} self.request.session = {}
self.cas_view = self.BasicCASView.adapter_view(ExampleCASAdapter) self.cas_view = self.BasicCASView.adapter_view(ExampleCASAdapter)
@ -85,27 +75,34 @@ class CASViewTests(CASViewTestCase):
""" """
view = self.cas_view( view = self.cas_view(
self.request, self.request,
'arg1', 'arg2', "arg1",
kwarg1='kwarg1', kwarg2='kwarg2', "arg2",
kwarg1="kwarg1",
kwarg2="kwarg2",
) )
self.assertIsInstance(view, CASView) self.assertIsInstance(view, CASView)
self.assertEqual(view.request, self.request) self.assertEqual(view.request, self.request)
self.assertTupleEqual(view.args, ('arg1', 'arg2')) self.assertTupleEqual(view.args, ("arg1", "arg2"))
self.assertDictEqual(view.kwargs, { self.assertDictEqual(
'kwarg1': 'kwarg1', view.kwargs,
'kwarg2': 'kwarg2', {
}) "kwarg1": "kwarg1",
"kwarg2": "kwarg2",
},
)
self.assertIsInstance(view.adapter, ExampleCASAdapter) self.assertIsInstance(view.adapter, ExampleCASAdapter)
@patch('allauth_cas.views.cas.CASClient') @patch("allauth_cas.views.cas.CASClient")
@override_settings(SOCIALACCOUNT_PROVIDERS={ @override_settings(
'theid': { SOCIALACCOUNT_PROVIDERS={
'AUTH_PARAMS': {'key': 'value'}, "theid": {
"AUTH_PARAMS": {"key": "value"},
}, },
}) }
)
def test_get_client(self, mock_casclient_class): def test_get_client(self, mock_casclient_class):
""" """
get_client returns a CAS client, configured from settings. get_client returns a CAS client, configured from settings.
@ -114,11 +111,11 @@ class CASViewTests(CASViewTestCase):
view.get_client(self.request) view.get_client(self.request)
mock_casclient_class.assert_called_once_with( mock_casclient_class.assert_called_once_with(
service_url='http://testserver/accounts/theid/login/callback/', service_url="http://testserver/accounts/theid/login/callback/",
server_url='https://server.cas', server_url="https://server.cas",
version=2, version=2,
renew=False, renew=False,
extra_login_params={'key': 'value'}, extra_login_params={"key": "value"},
) )
def test_render_error_on_failure(self): def test_render_error_on_failure(self):
@ -126,33 +123,33 @@ class CASViewTests(CASViewTestCase):
A common login failure page is rendered if CASAuthenticationError is A common login failure page is rendered if CASAuthenticationError is
raised by dispatch. raised by dispatch.
""" """
def dispatch_raise(self, request): def dispatch_raise(self, request):
raise CASAuthenticationError("failure") raise CASAuthenticationError("failure")
with patch.object(self.BasicCASView, 'dispatch', dispatch_raise): with patch.object(self.BasicCASView, "dispatch", dispatch_raise):
resp = self.cas_view(self.request) resp = self.cas_view(self.request)
self.assertLoginFailure(resp) self.assertLoginFailure(resp)
class CASLoginViewTests(CASViewTestCase): class CASLoginViewTests(CASViewTestCase):
def test_reverse(self): def test_reverse(self):
""" """
Login view name is "{provider_id}_login". Login view name is "{provider_id}_login".
""" """
url = reverse('theid_login') url = reverse("theid_login")
self.assertEqual('/accounts/theid/login/', url) self.assertEqual("/accounts/theid/login/", url)
def test_execute(self): def test_execute(self):
""" """
Login view redirects to the CAS server login url. Login view redirects to the CAS server login url.
Service is the callback url, as absolute uri. Service is the callback url, as absolute uri.
""" """
r = self.client.get('/accounts/theid/login/') r = self.client.get("/accounts/theid/login/")
expected = ( expected = (
'https://server.cas/login?service=http%3A%2F%2Ftestserver%2F' "https://server.cas/login?service=http%3A%2F%2Ftestserver%2F"
'accounts%2Ftheid%2Flogin%2Fcallback%2F' "accounts%2Ftheid%2Flogin%2Fcallback%2F"
) )
self.assertRedirects(r, expected, fetch_redirect_response=False) self.assertRedirects(r, expected, fetch_redirect_response=False)
@ -161,54 +158,59 @@ class CASLoginViewTests(CASViewTestCase):
""" """
Current GET parameter 'next' is kept on service url. Current GET parameter 'next' is kept on service url.
""" """
r = self.client.get('/accounts/theid/login/?next=/path/') r = self.client.get("/accounts/theid/login/?next=/path/")
expected = ( expected = (
'https://server.cas/login?service=http%3A%2F%2Ftestserver%2F' "https://server.cas/login?service=http%3A%2F%2Ftestserver%2F"
'accounts%2Ftheid%2Flogin%2Fcallback%2F%3Fnext%3D%252Fpath%252F' "accounts%2Ftheid%2Flogin%2Fcallback%2F%3Fnext%3D%252Fpath%252F"
) )
self.assertRedirects(r, expected, fetch_redirect_response=False) self.assertRedirects(r, expected, fetch_redirect_response=False)
class CASCallbackViewTests(CASViewTestCase): class CASCallbackViewTests(CASViewTestCase):
def setUp(self): def setUp(self):
self.client.get('/accounts/theid/login/') self.client.get("/accounts/theid/login/")
def test_reverse(self): def test_reverse(self):
""" """
Callback view name is "{provider_id}_callback". Callback view name is "{provider_id}_callback".
""" """
url = reverse('theid_callback') url = reverse("theid_callback")
self.assertEqual('/accounts/theid/login/callback/', url) self.assertEqual("/accounts/theid/login/callback/", url)
def test_ticket_valid(self): def test_ticket_valid(self):
""" """
If ticket is valid, the user is logged in. If ticket is valid, the user is logged in.
""" """
self.patch_cas_response(username='username', valid_ticket='123456') self.patch_cas_response(username="username", valid_ticket="123456")
r = self.client.get('/accounts/theid/login/callback/', { r = self.client.get(
'ticket': '123456', "/accounts/theid/login/callback/",
}) {
"ticket": "123456",
},
)
self.assertLoginSuccess(r) self.assertLoginSuccess(r)
def test_ticket_invalid(self): def test_ticket_invalid(self):
""" """
Login failure page is returned if the ticket is invalid. Login failure page is returned if the ticket is invalid.
""" """
self.patch_cas_response(username='username', valid_ticket='123456') self.patch_cas_response(username="username", valid_ticket="123456")
r = self.client.get('/accounts/theid/login/callback/', { r = self.client.get(
'ticket': '000000', "/accounts/theid/login/callback/",
}) {
"ticket": "000000",
},
)
self.assertLoginFailure(r) self.assertLoginFailure(r)
def test_ticket_missing(self): def test_ticket_missing(self):
""" """
Login failure page is returned if request lacks a ticket. Login failure page is returned if request lacks a ticket.
""" """
self.patch_cas_response(username='username', valid_ticket='123456') self.patch_cas_response(username="username", valid_ticket="123456")
r = self.client.get('/accounts/theid/login/callback/') r = self.client.get("/accounts/theid/login/callback/")
self.assertLoginFailure(r) self.assertLoginFailure(r)
def test_attributes_is_none(self): def test_attributes_is_none(self):
@ -216,31 +218,33 @@ class CASCallbackViewTests(CASViewTestCase):
Without extra attributes, CASClientV2 of python-cas returns None. Without extra attributes, CASClientV2 of python-cas returns None.
""" """
self.patch_cas_response( self.patch_cas_response(
username='username', valid_ticket='123456', attributes=None username="username", valid_ticket="123456", attributes=None
)
r = self.client.get(
"/accounts/theid/login/callback/",
{
"ticket": "123456",
},
) )
r = self.client.get('/accounts/theid/login/callback/', {
'ticket': '123456',
})
self.assertLoginSuccess(r) self.assertLoginSuccess(r)
class CASLogoutViewTests(CASViewTestCase): class CASLogoutViewTests(CASViewTestCase):
def test_reverse(self): def test_reverse(self):
""" """
Callback view name is "{provider_id}_logout". Callback view name is "{provider_id}_logout".
""" """
url = reverse('theid_logout') url = reverse("theid_logout")
self.assertEqual('/accounts/theid/logout/', url) self.assertEqual("/accounts/theid/logout/", url)
def test_execute(self): def test_execute(self):
""" """
Logout view redirects to the CAS server logout url. Logout view redirects to the CAS server logout url.
Service is a url to here, as absolute uri. Service is a url to here, as absolute uri.
""" """
r = self.client.get('/accounts/theid/logout/') r = self.client.get("/accounts/theid/logout/")
expected = 'https://server.cas/logout?url=http%3A%2F%2Ftestserver%2F' expected = "https://server.cas/logout?url=http%3A%2F%2Ftestserver%2F"
self.assertRedirects(r, expected, fetch_redirect_response=False) self.assertRedirects(r, expected, fetch_redirect_response=False)
@ -248,10 +252,8 @@ class CASLogoutViewTests(CASViewTestCase):
""" """
GET parameter 'next' is set as service url. GET parameter 'next' is set as service url.
""" """
r = self.client.get('/accounts/theid/logout/?next=/path/') r = self.client.get("/accounts/theid/logout/?next=/path/")
expected = ( expected = "https://server.cas/logout?url=http%3A%2F%2Ftestserver%2Fpath%2F"
'https://server.cas/logout?url=http%3A%2F%2Ftestserver%2Fpath%2F'
)
self.assertRedirects(r, expected, fetch_redirect_response=False) self.assertRedirects(r, expected, fetch_redirect_response=False)

View file

@ -1,6 +1,5 @@
# -*- coding: utf-8 -*- from django.urls import include, path
from django.conf.urls import include, url
urlpatterns = [ urlpatterns = [
url(r'^accounts/', include('allauth.urls')), path("accounts/", include("allauth.urls")),
] ]