version 1.0.1. Fixed package for Django 4.2.
This commit is contained in:
parent
6657ec6042
commit
77e02f3796
22 changed files with 548 additions and 453 deletions
8
.idea/.gitignore
vendored
Normal file
8
.idea/.gitignore
vendored
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
# Default ignored files
|
||||||
|
/shelf/
|
||||||
|
/workspace.xml
|
||||||
|
# Editor-based HTTP Client requests
|
||||||
|
/httpRequests/
|
||||||
|
# Datasource local storage ignored files
|
||||||
|
/dataSources/
|
||||||
|
/dataSources.local.xml
|
53
.pre-commit-config.yaml
Normal file
53
.pre-commit-config.yaml
Normal file
|
@ -0,0 +1,53 @@
|
||||||
|
exclude: "^docs/|/migrations/"
|
||||||
|
default_stages: [ commit ]
|
||||||
|
fail_fast: true
|
||||||
|
|
||||||
|
repos:
|
||||||
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
|
rev: v4.5.0
|
||||||
|
hooks:
|
||||||
|
- id: trailing-whitespace
|
||||||
|
- id: end-of-file-fixer
|
||||||
|
exclude: ".svg$|.min.*|.idea*"
|
||||||
|
- id: check-yaml
|
||||||
|
- id: check-added-large-files
|
||||||
|
|
||||||
|
- repo: https://github.com/asottile/pyupgrade
|
||||||
|
rev: v3.15.0
|
||||||
|
hooks:
|
||||||
|
- id: pyupgrade
|
||||||
|
args: [ --py39-plus ]
|
||||||
|
|
||||||
|
- repo: https://github.com/psf/black
|
||||||
|
rev: 23.12.1
|
||||||
|
hooks:
|
||||||
|
- id: black
|
||||||
|
|
||||||
|
- repo: https://github.com/PyCQA/isort
|
||||||
|
rev: 5.13.2
|
||||||
|
hooks:
|
||||||
|
- id: isort
|
||||||
|
args: [ "--profile", "black", "-l=88" ]
|
||||||
|
|
||||||
|
- repo: https://github.com/PyCQA/flake8
|
||||||
|
rev: 7.0.0
|
||||||
|
hooks:
|
||||||
|
- id: flake8
|
||||||
|
args: [ "--config=setup.cfg" ]
|
||||||
|
additional_dependencies:
|
||||||
|
- flake8-isort
|
||||||
|
- flake8-bugbear
|
||||||
|
- flake8-print
|
||||||
|
|
||||||
|
- repo: https://github.com/adamchainz/django-upgrade
|
||||||
|
rev: 1.15.0
|
||||||
|
hooks:
|
||||||
|
- id: django-upgrade
|
||||||
|
args: [ --target-version, "4.2" ]
|
||||||
|
|
||||||
|
|
||||||
|
# sets up .pre-commit-ci.yaml to ensure pre-commit dependencies stay up to date
|
||||||
|
ci:
|
||||||
|
autoupdate_schedule: weekly
|
||||||
|
skip: [ ]
|
||||||
|
submodules: false
|
|
@ -1,7 +1,3 @@
|
||||||
# -*- coding: utf-8 -*-
|
__version__ = "1.0.1"
|
||||||
|
|
||||||
__version__ = '1.0.0'
|
CAS_PROVIDER_SESSION_KEY = "allauth_cas__provider_id"
|
||||||
|
|
||||||
default_app_config = 'allauth_cas.apps.CASAccountConfig'
|
|
||||||
|
|
||||||
CAS_PROVIDER_SESSION_KEY = 'allauth_cas__provider_id'
|
|
||||||
|
|
|
@ -1,10 +1,9 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
from django.apps import AppConfig
|
from django.apps import AppConfig
|
||||||
from django.utils.translation import ugettext_lazy as _
|
from django.utils.translation import gettext_lazy as _
|
||||||
|
|
||||||
|
|
||||||
class CASAccountConfig(AppConfig):
|
class CASAccountConfig(AppConfig):
|
||||||
name = 'allauth_cas'
|
name = "allauth_cas"
|
||||||
verbose_name = _("CAS Accounts")
|
verbose_name = _("CAS Accounts")
|
||||||
|
|
||||||
def ready(self):
|
def ready(self):
|
||||||
|
|
|
@ -1,6 +1,3 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
|
|
||||||
|
|
||||||
class CASAuthenticationError(Exception):
|
class CASAuthenticationError(Exception):
|
||||||
"""
|
"""
|
||||||
Base exception to signal CAS authentication failure.
|
Base exception to signal CAS authentication failure.
|
||||||
|
|
|
@ -1,26 +1,18 @@
|
||||||
# -*- coding: utf-8 -*-
|
from urllib.parse import parse_qsl
|
||||||
from six.moves.urllib.parse import parse_qsl
|
|
||||||
|
|
||||||
import django
|
from allauth.socialaccount.providers.base import Provider
|
||||||
from django.contrib import messages
|
from django.contrib import messages
|
||||||
from django.template.loader import render_to_string
|
from django.template.loader import render_to_string
|
||||||
|
from django.urls import reverse
|
||||||
from django.utils.http import urlencode
|
from django.utils.http import urlencode
|
||||||
from django.utils.safestring import mark_safe
|
from django.utils.safestring import mark_safe
|
||||||
|
|
||||||
from allauth.socialaccount.providers.base import Provider
|
|
||||||
|
|
||||||
if django.VERSION >= (1, 10):
|
|
||||||
from django.urls import reverse
|
|
||||||
else:
|
|
||||||
from django.core.urlresolvers import reverse
|
|
||||||
|
|
||||||
|
|
||||||
class CASProvider(Provider):
|
class CASProvider(Provider):
|
||||||
|
|
||||||
def get_auth_params(self, request, action):
|
def get_auth_params(self, request, action):
|
||||||
settings = self.get_settings()
|
settings = self.get_settings()
|
||||||
ret = dict(settings.get('AUTH_PARAMS', {}))
|
ret = dict(settings.get("AUTH_PARAMS", {}))
|
||||||
dynamic_auth_params = request.GET.get('auth_params')
|
dynamic_auth_params = request.GET.get("auth_params")
|
||||||
if dynamic_auth_params:
|
if dynamic_auth_params:
|
||||||
ret.update(dict(parse_qsl(dynamic_auth_params)))
|
ret.update(dict(parse_qsl(dynamic_auth_params)))
|
||||||
return ret
|
return ret
|
||||||
|
@ -68,11 +60,11 @@ class CASProvider(Provider):
|
||||||
"""
|
"""
|
||||||
uid, extra = data
|
uid, extra = data
|
||||||
return {
|
return {
|
||||||
'username': extra.get('username', uid),
|
"username": extra.get("username", uid),
|
||||||
'email': extra.get('email'),
|
"email": extra.get("email"),
|
||||||
'first_name': extra.get('first_name'),
|
"first_name": extra.get("first_name"),
|
||||||
'last_name': extra.get('last_name'),
|
"last_name": extra.get("last_name"),
|
||||||
'name': extra.get('name'),
|
"name": extra.get("name"),
|
||||||
}
|
}
|
||||||
|
|
||||||
def extract_email_addresses(self, data):
|
def extract_email_addresses(self, data):
|
||||||
|
@ -99,7 +91,7 @@ class CASProvider(Provider):
|
||||||
]
|
]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
return super(CASProvider, self).extract_email_addresses(data)
|
return super().extract_email_addresses(data)
|
||||||
|
|
||||||
def extract_extra_data(self, data):
|
def extract_extra_data(self, data):
|
||||||
"""Extract the data to save to `SocialAccount.extra_data`.
|
"""Extract the data to save to `SocialAccount.extra_data`.
|
||||||
|
@ -119,7 +111,10 @@ class CASProvider(Provider):
|
||||||
##
|
##
|
||||||
|
|
||||||
def add_message_suggest_caslogout(
|
def add_message_suggest_caslogout(
|
||||||
self, request, next_page=None, level=None,
|
self,
|
||||||
|
request,
|
||||||
|
next_page=None,
|
||||||
|
level=None,
|
||||||
):
|
):
|
||||||
"""Add a message with a link for the user to logout of the CAS server.
|
"""Add a message with a link for the user to logout of the CAS server.
|
||||||
|
|
||||||
|
@ -144,14 +139,15 @@ class CASProvider(Provider):
|
||||||
# DefaultAccountAdapter.add_message is unusable because it always
|
# DefaultAccountAdapter.add_message is unusable because it always
|
||||||
# escape the message content.
|
# escape the message content.
|
||||||
|
|
||||||
template = 'socialaccount/messages/suggest_caslogout.html'
|
template = "socialaccount/messages/suggest_caslogout.html"
|
||||||
context = {
|
context = {
|
||||||
'provider': self,
|
"provider": self,
|
||||||
'logout_url': logout_url,
|
"logout_url": logout_url,
|
||||||
}
|
}
|
||||||
|
|
||||||
messages.add_message(
|
messages.add_message(
|
||||||
request, level,
|
request,
|
||||||
|
level,
|
||||||
mark_safe(render_to_string(template, context).strip()),
|
mark_safe(render_to_string(template, context).strip()),
|
||||||
fail_silently=True,
|
fail_silently=True,
|
||||||
)
|
)
|
||||||
|
@ -168,10 +164,7 @@ class CASProvider(Provider):
|
||||||
signal ``user_logged_out``.
|
signal ``user_logged_out``.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
return (
|
return self.get_settings().get("MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT", False)
|
||||||
self.get_settings()
|
|
||||||
.get('MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT', False)
|
|
||||||
)
|
|
||||||
|
|
||||||
def message_suggest_caslogout_on_logout_level(self, request):
|
def message_suggest_caslogout_on_logout_level(self, request):
|
||||||
"""Level of the logout message issued on user logout.
|
"""Level of the logout message issued on user logout.
|
||||||
|
@ -185,9 +178,8 @@ class CASProvider(Provider):
|
||||||
signal ``user_logged_out``.
|
signal ``user_logged_out``.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
return (
|
return self.get_settings().get(
|
||||||
self.get_settings()
|
"MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT_LEVEL", messages.INFO
|
||||||
.get('MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT_LEVEL', messages.INFO)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
##
|
##
|
||||||
|
@ -195,19 +187,19 @@ class CASProvider(Provider):
|
||||||
##
|
##
|
||||||
|
|
||||||
def get_login_url(self, request, **kwargs):
|
def get_login_url(self, request, **kwargs):
|
||||||
url = reverse(self.id + '_login')
|
url = reverse(self.id + "_login")
|
||||||
if kwargs:
|
if kwargs:
|
||||||
url += '?' + urlencode(kwargs)
|
url += "?" + urlencode(kwargs)
|
||||||
return url
|
return url
|
||||||
|
|
||||||
def get_callback_url(self, request, **kwargs):
|
def get_callback_url(self, request, **kwargs):
|
||||||
url = reverse(self.id + '_callback')
|
url = reverse(self.id + "_callback")
|
||||||
if kwargs:
|
if kwargs:
|
||||||
url += '?' + urlencode(kwargs)
|
url += "?" + urlencode(kwargs)
|
||||||
return url
|
return url
|
||||||
|
|
||||||
def get_logout_url(self, request, **kwargs):
|
def get_logout_url(self, request, **kwargs):
|
||||||
url = reverse(self.id + '_logout')
|
url = reverse(self.id + "_logout")
|
||||||
if kwargs:
|
if kwargs:
|
||||||
url += '?' + urlencode(kwargs)
|
url += "?" + urlencode(kwargs)
|
||||||
return url
|
return url
|
||||||
|
|
|
@ -1,10 +1,8 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
from django.contrib.auth.signals import user_logged_out
|
|
||||||
from django.dispatch import receiver
|
|
||||||
|
|
||||||
from allauth.account.adapter import get_adapter
|
from allauth.account.adapter import get_adapter
|
||||||
from allauth.account.utils import get_next_redirect_url
|
from allauth.account.utils import get_next_redirect_url
|
||||||
from allauth.socialaccount import providers
|
from allauth.socialaccount import providers
|
||||||
|
from django.contrib.auth.signals import user_logged_out
|
||||||
|
from django.dispatch import receiver
|
||||||
|
|
||||||
from . import CAS_PROVIDER_SESSION_KEY
|
from . import CAS_PROVIDER_SESSION_KEY
|
||||||
|
|
||||||
|
@ -21,12 +19,12 @@ def cas_account_logout(sender, request, **kwargs):
|
||||||
if not provider.message_suggest_caslogout_on_logout(request):
|
if not provider.message_suggest_caslogout_on_logout(request):
|
||||||
return
|
return
|
||||||
|
|
||||||
next_page = (
|
next_page = get_next_redirect_url(request) or get_adapter(
|
||||||
get_next_redirect_url(request) or
|
request
|
||||||
get_adapter(request).get_logout_redirect_url(request)
|
).get_logout_redirect_url(request)
|
||||||
)
|
|
||||||
|
|
||||||
provider.add_message_suggest_caslogout(
|
provider.add_message_suggest_caslogout(
|
||||||
request, next_page=next_page,
|
request,
|
||||||
|
next_page=next_page,
|
||||||
level=provider.message_suggest_caslogout_on_logout_level(request),
|
level=provider.message_suggest_caslogout_on_logout_level(request),
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,29 +1,20 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
try:
|
try:
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import django
|
|
||||||
from django.conf import settings
|
|
||||||
from django.test import TestCase
|
|
||||||
|
|
||||||
import cas
|
import cas
|
||||||
|
from django.conf import settings
|
||||||
|
from django.test import TestCase
|
||||||
|
from django.urls import reverse
|
||||||
|
|
||||||
from allauth_cas import CAS_PROVIDER_SESSION_KEY
|
from allauth_cas import CAS_PROVIDER_SESSION_KEY
|
||||||
|
|
||||||
if django.VERSION >= (1, 10):
|
|
||||||
from django.urls import reverse
|
|
||||||
else:
|
|
||||||
from django.core.urlresolvers import reverse
|
|
||||||
|
|
||||||
|
|
||||||
class CASTestCase(TestCase):
|
class CASTestCase(TestCase):
|
||||||
|
|
||||||
def client_cas_login(
|
def client_cas_login(
|
||||||
self,
|
self, client, provider_id=None, username=None, attributes=None
|
||||||
client, provider_id='theid',
|
):
|
||||||
username=None, attributes={}):
|
|
||||||
"""
|
"""
|
||||||
Authenticate client through provider_id.
|
Authenticate client through provider_id.
|
||||||
|
|
||||||
|
@ -32,20 +23,22 @@ class CASTestCase(TestCase):
|
||||||
username and attributes control the CAS server response when ticket is
|
username and attributes control the CAS server response when ticket is
|
||||||
checked.
|
checked.
|
||||||
"""
|
"""
|
||||||
client.get(reverse('{id}_login'.format(id=provider_id)))
|
if attributes is None:
|
||||||
|
attributes = {}
|
||||||
|
if provider_id is None:
|
||||||
|
provider_id = "theid"
|
||||||
|
client.get(reverse(f"{provider_id}_login"))
|
||||||
self.patch_cas_response(
|
self.patch_cas_response(
|
||||||
valid_ticket='__all__',
|
valid_ticket="__all__",
|
||||||
username=username, attributes=attributes,
|
username=username,
|
||||||
|
attributes=attributes,
|
||||||
)
|
)
|
||||||
callback_url = reverse('{id}_callback'.format(id=provider_id))
|
callback_url = reverse(f"{provider_id}_callback")
|
||||||
r = client.get(callback_url, {'ticket': 'fake-ticket'})
|
r = client.get(callback_url, {"ticket": "fake-ticket"})
|
||||||
self.patch_cas_response_stop()
|
self.patch_cas_response_stop()
|
||||||
return r
|
return r
|
||||||
|
|
||||||
def patch_cas_response(
|
def patch_cas_response(self, valid_ticket, username=None, attributes=None):
|
||||||
self,
|
|
||||||
valid_ticket,
|
|
||||||
username=None, attributes={}):
|
|
||||||
"""
|
"""
|
||||||
Patch the CASClient class used by views of CAS providers.
|
Patch the CASClient class used by views of CAS providers.
|
||||||
|
|
||||||
|
@ -68,35 +61,38 @@ class CASTestCase(TestCase):
|
||||||
Note that valid_ticket sould be a string (which is the type of the
|
Note that valid_ticket sould be a string (which is the type of the
|
||||||
ticket retrieved from GET parameter on request on the callback view).
|
ticket retrieved from GET parameter on request on the callback view).
|
||||||
"""
|
"""
|
||||||
if hasattr(self, '_patch_cas_client'):
|
if attributes is None:
|
||||||
|
attributes = {}
|
||||||
|
if hasattr(self, "_patch_cas_client"):
|
||||||
self.patch_cas_response_stop()
|
self.patch_cas_response_stop()
|
||||||
|
|
||||||
class MockCASClient(object):
|
class MockCASClient:
|
||||||
_username = username
|
_username = username
|
||||||
|
|
||||||
def __new__(self_client, *args, **kwargs):
|
def __new__(self_client, *args, **kwargs):
|
||||||
version = kwargs.pop('version')
|
version = kwargs.pop("version")
|
||||||
if version in (1, '1'):
|
if version in (1, "1"):
|
||||||
client_class = cas.CASClientV1
|
client_class = cas.CASClientV1
|
||||||
elif version in (2, '2'):
|
elif version in (2, "2"):
|
||||||
client_class = cas.CASClientV2
|
client_class = cas.CASClientV2
|
||||||
elif version in (3, '3'):
|
elif version in (3, "3"):
|
||||||
client_class = cas.CASClientV3
|
client_class = cas.CASClientV3
|
||||||
elif version == 'CAS_2_SAML_1_0':
|
elif version == "CAS_2_SAML_1_0":
|
||||||
client_class = cas.CASClientWithSAMLV1
|
client_class = cas.CASClientWithSAMLV1
|
||||||
else:
|
else:
|
||||||
raise ValueError('Unsupported CAS_VERSION %r' % version)
|
raise ValueError("Unsupported CAS_VERSION %r" % version)
|
||||||
|
|
||||||
client_class._username = self_client._username
|
client_class._username = self_client._username
|
||||||
|
|
||||||
def verify_ticket(self, ticket):
|
def verify_ticket(self, ticket):
|
||||||
if valid_ticket == '__all__' or ticket == valid_ticket:
|
if valid_ticket == "__all__" or ticket == valid_ticket:
|
||||||
username = self._username or 'username'
|
username = self._username or "username"
|
||||||
return username, attributes, None
|
return username, attributes, None
|
||||||
return None, {}, None
|
return None, {}, None
|
||||||
|
|
||||||
patcher = patch.object(
|
patcher = patch.object(
|
||||||
client_class, 'verify_ticket',
|
client_class,
|
||||||
|
"verify_ticket",
|
||||||
new=verify_ticket,
|
new=verify_ticket,
|
||||||
)
|
)
|
||||||
patcher.start()
|
patcher.start()
|
||||||
|
@ -104,7 +100,7 @@ class CASTestCase(TestCase):
|
||||||
return client_class(*args, **kwargs)
|
return client_class(*args, **kwargs)
|
||||||
|
|
||||||
self._patch_cas_client = patch(
|
self._patch_cas_client = patch(
|
||||||
'allauth_cas.views.cas.CASClient',
|
"allauth_cas.views.cas.CASClient",
|
||||||
MockCASClient,
|
MockCASClient,
|
||||||
)
|
)
|
||||||
self._patch_cas_client.start()
|
self._patch_cas_client.start()
|
||||||
|
@ -114,12 +110,11 @@ class CASTestCase(TestCase):
|
||||||
del self._patch_cas_client
|
del self._patch_cas_client
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
if hasattr(self, '_patch_cas_client'):
|
if hasattr(self, "_patch_cas_client"):
|
||||||
self.patch_cas_response_stop()
|
self.patch_cas_response_stop()
|
||||||
|
|
||||||
|
|
||||||
class CASViewTestCase(CASTestCase):
|
class CASViewTestCase(CASTestCase):
|
||||||
|
|
||||||
def assertLoginSuccess(self, response, redirect_to=None):
|
def assertLoginSuccess(self, response, redirect_to=None):
|
||||||
"""
|
"""
|
||||||
Asserts response corresponds to a successful login.
|
Asserts response corresponds to a successful login.
|
||||||
|
@ -133,7 +128,8 @@ class CASViewTestCase(CASTestCase):
|
||||||
redirect_to = settings.LOGIN_REDIRECT_URL
|
redirect_to = settings.LOGIN_REDIRECT_URL
|
||||||
|
|
||||||
self.assertRedirects(
|
self.assertRedirects(
|
||||||
response, redirect_to,
|
response,
|
||||||
|
redirect_to,
|
||||||
fetch_redirect_response=False,
|
fetch_redirect_response=False,
|
||||||
)
|
)
|
||||||
self.assertIn(
|
self.assertIn(
|
||||||
|
@ -146,6 +142,6 @@ class CASViewTestCase(CASTestCase):
|
||||||
Asserts response corresponds to a failed login.
|
Asserts response corresponds to a failed login.
|
||||||
"""
|
"""
|
||||||
return self.assertInHTML(
|
return self.assertInHTML(
|
||||||
'<h1>Social Network Login Failure</h1>',
|
"<h1>Social Network Login Failure</h1>",
|
||||||
str(response.content),
|
str(response.content),
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
# -*- coding: utf-8 -*-
|
from django.urls import include, path, re_path
|
||||||
from django.conf.urls import include, url
|
|
||||||
from django.utils.module_loading import import_string
|
from django.utils.module_loading import import_string
|
||||||
|
|
||||||
|
|
||||||
|
@ -7,45 +6,44 @@ def default_urlpatterns(provider):
|
||||||
package = provider.get_package()
|
package = provider.get_package()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
login_view = import_string(package + '.views.login')
|
login_view = import_string(package + ".views.login")
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"The login view for the '{id}' provider is lacking from the "
|
"The login view for the '{id}' provider is lacking from the "
|
||||||
"'views' module of its app.\n"
|
"'views' module of its app.\n"
|
||||||
"You may want to add:\n"
|
"You may want to add:\n"
|
||||||
"from allauth_cas.views import CASLoginView\n\n"
|
"from allauth_cas.views import CASLoginView\n\n"
|
||||||
"login = CASLoginView.adapter_view(<LocalCASAdapter>)"
|
"login = CASLoginView.adapter_view(<LocalCASAdapter>)".format(
|
||||||
.format(id=provider.id)
|
id=provider.id
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
callback_view = import_string(package + '.views.callback')
|
callback_view = import_string(package + ".views.callback")
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"The callback view for the '{id}' provider is lacking from the "
|
"The callback view for the '{id}' provider is lacking from the "
|
||||||
"'views' module of its app.\n"
|
"'views' module of its app.\n"
|
||||||
"You may want to add:\n"
|
"You may want to add:\n"
|
||||||
"from allauth_cas.views import CASCallbackView\n\n"
|
"from allauth_cas.views import CASCallbackView\n\n"
|
||||||
"callback = CASCallbackView.adapter_view(<LocalCASAdapter>)"
|
"callback = CASCallbackView.adapter_view(<LocalCASAdapter>)".format(
|
||||||
.format(id=provider.id)
|
id=provider.id
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logout_view = import_string(package + '.views.logout')
|
logout_view = import_string(package + ".views.logout")
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logout_view = None
|
logout_view = None
|
||||||
|
|
||||||
urlpatterns = [
|
urlpatterns = [
|
||||||
url('^login/$', login_view,
|
path("login/", login_view, name=provider.id + "_login"),
|
||||||
name=provider.id + '_login'),
|
path("login/callback/", callback_view, name=provider.id + "_callback"),
|
||||||
url('^login/callback/$', callback_view,
|
|
||||||
name=provider.id + '_callback'),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
if logout_view is not None:
|
if logout_view is not None:
|
||||||
urlpatterns += [
|
urlpatterns += [
|
||||||
url('^logout/$', logout_view,
|
path("logout/", logout_view, name=provider.id + "_logout"),
|
||||||
name=provider.id + '_logout'),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
return [url('^' + provider.get_slug() + '/', include(urlpatterns))]
|
return [re_path("^" + provider.get_slug() + "/", include(urlpatterns))]
|
||||||
|
|
|
@ -1,28 +1,26 @@
|
||||||
# -*- coding: utf-8 -*-
|
import cas
|
||||||
from django.http import HttpResponseRedirect
|
|
||||||
from django.utils.functional import cached_property
|
|
||||||
|
|
||||||
from allauth.account.adapter import get_adapter
|
from allauth.account.adapter import get_adapter
|
||||||
from allauth.account.utils import get_next_redirect_url
|
from allauth.account.utils import get_next_redirect_url
|
||||||
from allauth.socialaccount import providers
|
from allauth.socialaccount import providers
|
||||||
from allauth.socialaccount.helpers import (
|
from allauth.socialaccount.helpers import (
|
||||||
complete_social_login, render_authentication_error,
|
complete_social_login,
|
||||||
|
render_authentication_error,
|
||||||
)
|
)
|
||||||
from allauth.socialaccount.models import SocialLogin
|
from allauth.socialaccount.models import SocialLogin
|
||||||
|
from django.http import HttpResponseRedirect
|
||||||
import cas
|
from django.utils.functional import cached_property
|
||||||
|
|
||||||
from . import CAS_PROVIDER_SESSION_KEY
|
from . import CAS_PROVIDER_SESSION_KEY
|
||||||
from .exceptions import CASAuthenticationError
|
from .exceptions import CASAuthenticationError
|
||||||
|
|
||||||
|
|
||||||
class AuthAction(object):
|
class AuthAction:
|
||||||
AUTHENTICATE = 'authenticate'
|
AUTHENTICATE = "authenticate"
|
||||||
REAUTHENTICATE = 'reauthenticate'
|
REAUTHENTICATE = "reauthenticate"
|
||||||
DEAUTHENTICATE = 'deauthenticate'
|
DEAUTHENTICATE = "deauthenticate"
|
||||||
|
|
||||||
|
|
||||||
class CASAdapter(object):
|
class CASAdapter:
|
||||||
#: CAS server url.
|
#: CAS server url.
|
||||||
url = None
|
url = None
|
||||||
#: CAS server version.
|
#: CAS server version.
|
||||||
|
@ -92,19 +90,19 @@ class CASAdapter(object):
|
||||||
"""
|
"""
|
||||||
redirect_to = get_next_redirect_url(request)
|
redirect_to = get_next_redirect_url(request)
|
||||||
|
|
||||||
callback_kwargs = {'next': redirect_to} if redirect_to else {}
|
callback_kwargs = {"next": redirect_to} if redirect_to else {}
|
||||||
callback_url = (
|
callback_url = self.provider.get_callback_url(request, **callback_kwargs)
|
||||||
self.provider.get_callback_url(request, **callback_kwargs))
|
|
||||||
|
|
||||||
service_url = request.build_absolute_uri(callback_url)
|
service_url = request.build_absolute_uri(callback_url)
|
||||||
|
|
||||||
return service_url
|
return service_url
|
||||||
|
|
||||||
|
|
||||||
class CASView(object):
|
class CASView:
|
||||||
"""
|
"""
|
||||||
Base class for CAS views.
|
Base class for CAS views.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def adapter_view(cls, adapter):
|
def adapter_view(cls, adapter):
|
||||||
"""Transform the view class into a view function.
|
"""Transform the view class into a view function.
|
||||||
|
@ -124,6 +122,7 @@ class CASView(object):
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def view(request, *args, **kwargs):
|
def view(request, *args, **kwargs):
|
||||||
# Prepare the func-view.
|
# Prepare the func-view.
|
||||||
self = cls()
|
self = cls()
|
||||||
|
@ -169,19 +168,17 @@ class CASView(object):
|
||||||
|
|
||||||
|
|
||||||
class CASLoginView(CASView):
|
class CASLoginView(CASView):
|
||||||
|
|
||||||
def dispatch(self, request):
|
def dispatch(self, request):
|
||||||
"""
|
"""
|
||||||
Redirects to the CAS server login page.
|
Redirects to the CAS server login page.
|
||||||
"""
|
"""
|
||||||
action = request.GET.get('action', AuthAction.AUTHENTICATE)
|
action = request.GET.get("action", AuthAction.AUTHENTICATE)
|
||||||
SocialLogin.stash_state(request)
|
SocialLogin.stash_state(request)
|
||||||
client = self.get_client(request, action=action)
|
client = self.get_client(request, action=action)
|
||||||
return HttpResponseRedirect(client.get_login_url())
|
return HttpResponseRedirect(client.get_login_url())
|
||||||
|
|
||||||
|
|
||||||
class CASCallbackView(CASView):
|
class CASCallbackView(CASView):
|
||||||
|
|
||||||
def dispatch(self, request):
|
def dispatch(self, request):
|
||||||
"""
|
"""
|
||||||
The CAS server redirects the user to this view after a successful
|
The CAS server redirects the user to this view after a successful
|
||||||
|
@ -195,11 +192,9 @@ class CASCallbackView(CASView):
|
||||||
|
|
||||||
# CAS server should let a ticket.
|
# CAS server should let a ticket.
|
||||||
try:
|
try:
|
||||||
ticket = request.GET['ticket']
|
ticket = request.GET["ticket"]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise CASAuthenticationError(
|
raise CASAuthenticationError("CAS server didn't respond with a ticket.")
|
||||||
"CAS server didn't respond with a ticket."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check ticket validity.
|
# Check ticket validity.
|
||||||
# Response format on:
|
# Response format on:
|
||||||
|
@ -210,9 +205,7 @@ class CASCallbackView(CASView):
|
||||||
uid, extra, _ = response
|
uid, extra, _ = response
|
||||||
|
|
||||||
if not uid:
|
if not uid:
|
||||||
raise CASAuthenticationError(
|
raise CASAuthenticationError("CAS server doesn't validate the ticket.")
|
||||||
"CAS server doesn't validate the ticket."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Keep tracks of the last used CAS provider.
|
# Keep tracks of the last used CAS provider.
|
||||||
request.session[CAS_PROVIDER_SESSION_KEY] = self.provider.id
|
request.session[CAS_PROVIDER_SESSION_KEY] = self.provider.id
|
||||||
|
@ -226,7 +219,6 @@ class CASCallbackView(CASView):
|
||||||
|
|
||||||
|
|
||||||
class CASLogoutView(CASView):
|
class CASLogoutView(CASView):
|
||||||
|
|
||||||
def dispatch(self, request, next_page=None):
|
def dispatch(self, request, next_page=None):
|
||||||
"""
|
"""
|
||||||
Redirects to the CAS server logout page.
|
Redirects to the CAS server logout page.
|
||||||
|
@ -248,7 +240,6 @@ class CASLogoutView(CASView):
|
||||||
Returns the url to redirect after logout.
|
Returns the url to redirect after logout.
|
||||||
"""
|
"""
|
||||||
request = self.request
|
request = self.request
|
||||||
return (
|
return get_next_redirect_url(request) or get_adapter(
|
||||||
get_next_redirect_url(request) or
|
request
|
||||||
get_adapter(request).get_logout_redirect_url(request)
|
).get_logout_redirect_url(request)
|
||||||
)
|
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
@ -7,10 +6,10 @@ import django
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.test.utils import get_runner
|
from django.test.utils import get_runner
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
os.environ['DJANGO_SETTINGS_MODULE'] = 'tests.settings'
|
os.environ["DJANGO_SETTINGS_MODULE"] = "tests.settings"
|
||||||
django.setup()
|
django.setup()
|
||||||
TestRunner = get_runner(settings)
|
TestRunner = get_runner(settings)
|
||||||
test_runner = TestRunner()
|
test_runner = TestRunner()
|
||||||
failures = test_runner.run_tests(sys.argv[1:] or ['tests'])
|
failures = test_runner.run_tests(sys.argv[1:] or ["tests"])
|
||||||
sys.exit(bool(failures))
|
sys.exit(bool(failures))
|
||||||
|
|
45
setup.cfg
45
setup.cfg
|
@ -1,18 +1,45 @@
|
||||||
[flake8]
|
[flake8]
|
||||||
|
max-line-length = 120
|
||||||
|
exclude = .tox,.git,*/migrations/*,*/static/CACHE/*,docs,node_modules,venv
|
||||||
# E731: lambda expression
|
# E731: lambda expression
|
||||||
ignore = E731
|
ignore = E731
|
||||||
|
|
||||||
|
[pycodestyle]
|
||||||
|
max-line-length = 120
|
||||||
|
exclude = .tox,.git,*/migrations/*,*/static/CACHE/*,docs,node_modules,venv
|
||||||
|
|
||||||
[isort]
|
[isort]
|
||||||
combine_as_imports = True
|
line_length = 88
|
||||||
|
known_first_party = geoip,config
|
||||||
|
multi_line_output = 3
|
||||||
default_section = THIRDPARTY
|
default_section = THIRDPARTY
|
||||||
include_trailing_comma = True
|
skip = venv/
|
||||||
known_allauth = allauth
|
skip_glob = **/migrations/*.py
|
||||||
known_future_library = future,six
|
include_trailing_comma = true
|
||||||
known_django = django
|
force_grid_wrap = 0
|
||||||
known_first_party = allauth_cas
|
use_parentheses = true
|
||||||
multi_line_output = 5
|
|
||||||
not_skip = __init__.py
|
[mypy]
|
||||||
sections = FUTURE,STDLIB,DJANGO,ALLAUTH,THIRDPARTY,FIRSTPARTY,LOCALFOLDER
|
python_version = 3.9
|
||||||
|
check_untyped_defs = True
|
||||||
|
ignore_missing_imports = True
|
||||||
|
warn_unused_ignores = True
|
||||||
|
warn_redundant_casts = True
|
||||||
|
warn_unused_configs = True
|
||||||
|
plugins = mypy_django_plugin.main
|
||||||
|
|
||||||
|
[mypy.plugins.django-stubs]
|
||||||
|
django_settings_module = config.settings.test
|
||||||
|
|
||||||
|
[mypy-*.migrations.*]
|
||||||
|
# Django migrations should not produce any errors:
|
||||||
|
ignore_errors = True
|
||||||
|
|
||||||
|
[coverage:run]
|
||||||
|
include = geoip/*
|
||||||
|
omit = *migrations*, *tests*
|
||||||
|
plugins =
|
||||||
|
django_coverage_plugin
|
||||||
|
|
||||||
[bdist_wheel]
|
[bdist_wheel]
|
||||||
universal = 1
|
universal = 1
|
||||||
|
|
74
setup.py
74
setup.py
|
@ -1,4 +1,3 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from setuptools import find_packages, setup
|
from setuptools import find_packages, setup
|
||||||
|
@ -7,49 +6,56 @@ from allauth_cas import __version__
|
||||||
|
|
||||||
BASE_DIR = os.path.dirname(__file__)
|
BASE_DIR = os.path.dirname(__file__)
|
||||||
|
|
||||||
with open(os.path.join(BASE_DIR, 'README.rst')) as readme:
|
with open(os.path.join(BASE_DIR, "README.rst")) as readme:
|
||||||
README = readme.read()
|
README = readme.read()
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name='django-allauth-cas',
|
name="django-allauth-cas",
|
||||||
version=__version__,
|
version=__version__,
|
||||||
description='CAS support for django-allauth.',
|
description="CAS support for django-allauth.",
|
||||||
author='Aurélien Delobelle',
|
author="Aurélien Delobelle",
|
||||||
author_email='aurelien.delobelle@gmail.com',
|
author_email="aurelien.delobelle@gmail.com",
|
||||||
keywords='django allauth cas authentication',
|
keywords="django allauth cas authentication",
|
||||||
long_description=README,
|
long_description=README,
|
||||||
url='https://github.com/aureplop/django-allauth-cas',
|
url="https://github.com/aureplop/django-allauth-cas",
|
||||||
classifiers=[
|
classifiers=[
|
||||||
'Development Status :: 4 - Beta',
|
"Development Status :: 4 - Beta",
|
||||||
'Environment :: Web Environment',
|
"Environment :: Web Environment",
|
||||||
'Framework :: Django',
|
"Framework :: Django",
|
||||||
'Framework :: Django :: 1.8',
|
"Framework :: Django :: 1.8",
|
||||||
'Framework :: Django :: 1.9',
|
"Framework :: Django :: 1.9",
|
||||||
'Framework :: Django :: 1.10',
|
"Framework :: Django :: 1.10",
|
||||||
'Framework :: Django :: 1.11',
|
"Framework :: Django :: 1.11",
|
||||||
'Framework :: Django :: 2.0',
|
"Framework :: Django :: 2.0",
|
||||||
'Intended Audience :: Developers',
|
"Framework :: Django :: 3.0",
|
||||||
'License :: OSI Approved :: MIT License',
|
"Framework :: Django :: 4.0",
|
||||||
'Operating System :: OS Independent',
|
"Framework :: Django :: 5.0",
|
||||||
'Programming Language :: Python',
|
"Intended Audience :: Developers",
|
||||||
'Programming Language :: Python :: 2',
|
"License :: OSI Approved :: MIT License",
|
||||||
'Programming Language :: Python :: 2.7',
|
"Operating System :: OS Independent",
|
||||||
'Programming Language :: Python :: 3',
|
"Programming Language :: Python",
|
||||||
'Programming Language :: Python :: 3.4',
|
"Programming Language :: Python :: 2",
|
||||||
'Programming Language :: Python :: 3.5',
|
"Programming Language :: Python :: 2.7",
|
||||||
'Programming Language :: Python :: 3.6',
|
"Programming Language :: Python :: 3",
|
||||||
'Topic :: Internet :: WWW/HTTP',
|
"Programming Language :: Python :: 3.4",
|
||||||
|
"Programming Language :: Python :: 3.5",
|
||||||
|
"Programming Language :: Python :: 3.6",
|
||||||
|
"Programming Language :: Python :: 3.7",
|
||||||
|
"Programming Language :: Python :: 3.8",
|
||||||
|
"Programming Language :: Python :: 3.9",
|
||||||
|
"Programming Language :: Python :: 3.11",
|
||||||
|
"Topic :: Internet :: WWW/HTTP",
|
||||||
],
|
],
|
||||||
license='MIT',
|
license="MIT",
|
||||||
packages=find_packages(exclude=['tests']),
|
packages=find_packages(exclude=["tests"]),
|
||||||
include_package_data=True,
|
include_package_data=True,
|
||||||
install_requires=[
|
install_requires=[
|
||||||
'django-allauth',
|
"django-allauth",
|
||||||
'python-cas',
|
"python-cas",
|
||||||
'six',
|
"six",
|
||||||
],
|
],
|
||||||
extras_require={
|
extras_require={
|
||||||
'docs': ['sphinx'],
|
"docs": ["sphinx"],
|
||||||
'tests': ['tox'],
|
"tests": ["tox"],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
from allauth.socialaccount.providers.base import ProviderAccount
|
from allauth.socialaccount.providers.base import ProviderAccount
|
||||||
|
|
||||||
from allauth_cas.providers import CASProvider
|
from allauth_cas.providers import CASProvider
|
||||||
|
@ -9,8 +8,8 @@ class ExampleCASAccount(ProviderAccount):
|
||||||
|
|
||||||
|
|
||||||
class ExampleCASProvider(CASProvider):
|
class ExampleCASProvider(CASProvider):
|
||||||
id = 'theid'
|
id = "theid"
|
||||||
name = 'The Provider'
|
name = "The Provider"
|
||||||
account_class = ExampleCASAccount
|
account_class = ExampleCASAccount
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
from allauth_cas.urls import default_urlpatterns
|
from allauth_cas.urls import default_urlpatterns
|
||||||
|
|
||||||
from .provider import ExampleCASProvider
|
from .provider import ExampleCASProvider
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
from allauth_cas import views
|
from allauth_cas import views
|
||||||
|
|
||||||
from .provider import ExampleCASProvider
|
from .provider import ExampleCASProvider
|
||||||
|
@ -6,7 +5,7 @@ from .provider import ExampleCASProvider
|
||||||
|
|
||||||
class ExampleCASAdapter(views.CASAdapter):
|
class ExampleCASAdapter(views.CASAdapter):
|
||||||
provider_id = ExampleCASProvider.id
|
provider_id = ExampleCASProvider.id
|
||||||
url = 'https://server.cas'
|
url = "https://server.cas"
|
||||||
version = 2
|
version = 2
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,63 +1,54 @@
|
||||||
# -*- coding: utf-8 -*-
|
SECRET_KEY = "iamabird"
|
||||||
import django
|
|
||||||
|
|
||||||
SECRET_KEY = 'iamabird'
|
|
||||||
|
|
||||||
INSTALLED_APPS = [
|
INSTALLED_APPS = [
|
||||||
'django.contrib.admin',
|
"django.contrib.admin",
|
||||||
'django.contrib.auth',
|
"django.contrib.auth",
|
||||||
'django.contrib.contenttypes',
|
"django.contrib.contenttypes",
|
||||||
'django.contrib.messages',
|
"django.contrib.messages",
|
||||||
'django.contrib.sessions',
|
"django.contrib.sessions",
|
||||||
'django.contrib.sites',
|
"django.contrib.sites",
|
||||||
'django.contrib.staticfiles',
|
"django.contrib.staticfiles",
|
||||||
|
"allauth",
|
||||||
'allauth',
|
"allauth.account",
|
||||||
'allauth.account',
|
"allauth.socialaccount",
|
||||||
'allauth.socialaccount',
|
"allauth_cas",
|
||||||
|
"tests.example", # Dummy CAS provider app
|
||||||
'allauth_cas',
|
|
||||||
|
|
||||||
'tests.example', # Dummy CAS provider app
|
|
||||||
]
|
]
|
||||||
|
|
||||||
DATABASES = {
|
DATABASES = {
|
||||||
'default': {
|
"default": {
|
||||||
'ENGINE': 'django.db.backends.sqlite3',
|
"ENGINE": "django.db.backends.sqlite3",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
AUTHENTICATION_BACKENDS = [
|
AUTHENTICATION_BACKENDS = [
|
||||||
'allauth.account.auth_backends.AuthenticationBackend',
|
"allauth.account.auth_backends.AuthenticationBackend",
|
||||||
]
|
]
|
||||||
|
|
||||||
_MIDDLEWARES = [
|
_MIDDLEWARES = [
|
||||||
'django.middleware.common.CommonMiddleware',
|
"django.middleware.common.CommonMiddleware",
|
||||||
'django.contrib.sessions.middleware.SessionMiddleware',
|
"django.contrib.sessions.middleware.SessionMiddleware",
|
||||||
'django.middleware.csrf.CsrfViewMiddleware',
|
"django.middleware.csrf.CsrfViewMiddleware",
|
||||||
'django.contrib.auth.middleware.AuthenticationMiddleware',
|
"django.contrib.auth.middleware.AuthenticationMiddleware",
|
||||||
'django.contrib.messages.middleware.MessageMiddleware',
|
"django.contrib.messages.middleware.MessageMiddleware",
|
||||||
]
|
]
|
||||||
|
|
||||||
if django.VERSION >= (1, 10):
|
MIDDLEWARE = _MIDDLEWARES
|
||||||
MIDDLEWARE = _MIDDLEWARES
|
|
||||||
else:
|
|
||||||
MIDDLEWARE_CLASSES = _MIDDLEWARES
|
|
||||||
|
|
||||||
TEMPLATES = [
|
TEMPLATES = [
|
||||||
{
|
{
|
||||||
'BACKEND': 'django.template.backends.django.DjangoTemplates',
|
"BACKEND": "django.template.backends.django.DjangoTemplates",
|
||||||
'DIRS': [],
|
"DIRS": [],
|
||||||
'APP_DIRS': True,
|
"APP_DIRS": True,
|
||||||
'OPTIONS': {
|
"OPTIONS": {
|
||||||
'context_processors': [
|
"context_processors": [
|
||||||
'django.template.context_processors.debug',
|
"django.template.context_processors.debug",
|
||||||
'django.template.context_processors.request',
|
"django.template.context_processors.request",
|
||||||
'django.contrib.auth.context_processors.auth',
|
"django.contrib.auth.context_processors.auth",
|
||||||
'django.contrib.messages.context_processors.messages',
|
"django.contrib.messages.context_processors.messages",
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
ROOT_URLCONF = 'tests.urls'
|
ROOT_URLCONF = "tests.urls"
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
from django.contrib import messages
|
from django.contrib import messages
|
||||||
from django.contrib.auth import get_user_model
|
from django.contrib.auth import get_user_model
|
||||||
from django.contrib.messages.api import get_messages
|
from django.contrib.messages.api import get_messages
|
||||||
|
@ -13,7 +12,7 @@ User = get_user_model()
|
||||||
class LogoutFlowTests(CASTestCase):
|
class LogoutFlowTests(CASTestCase):
|
||||||
expected_msg_str = (
|
expected_msg_str = (
|
||||||
"To logout of The Provider, please close your browser, or visit this "
|
"To logout of The Provider, please close your browser, or visit this "
|
||||||
"<a href=\"/accounts/theid/logout/?next=%2Fredir%2F\">"
|
'<a href="/accounts/theid/logout/?next=%2Fredir%2F">'
|
||||||
"link</a>."
|
"link</a>."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -28,42 +27,45 @@ class LogoutFlowTests(CASTestCase):
|
||||||
)
|
)
|
||||||
self.assertTemplateNotUsed(
|
self.assertTemplateNotUsed(
|
||||||
response,
|
response,
|
||||||
'socialaccount/messages/suggest_caslogout.html',
|
"socialaccount/messages/suggest_caslogout.html",
|
||||||
)
|
)
|
||||||
|
|
||||||
@override_settings(SOCIALACCOUNT_PROVIDERS={
|
@override_settings(
|
||||||
'theid': {
|
SOCIALACCOUNT_PROVIDERS={
|
||||||
'MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT': True,
|
"theid": {
|
||||||
'MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT_LEVEL': messages.WARNING,
|
"MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT": True,
|
||||||
},
|
"MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT_LEVEL": messages.WARNING,
|
||||||
})
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
def test_message_on_logout(self):
|
def test_message_on_logout(self):
|
||||||
"""
|
"""
|
||||||
Message is sent to propose user to logout of CAS.
|
Message is sent to propose user to logout of CAS.
|
||||||
"""
|
"""
|
||||||
r = self.client.post('/accounts/logout/?next=/redir/')
|
r = self.client.post("/accounts/logout/?next=/redir/")
|
||||||
r_messages = get_messages(r.wsgi_request)
|
r_messages = get_messages(r.wsgi_request)
|
||||||
|
|
||||||
expected_msg = Message(messages.WARNING, self.expected_msg_str)
|
expected_msg = Message(messages.WARNING, self.expected_msg_str)
|
||||||
|
|
||||||
self.assertIn(expected_msg, r_messages)
|
self.assertIn(expected_msg, r_messages)
|
||||||
self.assertTemplateUsed(
|
self.assertTemplateUsed(r, "socialaccount/messages/suggest_caslogout.html")
|
||||||
r, 'socialaccount/messages/suggest_caslogout.html')
|
|
||||||
|
|
||||||
def test_message_on_logout_disabled(self):
|
def test_message_on_logout_disabled(self):
|
||||||
r = self.client.post('/accounts/logout/')
|
r = self.client.post("/accounts/logout/")
|
||||||
self.assertCASLogoutNotInMessages(r)
|
self.assertCASLogoutNotInMessages(r)
|
||||||
|
|
||||||
@override_settings(SOCIALACCOUNT_PROVIDERS={
|
@override_settings(
|
||||||
'theid': {'MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT': True},
|
SOCIALACCOUNT_PROVIDERS={
|
||||||
})
|
"theid": {"MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT": True},
|
||||||
|
}
|
||||||
|
)
|
||||||
def test_other_logout(self):
|
def test_other_logout(self):
|
||||||
"""
|
"""
|
||||||
The CAS logout message doesn't appear with other login methods.
|
The CAS logout message doesn't appear with other login methods.
|
||||||
"""
|
"""
|
||||||
User.objects.create_user('user', '', 'user')
|
User.objects.create_user("user", "", "user")
|
||||||
client = Client()
|
client = Client()
|
||||||
client.login(username='user', password='user')
|
client.login(username="user", password="user")
|
||||||
|
|
||||||
r = client.post('/accounts/logout/')
|
r = client.post("/accounts/logout/")
|
||||||
self.assertCASLogoutNotInMessages(r)
|
self.assertCASLogoutNotInMessages(r)
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
# -*- coding: utf-8 -*-
|
from urllib.parse import urlencode
|
||||||
from six.moves.urllib.parse import urlencode
|
|
||||||
|
|
||||||
|
from allauth.socialaccount.providers import registry
|
||||||
from django.contrib import messages
|
from django.contrib import messages
|
||||||
from django.contrib.messages.api import get_messages
|
from django.contrib.messages.api import get_messages
|
||||||
from django.contrib.messages.middleware import MessageMiddleware
|
from django.contrib.messages.middleware import MessageMiddleware
|
||||||
|
@ -8,21 +8,18 @@ from django.contrib.messages.storage.base import Message
|
||||||
from django.contrib.sessions.middleware import SessionMiddleware
|
from django.contrib.sessions.middleware import SessionMiddleware
|
||||||
from django.test import RequestFactory, TestCase, override_settings
|
from django.test import RequestFactory, TestCase, override_settings
|
||||||
|
|
||||||
from allauth.socialaccount.providers import registry
|
|
||||||
|
|
||||||
from allauth_cas.views import AuthAction
|
from allauth_cas.views import AuthAction
|
||||||
|
|
||||||
from .example.provider import ExampleCASProvider
|
from .example.provider import ExampleCASProvider
|
||||||
|
|
||||||
|
|
||||||
class CASProviderTests(TestCase):
|
class CASProviderTests(TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.request = self._get_request()
|
self.request = self._get_request()
|
||||||
self.provider = ExampleCASProvider(self.request)
|
self.provider = ExampleCASProvider(self.request)
|
||||||
|
|
||||||
def _get_request(self):
|
def _get_request(self):
|
||||||
request = RequestFactory().get('/test/')
|
request = RequestFactory().get("/test/")
|
||||||
SessionMiddleware().process_request(request)
|
SessionMiddleware().process_request(request)
|
||||||
MessageMiddleware().process_request(request)
|
MessageMiddleware().process_request(request)
|
||||||
return request
|
return request
|
||||||
|
@ -31,74 +28,81 @@ class CASProviderTests(TestCase):
|
||||||
"""
|
"""
|
||||||
Example CAS provider is registered as social account provider.
|
Example CAS provider is registered as social account provider.
|
||||||
"""
|
"""
|
||||||
self.assertIsInstance(registry.by_id('theid'), ExampleCASProvider)
|
self.assertIsInstance(registry.by_id("theid"), ExampleCASProvider)
|
||||||
|
|
||||||
def test_get_login_url(self):
|
def test_get_login_url(self):
|
||||||
url = self.provider.get_login_url(self.request)
|
url = self.provider.get_login_url(self.request)
|
||||||
self.assertEqual('/accounts/theid/login/', url)
|
self.assertEqual("/accounts/theid/login/", url)
|
||||||
|
|
||||||
url_with_qs = self.provider.get_login_url(
|
url_with_qs = self.provider.get_login_url(
|
||||||
self.request,
|
self.request,
|
||||||
next='/path?quéry=string&two=whoam%C3%AF',
|
next="/path?quéry=string&two=whoam%C3%AF",
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
url_with_qs,
|
url_with_qs,
|
||||||
'/accounts/theid/login/?next=%2Fpath%3Fqu%C3%A9ry%3Dstring%26two%3'
|
"/accounts/theid/login/?next=%2Fpath%3Fqu%C3%A9ry%3Dstring%26two%3"
|
||||||
'Dwhoam%25C3%25AF'
|
"Dwhoam%25C3%25AF",
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_get_callback_url(self):
|
def test_get_callback_url(self):
|
||||||
url = self.provider.get_callback_url(self.request)
|
url = self.provider.get_callback_url(self.request)
|
||||||
self.assertEqual('/accounts/theid/login/callback/', url)
|
self.assertEqual("/accounts/theid/login/callback/", url)
|
||||||
|
|
||||||
url_with_qs = self.provider.get_callback_url(
|
url_with_qs = self.provider.get_callback_url(
|
||||||
self.request,
|
self.request,
|
||||||
next='/path?quéry=string&two=whoam%C3%AF',
|
next="/path?quéry=string&two=whoam%C3%AF",
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
url_with_qs,
|
url_with_qs,
|
||||||
'/accounts/theid/login/callback/?next=%2Fpath%3Fqu%C3%A9ry%3Dstrin'
|
"/accounts/theid/login/callback/?next=%2Fpath%3Fqu%C3%A9ry%3Dstrin"
|
||||||
'g%26two%3Dwhoam%25C3%25AF'
|
"g%26two%3Dwhoam%25C3%25AF",
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_get_logout_url(self):
|
def test_get_logout_url(self):
|
||||||
url = self.provider.get_logout_url(self.request)
|
url = self.provider.get_logout_url(self.request)
|
||||||
self.assertEqual('/accounts/theid/logout/', url)
|
self.assertEqual("/accounts/theid/logout/", url)
|
||||||
|
|
||||||
url_with_qs = self.provider.get_logout_url(
|
url_with_qs = self.provider.get_logout_url(
|
||||||
self.request,
|
self.request,
|
||||||
next='/path?quéry=string&two=whoam%C3%AF',
|
next="/path?quéry=string&two=whoam%C3%AF",
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
url_with_qs,
|
url_with_qs,
|
||||||
'/accounts/theid/logout/?next=%2Fpath%3Fqu%C3%A9ry%3Dstring%26two%'
|
"/accounts/theid/logout/?next=%2Fpath%3Fqu%C3%A9ry%3Dstring%26two%"
|
||||||
'3Dwhoam%25C3%25AF'
|
"3Dwhoam%25C3%25AF",
|
||||||
)
|
)
|
||||||
|
|
||||||
@override_settings(SOCIALACCOUNT_PROVIDERS={
|
@override_settings(
|
||||||
'theid': {
|
SOCIALACCOUNT_PROVIDERS={
|
||||||
'AUTH_PARAMS': {'key': 'value'},
|
"theid": {
|
||||||
},
|
"AUTH_PARAMS": {"key": "value"},
|
||||||
})
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
def test_get_auth_params(self):
|
def test_get_auth_params(self):
|
||||||
action = AuthAction.AUTHENTICATE
|
action = AuthAction.AUTHENTICATE
|
||||||
|
|
||||||
auth_params = self.provider.get_auth_params(self.request, action)
|
auth_params = self.provider.get_auth_params(self.request, action)
|
||||||
|
|
||||||
self.assertDictEqual(auth_params, {
|
self.assertDictEqual(
|
||||||
'key': 'value',
|
auth_params,
|
||||||
})
|
{
|
||||||
|
"key": "value",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
@override_settings(SOCIALACCOUNT_PROVIDERS={
|
@override_settings(
|
||||||
'theid': {
|
SOCIALACCOUNT_PROVIDERS={
|
||||||
'AUTH_PARAMS': {'key': 'value'},
|
"theid": {
|
||||||
},
|
"AUTH_PARAMS": {"key": "value"},
|
||||||
})
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
def test_get_auth_params_with_dynamic(self):
|
def test_get_auth_params_with_dynamic(self):
|
||||||
factory = RequestFactory()
|
factory = RequestFactory()
|
||||||
request = factory.get(
|
request = factory.get(
|
||||||
'/test/?auth_params=next%3Dtwo%253Dwhoam%2525C3%2525AF%2526qu%2525'
|
"/test/?auth_params=next%3Dtwo%253Dwhoam%2525C3%2525AF%2526qu%2525"
|
||||||
'C3%2525A9ry%253Dstring'
|
"C3%2525A9ry%253Dstring"
|
||||||
)
|
)
|
||||||
request.session = {}
|
request.session = {}
|
||||||
|
|
||||||
|
@ -106,15 +110,18 @@ class CASProviderTests(TestCase):
|
||||||
|
|
||||||
auth_params = self.provider.get_auth_params(request, action)
|
auth_params = self.provider.get_auth_params(request, action)
|
||||||
|
|
||||||
self.assertDictEqual(auth_params, {
|
self.assertDictEqual(
|
||||||
'key': 'value',
|
auth_params,
|
||||||
'next': 'two=whoam%C3%AF&qu%C3%A9ry=string',
|
{
|
||||||
})
|
"key": "value",
|
||||||
|
"next": "two=whoam%C3%AF&qu%C3%A9ry=string",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
def test_add_message_suggest_caslogout(self):
|
def test_add_message_suggest_caslogout(self):
|
||||||
expected_msg_base_str = (
|
expected_msg_base_str = (
|
||||||
"To logout of The Provider, please close your browser, or visit "
|
"To logout of The Provider, please close your browser, or visit "
|
||||||
"this <a href=\"/accounts/theid/logout/?{}\">link</a>."
|
'this <a href="/accounts/theid/logout/?{}">link</a>.'
|
||||||
)
|
)
|
||||||
|
|
||||||
# Defaults.
|
# Defaults.
|
||||||
|
@ -124,7 +131,7 @@ class CASProviderTests(TestCase):
|
||||||
|
|
||||||
expected_msg1 = Message(
|
expected_msg1 = Message(
|
||||||
messages.INFO,
|
messages.INFO,
|
||||||
expected_msg_base_str.format(urlencode({'next': '/test/'})),
|
expected_msg_base_str.format(urlencode({"next": "/test/"})),
|
||||||
)
|
)
|
||||||
self.assertIn(expected_msg1, get_messages(req1))
|
self.assertIn(expected_msg1, get_messages(req1))
|
||||||
|
|
||||||
|
@ -132,69 +139,83 @@ class CASProviderTests(TestCase):
|
||||||
req2 = self._get_request()
|
req2 = self._get_request()
|
||||||
|
|
||||||
self.provider.add_message_suggest_caslogout(
|
self.provider.add_message_suggest_caslogout(
|
||||||
req2, next_page='/redir/', level=messages.WARNING)
|
req2, next_page="/redir/", level=messages.WARNING
|
||||||
|
)
|
||||||
|
|
||||||
expected_msg2 = Message(
|
expected_msg2 = Message(
|
||||||
messages.WARNING,
|
messages.WARNING,
|
||||||
expected_msg_base_str.format(urlencode({'next': '/redir/'})),
|
expected_msg_base_str.format(urlencode({"next": "/redir/"})),
|
||||||
)
|
)
|
||||||
self.assertIn(expected_msg2, get_messages(req2))
|
self.assertIn(expected_msg2, get_messages(req2))
|
||||||
|
|
||||||
def test_message_suggest_caslogout_on_logout(self):
|
def test_message_suggest_caslogout_on_logout(self):
|
||||||
self.assertFalse(
|
self.assertFalse(
|
||||||
self.provider.message_suggest_caslogout_on_logout(self.request))
|
self.provider.message_suggest_caslogout_on_logout(self.request)
|
||||||
|
)
|
||||||
|
|
||||||
with override_settings(SOCIALACCOUNT_PROVIDERS={
|
with override_settings(
|
||||||
'theid': {'MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT': True},
|
SOCIALACCOUNT_PROVIDERS={
|
||||||
}):
|
"theid": {"MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT": True},
|
||||||
|
}
|
||||||
|
):
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
self.provider
|
self.provider.message_suggest_caslogout_on_logout(self.request)
|
||||||
.message_suggest_caslogout_on_logout(self.request)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@override_settings(SOCIALACCOUNT_PROVIDERS={
|
@override_settings(
|
||||||
'theid': {
|
SOCIALACCOUNT_PROVIDERS={
|
||||||
'MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT_LEVEL': messages.WARNING,
|
"theid": {
|
||||||
},
|
"MESSAGE_SUGGEST_CASLOGOUT_ON_LOGOUT_LEVEL": messages.WARNING,
|
||||||
})
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
def test_message_suggest_caslogout_on_logout_level(self):
|
def test_message_suggest_caslogout_on_logout_level(self):
|
||||||
self.assertEqual(messages.WARNING, (
|
self.assertEqual(
|
||||||
self.provider
|
messages.WARNING,
|
||||||
.message_suggest_caslogout_on_logout_level(self.request)
|
(self.provider.message_suggest_caslogout_on_logout_level(self.request)),
|
||||||
))
|
)
|
||||||
|
|
||||||
def test_extract_uid(self):
|
def test_extract_uid(self):
|
||||||
response = 'useRName', {}
|
response = "useRName", {}
|
||||||
uid = self.provider.extract_uid(response)
|
uid = self.provider.extract_uid(response)
|
||||||
self.assertEqual('useRName', uid)
|
self.assertEqual("useRName", uid)
|
||||||
|
|
||||||
def test_extract_common_fields(self):
|
def test_extract_common_fields(self):
|
||||||
response = 'useRName', {}
|
response = "useRName", {}
|
||||||
common_fields = self.provider.extract_common_fields(response)
|
common_fields = self.provider.extract_common_fields(response)
|
||||||
self.assertDictEqual(common_fields, {
|
self.assertDictEqual(
|
||||||
'username': 'useRName',
|
common_fields,
|
||||||
'first_name': None,
|
{
|
||||||
'last_name': None,
|
"username": "useRName",
|
||||||
'name': None,
|
"first_name": None,
|
||||||
'email': None,
|
"last_name": None,
|
||||||
})
|
"name": None,
|
||||||
|
"email": None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
def test_extract_common_fields_with_extra(self):
|
def test_extract_common_fields_with_extra(self):
|
||||||
response = 'useRName', {'username': 'user', 'email': 'user@mail.net'}
|
response = "useRName", {"username": "user", "email": "user@mail.net"}
|
||||||
common_fields = self.provider.extract_common_fields(response)
|
common_fields = self.provider.extract_common_fields(response)
|
||||||
self.assertDictEqual(common_fields, {
|
self.assertDictEqual(
|
||||||
'username': 'user',
|
common_fields,
|
||||||
'first_name': None,
|
{
|
||||||
'last_name': None,
|
"username": "user",
|
||||||
'name': None,
|
"first_name": None,
|
||||||
'email': 'user@mail.net',
|
"last_name": None,
|
||||||
})
|
"name": None,
|
||||||
|
"email": "user@mail.net",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
def test_extract_extra_data(self):
|
def test_extract_extra_data(self):
|
||||||
response = 'useRName', {'user_attr': 'thevalue', 'another': 'value'}
|
response = "useRName", {"user_attr": "thevalue", "another": "value"}
|
||||||
extra_data = self.provider.extract_extra_data(response)
|
extra_data = self.provider.extract_extra_data(response)
|
||||||
self.assertDictEqual(extra_data, {
|
self.assertDictEqual(
|
||||||
'user_attr': 'thevalue',
|
extra_data,
|
||||||
'another': 'value',
|
{
|
||||||
'uid': 'useRName',
|
"user_attr": "thevalue",
|
||||||
})
|
"another": "value",
|
||||||
|
"uid": "useRName",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
from django.test import Client, RequestFactory
|
from django.test import Client, RequestFactory
|
||||||
|
|
||||||
from allauth_cas.test.testcases import CASViewTestCase
|
from allauth_cas.test.testcases import CASViewTestCase
|
||||||
|
@ -8,9 +7,8 @@ from .example.views import ExampleCASAdapter
|
||||||
|
|
||||||
|
|
||||||
class CASTestCaseTests(CASViewTestCase):
|
class CASTestCaseTests(CASViewTestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.client.get('/accounts/theid/login/')
|
self.client.get("/accounts/theid/login/")
|
||||||
|
|
||||||
def test_patch_cas_response_client_version(self):
|
def test_patch_cas_response_client_version(self):
|
||||||
"""
|
"""
|
||||||
|
@ -21,20 +19,24 @@ class CASTestCaseTests(CASViewTestCase):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
valid_versions = [
|
valid_versions = [
|
||||||
1, '1',
|
1,
|
||||||
2, '2',
|
"1",
|
||||||
3, '3',
|
2,
|
||||||
'CAS_2_SAML_1_0',
|
"2",
|
||||||
|
3,
|
||||||
|
"3",
|
||||||
|
"CAS_2_SAML_1_0",
|
||||||
]
|
]
|
||||||
invalid_versions = [
|
invalid_versions = [
|
||||||
'not_supported',
|
"not_supported",
|
||||||
]
|
]
|
||||||
|
|
||||||
factory = RequestFactory()
|
factory = RequestFactory()
|
||||||
request = factory.get('/path/')
|
request = factory.get("/path/")
|
||||||
request.session = {}
|
request.session = {}
|
||||||
|
|
||||||
for _version in valid_versions + invalid_versions:
|
for _version in valid_versions + invalid_versions:
|
||||||
|
|
||||||
class BasicCASAdapter(ExampleCASAdapter):
|
class BasicCASAdapter(ExampleCASAdapter):
|
||||||
version = _version
|
version = _version
|
||||||
|
|
||||||
|
@ -47,7 +49,7 @@ class CASTestCaseTests(CASViewTestCase):
|
||||||
if _version in valid_versions:
|
if _version in valid_versions:
|
||||||
raw_client = view(request)
|
raw_client = view(request)
|
||||||
|
|
||||||
self.patch_cas_response(valid_ticket='__all__')
|
self.patch_cas_response(valid_ticket="__all__")
|
||||||
mocked_client = view(request)
|
mocked_client = view(request)
|
||||||
|
|
||||||
self.assertEqual(type(raw_client), type(mocked_client))
|
self.assertEqual(type(raw_client), type(mocked_client))
|
||||||
|
@ -55,58 +57,79 @@ class CASTestCaseTests(CASViewTestCase):
|
||||||
# This is a sanity check.
|
# This is a sanity check.
|
||||||
self.assertRaises(ValueError, view, request)
|
self.assertRaises(ValueError, view, request)
|
||||||
|
|
||||||
self.patch_cas_response(valid_ticket='__all__')
|
self.patch_cas_response(valid_ticket="__all__")
|
||||||
self.assertRaises(ValueError, view, request)
|
self.assertRaises(ValueError, view, request)
|
||||||
|
|
||||||
def test_patch_cas_response_verify_success(self):
|
def test_patch_cas_response_verify_success(self):
|
||||||
self.patch_cas_response(valid_ticket='123456')
|
self.patch_cas_response(valid_ticket="123456")
|
||||||
r = self.client.get('/accounts/theid/login/callback/', {
|
r = self.client.get(
|
||||||
'ticket': '123456',
|
"/accounts/theid/login/callback/",
|
||||||
})
|
{
|
||||||
|
"ticket": "123456",
|
||||||
|
},
|
||||||
|
)
|
||||||
self.assertLoginSuccess(r)
|
self.assertLoginSuccess(r)
|
||||||
|
|
||||||
def test_patch_cas_response_verify_failure(self):
|
def test_patch_cas_response_verify_failure(self):
|
||||||
self.patch_cas_response(valid_ticket='123456')
|
self.patch_cas_response(valid_ticket="123456")
|
||||||
r = self.client.get('/accounts/theid/login/callback/', {
|
r = self.client.get(
|
||||||
'ticket': '000000',
|
"/accounts/theid/login/callback/",
|
||||||
})
|
{
|
||||||
|
"ticket": "000000",
|
||||||
|
},
|
||||||
|
)
|
||||||
self.assertLoginFailure(r)
|
self.assertLoginFailure(r)
|
||||||
|
|
||||||
def test_patch_cas_response_accept(self):
|
def test_patch_cas_response_accept(self):
|
||||||
self.patch_cas_response(valid_ticket='__all__')
|
self.patch_cas_response(valid_ticket="__all__")
|
||||||
r = self.client.get('/accounts/theid/login/callback/', {
|
r = self.client.get(
|
||||||
'ticket': '000000',
|
"/accounts/theid/login/callback/",
|
||||||
})
|
{
|
||||||
|
"ticket": "000000",
|
||||||
|
},
|
||||||
|
)
|
||||||
self.assertLoginSuccess(r)
|
self.assertLoginSuccess(r)
|
||||||
|
|
||||||
def test_patch_cas_response_reject(self):
|
def test_patch_cas_response_reject(self):
|
||||||
self.patch_cas_response(valid_ticket=None)
|
self.patch_cas_response(valid_ticket=None)
|
||||||
r = self.client.get('/accounts/theid/login/callback/', {
|
r = self.client.get(
|
||||||
'ticket': '000000',
|
"/accounts/theid/login/callback/",
|
||||||
})
|
{
|
||||||
|
"ticket": "000000",
|
||||||
|
},
|
||||||
|
)
|
||||||
self.assertLoginFailure(r)
|
self.assertLoginFailure(r)
|
||||||
|
|
||||||
def test_patch_cas_reponse_multiple(self):
|
def test_patch_cas_reponse_multiple(self):
|
||||||
self.patch_cas_response(valid_ticket='__all__')
|
self.patch_cas_response(valid_ticket="__all__")
|
||||||
client_0 = Client()
|
client_0 = Client()
|
||||||
client_0.get('/accounts/theid/login/')
|
client_0.get("/accounts/theid/login/")
|
||||||
r_0 = client_0.get('/accounts/theid/login/callback/', {
|
r_0 = client_0.get(
|
||||||
'ticket': '000000',
|
"/accounts/theid/login/callback/",
|
||||||
})
|
{
|
||||||
|
"ticket": "000000",
|
||||||
|
},
|
||||||
|
)
|
||||||
self.assertLoginSuccess(r_0)
|
self.assertLoginSuccess(r_0)
|
||||||
|
|
||||||
self.patch_cas_response(valid_ticket=None)
|
self.patch_cas_response(valid_ticket=None)
|
||||||
client_1 = Client()
|
client_1 = Client()
|
||||||
client_1.get('/accounts/theid/login/')
|
client_1.get("/accounts/theid/login/")
|
||||||
r_1 = client_1.get('/accounts/theid/login/callback/', {
|
r_1 = client_1.get(
|
||||||
'ticket': '111111',
|
"/accounts/theid/login/callback/",
|
||||||
})
|
{
|
||||||
|
"ticket": "111111",
|
||||||
|
},
|
||||||
|
)
|
||||||
self.assertLoginFailure(r_1)
|
self.assertLoginFailure(r_1)
|
||||||
|
|
||||||
def test_assertLoginSuccess(self):
|
def test_assertLoginSuccess(self):
|
||||||
self.patch_cas_response(valid_ticket='__all__')
|
self.patch_cas_response(valid_ticket="__all__")
|
||||||
r = self.client.get('/accounts/theid/login/callback/', {
|
r = self.client.get(
|
||||||
'ticket': '000000',
|
"/accounts/theid/login/callback/",
|
||||||
'next': '/path/',
|
{
|
||||||
})
|
"ticket": "000000",
|
||||||
self.assertLoginSuccess(r, redirect_to='/path/')
|
"next": "/path/",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.assertLoginSuccess(r, redirect_to="/path/")
|
||||||
|
|
|
@ -1,11 +1,10 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
try:
|
try:
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import django
|
|
||||||
from django.test import RequestFactory, override_settings
|
from django.test import RequestFactory, override_settings
|
||||||
|
from django.urls import reverse
|
||||||
|
|
||||||
from allauth_cas.exceptions import CASAuthenticationError
|
from allauth_cas.exceptions import CASAuthenticationError
|
||||||
from allauth_cas.test.testcases import CASTestCase, CASViewTestCase
|
from allauth_cas.test.testcases import CASTestCase, CASViewTestCase
|
||||||
|
@ -13,17 +12,11 @@ from allauth_cas.views import CASView
|
||||||
|
|
||||||
from .example.views import ExampleCASAdapter
|
from .example.views import ExampleCASAdapter
|
||||||
|
|
||||||
if django.VERSION >= (1, 10):
|
|
||||||
from django.urls import reverse
|
|
||||||
else:
|
|
||||||
from django.core.urlresolvers import reverse
|
|
||||||
|
|
||||||
|
|
||||||
class CASAdapterTests(CASTestCase):
|
class CASAdapterTests(CASTestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
factory = RequestFactory()
|
factory = RequestFactory()
|
||||||
self.request = factory.get('/path/')
|
self.request = factory.get("/path/")
|
||||||
self.request.session = {}
|
self.request.session = {}
|
||||||
self.adapter = ExampleCASAdapter(self.request)
|
self.adapter = ExampleCASAdapter(self.request)
|
||||||
|
|
||||||
|
@ -31,7 +24,7 @@ class CASAdapterTests(CASTestCase):
|
||||||
"""
|
"""
|
||||||
Service url (used by CAS client) is the callback url.
|
Service url (used by CAS client) is the callback url.
|
||||||
"""
|
"""
|
||||||
expected = 'http://testserver/accounts/theid/login/callback/'
|
expected = "http://testserver/accounts/theid/login/callback/"
|
||||||
service_url = self.adapter.get_service_url(self.request)
|
service_url = self.adapter.get_service_url(self.request)
|
||||||
self.assertEqual(expected, service_url)
|
self.assertEqual(expected, service_url)
|
||||||
|
|
||||||
|
@ -39,11 +32,9 @@ class CASAdapterTests(CASTestCase):
|
||||||
"""
|
"""
|
||||||
Current GET paramater next is appended on service url.
|
Current GET paramater next is appended on service url.
|
||||||
"""
|
"""
|
||||||
expected = (
|
expected = "http://testserver/accounts/theid/login/callback/?next=%2Fnext%2F"
|
||||||
'http://testserver/accounts/theid/login/callback/?next=%2Fnext%2F'
|
|
||||||
)
|
|
||||||
factory = RequestFactory()
|
factory = RequestFactory()
|
||||||
request = factory.get('/path/', {'next': '/next/'})
|
request = factory.get("/path/", {"next": "/next/"})
|
||||||
adapter = ExampleCASAdapter(request)
|
adapter = ExampleCASAdapter(request)
|
||||||
service_url = adapter.get_service_url(request)
|
service_url = adapter.get_service_url(request)
|
||||||
self.assertEqual(expected, service_url)
|
self.assertEqual(expected, service_url)
|
||||||
|
@ -67,14 +58,13 @@ class CASAdapterTests(CASTestCase):
|
||||||
|
|
||||||
|
|
||||||
class CASViewTests(CASViewTestCase):
|
class CASViewTests(CASViewTestCase):
|
||||||
|
|
||||||
class BasicCASView(CASView):
|
class BasicCASView(CASView):
|
||||||
def dispatch(self, request, *args, **kwargs):
|
def dispatch(self, request, *args, **kwargs):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
factory = RequestFactory()
|
factory = RequestFactory()
|
||||||
self.request = factory.get('/path/')
|
self.request = factory.get("/path/")
|
||||||
self.request.session = {}
|
self.request.session = {}
|
||||||
|
|
||||||
self.cas_view = self.BasicCASView.adapter_view(ExampleCASAdapter)
|
self.cas_view = self.BasicCASView.adapter_view(ExampleCASAdapter)
|
||||||
|
@ -85,27 +75,34 @@ class CASViewTests(CASViewTestCase):
|
||||||
"""
|
"""
|
||||||
view = self.cas_view(
|
view = self.cas_view(
|
||||||
self.request,
|
self.request,
|
||||||
'arg1', 'arg2',
|
"arg1",
|
||||||
kwarg1='kwarg1', kwarg2='kwarg2',
|
"arg2",
|
||||||
|
kwarg1="kwarg1",
|
||||||
|
kwarg2="kwarg2",
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertIsInstance(view, CASView)
|
self.assertIsInstance(view, CASView)
|
||||||
|
|
||||||
self.assertEqual(view.request, self.request)
|
self.assertEqual(view.request, self.request)
|
||||||
self.assertTupleEqual(view.args, ('arg1', 'arg2'))
|
self.assertTupleEqual(view.args, ("arg1", "arg2"))
|
||||||
self.assertDictEqual(view.kwargs, {
|
self.assertDictEqual(
|
||||||
'kwarg1': 'kwarg1',
|
view.kwargs,
|
||||||
'kwarg2': 'kwarg2',
|
{
|
||||||
})
|
"kwarg1": "kwarg1",
|
||||||
|
"kwarg2": "kwarg2",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
self.assertIsInstance(view.adapter, ExampleCASAdapter)
|
self.assertIsInstance(view.adapter, ExampleCASAdapter)
|
||||||
|
|
||||||
@patch('allauth_cas.views.cas.CASClient')
|
@patch("allauth_cas.views.cas.CASClient")
|
||||||
@override_settings(SOCIALACCOUNT_PROVIDERS={
|
@override_settings(
|
||||||
'theid': {
|
SOCIALACCOUNT_PROVIDERS={
|
||||||
'AUTH_PARAMS': {'key': 'value'},
|
"theid": {
|
||||||
},
|
"AUTH_PARAMS": {"key": "value"},
|
||||||
})
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
def test_get_client(self, mock_casclient_class):
|
def test_get_client(self, mock_casclient_class):
|
||||||
"""
|
"""
|
||||||
get_client returns a CAS client, configured from settings.
|
get_client returns a CAS client, configured from settings.
|
||||||
|
@ -114,11 +111,11 @@ class CASViewTests(CASViewTestCase):
|
||||||
view.get_client(self.request)
|
view.get_client(self.request)
|
||||||
|
|
||||||
mock_casclient_class.assert_called_once_with(
|
mock_casclient_class.assert_called_once_with(
|
||||||
service_url='http://testserver/accounts/theid/login/callback/',
|
service_url="http://testserver/accounts/theid/login/callback/",
|
||||||
server_url='https://server.cas',
|
server_url="https://server.cas",
|
||||||
version=2,
|
version=2,
|
||||||
renew=False,
|
renew=False,
|
||||||
extra_login_params={'key': 'value'},
|
extra_login_params={"key": "value"},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_render_error_on_failure(self):
|
def test_render_error_on_failure(self):
|
||||||
|
@ -126,33 +123,33 @@ class CASViewTests(CASViewTestCase):
|
||||||
A common login failure page is rendered if CASAuthenticationError is
|
A common login failure page is rendered if CASAuthenticationError is
|
||||||
raised by dispatch.
|
raised by dispatch.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def dispatch_raise(self, request):
|
def dispatch_raise(self, request):
|
||||||
raise CASAuthenticationError("failure")
|
raise CASAuthenticationError("failure")
|
||||||
|
|
||||||
with patch.object(self.BasicCASView, 'dispatch', dispatch_raise):
|
with patch.object(self.BasicCASView, "dispatch", dispatch_raise):
|
||||||
resp = self.cas_view(self.request)
|
resp = self.cas_view(self.request)
|
||||||
self.assertLoginFailure(resp)
|
self.assertLoginFailure(resp)
|
||||||
|
|
||||||
|
|
||||||
class CASLoginViewTests(CASViewTestCase):
|
class CASLoginViewTests(CASViewTestCase):
|
||||||
|
|
||||||
def test_reverse(self):
|
def test_reverse(self):
|
||||||
"""
|
"""
|
||||||
Login view name is "{provider_id}_login".
|
Login view name is "{provider_id}_login".
|
||||||
"""
|
"""
|
||||||
url = reverse('theid_login')
|
url = reverse("theid_login")
|
||||||
self.assertEqual('/accounts/theid/login/', url)
|
self.assertEqual("/accounts/theid/login/", url)
|
||||||
|
|
||||||
def test_execute(self):
|
def test_execute(self):
|
||||||
"""
|
"""
|
||||||
Login view redirects to the CAS server login url.
|
Login view redirects to the CAS server login url.
|
||||||
Service is the callback url, as absolute uri.
|
Service is the callback url, as absolute uri.
|
||||||
"""
|
"""
|
||||||
r = self.client.get('/accounts/theid/login/')
|
r = self.client.get("/accounts/theid/login/")
|
||||||
|
|
||||||
expected = (
|
expected = (
|
||||||
'https://server.cas/login?service=http%3A%2F%2Ftestserver%2F'
|
"https://server.cas/login?service=http%3A%2F%2Ftestserver%2F"
|
||||||
'accounts%2Ftheid%2Flogin%2Fcallback%2F'
|
"accounts%2Ftheid%2Flogin%2Fcallback%2F"
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertRedirects(r, expected, fetch_redirect_response=False)
|
self.assertRedirects(r, expected, fetch_redirect_response=False)
|
||||||
|
@ -161,54 +158,59 @@ class CASLoginViewTests(CASViewTestCase):
|
||||||
"""
|
"""
|
||||||
Current GET parameter 'next' is kept on service url.
|
Current GET parameter 'next' is kept on service url.
|
||||||
"""
|
"""
|
||||||
r = self.client.get('/accounts/theid/login/?next=/path/')
|
r = self.client.get("/accounts/theid/login/?next=/path/")
|
||||||
|
|
||||||
expected = (
|
expected = (
|
||||||
'https://server.cas/login?service=http%3A%2F%2Ftestserver%2F'
|
"https://server.cas/login?service=http%3A%2F%2Ftestserver%2F"
|
||||||
'accounts%2Ftheid%2Flogin%2Fcallback%2F%3Fnext%3D%252Fpath%252F'
|
"accounts%2Ftheid%2Flogin%2Fcallback%2F%3Fnext%3D%252Fpath%252F"
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertRedirects(r, expected, fetch_redirect_response=False)
|
self.assertRedirects(r, expected, fetch_redirect_response=False)
|
||||||
|
|
||||||
|
|
||||||
class CASCallbackViewTests(CASViewTestCase):
|
class CASCallbackViewTests(CASViewTestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.client.get('/accounts/theid/login/')
|
self.client.get("/accounts/theid/login/")
|
||||||
|
|
||||||
def test_reverse(self):
|
def test_reverse(self):
|
||||||
"""
|
"""
|
||||||
Callback view name is "{provider_id}_callback".
|
Callback view name is "{provider_id}_callback".
|
||||||
"""
|
"""
|
||||||
url = reverse('theid_callback')
|
url = reverse("theid_callback")
|
||||||
self.assertEqual('/accounts/theid/login/callback/', url)
|
self.assertEqual("/accounts/theid/login/callback/", url)
|
||||||
|
|
||||||
def test_ticket_valid(self):
|
def test_ticket_valid(self):
|
||||||
"""
|
"""
|
||||||
If ticket is valid, the user is logged in.
|
If ticket is valid, the user is logged in.
|
||||||
"""
|
"""
|
||||||
self.patch_cas_response(username='username', valid_ticket='123456')
|
self.patch_cas_response(username="username", valid_ticket="123456")
|
||||||
r = self.client.get('/accounts/theid/login/callback/', {
|
r = self.client.get(
|
||||||
'ticket': '123456',
|
"/accounts/theid/login/callback/",
|
||||||
})
|
{
|
||||||
|
"ticket": "123456",
|
||||||
|
},
|
||||||
|
)
|
||||||
self.assertLoginSuccess(r)
|
self.assertLoginSuccess(r)
|
||||||
|
|
||||||
def test_ticket_invalid(self):
|
def test_ticket_invalid(self):
|
||||||
"""
|
"""
|
||||||
Login failure page is returned if the ticket is invalid.
|
Login failure page is returned if the ticket is invalid.
|
||||||
"""
|
"""
|
||||||
self.patch_cas_response(username='username', valid_ticket='123456')
|
self.patch_cas_response(username="username", valid_ticket="123456")
|
||||||
r = self.client.get('/accounts/theid/login/callback/', {
|
r = self.client.get(
|
||||||
'ticket': '000000',
|
"/accounts/theid/login/callback/",
|
||||||
})
|
{
|
||||||
|
"ticket": "000000",
|
||||||
|
},
|
||||||
|
)
|
||||||
self.assertLoginFailure(r)
|
self.assertLoginFailure(r)
|
||||||
|
|
||||||
def test_ticket_missing(self):
|
def test_ticket_missing(self):
|
||||||
"""
|
"""
|
||||||
Login failure page is returned if request lacks a ticket.
|
Login failure page is returned if request lacks a ticket.
|
||||||
"""
|
"""
|
||||||
self.patch_cas_response(username='username', valid_ticket='123456')
|
self.patch_cas_response(username="username", valid_ticket="123456")
|
||||||
r = self.client.get('/accounts/theid/login/callback/')
|
r = self.client.get("/accounts/theid/login/callback/")
|
||||||
self.assertLoginFailure(r)
|
self.assertLoginFailure(r)
|
||||||
|
|
||||||
def test_attributes_is_none(self):
|
def test_attributes_is_none(self):
|
||||||
|
@ -216,31 +218,33 @@ class CASCallbackViewTests(CASViewTestCase):
|
||||||
Without extra attributes, CASClientV2 of python-cas returns None.
|
Without extra attributes, CASClientV2 of python-cas returns None.
|
||||||
"""
|
"""
|
||||||
self.patch_cas_response(
|
self.patch_cas_response(
|
||||||
username='username', valid_ticket='123456', attributes=None
|
username="username", valid_ticket="123456", attributes=None
|
||||||
|
)
|
||||||
|
r = self.client.get(
|
||||||
|
"/accounts/theid/login/callback/",
|
||||||
|
{
|
||||||
|
"ticket": "123456",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
r = self.client.get('/accounts/theid/login/callback/', {
|
|
||||||
'ticket': '123456',
|
|
||||||
})
|
|
||||||
self.assertLoginSuccess(r)
|
self.assertLoginSuccess(r)
|
||||||
|
|
||||||
|
|
||||||
class CASLogoutViewTests(CASViewTestCase):
|
class CASLogoutViewTests(CASViewTestCase):
|
||||||
|
|
||||||
def test_reverse(self):
|
def test_reverse(self):
|
||||||
"""
|
"""
|
||||||
Callback view name is "{provider_id}_logout".
|
Callback view name is "{provider_id}_logout".
|
||||||
"""
|
"""
|
||||||
url = reverse('theid_logout')
|
url = reverse("theid_logout")
|
||||||
self.assertEqual('/accounts/theid/logout/', url)
|
self.assertEqual("/accounts/theid/logout/", url)
|
||||||
|
|
||||||
def test_execute(self):
|
def test_execute(self):
|
||||||
"""
|
"""
|
||||||
Logout view redirects to the CAS server logout url.
|
Logout view redirects to the CAS server logout url.
|
||||||
Service is a url to here, as absolute uri.
|
Service is a url to here, as absolute uri.
|
||||||
"""
|
"""
|
||||||
r = self.client.get('/accounts/theid/logout/')
|
r = self.client.get("/accounts/theid/logout/")
|
||||||
|
|
||||||
expected = 'https://server.cas/logout?url=http%3A%2F%2Ftestserver%2F'
|
expected = "https://server.cas/logout?url=http%3A%2F%2Ftestserver%2F"
|
||||||
|
|
||||||
self.assertRedirects(r, expected, fetch_redirect_response=False)
|
self.assertRedirects(r, expected, fetch_redirect_response=False)
|
||||||
|
|
||||||
|
@ -248,10 +252,8 @@ class CASLogoutViewTests(CASViewTestCase):
|
||||||
"""
|
"""
|
||||||
GET parameter 'next' is set as service url.
|
GET parameter 'next' is set as service url.
|
||||||
"""
|
"""
|
||||||
r = self.client.get('/accounts/theid/logout/?next=/path/')
|
r = self.client.get("/accounts/theid/logout/?next=/path/")
|
||||||
|
|
||||||
expected = (
|
expected = "https://server.cas/logout?url=http%3A%2F%2Ftestserver%2Fpath%2F"
|
||||||
'https://server.cas/logout?url=http%3A%2F%2Ftestserver%2Fpath%2F'
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertRedirects(r, expected, fetch_redirect_response=False)
|
self.assertRedirects(r, expected, fetch_redirect_response=False)
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
# -*- coding: utf-8 -*-
|
from django.urls import include, path
|
||||||
from django.conf.urls import include, url
|
|
||||||
|
|
||||||
urlpatterns = [
|
urlpatterns = [
|
||||||
url(r'^accounts/', include('allauth.urls')),
|
path("accounts/", include("allauth.urls")),
|
||||||
]
|
]
|
||||||
|
|
Loading…
Reference in a new issue