This commit is contained in:
Aurélien Delobelle 2017-07-25 18:31:42 +02:00
parent f6a7dbc385
commit 76fd5ca344
25 changed files with 1100 additions and 5 deletions

1
.gitignore vendored
View file

@ -1,4 +1,5 @@
.coverage
.coverage.*
.tox/
build/
dist/

View file

@ -4,5 +4,10 @@ django-allauth-cas
CAS support for django-allauth_.
Supports:
- Django 1.8-10 - Python 2.7, 3.4-5
- Django 1.11 - Python 2.7, 3.4-6
.. _django-allauth: https://www.intenct.nl/projects/django-allauth/

7
allauth_cas/__init__.py Normal file
View file

@ -0,0 +1,7 @@
# -*- coding: utf-8 -*-
__version__ = '0.0.1.dev0'
default_app_config = 'allauth_cas.apps.CASAccountConfig'
CAS_PROVIDER_SESSION_KEY = 'allauth_cas__provider_id'

11
allauth_cas/apps.py Normal file
View file

@ -0,0 +1,11 @@
# -*- coding: utf-8 -*-
from django.apps import AppConfig
from django.utils.translation import ugettext_lazy as _
class CASAccountConfig(AppConfig):
name = 'allauth_cas'
verbose_name = _("CAS Accounts")
def ready(self):
from . import signals # noqa

View file

@ -0,0 +1,7 @@
# -*- coding: utf-8 -*-
class CASAuthenticationError(Exception):
"""
Base exception to signal CAS authentication failure.
"""

55
allauth_cas/providers.py Normal file
View file

@ -0,0 +1,55 @@
# -*- coding: utf-8 -*-
from six.moves.urllib.parse import parse_qsl
import django
from django.contrib import messages
from django.utils.http import urlencode
from allauth.socialaccount.providers.base import Provider
if django.VERSION >= (1, 10):
from django.urls import reverse
else:
from django.core.urlresolvers import reverse
class CASProvider(Provider):
def get_login_url(self, request, **kwargs):
url = reverse(self.id + '_login')
if kwargs:
url += '?' + urlencode(kwargs)
return url
def get_logout_url(self, request, **kwargs):
url = reverse(self.id + '_logout')
if kwargs:
url += '?' + urlencode(kwargs)
return url
def get_auth_params(self, request, action):
settings = self.get_settings()
ret = dict(settings.get('AUTH_PARAMS', {}))
dynamic_auth_params = request.GET.get('auth_params')
if dynamic_auth_params:
ret.update(dict(parse_qsl(dynamic_auth_params)))
return ret
def message_on_logout(self, request):
return self.get_settings().get('MESSAGE_ON_LOGOUT', True)
def message_on_logout_level(self, request):
return self.get_settings().get('MESSAGE_ON_LOGOUT_LEVEL',
messages.INFO)
def extract_uid(self, data):
username, _, _ = data
return username
def extract_common_fields(self, data):
username, _, _ = data
return {'username': username}
def extract_extra_data(self, data):
_, extra_data, _ = data
return extra_data

45
allauth_cas/signals.py Normal file
View file

@ -0,0 +1,45 @@
# -*- coding: utf-8 -*-
from django.contrib.auth.signals import user_logged_out
from django.dispatch import receiver
from django.utils.safestring import mark_safe
from allauth.account.adapter import get_adapter
from allauth.account.utils import get_next_redirect_url
from allauth.socialaccount import providers
from . import CAS_PROVIDER_SESSION_KEY
@receiver(user_logged_out)
def cas_account_logout(sender, request, **kwargs):
provider_id = request.session.get(CAS_PROVIDER_SESSION_KEY)
if not provider_id:
return
provider = providers.registry.by_id(provider_id, request)
if not provider.message_on_logout(request):
return
adapter = get_adapter(request)
redirect_url = (
get_next_redirect_url(request) or
adapter.get_logout_redirect_url(request)
)
logout_kwargs = {'next': redirect_url} if redirect_url else {}
logout_url = provider.get_logout_url(request, **logout_kwargs)
level = provider.message_on_logout_level(request)
logout_link = mark_safe('<a href="{}">link</a>'.format(logout_url))
adapter.add_message(
request, level,
message_template='cas_account/messages/logged_out.txt',
message_context={
'logout_url': logout_url,
'logout_link': logout_link,
}
)

