version 1.0.1. Fixed package for Django 4.2.

This commit is contained in:
João Pires 2024-01-25 14:33:24 +00:00
parent 6657ec6042
commit 77e02f3796
22 changed files with 548 additions and 453 deletions

8
.idea/.gitignore vendored Normal file
View file

@ -0,0 +1,8 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

53
.pre-commit-config.yaml Normal file
View file

@ -0,0 +1,53 @@
exclude: "^docs/|/migrations/"
default_stages: [ commit ]
fail_fast: true
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
exclude: ".svg$|.min.*|.idea*"
- id: check-yaml
- id: check-added-large-files
- repo: https://github.com/asottile/pyupgrade
rev: v3.15.0
hooks:
- id: pyupgrade
args: [ --py39-plus ]
- repo: https://github.com/psf/black
rev: 23.12.1
hooks:
- id: black
- repo: https://github.com/PyCQA/isort
rev: 5.13.2
hooks:
- id: isort
args: [ "--profile", "black", "-l=88" ]
- repo: https://github.com/PyCQA/flake8
rev: 7.0.0
hooks:
- id: flake8
args: [ "--config=setup.cfg" ]
additional_dependencies:
- flake8-isort
- flake8-bugbear
- flake8-print
- repo: https://github.com/adamchainz/django-upgrade
rev: 1.15.0
hooks:
- id: django-upgrade
args: [ --target-version, "4.2" ]
# sets up .pre-commit-ci.yaml to ensure pre-commit dependencies stay up to date
ci:
autoupdate_schedule: weekly
skip: [ ]
submodules: false

View file

@ -1,7 +1,3 @@
# -*- coding: utf-8 -*-
__version__ = "1.0.1"
__version__ = '1.0.0'
default_app_config = 'allauth_cas.apps.CASAccountConfig'
CAS_PROVIDER_SESSION_KEY = 'allauth_cas__provider_id'
CAS_PROVIDER_SESSION_KEY = "allauth_cas__provider_id"

View file

@ -1,10 +1,9 @@
# -*- coding: utf-8 -*-
from django.apps import AppConfig
from django.utils.translation import ugettext_lazy as _
from django.utils.translation import gettext_lazy as _
class CASAccountConfig(AppConfig):
name = 'allauth_cas'
name = "allauth_cas"
verbose_name = _("CAS Accounts")
def ready(self):

View file

