initial
This commit is contained in:
parent
f6a7dbc385
commit
76fd5ca344
25 changed files with 1100 additions and 5 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -1,4 +1,5 @@
|
|||
.coverage
|
||||
.coverage.*
|
||||
.tox/
|
||||
build/
|
||||
dist/
|
||||
|
|
|
@ -4,5 +4,10 @@ django-allauth-cas
|
|||
|
||||
CAS support for django-allauth_.
|
||||
|
||||
Supports:
|
||||
|
||||
- Django 1.8-10 - Python 2.7, 3.4-5
|
||||
- Django 1.11 - Python 2.7, 3.4-6
|
||||
|
||||
|
||||
.. _django-allauth: https://www.intenct.nl/projects/django-allauth/
|
||||
|
|
7
allauth_cas/__init__.py
Normal file
7
allauth_cas/__init__.py
Normal file
|
@ -0,0 +1,7 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
__version__ = '0.0.1.dev0'
|
||||
|
||||
default_app_config = 'allauth_cas.apps.CASAccountConfig'
|
||||
|
||||
CAS_PROVIDER_SESSION_KEY = 'allauth_cas__provider_id'
|
11
allauth_cas/apps.py
Normal file
11
allauth_cas/apps.py
Normal file
|
@ -0,0 +1,11 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
from django.apps import AppConfig
|
||||
from django.utils.translation import ugettext_lazy as _
|
||||
|
||||
|
||||
class CASAccountConfig(AppConfig):
|
||||
name = 'allauth_cas'
|
||||
verbose_name = _("CAS Accounts")
|
||||
|
||||
def ready(self):
|
||||
from . import signals # noqa
|
7
allauth_cas/exceptions.py
Normal file
7
allauth_cas/exceptions.py
Normal file
|
@ -0,0 +1,7 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
|
||||
class CASAuthenticationError(Exception):
|
||||
"""
|
||||
Base exception to signal CAS authentication failure.
|
||||
"""
|
55
allauth_cas/providers.py
Normal file
55
allauth_cas/providers.py
Normal file
|
@ -0,0 +1,55 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
from six.moves.urllib.parse import parse_qsl
|
||||
|
||||
import django
|
||||
from django.contrib import messages
|
||||
from django.utils.http import urlencode
|
||||
|
||||
from allauth.socialaccount.providers.base import Provider
|
||||
|
||||
if django.VERSION >= (1, 10):
|
||||
from django.urls import reverse
|
||||
else:
|
||||
from django.core.urlresolvers import reverse
|
||||
|
||||
|
||||
class CASProvider(Provider):
|
||||
|
||||
def get_login_url(self, request, **kwargs):
|
||||
url = reverse(self.id + '_login')
|
||||
if kwargs:
|
||||
url += '?' + urlencode(kwargs)
|
||||
return url
|
||||
|
||||
def get_logout_url(self, request, **kwargs):
|
||||
url = reverse(self.id + '_logout')
|
||||
if kwargs:
|
||||
url += '?' + urlencode(kwargs)
|
||||
return url
|
||||
|
||||
def get_auth_params(self, request, action):
|
||||
settings = self.get_settings()
|
||||
ret = dict(settings.get('AUTH_PARAMS', {}))
|
||||
dynamic_auth_params = request.GET.get('auth_params')
|
||||
if dynamic_auth_params:
|
||||
ret.update(dict(parse_qsl(dynamic_auth_params)))
|
||||
return ret
|
||||
|
||||
def message_on_logout(self, request):
|
||||
return self.get_settings().get('MESSAGE_ON_LOGOUT', True)
|
||||
|
||||
def message_on_logout_level(self, request):
|
||||
return self.get_settings().get('MESSAGE_ON_LOGOUT_LEVEL',
|
||||
messages.INFO)
|
||||
|
||||
def extract_uid(self, data):
|
||||
username, _, _ = data
|
||||
return username
|
||||
|
||||
def extract_common_fields(self, data):
|
||||
username, _, _ = data
|
||||
return {'username': username}
|
||||
|
||||
def extract_extra_data(self, data):
|
||||
_, extra_data, _ = data
|
||||
return extra_data
|
45
allauth_cas/signals.py
Normal file
45
allauth_cas/signals.py
Normal file
|
@ -0,0 +1,45 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
from django.contrib.auth.signals import user_logged_out
|
||||
from django.dispatch import receiver
|
||||
from django.utils.safestring import mark_safe
|
||||
|
||||
from allauth.account.adapter import get_adapter
|
||||
from allauth.account.utils import get_next_redirect_url
|
||||
from allauth.socialaccount import providers
|
||||
|
||||
from . import CAS_PROVIDER_SESSION_KEY
|
||||
|
||||
|
||||
@receiver(user_logged_out)
|
||||
def cas_account_logout(sender, request, **kwargs):
|
||||
provider_id = request.session.get(CAS_PROVIDER_SESSION_KEY)
|
||||
|
||||
if not provider_id:
|
||||
return
|
||||
|
||||
provider = providers.registry.by_id(provider_id, request)
|
||||
|
||||
if not provider.message_on_logout(request):
|
||||
return
|
||||
|
||||
adapter = get_adapter(request)
|
||||
|
||||
redirect_url = (
|
||||
get_next_redirect_url(request) or
|
||||
adapter.get_logout_redirect_url(request)
|
||||
)
|
||||
|
||||
logout_kwargs = {'next': redirect_url} if redirect_url else {}
|
||||
logout_url = provider.get_logout_url(request, **logout_kwargs)
|
||||
|
||||
level = provider.message_on_logout_level(request)
|
||||
logout_link = mark_safe('<a href="{}">link</a>'.format(logout_url))
|
||||
|
||||
adapter.add_message(
|
||||
request, level,
|
||||
message_template='cas_account/messages/logged_out.txt',
|
||||
message_context={
|
||||
'logout_url': logout_url,
|
||||
'logout_link': logout_link,
|
||||
}
|
||||
)
|
|
@ -0,0 +1,4 @@
|
|||
{% load i18n %}
|
||||
{% blocktrans %}
|
||||
To logout of CAS, please close your browser, or visit this {{ logout_link }}.
|
||||
{% endblocktrans %}
|
23
allauth_cas/urls.py
Normal file
23
allauth_cas/urls.py
Normal file
|
@ -0,0 +1,23 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
from django.conf.urls import include, url
|
||||
|
||||
from allauth.utils import import_attribute
|
||||
|
||||
|
||||
def default_urlpatterns(provider):
|
||||
package = provider.get_package()
|
||||
|
||||
login_view = import_attribute(package + '.views.login')
|
||||
callback_view = import_attribute(package + '.views.callback')
|
||||
logout_view = import_attribute(package + '.views.logout')
|
||||
|
||||
urlpatterns = [
|
||||
url('^login/$',
|
||||
login_view, name=provider.id + '_login'),
|
||||
url('^login/callback/$',
|
||||
callback_view, name=provider.id + '_callback'),
|
||||
url('^logout/$',
|
||||
logout_view, name=provider.id + '_logout'),
|
||||
]
|
||||
|
||||
return [url('^' + provider.get_slug() + '/', include(urlpatterns))]
|
242
allauth_cas/views.py
Normal file
242
allauth_cas/views.py
Normal file
|
@ -0,0 +1,242 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
import django
|
||||
from django.http import HttpResponseRedirect
|
||||
from django.utils.http import urlencode
|
||||
|
||||
from allauth.account.adapter import get_adapter
|
||||
from allauth.account.utils import get_next_redirect_url
|
||||
from allauth.socialaccount import providers
|
||||
from allauth.socialaccount.helpers import (
|
||||
complete_social_login, render_authentication_error,
|
||||
)
|
||||
|
||||
import cas
|
||||
|
||||
from . import CAS_PROVIDER_SESSION_KEY
|
||||
from .exceptions import CASAuthenticationError
|
||||
|
||||
if django.VERSION >= (1, 10):
|
||||
from django.urls import reverse
|
||||
else:
|
||||
from django.core.urlresolvers import reverse
|
||||
|
||||
|
||||
class AuthAction(object):
|
||||
AUTHENTICATE = 'authenticate'
|
||||
REAUTHENTICATE = 'reauthenticate'
|
||||
DEAUTHENTICATE = 'deauthenticate'
|
||||
|
||||
|
||||
class CASAdapter(object):
|
||||
# CAS client parameters
|
||||
renew = False
|
||||
|
||||
def __init__(self, request):
|
||||
self.request = request
|
||||
|
||||
def get_provider(self):
|
||||
"""
|
||||
Returns a provider instance for the current request.
|
||||
"""
|
||||
return providers.registry.by_id(self.provider_id, self.request)
|
||||
|
||||
def complete_login(self, request, response):
|
||||
"""
|
||||
Executed by the callback view after successful authentication on CAS
|
||||
server.
|
||||
|
||||
Returns the SocialLogin object which represents the state of the
|
||||
current login-session.
|
||||
"""
|
||||
login = (self.get_provider()
|
||||
.sociallogin_from_response(request, response))
|
||||
return login
|
||||
|
||||
def get_service_url(self, request):
|
||||
"""
|
||||
Returns the service url to for a CAS client.
|
||||
|
||||
From CAS specification, the service url is used in order to redirect
|
||||
user after a successful login on CAS server. Also, service_url sent
|
||||
when ticket is verified must be the one for which ticket was issued.
|
||||
|
||||
To conform this, the service url is always the callback url.
|
||||
|
||||
A redirect url is found from the current request and appended as
|
||||
parameter to the service url and is latter used by the callback view to
|
||||
redirect user.
|
||||
"""
|
||||
redirect_to = get_next_redirect_url(request)
|
||||
|
||||
callback_kwargs = {'next': redirect_to} if redirect_to else {}
|
||||
callback_url = self.get_callback_url(request, **callback_kwargs)
|
||||
|
||||
service_url = request.build_absolute_uri(callback_url)
|
||||
|
||||
return service_url
|
||||
|
||||
def get_callback_url(self, request, **kwargs):
|
||||
"""
|
||||
Returns the callback url of the provider.
|
||||
|
||||
Keyword arguments are set as query string.
|
||||
"""
|
||||
url = reverse(self.provider_id + '_callback')
|
||||
if kwargs:
|
||||
url += '?' + urlencode(kwargs)
|
||||
return url
|
||||
|
||||
|
||||
class CASView(object):
|
||||
|
||||
@classmethod
|
||||
def adapter_view(cls, adapter, **kwargs):
|
||||
"""
|
||||
Similar to the Django as_view() method.
|
||||
|
||||
It also setups a few things:
|
||||
- given adapter argument will be used in views internals.
|
||||
- if the view execution raises a CASAuthenticationError, the view
|
||||
renders an authentication error page.
|
||||
|
||||
To use this:
|
||||
|
||||
- subclass CAS adapter as wanted:
|
||||
|
||||
class MyAdapter(CASAdapter):
|
||||
url = 'https://my.cas.url'
|
||||
|
||||
- define views:
|
||||
|
||||
login = views.CASLoginView.adapter_view(MyAdapter)
|
||||
callback = views.CASCallbackView.adapter_view(MyAdapter)
|
||||
logout = views.CASLogoutView.adapter_view(MyAdapter)
|
||||
|
||||
"""
|
||||
def view(request, *args, **kwargs):
|
||||
# Prepare the func-view.
|
||||
self = cls()
|
||||
|
||||
self.request = request
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
# Setup and store adapter as view attribute.
|
||||
self.adapter = adapter(request)
|
||||
|
||||
try:
|
||||
return self.dispatch(request, *args, **kwargs)
|
||||
except CASAuthenticationError:
|
||||
return self.render_error()
|
||||
|
||||
return view
|
||||
|
||||
def get_client(self, request, action=AuthAction.AUTHENTICATE):
|
||||
"""
|
||||
Returns the CAS client to interact with the CAS server.
|
||||
"""
|
||||
provider = self.adapter.get_provider()
|
||||
auth_params = provider.get_auth_params(request, action)
|
||||
|
||||
service_url = self.adapter.get_service_url(request)
|
||||
|
||||
client = cas.CASClient(
|
||||
service_url=service_url,
|
||||
server_url=self.adapter.url,
|
||||
version=self.adapter.version,
|
||||
renew=self.adapter.renew,
|
||||
extra_login_params=auth_params,
|
||||
)
|
||||
|
||||
return client
|
||||
|
||||
def render_error(self):
|
||||
"""
|
||||
Returns an HTTP response in case an authentication failure happens.
|
||||
"""
|
||||
return render_authentication_error(
|
||||
self.request,
|
||||
self.adapter.provider_id,
|
||||
)
|
||||
|
||||
|
||||
class CASLoginView(CASView):
|
||||
|
||||
def dispatch(self, request):
|
||||
"""
|
||||
Redirects to the CAS server login page.
|
||||
"""
|
||||
action = request.GET.get('action', AuthAction.AUTHENTICATE)
|
||||
client = self.get_client(request, action=action)
|
||||
return HttpResponseRedirect(client.get_login_url())
|
||||
|
||||
|
||||
class CASCallbackView(CASView):
|
||||
|
||||
def dispatch(self, request):
|
||||
"""
|
||||
The CAS server redirects the user to this view after a successful
|
||||
authentication.
|
||||
|
||||
On redirect, CAS server should add a ticket whose validity is verified
|
||||
here. If ticket is valid, CAS server may also return extra attributes
|
||||
about user.
|
||||
"""
|
||||
provider = self.adapter.get_provider()
|
||||
client = self.get_client(request)
|
||||
|
||||
# CAS server should let a ticket.
|
||||
try:
|
||||
ticket = request.GET['ticket']
|
||||
except KeyError:
|
||||
raise CASAuthenticationError(
|
||||
"CAS server didn't respond with a ticket."
|
||||
)
|
||||
|
||||
# Check ticket validity.
|
||||
# Response format on:
|
||||
# - success: username, attributes, pgtiou
|
||||
# - error: None, {}, None
|
||||
response = client.verify_ticket(ticket)
|
||||
|
||||
if not response[0]:
|
||||
raise CASAuthenticationError(
|
||||
"CAS server doesn't validate the ticket."
|
||||
)
|
||||
|
||||
# The CAS provider in use is stored to propose to the user to
|
||||
# disconnect from the latter when he logouts.
|
||||
request.session[CAS_PROVIDER_SESSION_KEY] = provider.id
|
||||
|
||||
# Finish the login flow
|
||||
login = self.adapter.complete_login(request, response)
|
||||
return complete_social_login(request, login)
|
||||
|
||||
|
||||
class CASLogoutView(CASView):
|
||||
|
||||
def dispatch(self, request, next_page=None):
|
||||
"""
|
||||
Redirects to the CAS server logout page.
|
||||
|
||||
next_page is used to let the CAS server send back the user. If empty,
|
||||
the redirect url is built on request data.
|
||||
"""
|
||||
action = AuthAction.DEAUTHENTICATE
|
||||
|
||||
redirect_url = next_page or self.get_redirect_url()
|
||||
redirect_to = request.build_absolute_uri(redirect_url)
|
||||
|
||||
client = self.get_client(request, action=action)
|
||||
|
||||
return HttpResponseRedirect(client.get_logout_url(redirect_to))
|
||||
|
||||
def get_redirect_url(self):
|
||||
"""
|
||||
Returns the url to redirect after logout from current request.
|
||||
"""
|
||||
request = self.request
|
||||
return (
|
||||
get_next_redirect_url(request) or
|
||||
get_adapter(request).get_logout_redirect_url(request)
|
||||
)
|
|
@ -1,3 +1,5 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
|
|
@ -6,11 +6,13 @@ ignore = E731
|
|||
combine_as_imports = True
|
||||
default_section = THIRDPARTY
|
||||
include_trailing_comma = True
|
||||
known_allauth = allauth
|
||||
known_future_library = future,six
|
||||
known_django = django
|
||||
known_first_party = allauth_cas
|
||||
multi_line_output = 5
|
||||
not_skip = __init__.py
|
||||
sections = FUTURE,STDLIB,DJANGO,THIRDPARTY,FIRSTPARTY,LOCALFOLDER
|
||||
sections = FUTURE,STDLIB,DJANGO,ALLAUTH,THIRDPARTY,FIRSTPARTY,LOCALFOLDER
|
||||
|
||||
[bdist_wheel]
|
||||
universal = 1
|
||||
|
|
4
setup.py
4
setup.py
|
@ -1,3 +1,4 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
import os
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
|
@ -15,7 +16,7 @@ setup(
|
|||
description='CAS support for django-allauth.',
|
||||
author='Aurélien Delobelle',
|
||||
author_email='aurelien.delobelle@gmail.com',
|
||||
keyword='django allauth cas authentication',
|
||||
keywords='django allauth cas authentication',
|
||||
long_description=README,
|
||||
url='https://github.com/aureplop/django-allauth-cas',
|
||||
classifiers=[
|
||||
|
@ -44,6 +45,7 @@ setup(
|
|||
install_requires=[
|
||||
'django-allauth',
|
||||
'python-cas',
|
||||
'six',
|
||||
],
|
||||
extras_require={
|
||||
'tests': ['tox'],
|
||||
|
|
37
tests/cas_clients.py
Normal file
37
tests/cas_clients.py
Normal file
|
@ -0,0 +1,37 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
import cas
|
||||
|
||||
|
||||
class MockCASClient(cas.CASClientV2):
|
||||
"""
|
||||
Base class to mock cas.CASClient
|
||||
"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs.pop('version')
|
||||
super(MockCASClient, self).__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class VerifyCASClient(MockCASClient):
|
||||
"""
|
||||
CAS client which verifies ticket is '123456'.
|
||||
"""
|
||||
def verify_ticket(self, ticket):
|
||||
if ticket == '123456':
|
||||
return 'username', {}, None
|
||||
return None, {}, None
|
||||
|
||||
|
||||
class AcceptCASClient(MockCASClient):
|
||||
"""
|
||||
CAS client which accepts all tickets.
|
||||
"""
|
||||
def verify_ticket(self, ticket):
|
||||
return 'username', {}, None
|
||||
|
||||
|
||||
class RejectCASClient(MockCASClient):
|
||||
"""
|
||||
CAS client which rejects all tickets.
|
||||
"""
|
||||
def verify_ticket(self, ticket):
|
||||
return None, {}, None
|
0
tests/example/__init__.py
Normal file
0
tests/example/__init__.py
Normal file
17
tests/example/provider.py
Normal file
17
tests/example/provider.py
Normal file
|
@ -0,0 +1,17 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
from allauth.socialaccount.providers.base import ProviderAccount
|
||||
|
||||
from allauth_cas.providers import CASProvider
|
||||
|
||||
|
||||
class ExampleCASAccount(ProviderAccount):
|
||||
pass
|
||||
|
||||
|
||||
class ExampleCASProvider(CASProvider):
|
||||
id = 'theid'
|
||||
name = 'The Provider'
|
||||
account_class = ExampleCASAccount
|
||||
|
||||
|
||||
provider_classes = [ExampleCASProvider]
|
6
tests/example/urls.py
Normal file
6
tests/example/urls.py
Normal file
|
@ -0,0 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
from allauth_cas.urls import default_urlpatterns
|
||||
|
||||
from .provider import ExampleCASProvider
|
||||
|
||||
urlpatterns = default_urlpatterns(ExampleCASProvider)
|
15
tests/example/views.py
Normal file
15
tests/example/views.py
Normal file
|
@ -0,0 +1,15 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
from allauth_cas import views
|
||||
|
||||
from .provider import ExampleCASProvider
|
||||
|
||||
|
||||
class ExampleCASAdapter(views.CASAdapter):
|
||||
provider_id = ExampleCASProvider.id
|
||||
url = 'https://server.cas'
|
||||
version = 2
|
||||
|
||||
|
||||
login = views.CASLoginView.adapter_view(ExampleCASAdapter)
|
||||
callback = views.CASCallbackView.adapter_view(ExampleCASAdapter)
|
||||
logout = views.CASLogoutView.adapter_view(ExampleCASAdapter)
|
63
tests/settings.py
Normal file
63
tests/settings.py
Normal file
|
@ -0,0 +1,63 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
import django
|
||||
|
||||
SECRET_KEY = 'iamabird'
|
||||
|
||||
INSTALLED_APPS = [
|
||||
'django.contrib.admin',
|
||||
'django.contrib.auth',
|
||||
'django.contrib.contenttypes',
|
||||
'django.contrib.messages',
|
||||
'django.contrib.sessions',
|
||||
'django.contrib.sites',
|
||||
'django.contrib.staticfiles',
|
||||
|
||||
'allauth',
|
||||
'allauth.account',
|
||||
'allauth.socialaccount',
|
||||
|
||||
'allauth_cas',
|
||||
|
||||
'tests.example', # Dummy CAS provider app
|
||||
]
|
||||
|
||||
DATABASES = {
|
||||
'default': {
|
||||
'ENGINE': 'django.db.backends.sqlite3',
|
||||
},
|
||||
}
|
||||
|
||||
AUTHENTICATION_BACKENDS = [
|
||||
'allauth.account.auth_backends.AuthenticationBackend',
|
||||
]
|
||||
|
||||
_MIDDLEWARES = [
|
||||
'django.middleware.common.CommonMiddleware',
|
||||
'django.contrib.sessions.middleware.SessionMiddleware',
|
||||
'django.middleware.csrf.CsrfViewMiddleware',
|
||||
'django.contrib.auth.middleware.AuthenticationMiddleware',
|
||||
'django.contrib.messages.middleware.MessageMiddleware',
|
||||
]
|
||||
|
||||
if django.VERSION >= (1, 10):
|
||||
MIDDLEWARE = _MIDDLEWARES
|
||||
else:
|
||||
MIDDLEWARE_CLASSES = _MIDDLEWARES
|
||||
|
||||
TEMPLATES = [
|
||||
{
|
||||
'BACKEND': 'django.template.backends.django.DjangoTemplates',
|
||||
'DIRS': [],
|
||||
'APP_DIRS': True,
|
||||
'OPTIONS': {
|
||||
'context_processors': [
|
||||
'django.template.context_processors.debug',
|
||||
'django.template.context_processors.request',
|
||||
'django.contrib.auth.context_processors.auth',
|
||||
'django.contrib.messages.context_processors.messages',
|
||||
],
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
ROOT_URLCONF = 'tests.urls'
|
91
tests/test_flows.py
Normal file
91
tests/test_flows.py
Normal file
|
@ -0,0 +1,91 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
try:
|
||||
from unittest.mock import patch
|
||||
except ImportError:
|
||||
from mock import patch
|
||||
|
||||
from django.contrib import messages
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.contrib.messages.api import get_messages
|
||||
from django.contrib.messages.storage.base import Message
|
||||
from django.test import TestCase, override_settings
|
||||
|
||||
from .cas_clients import AcceptCASClient
|
||||
|
||||
User = get_user_model()
|
||||
|
||||
|
||||
@patch('allauth_cas.views.cas.CASClient', AcceptCASClient)
|
||||
def client_cas_login(client):
|
||||
"""
|
||||
Sign in client through the example CAS provider.
|
||||
Returns the response of callbacK view.
|
||||
"""
|
||||
r = client.get('/accounts/theid/login/callback/', {
|
||||
'ticket': 'fake-ticket',
|
||||
})
|
||||
return r
|
||||
|
||||
|
||||
class LogoutFlowTests(TestCase):
|
||||
expected_msg_str = (
|
||||
"To logout of CAS, please close your browser, or visit this <a "
|
||||
"href=\"/accounts/theid/logout/?next=%2F\">link</a>."
|
||||
)
|
||||
|
||||
def setUp(self):
|
||||
client_cas_login(self.client)
|
||||
|
||||
def assertCASLogoutNotInMessages(self, response):
|
||||
r_messages = get_messages(response.wsgi_request)
|
||||
self.assertNotIn(
|
||||
self.expected_msg_str,
|
||||
(str(msg) for msg in r_messages),
|
||||
)
|
||||
self.assertTemplateNotUsed(
|
||||
response,
|
||||
'cas_account/messages/logged_out.txt',
|
||||
)
|
||||
|
||||
@override_settings(SOCIALACCOUNT_PROVIDERS={
|
||||
'theid': {
|
||||
'MESSAGE_ON_LOGOUT': True,
|
||||
'MESSAGE_ON_LOGOUT_LEVEL': messages.WARNING,
|
||||
},
|
||||
})
|
||||
def test_message_on_logout(self):
|
||||
"""
|
||||
Message is sent to propose user to logout of CAS.
|
||||
"""
|
||||
r = self.client.post('/accounts/logout/')
|
||||
r_messages = get_messages(r.wsgi_request)
|
||||
|
||||
expected_msg = Message(messages.WARNING, self.expected_msg_str)
|
||||
|
||||
self.assertIn(expected_msg, r_messages)
|
||||
self.assertTemplateUsed(r, 'cas_account/messages/logged_out.txt')
|
||||
|
||||
@override_settings(SOCIALACCOUNT_PROVIDERS={
|
||||
'theid': {
|
||||
'MESSAGE_ON_LOGOUT': False,
|
||||
},
|
||||
})
|
||||
def test_message_on_logout_disabled(self):
|
||||
"""
|
||||
The logout message can be disabled in settings.
|
||||
"""
|
||||
r = self.client.post('/accounts/logout/')
|
||||
self.assertCASLogoutNotInMessages(r)
|
||||
|
||||
@override_settings(SOCIALACCOUNT_PROVIDERS={
|
||||
'theid': {'MESSAGE_ON_LOGOUT': True},
|
||||
})
|
||||
def test_default_logout(self):
|
||||
"""
|
||||
The CAS logout message doesn't appear with other login methods.
|
||||
"""
|
||||
User.objects.create_user('user', '', 'user')
|
||||
self.client.login(username='user', password='user')
|
||||
|
||||
r = self.client.post('/accounts/logout/')
|
||||
self.assertCASLogoutNotInMessages(r)
|
131
tests/test_providers.py
Normal file
131
tests/test_providers.py
Normal file
|
@ -0,0 +1,131 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
from django.contrib import messages
|
||||
from django.test import RequestFactory, TestCase, override_settings
|
||||
|
||||
from allauth.socialaccount.providers import registry
|
||||
|
||||
from allauth_cas.views import AuthAction
|
||||
|
||||
from .example.provider import ExampleCASProvider
|
||||
|
||||
|
||||
class CASProviderTests(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
factory = RequestFactory()
|
||||
request = factory.get('/test/')
|
||||
request.session = {}
|
||||
self.request = request
|
||||
|
||||
self.provider = ExampleCASProvider(request)
|
||||
|
||||
def test_register(self):
|
||||
"""
|
||||
Example CAS provider is registered as social account provider.
|
||||
"""
|
||||
self.assertIsInstance(registry.by_id('theid'), ExampleCASProvider)
|
||||
|
||||
def test_get_login_url(self):
|
||||
"""
|
||||
get_login_url returns the url to logout of the provider.
|
||||
Keyword arguments are set as query string.
|
||||
"""
|
||||
url = self.provider.get_login_url(self.request)
|
||||
self.assertEqual('/accounts/theid/login/', url)
|
||||
|
||||
url_with_qs = self.provider.get_login_url(
|
||||
self.request,
|
||||
next='/path?quéry=string&two=whoam%C3%AF',
|
||||
)
|
||||
self.assertEqual(
|
||||
url_with_qs,
|
||||
'/accounts/theid/login/?next=%2Fpath%3Fqu%C3%A9ry%3Dstring%26two%3'
|
||||
'Dwhoam%25C3%25AF'
|
||||
)
|
||||
|
||||
def test_get_logout_url(self):
|
||||
"""
|
||||
get_logout_url returns the url to logout of the provider.
|
||||
Keyword arguments are set as query string.
|
||||
"""
|
||||
url = self.provider.get_logout_url(self.request)
|
||||
self.assertEqual('/accounts/theid/logout/', url)
|
||||
|
||||
url_with_qs = self.provider.get_logout_url(
|
||||
self.request,
|
||||
next='/path?quéry=string&two=whoam%C3%AF',
|
||||
)
|
||||
self.assertEqual(
|
||||
url_with_qs,
|
||||
'/accounts/theid/logout/?next=%2Fpath%3Fqu%C3%A9ry%3Dstring%26two%'
|
||||
'3Dwhoam%25C3%25AF'
|
||||
)
|
||||
|
||||
@override_settings(SOCIALACCOUNT_PROVIDERS={
|
||||
'theid': {
|
||||
'AUTH_PARAMS': {'key': 'value'},
|
||||
},
|
||||
})
|
||||
def test_get_auth_params(self):
|
||||
action = AuthAction.AUTHENTICATE
|
||||
|
||||
auth_params = self.provider.get_auth_params(self.request, action)
|
||||
|
||||
self.assertDictEqual(auth_params, {
|
||||
'key': 'value',
|
||||
})
|
||||
|
||||
@override_settings(SOCIALACCOUNT_PROVIDERS={
|
||||
'theid': {
|
||||
'AUTH_PARAMS': {'key': 'value'},
|
||||
},
|
||||
})
|
||||
def test_get_auth_params_with_dynamic(self):
|
||||
factory = RequestFactory()
|
||||
request = factory.get(
|
||||
'/test/?auth_params=next%3Dtwo%253Dwhoam%2525C3%2525AF%2526qu%2525'
|
||||
'C3%2525A9ry%253Dstring'
|
||||
)
|
||||
request.session = {}
|
||||
|
||||
action = AuthAction.AUTHENTICATE
|
||||
|
||||
auth_params = self.provider.get_auth_params(request, action)
|
||||
|
||||
self.assertDictEqual(auth_params, {
|
||||
'key': 'value',
|
||||
'next': 'two=whoam%C3%AF&qu%C3%A9ry=string',
|
||||
})
|
||||
|
||||
@override_settings(SOCIALACCOUNT_PROVIDERS={
|
||||
'theid': {
|
||||
'MESSAGE_ON_LOGOUT_LEVEL': messages.WARNING,
|
||||
},
|
||||
})
|
||||
def test_message_on_logout(self):
|
||||
message_on_logout = self.provider.message_on_logout(self.request)
|
||||
self.assertTrue(message_on_logout)
|
||||
|
||||
message_level = self.provider.message_on_logout_level(self.request)
|
||||
self.assertEqual(messages.WARNING, message_level)
|
||||
|
||||
def test_extract_uid(self):
|
||||
response = 'useRName', {}, None
|
||||
uid = self.provider.extract_uid(response)
|
||||
self.assertEqual('useRName', uid)
|
||||
|
||||
def test_extract_common_fields(self):
|
||||
response = 'useRName', {}, None
|
||||
common_fields = self.provider.extract_common_fields(response)
|
||||
self.assertDictEqual(common_fields, {
|
||||
'username': 'useRName',
|
||||
})
|
||||
|
||||
def test_extract_extra_data(self):
|
||||
attributes = {'user_attr': 'thevalue', 'another': 'value'}
|
||||
response = 'useRName', attributes, None
|
||||
extra_data = self.provider.extract_extra_data(response)
|
||||
self.assertDictEqual(extra_data, {
|
||||
'user_attr': 'thevalue',
|
||||
'another': 'value',
|
||||
})
|
237
tests/test_views.py
Normal file
237
tests/test_views.py
Normal file
|
@ -0,0 +1,237 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
try:
|
||||
from unittest.mock import patch
|
||||
except ImportError:
|
||||
from mock import patch
|
||||
|
||||
import django
|
||||
from django.test import RequestFactory, TestCase, override_settings
|
||||
|
||||
from allauth_cas.exceptions import CASAuthenticationError
|
||||
from allauth_cas.views import CASView
|
||||
|
||||
from .example.views import ExampleCASAdapter
|
||||
from .testcases import CASViewTestCase
|
||||
|
||||
if django.VERSION >= (1, 10):
|
||||
from django.urls import reverse
|
||||
else:
|
||||
from django.core.urlresolvers import reverse
|
||||
|
||||
|
||||
class CASAdapterTests(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
factory = RequestFactory()
|
||||
self.request = factory.get('/path/')
|
||||
self.adapter = ExampleCASAdapter(self.request)
|
||||
|
||||
def test_get_service_url(self):
|
||||
"""
|
||||
Service url (used by CAS client) is the callback url.
|
||||
"""
|
||||
expected = 'http://testserver/accounts/theid/login/callback/'
|
||||
service_url = self.adapter.get_service_url(self.request)
|
||||
self.assertEqual(expected, service_url)
|
||||
|
||||
def test_get_service_url_keep_next(self):
|
||||
"""
|
||||
Current GET paramater next is appended on service url.
|
||||
"""
|
||||
expected = (
|
||||
'http://testserver/accounts/theid/login/callback/?next=%2Fnext%2F'
|
||||
)
|
||||
factory = RequestFactory()
|
||||
request = factory.get('/path/', {'next': '/next/'})
|
||||
adapter = ExampleCASAdapter(request)
|
||||
service_url = adapter.get_service_url(request)
|
||||
self.assertEqual(expected, service_url)
|
||||
|
||||
def test_get_callback_url(self):
|
||||
expected = '/accounts/theid/login/callback/'
|
||||
callback_url = self.adapter.get_callback_url(self.request)
|
||||
self.assertEqual(expected, callback_url)
|
||||
|
||||
def test_get_callback_url_with_kwargs(self):
|
||||
expected = (
|
||||
'/accounts/theid/login/callback/?next=%2Fpath%2F'
|
||||
)
|
||||
callback_url = self.adapter.get_callback_url(self.request, **{
|
||||
'next': '/path/',
|
||||
})
|
||||
self.assertEqual(expected, callback_url)
|
||||
|
||||
|
||||
class CASViewTests(CASViewTestCase):
|
||||
|
||||
class BasicCASView(CASView):
|
||||
def dispatch(self, request, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def setUp(self):
|
||||
factory = RequestFactory()
|
||||
self.request = factory.get('path')
|
||||
|
||||
self.cas_view = self.BasicCASView.adapter_view(ExampleCASAdapter)
|
||||
|
||||
def test_adapter_view(self):
|
||||
"""
|
||||
adapter_view prepares the func view from a class view.
|
||||
"""
|
||||
view = self.cas_view(
|
||||
self.request,
|
||||
'arg1', 'arg2',
|
||||
kwarg1='kwarg1', kwarg2='kwarg2',
|
||||
)
|
||||
|
||||
self.assertIsInstance(view, CASView)
|
||||
|
||||
self.assertEqual(view.request, view.request)
|
||||
self.assertTupleEqual(view.args, ('arg1', 'arg2'))
|
||||
self.assertDictEqual(view.kwargs, {
|
||||
'kwarg1': 'kwarg1',
|
||||
'kwarg2': 'kwarg2',
|
||||
})
|
||||
|
||||
self.assertIsInstance(view.adapter, ExampleCASAdapter)
|
||||
|
||||
@patch('allauth_cas.views.cas.CASClient')
|
||||
@override_settings(SOCIALACCOUNT_PROVIDERS={
|
||||
'theid': {
|
||||
'AUTH_PARAMS': {'key': 'value'},
|
||||
},
|
||||
})
|
||||
def test_get_client(self, mock_casclient_class):
|
||||
"""
|
||||
get_client returns a CAS client, configured from settings.
|
||||
"""
|
||||
view = self.cas_view(self.request)
|
||||
view.get_client(self.request)
|
||||
|
||||
mock_casclient_class.assert_called_once_with(
|
||||
service_url='http://testserver/accounts/theid/login/callback/',
|
||||
server_url='https://server.cas',
|
||||
version=2,
|
||||
renew=False,
|
||||
extra_login_params={'key': 'value'},
|
||||
)
|
||||
|
||||
def test_render_error_on_failure(self):
|
||||
"""
|
||||
A common login failure page is rendered if CASAuthenticationError is
|
||||
raised by dispatch.
|
||||
"""
|
||||
def dispatch_raise(self, request):
|
||||
raise CASAuthenticationError("failure")
|
||||
|
||||
with patch.object(self.BasicCASView, 'dispatch', dispatch_raise):
|
||||
resp = self.cas_view(self.request)
|
||||
self.assertLoginFailure(resp)
|
||||
|
||||
|
||||
class CASLoginViewTests(CASViewTestCase):
|
||||
|
||||
def test_reverse(self):
|
||||
"""
|
||||
Login view name is "{provider_id}_login".
|
||||
"""
|
||||
url = reverse('theid_login')
|
||||
self.assertEqual('/accounts/theid/login/', url)
|
||||
|
||||
def test_execute(self):
|
||||
"""
|
||||
Login view redirects to the CAS server login url.
|
||||
Service is the callback url, as absolute uri.
|
||||
"""
|
||||
r = self.client.get('/accounts/theid/login/')
|
||||
|
||||
expected = (
|
||||
'https://server.cas/login?service=http%3A%2F%2Ftestserver%2F'
|
||||
'accounts%2Ftheid%2Flogin%2Fcallback%2F'
|
||||
)
|
||||
|
||||
self.assertRedirects(r, expected, fetch_redirect_response=False)
|
||||
|
||||
def test_execute_keep_next(self):
|
||||
"""
|
||||
Current GET parameter 'next' is kept on service url.
|
||||
"""
|
||||
r = self.client.get('/accounts/theid/login/?next=/path/')
|
||||
|
||||
expected = (
|
||||
'https://server.cas/login?service=http%3A%2F%2Ftestserver%2F'
|
||||
'accounts%2Ftheid%2Flogin%2Fcallback%2F%3Fnext%3D%252Fpath%252F'
|
||||
)
|
||||
|
||||
self.assertRedirects(r, expected, fetch_redirect_response=False)
|
||||
|
||||
|
||||
class CASCallbackViewTests(CASViewTestCase):
|
||||
|
||||
def test_reverse(self):
|
||||
"""
|
||||
Callback view name is "{provider_id}_callback".
|
||||
"""
|
||||
url = reverse('theid_callback')
|
||||
self.assertEqual('/accounts/theid/login/callback/', url)
|
||||
|
||||
def test_ticket_valid(self):
|
||||
"""
|
||||
If ticket is valid, the user is logged in.
|
||||
"""
|
||||
self.patch_cas_client('verify')
|
||||
r = self.client.get('/accounts/theid/login/callback/', {
|
||||
'ticket': '123456',
|
||||
})
|
||||
self.assertLoginSuccess(r)
|
||||
|
||||
def test_ticket_invalid(self):
|
||||
"""
|
||||
Login failure page is returned if the ticket is invalid.
|
||||
"""
|
||||
self.patch_cas_client('verify')
|
||||
r = self.client.get('/accounts/theid/login/callback/', {
|
||||
'ticket': '000000',
|
||||
})
|
||||
self.assertLoginFailure(r)
|
||||
|
||||
def test_ticket_missing(self):
|
||||
"""
|
||||
Login failure page is returned if request lacks a ticket.
|
||||
"""
|
||||
self.patch_cas_client('verify')
|
||||
r = self.client.get('/accounts/theid/login/callback/')
|
||||
self.assertLoginFailure(r)
|
||||
|
||||
|
||||
class CASLogoutViewTests(CASViewTestCase):
|
||||
|
||||
def test_reverse(self):
|
||||
"""
|
||||
Callback view name is "{provider_id}_logout".
|
||||
"""
|
||||
url = reverse('theid_logout')
|
||||
self.assertEqual('/accounts/theid/logout/', url)
|
||||
|
||||
def test_execute(self):
|
||||
"""
|
||||
Logout view redirects to the CAS server logout url.
|
||||
Service is a url to here, as absolute uri.
|
||||
"""
|
||||
r = self.client.get('/accounts/theid/logout/')
|
||||
|
||||
expected = 'https://server.cas/logout?url=http%3A%2F%2Ftestserver%2F'
|
||||
|
||||
self.assertRedirects(r, expected, fetch_redirect_response=False)
|
||||
|
||||
def test_execute_with_next(self):
|
||||
"""
|
||||
GET parameter 'next' is set as service url.
|
||||
"""
|
||||
r = self.client.get('/accounts/theid/logout/?next=/path/')
|
||||
|
||||
expected = (
|
||||
'https://server.cas/logout?url=http%3A%2F%2Ftestserver%2Fpath%2F'
|
||||
)
|
||||
|
||||
self.assertRedirects(r, expected, fetch_redirect_response=False)
|
76
tests/testcases.py
Normal file
76
tests/testcases.py
Normal file
|
@ -0,0 +1,76 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
try:
|
||||
from unittest.mock import patch
|
||||
except ImportError:
|
||||
from mock import patch
|
||||
|
||||
from django.test import TestCase as DjangoTestCase
|
||||
|
||||
from allauth_cas import CAS_PROVIDER_SESSION_KEY
|
||||
|
||||
from . import cas_clients
|
||||
|
||||
|
||||
class TestCase(DjangoTestCase):
|
||||
|
||||
def patch_cas_client(self, label):
|
||||
"""
|
||||
Patch cas.CASClient in allauth_cs.views module with another CAS client
|
||||
selectable with label argument.
|
||||
|
||||
Patch is stopped at the end of the current test.
|
||||
"""
|
||||
if hasattr(self, '_patch_cas_client'):
|
||||
self.patch_cas_client_stop()
|
||||
|
||||
if label == 'verify':
|
||||
new = cas_clients.VerifyCASClient
|
||||
elif label == 'accept':
|
||||
new = cas_clients.AcceptCASClient
|
||||
elif label == 'reject':
|
||||
new = cas_clients.RejectCASClient
|
||||
|
||||
self._patch_cas_client = patch('allauth_cas.views.cas.CASClient', new)
|
||||
self._patch_cas_client.start()
|
||||
|
||||
def patch_cas_client_stop(self):
|
||||
self._patch_cas_client.stop()
|
||||
|
||||
def tearDown(self):
|
||||
if hasattr(self, '_patch_cas_client'):
|
||||
self.patch_cas_client_stop()
|
||||
|
||||
|
||||
class CASViewTestCase(TestCase):
|
||||
|
||||
def assertLoginSuccess(self, response, redirect_to=None, client=None):
|
||||
"""
|
||||
Asserts response corresponds to a successful login.
|
||||
|
||||
To check this, the response should redirect to redirect_to (default to
|
||||
/accounts/profile/, the default redirect after a successful login).
|
||||
Also CAS_PROVIDER_SESSION_KEY should be set in the client' session. By
|
||||
default, self.client is used.
|
||||
"""
|
||||
if client is None:
|
||||
client = self.client
|
||||
if redirect_to is None:
|
||||
redirect_to = '/accounts/profile/'
|
||||
|
||||
self.assertRedirects(
|
||||
response, redirect_to,
|
||||
fetch_redirect_response=False,
|
||||
)
|
||||
self.assertIn(
|
||||
CAS_PROVIDER_SESSION_KEY,
|
||||
client.session,
|
||||
)
|
||||
|
||||
def assertLoginFailure(self, response):
|
||||
"""
|
||||
Asserts response corresponds to a failed login.
|
||||
"""
|
||||
return self.assertInHTML(
|
||||
'<h1>Social Network Login Failure</h1>',
|
||||
str(response.content),
|
||||
)
|
6
tests/urls.py
Normal file
6
tests/urls.py
Normal file
|
@ -0,0 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
from django.conf.urls import include, url
|
||||
|
||||
urlpatterns = [
|
||||
url(r'^accounts/', include('allauth.urls')),
|
||||
]
|
16
tox.ini
16
tox.ini
|
@ -1,7 +1,8 @@
|
|||
[tox]
|
||||
envlist =
|
||||
py{34,35}-django{18,19,110}
|
||||
py{34,35,36}-django111,
|
||||
py{27,34,35}-django{18,19,110}
|
||||
py{27,34,35,36}-django111,
|
||||
cov_combine,
|
||||
flake8,
|
||||
isort
|
||||
|
||||
|
@ -12,12 +13,21 @@ deps =
|
|||
django110: django>=1.10,<1.11
|
||||
django111: django>=1.11,<2.0
|
||||
coverage
|
||||
mock ; python_version < "3.0"
|
||||
usedevelop= True
|
||||
commands =
|
||||
python -V
|
||||
coverage run \
|
||||
--branch \
|
||||
--source=allauth_cas --omit=*migrations* \
|
||||
./runtests.py {posargs}
|
||||
--parallel-mode \
|
||||
runtests.py {posargs}
|
||||
|
||||
[testenv:cov_combine]
|
||||
deps =
|
||||
coverage
|
||||
commands =
|
||||
coverage combine
|
||||
coverage report --show-missing
|
||||
|
||||
[testenv:flake8]
|
||||
|
|
Loading…
Reference in a new issue