View file

@ -0,0 +1,4 @@
{% load i18n %}
{% blocktrans %}
To logout of CAS, please close your browser, or visit this {{ logout_link }}.
{% endblocktrans %}

23
allauth_cas/urls.py Normal file
View file

@ -0,0 +1,23 @@
# -*- coding: utf-8 -*-
from django.conf.urls import include, url
from allauth.utils import import_attribute
def default_urlpatterns(provider):
package = provider.get_package()
login_view = import_attribute(package + '.views.login')
callback_view = import_attribute(package + '.views.callback')
logout_view = import_attribute(package + '.views.logout')
urlpatterns = [
url('^login/$',
login_view, name=provider.id + '_login'),
url('^login/callback/$',
callback_view, name=provider.id + '_callback'),
url('^logout/$',
logout_view, name=provider.id + '_logout'),
]
return [url('^' + provider.get_slug() + '/', include(urlpatterns))]

242
allauth_cas/views.py Normal file
View file

@ -0,0 +1,242 @@
# -*- coding: utf-8 -*-
import django
from django.http import HttpResponseRedirect
from django.utils.http import urlencode
from allauth.account.adapter import get_adapter
from allauth.account.utils import get_next_redirect_url
from allauth.socialaccount import providers
from allauth.socialaccount.helpers import (
complete_social_login, render_authentication_error,
)
import cas
from . import CAS_PROVIDER_SESSION_KEY
from .exceptions import CASAuthenticationError
if django.VERSION >= (1, 10):
from django.urls import reverse
else:
from django.core.urlresolvers import reverse
class AuthAction(object):
AUTHENTICATE = 'authenticate'
REAUTHENTICATE = 'reauthenticate'
DEAUTHENTICATE = 'deauthenticate'
class CASAdapter(object):
# CAS client parameters
renew = False
def __init__(self, request):
self.request = request
def get_provider(self):
"""
Returns a provider instance for the current request.
"""
return providers.registry.by_id(self.provider_id, self.request)
def complete_login(self, request, response):
"""
Executed by the callback view after successful authentication on CAS
server.
Returns the SocialLogin object which represents the state of the
current login-session.
"""
login = (self.get_provider()
.sociallogin_from_response(request, response))
return login
def get_service_url(self, request):
"""
Returns the service url to for a CAS client.
From CAS specification, the service url is used in order to redirect
user after a successful login on CAS server. Also, service_url sent
when ticket is verified must be the one for which ticket was issued.
To conform this, the service url is always the callback url.
A redirect url is found from the current request and appended as
parameter to the service url and is latter used by the callback view to
redirect user.
"""
redirect_to = get_next_redirect_url(request)
callback_kwargs = {'next': redirect_to} if redirect_to else {}
callback_url = self.get_callback_url(request, **callback_kwargs)
service_url = request.build_absolute_uri(callback_url)
return service_url
def get_callback_url(self, request, **kwargs):
"""
Returns the callback url of the provider.
Keyword arguments are set as query string.
"""
url = reverse(self.provider_id + '_callback')
if kwargs:
url += '?' + urlencode(kwargs)
return url
class CASView(object):
@classmethod
def adapter_view(cls, adapter, **kwargs):
"""
Similar to the Django as_view() method.
It also setups a few things:
- given adapter argument will be used in views internals.
- if the view execution raises a CASAuthenticationError, the view
renders an authentication error page.
To use this:
- subclass CAS adapter as wanted:
class MyAdapter(CASAdapter):
url = 'https://my.cas.url'
- define views:
login = views.CASLoginView.adapter_view(MyAdapter)
callback = views.CASCallbackView.adapter_view(MyAdapter)
logout = views.CASLogoutView.adapter_view(MyAdapter)
"""
def view(request, *args, **kwargs):
# Prepare the func-view.
self = cls()
self.request = request
self.args = args
self.kwargs = kwargs
# Setup and store adapter as view attribute.
self.adapter = adapter(request)
try:
return self.dispatch(request, *args, **kwargs)
except CASAuthenticationError:
return self.render_error()
return view
def get_client(self, request, action=AuthAction.AUTHENTICATE):
"""
Returns the CAS client to interact with the CAS server.
"""
provider = self.adapter.get_provider()
auth_params = provider.get_auth_params(request, action)
service_url = self.adapter.get_service_url(request)
client = cas.CASClient(
service_url=service_url,
server_url=self.adapter.url,
version=self.adapter.version,
renew=self.adapter.renew,
extra_login_params=auth_params,
)
return client
def render_error(self):
"""
Returns an HTTP response in case an authentication failure happens.
"""
return render_authentication_error(
self.request,
self.adapter.provider_id,
)
class CASLoginView(CASView):
def dispatch(self, request):
"""
Redirects to the CAS server login page.
"""
action = request.GET.get('action', AuthAction.AUTHENTICATE)
client = self.get_client(request, action=action)
return HttpResponseRedirect(client.get_login_url())
class CASCallbackView(CASView):
def dispatch(self, request):
"""
The CAS server redirects the user to this view after a successful
authentication.
On redirect, CAS server should add a ticket whose validity is verified
here. If ticket is valid, CAS server may also return extra attributes
about user.
"""
provider = self.adapter.get_provider()
client = self.get_client(request)
# CAS server should let a ticket.
try:
ticket = request.GET['ticket']
except KeyError:
raise CASAuthenticationError(
"CAS server didn't respond with a ticket."
)
# Check ticket validity.
# Response format on:
# - success: username, attributes, pgtiou
# - error: None, {}, None
response = client.verify_ticket(ticket)
if not response[0]:
raise CASAuthenticationError(
"CAS server doesn't validate the ticket."
)
# The CAS provider in use is stored to propose to the user to
# disconnect from the latter when he logouts.
request.session[CAS_PROVIDER_SESSION_KEY] = provider.id
# Finish the login flow
login = self.adapter.complete_login(request, response)
return complete_social_login(request, login)
class CASLogoutView(CASView):
def dispatch(self, request, next_page=None):
"""
Redirects to the CAS server logout page.
next_page is used to let the CAS server send back the user. If empty,
the redirect url is built on request data.
"""
action = AuthAction.DEAUTHENTICATE
redirect_url = next_page or self.get_redirect_url()
redirect_to = request.build_absolute_uri(redirect_url)
client = self.get_client(request, action=action)
return HttpResponseRedirect(client.get_logout_url(redirect_to))
def get_redirect_url(self):
"""
Returns the url to redirect after logout from current request.
"""
request = self.request
return (
get_next_redirect_url(request) or
get_adapter(request).get_logout_redirect_url(request)
)