@ -1,6 +1,3 @@
# -*- coding: utf-8 -*-
class CASAuthenticationError(Exception):
"""
Base exception to signal CAS authentication failure.

View file

@ -1,26 +1,18 @@
# -*- coding: utf-8 -*-
from six.moves.urllib.parse import parse_qsl
from urllib.parse import parse_qsl
import django
from allauth.socialaccount.providers.base import Provider
from django.contrib import messages
from django.template.loader import render_to_string
from django.urls import reverse
from django.utils.http import urlencode
from django.utils.safestring import mark_safe
from allauth.socialaccount.providers.base import Provider
if django.VERSION >= (1, 10):
from django.urls import reverse
else:
from django.core.urlresolvers import reverse
class CASProvider(Provider):
def get_auth_params(self, request, action):
settings = self.get_settings()
ret = dict(settings.get('AUTH_PARAMS', {}))
dynamic_auth_params = request.GET.get('auth_params')
ret = dict(settings.get("AUTH_PARAMS", {}))
dynamic_auth_params = request.GET.get("auth_params")
if dynamic_auth_params:
ret.update(dict(parse_qsl(dynamic_auth_params)))
return ret
@ -68,11 +60,11 @@ class CASProvider(Provider):
"""
uid, extra = data
return {
'username': extra.get('username', uid),
'email': extra.get('email'),
'first_name': extra.get('first_name'),
'last_name': extra.get('last_name'),
'name': extra.get('name'),
"username": extra.get("username", uid),
"email": extra.get("email"),
"first_name": extra.get("first_name"),
"last_name": extra.get("last_name"),
"name": extra.get("name"),
}
def extract_email_addresses(self, data):
@ -99,7 +91,7 @@ class CASProvider(Provider):
]
"""
return super(CASProvider, self).extract_email_addresses(data)
return super().extract_email_addresses(data)
def extract_extra_data(self, data):
"""Extract the data to save to `SocialAccount.extra_data`.
@ -119,7 +111,10 @@ class CASProvider(Provider):
##
def add_message_suggest_caslogout(
self, request, next_page=None, level=None,
self,
request,
next_page=None,
level=None,
):
"""Add a message with a link for the user to logout of the CAS server.
@ -144,14 +139,15 @@ class CASProvider(Provider):
# DefaultAccountAdapter.add_message is unusable because it always
# escape the message content.
template = 'socialaccount/messages/suggest_caslogout.html'
template = "socialaccount/messages/suggest_caslogout.html"
context = {
'provider': self,
'logout_url': logout_url,
"provider": self,
"logout_url": logout_url,
}
messages.add_message(
request, level,
request,
level,
mark_safe(render_to_string(template, context).strip()),
fail_silently=True,
)
@ -168,10 +164,7 @@ class CASProvider(Provider):
signal ``user_logged_out``.
"""
return (
self.get_settings()
.get('MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT', False)
)
return self.get_settings().get("MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT", False)
def message_suggest_caslogout_on_logout_level(self, request):
"""Level of the logout message issued on user logout.
@ -185,9 +178,8 @@ class CASProvider(Provider):
signal ``user_logged_out``.
"""
return (
self.get_settings()
.get('MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT_LEVEL', messages.INFO)
return self.get_settings().get(
"MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT_LEVEL", messages.INFO
)
##
@ -195,19 +187,19 @@ class CASProvider(Provider):
##
def get_login_url(self, request, **kwargs):
url = reverse(self.id + '_login')
url = reverse(self.id + "_login")
if kwargs:
url += '?' + urlencode(kwargs)
url += "?" + urlencode(kwargs)
return url
def get_callback_url(self, request, **kwargs):
url = reverse(self.id + '_callback')
url = reverse(self.id + "_callback")
if kwargs:
url += '?' + urlencode(kwargs)
url += "?" + urlencode(kwargs)
return url
def get_logout_url(self, request, **kwargs):
url = reverse(self.id + '_logout')
url = reverse(self.id + "_logout")
if kwargs:
url += '?' + urlencode(kwargs)
url += "?" + urlencode(kwargs)
return url

View file

@ -1,10 +1,8 @@
# -*- coding: utf-8 -*-
from django.contrib.auth.signals import user_logged_out
from django.dispatch import receiver
from allauth.account.adapter import get_adapter
from allauth.account.utils import get_next_redirect_url
from allauth.socialaccount import providers
from django.contrib.auth.signals import user_logged_out
from django.dispatch import receiver
from . import CAS_PROVIDER_SESSION_KEY
@ -21,12 +19,12 @@ def cas_account_logout(sender, request, **kwargs):
if not provider.message_suggest_caslogout_on_logout(request):
return
next_page = (
get_next_redirect_url(request) or
get_adapter(request).get_logout_redirect_url(request)
)
next_page = get_next_redirect_url(request) or get_adapter(
request
).get_logout_redirect_url(request)
provider.add_message_suggest_caslogout(
request, next_page=next_page,
request,
next_page=next_page,
level=provider.message_suggest_caslogout_on_logout_level(request),
)

View file

@ -1,29 +1,20 @@
# -*- coding: utf-8 -*-
try:
from unittest.mock import patch
except ImportError:
from mock import patch
import django
from django.conf import settings
from django.test import TestCase
from unittest.mock import patch
import cas
from django.conf import settings
from django.test import TestCase
from django.urls import reverse
from allauth_cas import CAS_PROVIDER_SESSION_KEY
if django.VERSION >= (1, 10):
from django.urls import reverse
else:
from django.core.urlresolvers import reverse
class CASTestCase(TestCase):
def client_cas_login(
self,
client, provider_id='theid',
username=None, attributes={}):
self, client, provider_id=None, username=None, attributes=None
):
"""
Authenticate client through provider_id.
@ -32,20 +23,22 @@ class CASTestCase(TestCase):
username and attributes control the CAS server response when ticket is
checked.
"""
client.get(reverse('{id}_login'.format(id=provider_id)))
if attributes is None:
attributes = {}
if provider_id is None:
provider_id = "theid"
client.get(reverse(f"{provider_id}_login"))
self.patch_cas_response(
valid_ticket='__all__',
username=username, attributes=attributes,
valid_ticket="__all__",
username=username,
attributes=attributes,
)
callback_url = reverse('{id}_callback'.format(id=provider_id))
r = client.get(callback_url, {'ticket': 'fake-ticket'})
callback_url = reverse(f"{provider_id}_callback")
r = client.get(callback_url, {"ticket": "fake-ticket"})
self.patch_cas_response_stop()
return r
def patch_cas_response(
self,
valid_ticket,
username=None, attributes={}):
def patch_cas_response(self, valid_ticket, username=None, attributes=None):
"""
Patch the CASClient class used by views of CAS providers.
@ -68,35 +61,38 @@ class CASTestCase(TestCase):
Note that valid_ticket sould be a string (which is the type of the
ticket retrieved from GET parameter on request on the callback view).
"""
if hasattr(self, '_patch_cas_client'):
if attributes is None:
attributes = {}
if hasattr(self, "_patch_cas_client"):
self.patch_cas_response_stop()
class MockCASClient(object):
class MockCASClient:
_username = username
def __new__(self_client, *args, **kwargs):
version = kwargs.pop('version')
if version in (1, '1'):
version = kwargs.pop("version")
if version in (1, "1"):
client_class = cas.CASClientV1
elif version in (2, '2'):
elif version in (2, "2"):
client_class = cas.CASClientV2
elif version in (3, '3'):
elif version in (3, "3"):
client_class = cas.CASClientV3
elif version == 'CAS_2_SAML_1_0':
elif version == "CAS_2_SAML_1_0":
client_class = cas.CASClientWithSAMLV1
else:
raise ValueError('Unsupported CAS_VERSION %r' % version)
raise ValueError("Unsupported CAS_VERSION %r" % version)
client_class._username = self_client._username
def verify_ticket(self, ticket):
if valid_ticket == '__all__' or ticket == valid_ticket:
username = self._username or 'username'
if valid_ticket == "__all__" or ticket == valid_ticket:
username = self._username or "username"
return username, attributes, None
return None, {}, None
patcher = patch.object(
client_class, 'verify_ticket',
client_class,
"verify_ticket",
new=verify_ticket,
)
patcher.start()
@ -104,7 +100,7 @@ class CASTestCase(TestCase):
return client_class(*args, **kwargs)
self._patch_cas_client = patch(
'allauth_cas.views.cas.CASClient',
"allauth_cas.views.cas.CASClient",
MockCASClient,
)
self._patch_cas_client.start()
@ -114,12 +110,11 @@ class CASTestCase(TestCase):
del self._patch_cas_client
def tearDown(self):
if hasattr(self, '_patch_cas_client'):
if hasattr(self, "_patch_cas_client"):
self.patch_cas_response_stop()
class CASViewTestCase(CASTestCase):
def assertLoginSuccess(self, response, redirect_to=None):
"""
Asserts response corresponds to a successful login.
@ -133,7 +128,8 @@ class CASViewTestCase(CASTestCase):
redirect_to = settings.LOGIN_REDIRECT_URL
self.assertRedirects(
response, redirect_to,
response,
redirect_to,
fetch_redirect_response=False,
)
self.assertIn(
@ -146,6 +142,6 @@ class CASViewTestCase(CASTestCase):
Asserts response corresponds to a failed login.
"""
return self.assertInHTML(
'<h1>Social Network Login Failure</h1>',
"<h1>Social Network Login Failure</h1>",
str(response.content),
)

View file

@ -1,5 +1,4 @@
# -*- coding: utf-8 -*-
from django.conf.urls import include, url
from django.urls import include, path, re_path
from django.utils.module_loading import import_string
@ -7,45 +6,44 @@ def default_urlpatterns(provider):
package = provider.get_package()
try:
login_view = import_string(package + '.views.login')
login_view = import_string(package + ".views.login")
except ImportError:
raise ImportError(
"The login view for the '{id}' provider is lacking from the "
"'views' module of its app.\n"
"You may want to add:\n"
"from allauth_cas.views import CASLoginView\n\n"
"login = CASLoginView.adapter_view(<LocalCASAdapter>)"
.format(id=provider.id)
"login = CASLoginView.adapter_view(<LocalCASAdapter>)".format(
id=provider.id
)
)
try:
callback_view = import_string(package + '.views.callback')
callback_view = import_string(package + ".views.callback")
except ImportError:
raise ImportError(
"The callback view for the '{id}' provider is lacking from the "
"'views' module of its app.\n"
"You may want to add:\n"
"from allauth_cas.views import CASCallbackView\n\n"
"callback = CASCallbackView.adapter_view(<LocalCASAdapter>)"
.format(id=provider.id)
"callback = CASCallbackView.adapter_view(<LocalCASAdapter>)".format(
id=provider.id
)
)
try:
logout_view = import_string(package + '.views.logout')
logout_view = import_string(package + ".views.logout")
except ImportError:
logout_view = None
urlpatterns = [
url('^login/$', login_view,
name=provider.id + '_login'),
url('^login/callback/$', callback_view,
name=provider.id + '_callback'),
path("login/", login_view, name=provider.id + "_login"),
path("login/callback/", callback_view, name=provider.id + "_callback"),
]
if logout_view is not None:
urlpatterns += [
url('^logout/$', logout_view,
name=provider.id + '_logout'),
path("logout/", logout_view, name=provider.id + "_logout"),
]
return [url('^' + provider.get_slug() + '/', include(urlpatterns))]
return [re_path("^" + provider.get_slug() + "/", include(urlpatterns))]