View file

@ -1,3 +1,5 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import sys

View file

@ -6,11 +6,13 @@ ignore = E731
combine_as_imports = True
default_section = THIRDPARTY
include_trailing_comma = True
known_allauth = allauth
known_future_library = future,six
known_django = django
known_first_party = allauth_cas
multi_line_output = 5
not_skip = __init__.py
sections = FUTURE,STDLIB,DJANGO,THIRDPARTY,FIRSTPARTY,LOCALFOLDER
sections = FUTURE,STDLIB,DJANGO,ALLAUTH,THIRDPARTY,FIRSTPARTY,LOCALFOLDER
[bdist_wheel]
universal = 1

View file

@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
import os
from setuptools import find_packages, setup
@ -15,7 +16,7 @@ setup(
description='CAS support for django-allauth.',
author='Aurélien Delobelle',
author_email='aurelien.delobelle@gmail.com',
keyword='django allauth cas authentication',
keywords='django allauth cas authentication',
long_description=README,
url='https://github.com/aureplop/django-allauth-cas',
classifiers=[
@ -44,6 +45,7 @@ setup(
install_requires=[
'django-allauth',
'python-cas',
'six',
],
extras_require={
'tests': ['tox'],

37
tests/cas_clients.py Normal file
View file

@ -0,0 +1,37 @@
# -*- coding: utf-8 -*-
import cas
class MockCASClient(cas.CASClientV2):
"""
Base class to mock cas.CASClient
"""
def __init__(self, *args, **kwargs):
kwargs.pop('version')
super(MockCASClient, self).__init__(*args, **kwargs)
class VerifyCASClient(MockCASClient):
"""
CAS client which verifies ticket is '123456'.
"""
def verify_ticket(self, ticket):
if ticket == '123456':
return 'username', {}, None
return None, {}, None
class AcceptCASClient(MockCASClient):
"""
CAS client which accepts all tickets.
"""
def verify_ticket(self, ticket):
return 'username', {}, None
class RejectCASClient(MockCASClient):
"""
CAS client which rejects all tickets.
"""
def verify_ticket(self, ticket):
return None, {}, None

View file

17
tests/example/provider.py Normal file
View file

@ -0,0 +1,17 @@
# -*- coding: utf-8 -*-
from allauth.socialaccount.providers.base import ProviderAccount
from allauth_cas.providers import CASProvider
class ExampleCASAccount(ProviderAccount):
pass
class ExampleCASProvider(CASProvider):
id = 'theid'
name = 'The Provider'
account_class = ExampleCASAccount
provider_classes = [ExampleCASProvider]

6
tests/example/urls.py Normal file
View file

@ -0,0 +1,6 @@
# -*- coding: utf-8 -*-
from allauth_cas.urls import default_urlpatterns
from .provider import ExampleCASProvider
urlpatterns = default_urlpatterns(ExampleCASProvider)

15
tests/example/views.py Normal file
View file

@ -0,0 +1,15 @@
# -*- coding: utf-8 -*-
from allauth_cas import views
from .provider import ExampleCASProvider
class ExampleCASAdapter(views.CASAdapter):
provider_id = ExampleCASProvider.id
url = 'https://server.cas'
version = 2
login = views.CASLoginView.adapter_view(ExampleCASAdapter)
callback = views.CASCallbackView.adapter_view(ExampleCASAdapter)
logout = views.CASLogoutView.adapter_view(ExampleCASAdapter)

63
tests/settings.py Normal file
View file

@ -0,0 +1,63 @@
# -*- coding: utf-8 -*-
import django
SECRET_KEY = 'iamabird'
INSTALLED_APPS = [
'django.contrib.admin',
'django.contrib.auth',
'django.contrib.contenttypes',
'django.contrib.messages',
'django.contrib.sessions',
'django.contrib.sites',
'django.contrib.staticfiles',
'allauth',
'allauth.account',
'allauth.socialaccount',
'allauth_cas',
'tests.example', # Dummy CAS provider app
]
DATABASES = {
'default': {
'ENGINE': 'django.db.backends.sqlite3',
},
}
AUTHENTICATION_BACKENDS = [
'allauth.account.auth_backends.AuthenticationBackend',
]
_MIDDLEWARES = [
'django.middleware.common.CommonMiddleware',
'django.contrib.sessions.middleware.SessionMiddleware',
'django.middleware.csrf.CsrfViewMiddleware',
'django.contrib.auth.middleware.AuthenticationMiddleware',
'django.contrib.messages.middleware.MessageMiddleware',
]
if django.VERSION >= (1, 10):
MIDDLEWARE = _MIDDLEWARES
else:
MIDDLEWARE_CLASSES = _MIDDLEWARES
TEMPLATES = [
{
'BACKEND': 'django.template.backends.django.DjangoTemplates',
'DIRS': [],
'APP_DIRS': True,
'OPTIONS': {
'context_processors': [
'django.template.context_processors.debug',
'django.template.context_processors.request',
'django.contrib.auth.context_processors.auth',
'django.contrib.messages.context_processors.messages',
],
},
},
]
ROOT_URLCONF = 'tests.urls'

91
tests/test_flows.py Normal file
View file

@ -0,0 +1,91 @@
# -*- coding: utf-8 -*-
try:
from unittest.mock import patch
except ImportError:
from mock import patch
from django.contrib import messages
from django.contrib.auth import get_user_model
from django.contrib.messages.api import get_messages
from django.contrib.messages.storage.base import Message
from django.test import TestCase, override_settings
from .cas_clients import AcceptCASClient
User = get_user_model()
@patch('allauth_cas.views.cas.CASClient', AcceptCASClient)
def client_cas_login(client):
"""
Sign in client through the example CAS provider.
Returns the response of callbacK view.
"""
r = client.get('/accounts/theid/login/callback/', {
'ticket': 'fake-ticket',
})
return r
class LogoutFlowTests(TestCase):
expected_msg_str = (
"To logout of CAS, please close your browser, or visit this <a "
"href=\"/accounts/theid/logout/?next=%2F\">link</a>."
)
def setUp(self):
client_cas_login(self.client)
def assertCASLogoutNotInMessages(self, response):
r_messages = get_messages(response.wsgi_request)
self.assertNotIn(
self.expected_msg_str,
(str(msg) for msg in r_messages),
)
self.assertTemplateNotUsed(
response,
'cas_account/messages/logged_out.txt',
)
@override_settings(SOCIALACCOUNT_PROVIDERS={
'theid': {
'MESSAGE_ON_LOGOUT': True,
'MESSAGE_ON_LOGOUT_LEVEL': messages.WARNING,
},
})
def test_message_on_logout(self):
"""
Message is sent to propose user to logout of CAS.
"""
r = self.client.post('/accounts/logout/')
r_messages = get_messages(r.wsgi_request)
expected_msg = Message(messages.WARNING, self.expected_msg_str)
self.assertIn(expected_msg, r_messages)
self.assertTemplateUsed(r, 'cas_account/messages/logged_out.txt')
@override_settings(SOCIALACCOUNT_PROVIDERS={
'theid': {
'MESSAGE_ON_LOGOUT': False,
},
})
def test_message_on_logout_disabled(self):
"""
The logout message can be disabled in settings.
"""
r = self.client.post('/accounts/logout/')
self.assertCASLogoutNotInMessages(r)
@override_settings(SOCIALACCOUNT_PROVIDERS={
'theid': {'MESSAGE_ON_LOGOUT': True},
})
def test_default_logout(self):
"""
The CAS logout message doesn't appear with other login methods.
"""
User.objects.create_user('user', '', 'user')
self.client.login(username='user', password='user')
r = self.client.post('/accounts/logout/')
self.assertCASLogoutNotInMessages(r)

131
tests/test_providers.py Normal file
View file

@ -0,0 +1,131 @@
# -*- coding: utf-8 -*-
from django.contrib import messages
from django.test import RequestFactory, TestCase, override_settings
from allauth.socialaccount.providers import registry
from allauth_cas.views import AuthAction
from .example.provider import ExampleCASProvider
class CASProviderTests(TestCase):
def setUp(self):
factory = RequestFactory()
request = factory.get('/test/')
request.session = {}
self.request = request
self.provider = ExampleCASProvider(request)
def test_register(self):
"""
Example CAS provider is registered as social account provider.
"""
self.assertIsInstance(registry.by_id('theid'), ExampleCASProvider)
def test_get_login_url(self):
"""
get_login_url returns the url to logout of the provider.
Keyword arguments are set as query string.
"""
url = self.provider.get_login_url(self.request)
self.assertEqual('/accounts/theid/login/', url)
url_with_qs = self.provider.get_login_url(
self.request,
next='/path?quéry=string&two=whoam%C3%AF',
)
self.assertEqual(
url_with_qs,
'/accounts/theid/login/?next=%2Fpath%3Fqu%C3%A9ry%3Dstring%26two%3'
'Dwhoam%25C3%25AF'
)
def test_get_logout_url(self):
"""
get_logout_url returns the url to logout of the provider.
Keyword arguments are set as query string.
"""
url = self.provider.get_logout_url(self.request)
self.assertEqual('/accounts/theid/logout/', url)
url_with_qs = self.provider.get_logout_url(
self.request,
next='/path?quéry=string&two=whoam%C3%AF',
)
self.assertEqual(
url_with_qs,
'/accounts/theid/logout/?next=%2Fpath%3Fqu%C3%A9ry%3Dstring%26two%'
'3Dwhoam%25C3%25AF'
)
@override_settings(SOCIALACCOUNT_PROVIDERS={
'theid': {
'AUTH_PARAMS': {'key': 'value'},
},
})
def test_get_auth_params(self):
action = AuthAction.AUTHENTICATE
auth_params = self.provider.get_auth_params(self.request, action)
self.assertDictEqual(auth_params, {
'key': 'value',
})
@override_settings(SOCIALACCOUNT_PROVIDERS={
'theid': {
'AUTH_PARAMS': {'key': 'value'},
},
})
def test_get_auth_params_with_dynamic(self):
factory = RequestFactory()
request = factory.get(
'/test/?auth_params=next%3Dtwo%253Dwhoam%2525C3%2525AF%2526qu%2525'
'C3%2525A9ry%253Dstring'
)
request.session = {}
action = AuthAction.AUTHENTICATE
auth_params = self.provider.get_auth_params(request, action)
self.assertDictEqual(auth_params, {
'key': 'value',
'next': 'two=whoam%C3%AF&qu%C3%A9ry=string',
})
@override_settings(SOCIALACCOUNT_PROVIDERS={
'theid': {
'MESSAGE_ON_LOGOUT_LEVEL': messages.WARNING,
},
})
def test_message_on_logout(self):
message_on_logout = self.provider.message_on_logout(self.request)
self.assertTrue(message_on_logout)
message_level = self.provider.message_on_logout_level(self.request)
self.assertEqual(messages.WARNING, message_level)
def test_extract_uid(self):
response = 'useRName', {}, None
uid = self.provider.extract_uid(response)
self.assertEqual('useRName', uid)
def test_extract_common_fields(self):
response = 'useRName', {}, None
common_fields = self.provider.extract_common_fields(response)
self.assertDictEqual(common_fields, {
'username': 'useRName',
})
def test_extract_extra_data(self):
attributes = {'user_attr': 'thevalue', 'another': 'value'}
response = 'useRName', attributes, None
extra_data = self.provider.extract_extra_data(response)
self.assertDictEqual(extra_data, {
'user_attr': 'thevalue',
'another': 'value',
})

237
tests/test_views.py Normal file
View file

@ -0,0 +1,237 @@
# -*- coding: utf-8 -*-
try:
from unittest.mock import patch
except ImportError:
from mock import patch
import django
from django.test import RequestFactory, TestCase, override_settings
from allauth_cas.exceptions import CASAuthenticationError
from allauth_cas.views import CASView
from .example.views import ExampleCASAdapter
from .testcases import CASViewTestCase
if django.VERSION >= (1, 10):
from django.urls import reverse
else:
from django.core.urlresolvers import reverse
class CASAdapterTests(TestCase):
def setUp(self):
factory = RequestFactory()
self.request = factory.get('/path/')
self.adapter = ExampleCASAdapter(self.request)
def test_get_service_url(self):
"""
Service url (used by CAS client) is the callback url.
"""
expected = 'http://testserver/accounts/theid/login/callback/'
service_url = self.adapter.get_service_url(self.request)
self.assertEqual(expected, service_url)
def test_get_service_url_keep_next(self):
"""
Current GET paramater next is appended on service url.
"""
expected = (
'http://testserver/accounts/theid/login/callback/?next=%2Fnext%2F'
)
factory = RequestFactory()
request = factory.get('/path/', {'next': '/next/'})
adapter = ExampleCASAdapter(request)
service_url = adapter.get_service_url(request)
self.assertEqual(expected, service_url)
def test_get_callback_url(self):
expected = '/accounts/theid/login/callback/'
callback_url = self.adapter.get_callback_url(self.request)
self.assertEqual(expected, callback_url)
def test_get_callback_url_with_kwargs(self):
expected = (
'/accounts/theid/login/callback/?next=%2Fpath%2F'
)
callback_url = self.adapter.get_callback_url(self.request, **{
'next': '/path/',
})
self.assertEqual(expected, callback_url)
class CASViewTests(CASViewTestCase):
class BasicCASView(CASView):
def dispatch(self, request, *args, **kwargs):
return self
def setUp(self):
factory = RequestFactory()
self.request = factory.get('path')
self.cas_view = self.BasicCASView.adapter_view(ExampleCASAdapter)
def test_adapter_view(self):
"""
adapter_view prepares the func view from a class view.
"""
view = self.cas_view(
self.request,
'arg1', 'arg2',
kwarg1='kwarg1', kwarg2='kwarg2',
)
self.assertIsInstance(view, CASView)
self.assertEqual(view.request, view.request)
self.assertTupleEqual(view.args, ('arg1', 'arg2'))
self.assertDictEqual(view.kwargs, {
'kwarg1': 'kwarg1',
'kwarg2': 'kwarg2',
})
self.assertIsInstance(view.adapter, ExampleCASAdapter)
@patch('allauth_cas.views.cas.CASClient')
@override_settings(SOCIALACCOUNT_PROVIDERS={
'theid': {
'AUTH_PARAMS': {'key': 'value'},
},
})
def test_get_client(self, mock_casclient_class):
"""
get_client returns a CAS client, configured from settings.
"""
view = self.cas_view(self.request)
view.get_client(self.request)
mock_casclient_class.assert_called_once_with(
service_url='http://testserver/accounts/theid/login/callback/',
server_url='https://server.cas',
version=2,
renew=False,
extra_login_params={'key': 'value'},
)
def test_render_error_on_failure(self):
"""
A common login failure page is rendered if CASAuthenticationError is
raised by dispatch.
"""
def dispatch_raise(self, request):
raise CASAuthenticationError("failure")
with patch.object(self.BasicCASView, 'dispatch', dispatch_raise):
resp = self.cas_view(self.request)
self.assertLoginFailure(resp)
class CASLoginViewTests(CASViewTestCase):
def test_reverse(self):
"""
Login view name is "{provider_id}_login".
"""
url = reverse('theid_login')
self.assertEqual('/accounts/theid/login/', url)
def test_execute(self):
"""
Login view redirects to the CAS server login url.
Service is the callback url, as absolute uri.
"""
r = self.client.get('/accounts/theid/login/')
expected = (
'https://server.cas/login?service=http%3A%2F%2Ftestserver%2F'
'accounts%2Ftheid%2Flogin%2Fcallback%2F'
)
self.assertRedirects(r, expected, fetch_redirect_response=False)
def test_execute_keep_next(self):
"""
Current GET parameter 'next' is kept on service url.
"""
r = self.client.get('/accounts/theid/login/?next=/path/')
expected = (
'https://server.cas/login?service=http%3A%2F%2Ftestserver%2F'
'accounts%2Ftheid%2Flogin%2Fcallback%2F%3Fnext%3D%252Fpath%252F'
)
self.assertRedirects(r, expected, fetch_redirect_response=False)
class CASCallbackViewTests(CASViewTestCase):
def test_reverse(self):
"""
Callback view name is "{provider_id}_callback".
"""
url = reverse('theid_callback')
self.assertEqual('/accounts/theid/login/callback/', url)
def test_ticket_valid(self):
"""
If ticket is valid, the user is logged in.
"""
self.patch_cas_client('verify')
r = self.client.get('/accounts/theid/login/callback/', {
'ticket': '123456',
})
self.assertLoginSuccess(r)
def test_ticket_invalid(self):
"""
Login failure page is returned if the ticket is invalid.
"""
self.patch_cas_client('verify')
r = self.client.get('/accounts/theid/login/callback/', {
'ticket': '000000',
})
self.assertLoginFailure(r)
def test_ticket_missing(self):
"""
Login failure page is returned if request lacks a ticket.
"""
self.patch_cas_client('verify')
r = self.client.get('/accounts/theid/login/callback/')
self.assertLoginFailure(r)
class CASLogoutViewTests(CASViewTestCase):
def test_reverse(self):
"""
Callback view name is "{provider_id}_logout".
"""
url = reverse('theid_logout')
self.assertEqual('/accounts/theid/logout/', url)
def test_execute(self):
"""
Logout view redirects to the CAS server logout url.
Service is a url to here, as absolute uri.
"""
r = self.client.get('/accounts/theid/logout/')
expected = 'https://server.cas/logout?url=http%3A%2F%2Ftestserver%2F'
self.assertRedirects(r, expected, fetch_redirect_response=False)
def test_execute_with_next(self):
"""
GET parameter 'next' is set as service url.
"""
r = self.client.get('/accounts/theid/logout/?next=/path/')
expected = (
'https://server.cas/logout?url=http%3A%2F%2Ftestserver%2Fpath%2F'
)
self.assertRedirects(r, expected, fetch_redirect_response=False)

76
tests/testcases.py Normal file
View file

@ -0,0 +1,76 @@
# -*- coding: utf-8 -*-
try:
from unittest.mock import patch
except ImportError:
from mock import patch
from django.test import TestCase as DjangoTestCase
from allauth_cas import CAS_PROVIDER_SESSION_KEY
from . import cas_clients
class TestCase(DjangoTestCase):
def patch_cas_client(self, label):
"""
Patch cas.CASClient in allauth_cs.views module with another CAS client
selectable with label argument.
Patch is stopped at the end of the current test.
"""
if hasattr(self, '_patch_cas_client'):
self.patch_cas_client_stop()
if label == 'verify':
new = cas_clients.VerifyCASClient
elif label == 'accept':
new = cas_clients.AcceptCASClient
elif label == 'reject':
new = cas_clients.RejectCASClient
self._patch_cas_client = patch('allauth_cas.views.cas.CASClient', new)
self._patch_cas_client.start()
def patch_cas_client_stop(self):
self._patch_cas_client.stop()
def tearDown(self):
if hasattr(self, '_patch_cas_client'):
self.patch_cas_client_stop()
class CASViewTestCase(TestCase):
def assertLoginSuccess(self, response, redirect_to=None, client=None):
"""
Asserts response corresponds to a successful login.
To check this, the response should redirect to redirect_to (default to
/accounts/profile/, the default redirect after a successful login).
Also CAS_PROVIDER_SESSION_KEY should be set in the client' session. By
default, self.client is used.
"""
if client is None:
client = self.client
if redirect_to is None:
redirect_to = '/accounts/profile/'
self.assertRedirects(
response, redirect_to,
fetch_redirect_response=False,
)
self.assertIn(
CAS_PROVIDER_SESSION_KEY,
client.session,
)
def assertLoginFailure(self, response):
"""
Asserts response corresponds to a failed login.
"""
return self.assertInHTML(
'<h1>Social Network Login Failure</h1>',
str(response.content),
)

6
tests/urls.py Normal file
View file

@ -0,0 +1,6 @@
# -*- coding: utf-8 -*-
from django.conf.urls import include, url
urlpatterns = [
url(r'^accounts/', include('allauth.urls')),
]

16
tox.ini
View file

@ -1,7 +1,8 @@
[tox]
envlist =
py{34,35}-django{18,19,110}
py{34,35,36}-django111,
py{27,34,35}-django{18,19,110}
py{27,34,35,36}-django111,
cov_combine,
flake8,
isort
@ -12,12 +13,21 @@ deps =
django110: django>=1.10,<1.11
django111: django>=1.11,<2.0
coverage
mock ; python_version < "3.0"
usedevelop= True
commands =
python -V
coverage run \
--branch \
--source=allauth_cas --omit=*migrations* \
./runtests.py {posargs}
--parallel-mode \
runtests.py {posargs}
[testenv:cov_combine]
deps =
coverage
commands =
coverage combine
coverage report --show-missing
[testenv:flake8]