View file

@ -1,28 +1,26 @@
# -*- coding: utf-8 -*-
from django.http import HttpResponseRedirect
from django.utils.functional import cached_property
import cas
from allauth.account.adapter import get_adapter
from allauth.account.utils import get_next_redirect_url
from allauth.socialaccount import providers
from allauth.socialaccount.helpers import (
complete_social_login, render_authentication_error,
complete_social_login,
render_authentication_error,
)
from allauth.socialaccount.models import SocialLogin
import cas
from django.http import HttpResponseRedirect
from django.utils.functional import cached_property
from . import CAS_PROVIDER_SESSION_KEY
from .exceptions import CASAuthenticationError
class AuthAction(object):
AUTHENTICATE = 'authenticate'
REAUTHENTICATE = 'reauthenticate'
DEAUTHENTICATE = 'deauthenticate'
class AuthAction:
AUTHENTICATE = "authenticate"
REAUTHENTICATE = "reauthenticate"
DEAUTHENTICATE = "deauthenticate"
class CASAdapter(object):
class CASAdapter:
#: CAS server url.
url = None
#: CAS server version.
@ -92,19 +90,19 @@ class CASAdapter(object):
"""
redirect_to = get_next_redirect_url(request)
callback_kwargs = {'next': redirect_to} if redirect_to else {}
callback_url = (
self.provider.get_callback_url(request, **callback_kwargs))
callback_kwargs = {"next": redirect_to} if redirect_to else {}
callback_url = self.provider.get_callback_url(request, **callback_kwargs)
service_url = request.build_absolute_uri(callback_url)
return service_url
class CASView(object):
class CASView:
"""
Base class for CAS views.
"""
@classmethod
def adapter_view(cls, adapter):
"""Transform the view class into a view function.
@ -124,6 +122,7 @@ class CASView(object):
"""
def view(request, *args, **kwargs):
# Prepare the func-view.
self = cls()
@ -169,19 +168,17 @@ class CASView(object):
class CASLoginView(CASView):
def dispatch(self, request):
"""
Redirects to the CAS server login page.
"""
action = request.GET.get('action', AuthAction.AUTHENTICATE)
action = request.GET.get("action", AuthAction.AUTHENTICATE)
SocialLogin.stash_state(request)
client = self.get_client(request, action=action)
return HttpResponseRedirect(client.get_login_url())
class CASCallbackView(CASView):
def dispatch(self, request):
"""
The CAS server redirects the user to this view after a successful
@ -195,11 +192,9 @@ class CASCallbackView(CASView):
# CAS server should let a ticket.
try:
ticket = request.GET['ticket']
ticket = request.GET["ticket"]
except KeyError:
raise CASAuthenticationError(
"CAS server didn't respond with a ticket."
)
raise CASAuthenticationError("CAS server didn't respond with a ticket.")
# Check ticket validity.
# Response format on:
@ -210,9 +205,7 @@ class CASCallbackView(CASView):
uid, extra, _ = response
if not uid:
raise CASAuthenticationError(
"CAS server doesn't validate the ticket."
)
raise CASAuthenticationError("CAS server doesn't validate the ticket.")
# Keep tracks of the last used CAS provider.
request.session[CAS_PROVIDER_SESSION_KEY] = self.provider.id
@ -226,7 +219,6 @@ class CASCallbackView(CASView):
class CASLogoutView(CASView):
def dispatch(self, request, next_page=None):
"""
Redirects to the CAS server logout page.
@ -248,7 +240,6 @@ class CASLogoutView(CASView):
Returns the url to redirect after logout.
"""
request = self.request
return (
get_next_redirect_url(request) or
get_adapter(request).get_logout_redirect_url(request)
)
return get_next_redirect_url(request) or get_adapter(
request
).get_logout_redirect_url(request)

View file

@ -1,5 +1,4 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import sys
@ -7,10 +6,10 @@ import django
from django.conf import settings
from django.test.utils import get_runner
if __name__ == '__main__':
os.environ['DJANGO_SETTINGS_MODULE'] = 'tests.settings'
if __name__ == "__main__":
os.environ["DJANGO_SETTINGS_MODULE"] = "tests.settings"
django.setup()
TestRunner = get_runner(settings)
test_runner = TestRunner()
failures = test_runner.run_tests(sys.argv[1:] or ['tests'])
failures = test_runner.run_tests(sys.argv[1:] or ["tests"])
sys.exit(bool(failures))

View file

@ -1,18 +1,45 @@
[flake8]
max-line-length = 120
exclude = .tox,.git,*/migrations/*,*/static/CACHE/*,docs,node_modules,venv
# E731: lambda expression
ignore = E731
[pycodestyle]
max-line-length = 120
exclude = .tox,.git,*/migrations/*,*/static/CACHE/*,docs,node_modules,venv
[isort]
combine_as_imports = True
line_length = 88
known_first_party = geoip,config
multi_line_output = 3
default_section = THIRDPARTY
include_trailing_comma = True
known_allauth = allauth
known_future_library = future,six
known_django = django
known_first_party = allauth_cas
multi_line_output = 5
not_skip = __init__.py
sections = FUTURE,STDLIB,DJANGO,ALLAUTH,THIRDPARTY,FIRSTPARTY,LOCALFOLDER
skip = venv/
skip_glob = **/migrations/*.py
include_trailing_comma = true
force_grid_wrap = 0
use_parentheses = true
[mypy]
python_version = 3.9
check_untyped_defs = True
ignore_missing_imports = True
warn_unused_ignores = True
warn_redundant_casts = True
warn_unused_configs = True
plugins = mypy_django_plugin.main
[mypy.plugins.django-stubs]
django_settings_module = config.settings.test
[mypy-*.migrations.*]
# Django migrations should not produce any errors:
ignore_errors = True
[coverage:run]
include = geoip/*
omit = *migrations*, *tests*
plugins =
django_coverage_plugin
[bdist_wheel]
universal = 1

View file

@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
import os
from setuptools import find_packages, setup
@ -7,49 +6,56 @@ from allauth_cas import __version__
BASE_DIR = os.path.dirname(__file__)
with open(os.path.join(BASE_DIR, 'README.rst')) as readme:
with open(os.path.join(BASE_DIR, "README.rst")) as readme:
README = readme.read()
setup(
name='django-allauth-cas',
name="django-allauth-cas",
version=__version__,
description='CAS support for django-allauth.',
author='Aurélien Delobelle',
author_email='aurelien.delobelle@gmail.com',
keywords='django allauth cas authentication',
description="CAS support for django-allauth.",
author="Aurélien Delobelle",
author_email="aurelien.delobelle@gmail.com",
keywords="django allauth cas authentication",
long_description=README,
url='https://github.com/aureplop/django-allauth-cas',
url="https://github.com/aureplop/django-allauth-cas",
classifiers=[
'Development Status :: 4 - Beta',
'Environment :: Web Environment',
'Framework :: Django',
'Framework :: Django :: 1.8',
'Framework :: Django :: 1.9',
'Framework :: Django :: 1.10',
'Framework :: Django :: 1.11',
'Framework :: Django :: 2.0',
'Intended Audience :: Developers',
'License :: OSI Approved :: MIT License',
'Operating System :: OS Independent',
'Programming Language :: Python',
'Programming Language :: Python :: 2',
'Programming Language :: Python :: 2.7',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.4',
'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6',
'Topic :: Internet :: WWW/HTTP',
"Development Status :: 4 - Beta",
"Environment :: Web Environment",
"Framework :: Django",
"Framework :: Django :: 1.8",
"Framework :: Django :: 1.9",
"Framework :: Django :: 1.10",
"Framework :: Django :: 1.11",
"Framework :: Django :: 2.0",
"Framework :: Django :: 3.0",
"Framework :: Django :: 4.0",
"Framework :: Django :: 5.0",
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Programming Language :: Python",
"Programming Language :: Python :: 2",
"Programming Language :: Python :: 2.7",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.4",
"Programming Language :: Python :: 3.5",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.11",
"Topic :: Internet :: WWW/HTTP",
],
license='MIT',
packages=find_packages(exclude=['tests']),
license="MIT",
packages=find_packages(exclude=["tests"]),
include_package_data=True,
install_requires=[
'django-allauth',
'python-cas',
'six',
"django-allauth",
"python-cas",
"six",
],
extras_require={
'docs': ['sphinx'],
'tests': ['tox'],
"docs": ["sphinx"],
"tests": ["tox"],
},
)

View file

@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
from allauth.socialaccount.providers.base import ProviderAccount
from allauth_cas.providers import CASProvider
@ -9,8 +8,8 @@ class ExampleCASAccount(ProviderAccount):
class ExampleCASProvider(CASProvider):
id = 'theid'
name = 'The Provider'
id = "theid"
name = "The Provider"
account_class = ExampleCASAccount

View file

@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
from allauth_cas.urls import default_urlpatterns
from .provider import ExampleCASProvider

View file

@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
from allauth_cas import views
from .provider import ExampleCASProvider
@ -6,7 +5,7 @@ from .provider import ExampleCASProvider
class ExampleCASAdapter(views.CASAdapter):
provider_id = ExampleCASProvider.id
url = 'https://server.cas'
url = "https://server.cas"
version = 2

View file

@ -1,63 +1,54 @@
# -*- coding: utf-8 -*-
import django
SECRET_KEY = 'iamabird'
SECRET_KEY = "iamabird"
INSTALLED_APPS = [
'django.contrib.admin',
'django.contrib.auth',
'django.contrib.contenttypes',
'django.contrib.messages',
'django.contrib.sessions',
'django.contrib.sites',
'django.contrib.staticfiles',
'allauth',
'allauth.account',
'allauth.socialaccount',
'allauth_cas',
'tests.example', # Dummy CAS provider app
"django.contrib.admin",
"django.contrib.auth",
"django.contrib.contenttypes",
"django.contrib.messages",
"django.contrib.sessions",
"django.contrib.sites",
"django.contrib.staticfiles",
"allauth",
"allauth.account",
"allauth.socialaccount",
"allauth_cas",
"tests.example", # Dummy CAS provider app
]
DATABASES = {
'default': {
'ENGINE': 'django.db.backends.sqlite3',
"default": {
"ENGINE": "django.db.backends.sqlite3",
},
}
AUTHENTICATION_BACKENDS = [
'allauth.account.auth_backends.AuthenticationBackend',
"allauth.account.auth_backends.AuthenticationBackend",
]
_MIDDLEWARES = [
'django.middleware.common.CommonMiddleware',
'django.contrib.sessions.middleware.SessionMiddleware',
'django.middleware.csrf.CsrfViewMiddleware',
'django.contrib.auth.middleware.AuthenticationMiddleware',
'django.contrib.messages.middleware.MessageMiddleware',
"django.middleware.common.CommonMiddleware",
"django.contrib.sessions.middleware.SessionMiddleware",
"django.middleware.csrf.CsrfViewMiddleware",
"django.contrib.auth.middleware.AuthenticationMiddleware",
"django.contrib.messages.middleware.MessageMiddleware",
]
if django.VERSION >= (1, 10):
MIDDLEWARE = _MIDDLEWARES
else:
MIDDLEWARE_CLASSES = _MIDDLEWARES
TEMPLATES = [
{
'BACKEND': 'django.template.backends.django.DjangoTemplates',
'DIRS': [],
'APP_DIRS': True,
'OPTIONS': {
'context_processors': [
'django.template.context_processors.debug',
'django.template.context_processors.request',
'django.contrib.auth.context_processors.auth',
'django.contrib.messages.context_processors.messages',
"BACKEND": "django.template.backends.django.DjangoTemplates",
"DIRS": [],
"APP_DIRS": True,
"OPTIONS": {
"context_processors": [
"django.template.context_processors.debug",
"django.template.context_processors.request",
"django.contrib.auth.context_processors.auth",
"django.contrib.messages.context_processors.messages",
],
},
},
]
ROOT_URLCONF = 'tests.urls'
ROOT_URLCONF = "tests.urls"

View file

@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
from django.contrib import messages
from django.contrib.auth import get_user_model
from django.contrib.messages.api import get_messages
@ -13,7 +12,7 @@ User = get_user_model()
class LogoutFlowTests(CASTestCase):
expected_msg_str = (
"To logout of The Provider, please close your browser, or visit this "
"<a href=\"/accounts/theid/logout/?next=%2Fredir%2F\">"
'<a href="/accounts/theid/logout/?next=%2Fredir%2F">'
"link</a>."
)
@ -28,42 +27,45 @@ class LogoutFlowTests(CASTestCase):
)
self.assertTemplateNotUsed(
response,
'socialaccount/messages/suggest_caslogout.html',
"socialaccount/messages/suggest_caslogout.html",
)
@override_settings(SOCIALACCOUNT_PROVIDERS={
'theid': {
'MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT': True,
'MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT_LEVEL': messages.WARNING,
@override_settings(
SOCIALACCOUNT_PROVIDERS={
"theid": {
"MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT": True,
"MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT_LEVEL": messages.WARNING,
},
})
}
)
def test_message_on_logout(self):
"""
Message is sent to propose user to logout of CAS.
"""
r = self.client.post('/accounts/logout/?next=/redir/')
r = self.client.post("/accounts/logout/?next=/redir/")
r_messages = get_messages(r.wsgi_request)
expected_msg = Message(messages.WARNING, self.expected_msg_str)
self.assertIn(expected_msg, r_messages)
self.assertTemplateUsed(
r, 'socialaccount/messages/suggest_caslogout.html')
self.assertTemplateUsed(r, "socialaccount/messages/suggest_caslogout.html")
def test_message_on_logout_disabled(self):
r = self.client.post('/accounts/logout/')
r = self.client.post("/accounts/logout/")
self.assertCASLogoutNotInMessages(r)
@override_settings(SOCIALACCOUNT_PROVIDERS={
'theid': {'MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT': True},
})
@override_settings(
SOCIALACCOUNT_PROVIDERS={
"theid": {"MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT": True},
}
)
def test_other_logout(self):
"""
The CAS logout message doesn't appear with other login methods.
"""
User.objects.create_user('user', '', 'user')
User.objects.create_user("user", "", "user")
client = Client()
client.login(username='user', password='user')
client.login(username="user", password="user")
r = client.post('/accounts/logout/')
r = client.post("/accounts/logout/")
self.assertCASLogoutNotInMessages(r)

View file

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
from six.moves.urllib.parse import urlencode
from urllib.parse import urlencode
from allauth.socialaccount.providers import registry
from django.contrib import messages
from django.contrib.messages.api import get_messages
from django.contrib.messages.middleware import MessageMiddleware
@ -8,21 +8,18 @@ from django.contrib.messages.storage.base import Message
from django.contrib.sessions.middleware import SessionMiddleware
from django.test import RequestFactory, TestCase, override_settings
from allauth.socialaccount.providers import registry
from allauth_cas.views import AuthAction
from .example.provider import ExampleCASProvider
class CASProviderTests(TestCase):
def setUp(self):
self.request = self._get_request()
self.provider = ExampleCASProvider(self.request)
def _get_request(self):
request = RequestFactory().get('/test/')
request = RequestFactory().get("/test/")
SessionMiddleware().process_request(request)
MessageMiddleware().process_request(request)
return request
@ -31,74 +28,81 @@ class CASProviderTests(TestCase):
"""
Example CAS provider is registered as social account provider.
"""
self.assertIsInstance(registry.by_id('theid'), ExampleCASProvider)
self.assertIsInstance(registry.by_id("theid"), ExampleCASProvider)
def test_get_login_url(self):
url = self.provider.get_login_url(self.request)
self.assertEqual('/accounts/theid/login/', url)
self.assertEqual("/accounts/theid/login/", url)
url_with_qs = self.provider.get_login_url(
self.request,
next='/path?quéry=string&two=whoam%C3%AF',
next="/path?quéry=string&two=whoam%C3%AF",
)
self.assertEqual(
url_with_qs,
'/accounts/theid/login/?next=%2Fpath%3Fqu%C3%A9ry%3Dstring%26two%3'
'Dwhoam%25C3%25AF'
"/accounts/theid/login/?next=%2Fpath%3Fqu%C3%A9ry%3Dstring%26two%3"
"Dwhoam%25C3%25AF",
)
def test_get_callback_url(self):
url = self.provider.get_callback_url(self.request)
self.assertEqual('/accounts/theid/login/callback/', url)
self.assertEqual("/accounts/theid/login/callback/", url)
url_with_qs = self.provider.get_callback_url(
self.request,
next='/path?quéry=string&two=whoam%C3%AF',
next="/path?quéry=string&two=whoam%C3%AF",
)
self.assertEqual(
url_with_qs,
'/accounts/theid/login/callback/?next=%2Fpath%3Fqu%C3%A9ry%3Dstrin'
'g%26two%3Dwhoam%25C3%25AF'
"/accounts/theid/login/callback/?next=%2Fpath%3Fqu%C3%A9ry%3Dstrin"
"g%26two%3Dwhoam%25C3%25AF",
)
def test_get_logout_url(self):
url = self.provider.get_logout_url(self.request)
self.assertEqual('/accounts/theid/logout/', url)
self.assertEqual("/accounts/theid/logout/", url)
url_with_qs = self.provider.get_logout_url(
self.request,
next='/path?quéry=string&two=whoam%C3%AF',
next="/path?quéry=string&two=whoam%C3%AF",
)
self.assertEqual(
url_with_qs,
'/accounts/theid/logout/?next=%2Fpath%3Fqu%C3%A9ry%3Dstring%26two%'
'3Dwhoam%25C3%25AF'
"/accounts/theid/logout/?next=%2Fpath%3Fqu%C3%A9ry%3Dstring%26two%"
"3Dwhoam%25C3%25AF",
)
@override_settings(SOCIALACCOUNT_PROVIDERS={
'theid': {
'AUTH_PARAMS': {'key': 'value'},
@override_settings(
SOCIALACCOUNT_PROVIDERS={
"theid": {
"AUTH_PARAMS": {"key": "value"},
},
})
}
)
def test_get_auth_params(self):
action = AuthAction.AUTHENTICATE
auth_params = self.provider.get_auth_params(self.request, action)
self.assertDictEqual(auth_params, {
'key': 'value',
})
@override_settings(SOCIALACCOUNT_PROVIDERS={
'theid': {
'AUTH_PARAMS': {'key': 'value'},
self.assertDictEqual(
auth_params,
{
"key": "value",
},
})
)
@override_settings(
SOCIALACCOUNT_PROVIDERS={
"theid": {
"AUTH_PARAMS": {"key": "value"},
},
}
)
def test_get_auth_params_with_dynamic(self):
factory = RequestFactory()
request = factory.get(
'/test/?auth_params=next%3Dtwo%253Dwhoam%2525C3%2525AF%2526qu%2525'
'C3%2525A9ry%253Dstring'
"/test/?auth_params=next%3Dtwo%253Dwhoam%2525C3%2525AF%2526qu%2525"
"C3%2525A9ry%253Dstring"
)
request.session = {}
@ -106,15 +110,18 @@ class CASProviderTests(TestCase):
auth_params = self.provider.get_auth_params(request, action)
self.assertDictEqual(auth_params, {
'key': 'value',
'next': 'two=whoam%C3%AF&qu%C3%A9ry=string',
})
self.assertDictEqual(
auth_params,
{
"key": "value",
"next": "two=whoam%C3%AF&qu%C3%A9ry=string",
},
)
def test_add_message_suggest_caslogout(self):
expected_msg_base_str = (
"To logout of The Provider, please close your browser, or visit "
"this <a href=\"/accounts/theid/logout/?{}\">link</a>."
'this <a href="/accounts/theid/logout/?{}">link</a>.'
)
# Defaults.
@ -124,7 +131,7 @@ class CASProviderTests(TestCase):
expected_msg1 = Message(
messages.INFO,
expected_msg_base_str.format(urlencode({'next': '/test/'})),
expected_msg_base_str.format(urlencode({"next": "/test/"})),
)
self.assertIn(expected_msg1, get_messages(req1))
@ -132,69 +139,83 @@ class CASProviderTests(TestCase):
req2 = self._get_request()
self.provider.add_message_suggest_caslogout(
req2, next_page='/redir/', level=messages.WARNING)
req2, next_page="/redir/", level=messages.WARNING
)
expected_msg2 = Message(
messages.WARNING,
expected_msg_base_str.format(urlencode({'next': '/redir/'})),
expected_msg_base_str.format(urlencode({"next": "/redir/"})),
)
self.assertIn(expected_msg2, get_messages(req2))
def test_message_suggest_caslogout_on_logout(self):
self.assertFalse(
self.provider.message_suggest_caslogout_on_logout(self.request))
with override_settings(SOCIALACCOUNT_PROVIDERS={
'theid': {'MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT': True},
}):
self.assertTrue(
self.provider
.message_suggest_caslogout_on_logout(self.request)
self.provider.message_suggest_caslogout_on_logout(self.request)
)
@override_settings(SOCIALACCOUNT_PROVIDERS={
'theid': {
'MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT_LEVEL': messages.WARNING,
with override_settings(
SOCIALACCOUNT_PROVIDERS={
"theid": {"MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT": True},
}
):
self.assertTrue(
self.provider.message_suggest_caslogout_on_logout(self.request)
)
@override_settings(
SOCIALACCOUNT_PROVIDERS={
"theid": {
"MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT_LEVEL": messages.WARNING,
},
})
}
)
def test_message_suggest_caslogout_on_logout_level(self):
self.assertEqual(messages.WARNING, (
self.provider
.message_suggest_caslogout_on_logout_level(self.request)
))
self.assertEqual(
messages.WARNING,
(self.provider.message_suggest_caslogout_on_logout_level(self.request)),
)
def test_extract_uid(self):
response = 'useRName', {}
response = "useRName", {}
uid = self.provider.extract_uid(response)
self.assertEqual('useRName', uid)
self.assertEqual("useRName", uid)
def test_extract_common_fields(self):
response = 'useRName', {}
response = "useRName", {}
common_fields = self.provider.extract_common_fields(response)
self.assertDictEqual(common_fields, {
'username': 'useRName',
'first_name': None,
'last_name': None,
'name': None,
'email': None,
})
self.assertDictEqual(
common_fields,
{
"username": "useRName",
"first_name": None,
"last_name": None,
"name": None,
"email": None,
},
)
def test_extract_common_fields_with_extra(self):
response = 'useRName', {'username': 'user', 'email': 'user@mail.net'}
response = "useRName", {"username": "user", "email": "user@mail.net"}
common_fields = self.provider.extract_common_fields(response)
self.assertDictEqual(common_fields, {
'username': 'user',
'first_name': None,
'last_name': None,
'name': None,
'email': 'user@mail.net',
})
self.assertDictEqual(
common_fields,
{
"username": "user",
"first_name": None,
"last_name": None,
"name": None,
"email": "user@mail.net",
},
)
def test_extract_extra_data(self):
response = 'useRName', {'user_attr': 'thevalue', 'another': 'value'}
response = "useRName", {"user_attr": "thevalue", "another": "value"}
extra_data = self.provider.extract_extra_data(response)
self.assertDictEqual(extra_data, {
'user_attr': 'thevalue',
'another': 'value',
'uid': 'useRName',
})
self.assertDictEqual(
extra_data,
{
"user_attr": "thevalue",
"another": "value",
"uid": "useRName",
},
)

View file

@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
from django.test import Client, RequestFactory
from allauth_cas.test.testcases import CASViewTestCase
@ -8,9 +7,8 @@ from .example.views import ExampleCASAdapter
class CASTestCaseTests(CASViewTestCase):
def setUp(self):
self.client.get('/accounts/theid/login/')
self.client.get("/accounts/theid/login/")
def test_patch_cas_response_client_version(self):
"""
@ -21,20 +19,24 @@ class CASTestCaseTests(CASViewTestCase):
"""
valid_versions = [
1, '1',
2, '2',
3, '3',
'CAS_2_SAML_1_0',
1,
"1",
2,
"2",
3,
"3",
"CAS_2_SAML_1_0",
]
invalid_versions = [
'not_supported',
"not_supported",
]
factory = RequestFactory()
request = factory.get('/path/')
request = factory.get("/path/")
request.session = {}
for _version in valid_versions + invalid_versions:
class BasicCASAdapter(ExampleCASAdapter):
version = _version
@ -47,7 +49,7 @@ class CASTestCaseTests(CASViewTestCase):
if _version in valid_versions:
raw_client = view(request)
self.patch_cas_response(valid_ticket='__all__')
self.patch_cas_response(valid_ticket="__all__")
mocked_client = view(request)
self.assertEqual(type(raw_client), type(mocked_client))
@ -55,58 +57,79 @@ class CASTestCaseTests(CASViewTestCase):
# This is a sanity check.
self.assertRaises(ValueError, view, request)
self.patch_cas_response(valid_ticket='__all__')
self.patch_cas_response(valid_ticket="__all__")
self.assertRaises(ValueError, view, request)
def test_patch_cas_response_verify_success(self):
self.patch_cas_response(valid_ticket='123456')
r = self.client.get('/accounts/theid/login/callback/', {
'ticket': '123456',
})
self.patch_cas_response(valid_ticket="123456")
r = self.client.get(
"/accounts/theid/login/callback/",
{
"ticket": "123456",
},
)
self.assertLoginSuccess(r)
def test_patch_cas_response_verify_failure(self):
self.patch_cas_response(valid_ticket='123456')
r = self.client.get('/accounts/theid/login/callback/', {
'ticket': '000000',
})
self.patch_cas_response(valid_ticket="123456")
r = self.client.get(
"/accounts/theid/login/callback/",
{
"ticket": "000000",
},
)
self.assertLoginFailure(r)
def test_patch_cas_response_accept(self):
self.patch_cas_response(valid_ticket='__all__')
r = self.client.get('/accounts/theid/login/callback/', {
'ticket': '000000',
})
self.patch_cas_response(valid_ticket="__all__")
r = self.client.get(
"/accounts/theid/login/callback/",
{
"ticket": "000000",
},
)
self.assertLoginSuccess(r)
def test_patch_cas_response_reject(self):
self.patch_cas_response(valid_ticket=None)
r = self.client.get('/accounts/theid/login/callback/', {
'ticket': '000000',
})
r = self.client.get(
"/accounts/theid/login/callback/",
{
"ticket": "000000",
},
)
self.assertLoginFailure(r)
def test_patch_cas_reponse_multiple(self):
self.patch_cas_response(valid_ticket='__all__')
self.patch_cas_response(valid_ticket="__all__")
client_0 = Client()
client_0.get('/accounts/theid/login/')
r_0 = client_0.get('/accounts/theid/login/callback/', {
'ticket': '000000',
})
client_0.get("/accounts/theid/login/")
r_0 = client_0.get(
"/accounts/theid/login/callback/",
{
"ticket": "000000",
},
)
self.assertLoginSuccess(r_0)
self.patch_cas_response(valid_ticket=None)
client_1 = Client()
client_1.get('/accounts/theid/login/')
r_1 = client_1.get('/accounts/theid/login/callback/', {
'ticket': '111111',
})
client_1.get("/accounts/theid/login/")
r_1 = client_1.get(
"/accounts/theid/login/callback/",
{
"ticket": "111111",
},
)
self.assertLoginFailure(r_1)
def test_assertLoginSuccess(self):
self.patch_cas_response(valid_ticket='__all__')
r = self.client.get('/accounts/theid/login/callback/', {
'ticket': '000000',
'next': '/path/',
})
self.assertLoginSuccess(r, redirect_to='/path/')
self.patch_cas_response(valid_ticket="__all__")
r = self.client.get(
"/accounts/theid/login/callback/",
{
"ticket": "000000",
"next": "/path/",
},
)
self.assertLoginSuccess(r, redirect_to="/path/")

View file

@ -1,11 +1,10 @@
# -*- coding: utf-8 -*-
try:
from unittest.mock import patch
except ImportError:
from mock import patch
from unittest.mock import patch
import django
from django.test import RequestFactory, override_settings
from django.urls import reverse
from allauth_cas.exceptions import CASAuthenticationError
from allauth_cas.test.testcases import CASTestCase, CASViewTestCase
@ -13,17 +12,11 @@ from allauth_cas.views import CASView
from .example.views import ExampleCASAdapter
if django.VERSION >= (1, 10):
from django.urls import reverse
else:
from django.core.urlresolvers import reverse
class CASAdapterTests(CASTestCase):
def setUp(self):
factory = RequestFactory()
self.request = factory.get('/path/')
self.request = factory.get("/path/")
self.request.session = {}
self.adapter = ExampleCASAdapter(self.request)
@ -31,7 +24,7 @@ class CASAdapterTests(CASTestCase):
"""
Service url (used by CAS client) is the callback url.
"""
expected = 'http://testserver/accounts/theid/login/callback/'
expected = "http://testserver/accounts/theid/login/callback/"
service_url = self.adapter.get_service_url(self.request)
self.assertEqual(expected, service_url)
@ -39,11 +32,9 @@ class CASAdapterTests(CASTestCase):
"""
Current GET paramater next is appended on service url.
"""
expected = (
'http://testserver/accounts/theid/login/callback/?next=%2Fnext%2F'
)
expected = "http://testserver/accounts/theid/login/callback/?next=%2Fnext%2F"
factory = RequestFactory()
request = factory.get('/path/', {'next': '/next/'})
request = factory.get("/path/", {"next": "/next/"})
adapter = ExampleCASAdapter(request)
service_url = adapter.get_service_url(request)
self.assertEqual(expected, service_url)
@ -67,14 +58,13 @@ class CASAdapterTests(CASTestCase):
class CASViewTests(CASViewTestCase):
class BasicCASView(CASView):
def dispatch(self, request, *args, **kwargs):
return self
def setUp(self):
factory = RequestFactory()
self.request = factory.get('/path/')
self.request = factory.get("/path/")
self.request.session = {}
self.cas_view = self.BasicCASView.adapter_view(ExampleCASAdapter)
@ -85,27 +75,34 @@ class CASViewTests(CASViewTestCase):
"""
view = self.cas_view(
self.request,
'arg1', 'arg2',
kwarg1='kwarg1', kwarg2='kwarg2',
"arg1",
"arg2",
kwarg1="kwarg1",
kwarg2="kwarg2",
)
self.assertIsInstance(view, CASView)
self.assertEqual(view.request, self.request)
self.assertTupleEqual(view.args, ('arg1', 'arg2'))
self.assertDictEqual(view.kwargs, {
'kwarg1': 'kwarg1',
'kwarg2': 'kwarg2',
})
self.assertTupleEqual(view.args, ("arg1", "arg2"))
self.assertDictEqual(
view.kwargs,
{
"kwarg1": "kwarg1",
"kwarg2": "kwarg2",
},
)
self.assertIsInstance(view.adapter, ExampleCASAdapter)
@patch('allauth_cas.views.cas.CASClient')
@override_settings(SOCIALACCOUNT_PROVIDERS={
'theid': {
'AUTH_PARAMS': {'key': 'value'},
@patch("allauth_cas.views.cas.CASClient")
@override_settings(
SOCIALACCOUNT_PROVIDERS={
"theid": {
"AUTH_PARAMS": {"key": "value"},
},
})
}
)
def test_get_client(self, mock_casclient_class):
"""
get_client returns a CAS client, configured from settings.
@ -114,11 +111,11 @@ class CASViewTests(CASViewTestCase):
view.get_client(self.request)
mock_casclient_class.assert_called_once_with(
service_url='http://testserver/accounts/theid/login/callback/',
server_url='https://server.cas',
service_url="http://testserver/accounts/theid/login/callback/",
server_url="https://server.cas",
version=2,
renew=False,
extra_login_params={'key': 'value'},
extra_login_params={"key": "value"},
)
def test_render_error_on_failure(self):
@ -126,33 +123,33 @@ class CASViewTests(CASViewTestCase):
A common login failure page is rendered if CASAuthenticationError is
raised by dispatch.
"""
def dispatch_raise(self, request):
raise CASAuthenticationError("failure")
with patch.object(self.BasicCASView, 'dispatch', dispatch_raise):
with patch.object(self.BasicCASView, "dispatch", dispatch_raise):
resp = self.cas_view(self.request)
self.assertLoginFailure(resp)
class CASLoginViewTests(CASViewTestCase):
def test_reverse(self):
"""
Login view name is "{provider_id}_login".
"""
url = reverse('theid_login')
self.assertEqual('/accounts/theid/login/', url)
url = reverse("theid_login")
self.assertEqual("/accounts/theid/login/", url)
def test_execute(self):
"""
Login view redirects to the CAS server login url.
Service is the callback url, as absolute uri.
"""
r = self.client.get('/accounts/theid/login/')
r = self.client.get("/accounts/theid/login/")
expected = (
'https://server.cas/login?service=http%3A%2F%2Ftestserver%2F'
'accounts%2Ftheid%2Flogin%2Fcallback%2F'
"https://server.cas/login?service=http%3A%2F%2Ftestserver%2F"
"accounts%2Ftheid%2Flogin%2Fcallback%2F"
)
self.assertRedirects(r, expected, fetch_redirect_response=False)
@ -161,54 +158,59 @@ class CASLoginViewTests(CASViewTestCase):
"""
Current GET parameter 'next' is kept on service url.
"""
r = self.client.get('/accounts/theid/login/?next=/path/')
r = self.client.get("/accounts/theid/login/?next=/path/")
expected = (
'https://server.cas/login?service=http%3A%2F%2Ftestserver%2F'
'accounts%2Ftheid%2Flogin%2Fcallback%2F%3Fnext%3D%252Fpath%252F'
"https://server.cas/login?service=http%3A%2F%2Ftestserver%2F"
"accounts%2Ftheid%2Flogin%2Fcallback%2F%3Fnext%3D%252Fpath%252F"
)
self.assertRedirects(r, expected, fetch_redirect_response=False)
class CASCallbackViewTests(CASViewTestCase):
def setUp(self):
self.client.get('/accounts/theid/login/')
self.client.get("/accounts/theid/login/")
def test_reverse(self):
"""
Callback view name is "{provider_id}_callback".
"""
url = reverse('theid_callback')
self.assertEqual('/accounts/theid/login/callback/', url)
url = reverse("theid_callback")
self.assertEqual("/accounts/theid/login/callback/", url)
def test_ticket_valid(self):
"""
If ticket is valid, the user is logged in.
"""
self.patch_cas_response(username='username', valid_ticket='123456')
r = self.client.get('/accounts/theid/login/callback/', {
'ticket': '123456',
})
self.patch_cas_response(username="username", valid_ticket="123456")
r = self.client.get(
"/accounts/theid/login/callback/",
{
"ticket": "123456",
},
)
self.assertLoginSuccess(r)
def test_ticket_invalid(self):
"""
Login failure page is returned if the ticket is invalid.
"""
self.patch_cas_response(username='username', valid_ticket='123456')
r = self.client.get('/accounts/theid/login/callback/', {
'ticket': '000000',
})
self.patch_cas_response(username="username", valid_ticket="123456")
r = self.client.get(
"/accounts/theid/login/callback/",
{
"ticket": "000000",
},
)
self.assertLoginFailure(r)
def test_ticket_missing(self):
"""
Login failure page is returned if request lacks a ticket.
"""
self.patch_cas_response(username='username', valid_ticket='123456')
r = self.client.get('/accounts/theid/login/callback/')
self.patch_cas_response(username="username", valid_ticket="123456")
r = self.client.get("/accounts/theid/login/callback/")
self.assertLoginFailure(r)
def test_attributes_is_none(self):
@ -216,31 +218,33 @@ class CASCallbackViewTests(CASViewTestCase):
Without extra attributes, CASClientV2 of python-cas returns None.
"""
self.patch_cas_response(
username='username', valid_ticket='123456', attributes=None
username="username", valid_ticket="123456", attributes=None
)
r = self.client.get(
"/accounts/theid/login/callback/",
{
"ticket": "123456",
},
)
r = self.client.get('/accounts/theid/login/callback/', {
'ticket': '123456',
})
self.assertLoginSuccess(r)
class CASLogoutViewTests(CASViewTestCase):
def test_reverse(self):
"""
Callback view name is "{provider_id}_logout".
"""
url = reverse('theid_logout')
self.assertEqual('/accounts/theid/logout/', url)
url = reverse("theid_logout")
self.assertEqual("/accounts/theid/logout/", url)
def test_execute(self):
"""
Logout view redirects to the CAS server logout url.
Service is a url to here, as absolute uri.
"""
r = self.client.get('/accounts/theid/logout/')
r = self.client.get("/accounts/theid/logout/")
expected = 'https://server.cas/logout?url=http%3A%2F%2Ftestserver%2F'
expected = "https://server.cas/logout?url=http%3A%2F%2Ftestserver%2F"
self.assertRedirects(r, expected, fetch_redirect_response=False)
@ -248,10 +252,8 @@ class CASLogoutViewTests(CASViewTestCase):
"""
GET parameter 'next' is set as service url.
"""
r = self.client.get('/accounts/theid/logout/?next=/path/')
r = self.client.get("/accounts/theid/logout/?next=/path/")
expected = (
'https://server.cas/logout?url=http%3A%2F%2Ftestserver%2Fpath%2F'
)
expected = "https://server.cas/logout?url=http%3A%2F%2Ftestserver%2Fpath%2F"
self.assertRedirects(r, expected, fetch_redirect_response=False)

View file

@ -1,6 +1,5 @@
# -*- coding: utf-8 -*-
from django.conf.urls import include, url
from django.urls import include, path
urlpatterns = [
url(r'^accounts/', include('allauth.urls')),
path("accounts/", include("allauth.urls")),
]