diff --git a/.gitignore b/.gitignore index 973b5a9..3d5d4c2 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ .coverage +.coverage.* .tox/ build/ dist/ diff --git a/README.rst b/README.rst index a9b90dc..72eb55e 100644 --- a/README.rst +++ b/README.rst @@ -4,5 +4,10 @@ django-allauth-cas CAS support for django-allauth_. +Supports: + +- Django 1.8-10 - Python 2.7, 3.4-5 +- Django 1.11 - Python 2.7, 3.4-6 + .. _django-allauth: https://www.intenct.nl/projects/django-allauth/ diff --git a/allauth_cas/__init__.py b/allauth_cas/__init__.py new file mode 100644 index 0000000..cceb46f --- /dev/null +++ b/allauth_cas/__init__.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- + +__version__ = '0.0.1.dev0' + +default_app_config = 'allauth_cas.apps.CASAccountConfig' + +CAS_PROVIDER_SESSION_KEY = 'allauth_cas__provider_id' diff --git a/allauth_cas/apps.py b/allauth_cas/apps.py new file mode 100644 index 0000000..2ee7f0c --- /dev/null +++ b/allauth_cas/apps.py @@ -0,0 +1,11 @@ +# -*- coding: utf-8 -*- +from django.apps import AppConfig +from django.utils.translation import ugettext_lazy as _ + + +class CASAccountConfig(AppConfig): + name = 'allauth_cas' + verbose_name = _("CAS Accounts") + + def ready(self): + from . import signals # noqa diff --git a/allauth_cas/exceptions.py b/allauth_cas/exceptions.py new file mode 100644 index 0000000..2ed507e --- /dev/null +++ b/allauth_cas/exceptions.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- + + +class CASAuthenticationError(Exception): + """ + Base exception to signal CAS authentication failure. + """ diff --git a/allauth_cas/providers.py b/allauth_cas/providers.py new file mode 100644 index 0000000..cf5ec52 --- /dev/null +++ b/allauth_cas/providers.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- +from six.moves.urllib.parse import parse_qsl + +import django +from django.contrib import messages +from django.utils.http import urlencode + +from allauth.socialaccount.providers.base import Provider + +if django.VERSION >= (1, 10): + from django.urls import reverse +else: + from django.core.urlresolvers import reverse + + +class CASProvider(Provider): + + def get_login_url(self, request, **kwargs): + url = reverse(self.id + '_login') + if kwargs: + url += '?' + urlencode(kwargs) + return url + + def get_logout_url(self, request, **kwargs): + url = reverse(self.id + '_logout') + if kwargs: + url += '?' + urlencode(kwargs) + return url + + def get_auth_params(self, request, action): + settings = self.get_settings() + ret = dict(settings.get('AUTH_PARAMS', {})) + dynamic_auth_params = request.GET.get('auth_params') + if dynamic_auth_params: + ret.update(dict(parse_qsl(dynamic_auth_params))) + return ret + + def message_on_logout(self, request): + return self.get_settings().get('MESSAGE_ON_LOGOUT', True) + + def message_on_logout_level(self, request): + return self.get_settings().get('MESSAGE_ON_LOGOUT_LEVEL', + messages.INFO) + + def extract_uid(self, data): + username, _, _ = data + return username + + def extract_common_fields(self, data): + username, _, _ = data + return {'username': username} + + def extract_extra_data(self, data): + _, extra_data, _ = data + return extra_data diff --git a/allauth_cas/signals.py b/allauth_cas/signals.py new file mode 100644 index 0000000..8b26023 --- /dev/null +++ b/allauth_cas/signals.py @@ -0,0 +1,45 @@ +# -*- coding: utf-8 -*- +from django.contrib.auth.signals import user_logged_out +from django.dispatch import receiver +from django.utils.safestring import mark_safe + +from allauth.account.adapter import get_adapter +from allauth.account.utils import get_next_redirect_url +from allauth.socialaccount import providers + +from . import CAS_PROVIDER_SESSION_KEY + + +@receiver(user_logged_out) +def cas_account_logout(sender, request, **kwargs): + provider_id = request.session.get(CAS_PROVIDER_SESSION_KEY) + + if not provider_id: + return + + provider = providers.registry.by_id(provider_id, request) + + if not provider.message_on_logout(request): + return + + adapter = get_adapter(request) + + redirect_url = ( + get_next_redirect_url(request) or + adapter.get_logout_redirect_url(request) + ) + + logout_kwargs = {'next': redirect_url} if redirect_url else {} + logout_url = provider.get_logout_url(request, **logout_kwargs) + + level = provider.message_on_logout_level(request) + logout_link = mark_safe('link'.format(logout_url)) + + adapter.add_message( + request, level, + message_template='cas_account/messages/logged_out.txt', + message_context={ + 'logout_url': logout_url, + 'logout_link': logout_link, + } + ) diff --git a/allauth_cas/templates/cas_account/messages/logged_out.txt b/allauth_cas/templates/cas_account/messages/logged_out.txt new file mode 100644 index 0000000..b0b88ab --- /dev/null +++ b/allauth_cas/templates/cas_account/messages/logged_out.txt @@ -0,0 +1,4 @@ +{% load i18n %} +{% blocktrans %} +To logout of CAS, please close your browser, or visit this {{ logout_link }}. +{% endblocktrans %} diff --git a/allauth_cas/urls.py b/allauth_cas/urls.py new file mode 100644 index 0000000..be30316 --- /dev/null +++ b/allauth_cas/urls.py @@ -0,0 +1,23 @@ +# -*- coding: utf-8 -*- +from django.conf.urls import include, url + +from allauth.utils import import_attribute + + +def default_urlpatterns(provider): + package = provider.get_package() + + login_view = import_attribute(package + '.views.login') + callback_view = import_attribute(package + '.views.callback') + logout_view = import_attribute(package + '.views.logout') + + urlpatterns = [ + url('^login/$', + login_view, name=provider.id + '_login'), + url('^login/callback/$', + callback_view, name=provider.id + '_callback'), + url('^logout/$', + logout_view, name=provider.id + '_logout'), + ] + + return [url('^' + provider.get_slug() + '/', include(urlpatterns))] diff --git a/allauth_cas/views.py b/allauth_cas/views.py new file mode 100644 index 0000000..4c2b6a7 --- /dev/null +++ b/allauth_cas/views.py @@ -0,0 +1,242 @@ +# -*- coding: utf-8 -*- +import django +from django.http import HttpResponseRedirect +from django.utils.http import urlencode + +from allauth.account.adapter import get_adapter +from allauth.account.utils import get_next_redirect_url +from allauth.socialaccount import providers +from allauth.socialaccount.helpers import ( + complete_social_login, render_authentication_error, +) + +import cas + +from . import CAS_PROVIDER_SESSION_KEY +from .exceptions import CASAuthenticationError + +if django.VERSION >= (1, 10): + from django.urls import reverse +else: + from django.core.urlresolvers import reverse + + +class AuthAction(object): + AUTHENTICATE = 'authenticate' + REAUTHENTICATE = 'reauthenticate' + DEAUTHENTICATE = 'deauthenticate' + + +class CASAdapter(object): + # CAS client parameters + renew = False + + def __init__(self, request): + self.request = request + + def get_provider(self): + """ + Returns a provider instance for the current request. + """ + return providers.registry.by_id(self.provider_id, self.request) + + def complete_login(self, request, response): + """ + Executed by the callback view after successful authentication on CAS + server. + + Returns the SocialLogin object which represents the state of the + current login-session. + """ + login = (self.get_provider() + .sociallogin_from_response(request, response)) + return login + + def get_service_url(self, request): + """ + Returns the service url to for a CAS client. + + From CAS specification, the service url is used in order to redirect + user after a successful login on CAS server. Also, service_url sent + when ticket is verified must be the one for which ticket was issued. + + To conform this, the service url is always the callback url. + + A redirect url is found from the current request and appended as + parameter to the service url and is latter used by the callback view to + redirect user. + """ + redirect_to = get_next_redirect_url(request) + + callback_kwargs = {'next': redirect_to} if redirect_to else {} + callback_url = self.get_callback_url(request, **callback_kwargs) + + service_url = request.build_absolute_uri(callback_url) + + return service_url + + def get_callback_url(self, request, **kwargs): + """ + Returns the callback url of the provider. + + Keyword arguments are set as query string. + """ + url = reverse(self.provider_id + '_callback') + if kwargs: + url += '?' + urlencode(kwargs) + return url + + +class CASView(object): + + @classmethod + def adapter_view(cls, adapter, **kwargs): + """ + Similar to the Django as_view() method. + + It also setups a few things: + - given adapter argument will be used in views internals. + - if the view execution raises a CASAuthenticationError, the view + renders an authentication error page. + + To use this: + + - subclass CAS adapter as wanted: + + class MyAdapter(CASAdapter): + url = 'https://my.cas.url' + + - define views: + + login = views.CASLoginView.adapter_view(MyAdapter) + callback = views.CASCallbackView.adapter_view(MyAdapter) + logout = views.CASLogoutView.adapter_view(MyAdapter) + + """ + def view(request, *args, **kwargs): + # Prepare the func-view. + self = cls() + + self.request = request + self.args = args + self.kwargs = kwargs + + # Setup and store adapter as view attribute. + self.adapter = adapter(request) + + try: + return self.dispatch(request, *args, **kwargs) + except CASAuthenticationError: + return self.render_error() + + return view + + def get_client(self, request, action=AuthAction.AUTHENTICATE): + """ + Returns the CAS client to interact with the CAS server. + """ + provider = self.adapter.get_provider() + auth_params = provider.get_auth_params(request, action) + + service_url = self.adapter.get_service_url(request) + + client = cas.CASClient( + service_url=service_url, + server_url=self.adapter.url, + version=self.adapter.version, + renew=self.adapter.renew, + extra_login_params=auth_params, + ) + + return client + + def render_error(self): + """ + Returns an HTTP response in case an authentication failure happens. + """ + return render_authentication_error( + self.request, + self.adapter.provider_id, + ) + + +class CASLoginView(CASView): + + def dispatch(self, request): + """ + Redirects to the CAS server login page. + """ + action = request.GET.get('action', AuthAction.AUTHENTICATE) + client = self.get_client(request, action=action) + return HttpResponseRedirect(client.get_login_url()) + + +class CASCallbackView(CASView): + + def dispatch(self, request): + """ + The CAS server redirects the user to this view after a successful + authentication. + + On redirect, CAS server should add a ticket whose validity is verified + here. If ticket is valid, CAS server may also return extra attributes + about user. + """ + provider = self.adapter.get_provider() + client = self.get_client(request) + + # CAS server should let a ticket. + try: + ticket = request.GET['ticket'] + except KeyError: + raise CASAuthenticationError( + "CAS server didn't respond with a ticket." + ) + + # Check ticket validity. + # Response format on: + # - success: username, attributes, pgtiou + # - error: None, {}, None + response = client.verify_ticket(ticket) + + if not response[0]: + raise CASAuthenticationError( + "CAS server doesn't validate the ticket." + ) + + # The CAS provider in use is stored to propose to the user to + # disconnect from the latter when he logouts. + request.session[CAS_PROVIDER_SESSION_KEY] = provider.id + + # Finish the login flow + login = self.adapter.complete_login(request, response) + return complete_social_login(request, login) + + +class CASLogoutView(CASView): + + def dispatch(self, request, next_page=None): + """ + Redirects to the CAS server logout page. + + next_page is used to let the CAS server send back the user. If empty, + the redirect url is built on request data. + """ + action = AuthAction.DEAUTHENTICATE + + redirect_url = next_page or self.get_redirect_url() + redirect_to = request.build_absolute_uri(redirect_url) + + client = self.get_client(request, action=action) + + return HttpResponseRedirect(client.get_logout_url(redirect_to)) + + def get_redirect_url(self): + """ + Returns the url to redirect after logout from current request. + """ + request = self.request + return ( + get_next_redirect_url(request) or + get_adapter(request).get_logout_redirect_url(request) + ) diff --git a/runtests.py b/runtests.py index 2e8d4e7..5cea84f 100644 --- a/runtests.py +++ b/runtests.py @@ -1,3 +1,5 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- import os import sys diff --git a/setup.cfg b/setup.cfg index 82951d1..7e7ca4d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -6,11 +6,13 @@ ignore = E731 combine_as_imports = True default_section = THIRDPARTY include_trailing_comma = True +known_allauth = allauth +known_future_library = future,six known_django = django known_first_party = allauth_cas multi_line_output = 5 not_skip = __init__.py -sections = FUTURE,STDLIB,DJANGO,THIRDPARTY,FIRSTPARTY,LOCALFOLDER +sections = FUTURE,STDLIB,DJANGO,ALLAUTH,THIRDPARTY,FIRSTPARTY,LOCALFOLDER [bdist_wheel] universal = 1 diff --git a/setup.py b/setup.py index da30033..bd84887 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- import os from setuptools import find_packages, setup @@ -15,7 +16,7 @@ setup( description='CAS support for django-allauth.', author='Aurélien Delobelle', author_email='aurelien.delobelle@gmail.com', - keyword='django allauth cas authentication', + keywords='django allauth cas authentication', long_description=README, url='https://github.com/aureplop/django-allauth-cas', classifiers=[ @@ -44,6 +45,7 @@ setup( install_requires=[ 'django-allauth', 'python-cas', + 'six', ], extras_require={ 'tests': ['tox'], diff --git a/tests/cas_clients.py b/tests/cas_clients.py new file mode 100644 index 0000000..69c025e --- /dev/null +++ b/tests/cas_clients.py @@ -0,0 +1,37 @@ +# -*- coding: utf-8 -*- +import cas + + +class MockCASClient(cas.CASClientV2): + """ + Base class to mock cas.CASClient + """ + def __init__(self, *args, **kwargs): + kwargs.pop('version') + super(MockCASClient, self).__init__(*args, **kwargs) + + +class VerifyCASClient(MockCASClient): + """ + CAS client which verifies ticket is '123456'. + """ + def verify_ticket(self, ticket): + if ticket == '123456': + return 'username', {}, None + return None, {}, None + + +class AcceptCASClient(MockCASClient): + """ + CAS client which accepts all tickets. + """ + def verify_ticket(self, ticket): + return 'username', {}, None + + +class RejectCASClient(MockCASClient): + """ + CAS client which rejects all tickets. + """ + def verify_ticket(self, ticket): + return None, {}, None diff --git a/tests/example/__init__.py b/tests/example/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/example/provider.py b/tests/example/provider.py new file mode 100644 index 0000000..c393b70 --- /dev/null +++ b/tests/example/provider.py @@ -0,0 +1,17 @@ +# -*- coding: utf-8 -*- +from allauth.socialaccount.providers.base import ProviderAccount + +from allauth_cas.providers import CASProvider + + +class ExampleCASAccount(ProviderAccount): + pass + + +class ExampleCASProvider(CASProvider): + id = 'theid' + name = 'The Provider' + account_class = ExampleCASAccount + + +provider_classes = [ExampleCASProvider] diff --git a/tests/example/urls.py b/tests/example/urls.py new file mode 100644 index 0000000..335c441 --- /dev/null +++ b/tests/example/urls.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- +from allauth_cas.urls import default_urlpatterns + +from .provider import ExampleCASProvider + +urlpatterns = default_urlpatterns(ExampleCASProvider) diff --git a/tests/example/views.py b/tests/example/views.py new file mode 100644 index 0000000..6aa916a --- /dev/null +++ b/tests/example/views.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- +from allauth_cas import views + +from .provider import ExampleCASProvider + + +class ExampleCASAdapter(views.CASAdapter): + provider_id = ExampleCASProvider.id + url = 'https://server.cas' + version = 2 + + +login = views.CASLoginView.adapter_view(ExampleCASAdapter) +callback = views.CASCallbackView.adapter_view(ExampleCASAdapter) +logout = views.CASLogoutView.adapter_view(ExampleCASAdapter) diff --git a/tests/settings.py b/tests/settings.py new file mode 100644 index 0000000..5acb51d --- /dev/null +++ b/tests/settings.py @@ -0,0 +1,63 @@ +# -*- coding: utf-8 -*- +import django + +SECRET_KEY = 'iamabird' + +INSTALLED_APPS = [ + 'django.contrib.admin', + 'django.contrib.auth', + 'django.contrib.contenttypes', + 'django.contrib.messages', + 'django.contrib.sessions', + 'django.contrib.sites', + 'django.contrib.staticfiles', + + 'allauth', + 'allauth.account', + 'allauth.socialaccount', + + 'allauth_cas', + + 'tests.example', # Dummy CAS provider app +] + +DATABASES = { + 'default': { + 'ENGINE': 'django.db.backends.sqlite3', + }, +} + +AUTHENTICATION_BACKENDS = [ + 'allauth.account.auth_backends.AuthenticationBackend', +] + +_MIDDLEWARES = [ + 'django.middleware.common.CommonMiddleware', + 'django.contrib.sessions.middleware.SessionMiddleware', + 'django.middleware.csrf.CsrfViewMiddleware', + 'django.contrib.auth.middleware.AuthenticationMiddleware', + 'django.contrib.messages.middleware.MessageMiddleware', +] + +if django.VERSION >= (1, 10): + MIDDLEWARE = _MIDDLEWARES +else: + MIDDLEWARE_CLASSES = _MIDDLEWARES + +TEMPLATES = [ + { + 'BACKEND': 'django.template.backends.django.DjangoTemplates', + 'DIRS': [], + 'APP_DIRS': True, + 'OPTIONS': { + 'context_processors': [ + 'django.template.context_processors.debug', + 'django.template.context_processors.request', + 'django.contrib.auth.context_processors.auth', + 'django.contrib.messages.context_processors.messages', + ], + }, + }, +] + +ROOT_URLCONF = 'tests.urls' diff --git a/tests/test_flows.py b/tests/test_flows.py new file mode 100644 index 0000000..9feadfc --- /dev/null +++ b/tests/test_flows.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- +try: + from unittest.mock import patch +except ImportError: + from mock import patch + +from django.contrib import messages +from django.contrib.auth import get_user_model +from django.contrib.messages.api import get_messages +from django.contrib.messages.storage.base import Message +from django.test import TestCase, override_settings + +from .cas_clients import AcceptCASClient + +User = get_user_model() + + +@patch('allauth_cas.views.cas.CASClient', AcceptCASClient) +def client_cas_login(client): + """ + Sign in client through the example CAS provider. + Returns the response of callbacK view. + """ + r = client.get('/accounts/theid/login/callback/', { + 'ticket': 'fake-ticket', + }) + return r + + +class LogoutFlowTests(TestCase): + expected_msg_str = ( + "To logout of CAS, please close your browser, or visit this link." + ) + + def setUp(self): + client_cas_login(self.client) + + def assertCASLogoutNotInMessages(self, response): + r_messages = get_messages(response.wsgi_request) + self.assertNotIn( + self.expected_msg_str, + (str(msg) for msg in r_messages), + ) + self.assertTemplateNotUsed( + response, + 'cas_account/messages/logged_out.txt', + ) + + @override_settings(SOCIALACCOUNT_PROVIDERS={ + 'theid': { + 'MESSAGE_ON_LOGOUT': True, + 'MESSAGE_ON_LOGOUT_LEVEL': messages.WARNING, + }, + }) + def test_message_on_logout(self): + """ + Message is sent to propose user to logout of CAS. + """ + r = self.client.post('/accounts/logout/') + r_messages = get_messages(r.wsgi_request) + + expected_msg = Message(messages.WARNING, self.expected_msg_str) + + self.assertIn(expected_msg, r_messages) + self.assertTemplateUsed(r, 'cas_account/messages/logged_out.txt') + + @override_settings(SOCIALACCOUNT_PROVIDERS={ + 'theid': { + 'MESSAGE_ON_LOGOUT': False, + }, + }) + def test_message_on_logout_disabled(self): + """ + The logout message can be disabled in settings. + """ + r = self.client.post('/accounts/logout/') + self.assertCASLogoutNotInMessages(r) + + @override_settings(SOCIALACCOUNT_PROVIDERS={ + 'theid': {'MESSAGE_ON_LOGOUT': True}, + }) + def test_default_logout(self): + """ + The CAS logout message doesn't appear with other login methods. + """ + User.objects.create_user('user', '', 'user') + self.client.login(username='user', password='user') + + r = self.client.post('/accounts/logout/') + self.assertCASLogoutNotInMessages(r) diff --git a/tests/test_providers.py b/tests/test_providers.py new file mode 100644 index 0000000..abe7bb2 --- /dev/null +++ b/tests/test_providers.py @@ -0,0 +1,131 @@ +# -*- coding: utf-8 -*- +from django.contrib import messages +from django.test import RequestFactory, TestCase, override_settings + +from allauth.socialaccount.providers import registry + +from allauth_cas.views import AuthAction + +from .example.provider import ExampleCASProvider + + +class CASProviderTests(TestCase): + + def setUp(self): + factory = RequestFactory() + request = factory.get('/test/') + request.session = {} + self.request = request + + self.provider = ExampleCASProvider(request) + + def test_register(self): + """ + Example CAS provider is registered as social account provider. + """ + self.assertIsInstance(registry.by_id('theid'), ExampleCASProvider) + + def test_get_login_url(self): + """ + get_login_url returns the url to logout of the provider. + Keyword arguments are set as query string. + """ + url = self.provider.get_login_url(self.request) + self.assertEqual('/accounts/theid/login/', url) + + url_with_qs = self.provider.get_login_url( + self.request, + next='/path?quéry=string&two=whoam%C3%AF', + ) + self.assertEqual( + url_with_qs, + '/accounts/theid/login/?next=%2Fpath%3Fqu%C3%A9ry%3Dstring%26two%3' + 'Dwhoam%25C3%25AF' + ) + + def test_get_logout_url(self): + """ + get_logout_url returns the url to logout of the provider. + Keyword arguments are set as query string. + """ + url = self.provider.get_logout_url(self.request) + self.assertEqual('/accounts/theid/logout/', url) + + url_with_qs = self.provider.get_logout_url( + self.request, + next='/path?quéry=string&two=whoam%C3%AF', + ) + self.assertEqual( + url_with_qs, + '/accounts/theid/logout/?next=%2Fpath%3Fqu%C3%A9ry%3Dstring%26two%' + '3Dwhoam%25C3%25AF' + ) + + @override_settings(SOCIALACCOUNT_PROVIDERS={ + 'theid': { + 'AUTH_PARAMS': {'key': 'value'}, + }, + }) + def test_get_auth_params(self): + action = AuthAction.AUTHENTICATE + + auth_params = self.provider.get_auth_params(self.request, action) + + self.assertDictEqual(auth_params, { + 'key': 'value', + }) + + @override_settings(SOCIALACCOUNT_PROVIDERS={ + 'theid': { + 'AUTH_PARAMS': {'key': 'value'}, + }, + }) + def test_get_auth_params_with_dynamic(self): + factory = RequestFactory() + request = factory.get( + '/test/?auth_params=next%3Dtwo%253Dwhoam%2525C3%2525AF%2526qu%2525' + 'C3%2525A9ry%253Dstring' + ) + request.session = {} + + action = AuthAction.AUTHENTICATE + + auth_params = self.provider.get_auth_params(request, action) + + self.assertDictEqual(auth_params, { + 'key': 'value', + 'next': 'two=whoam%C3%AF&qu%C3%A9ry=string', + }) + + @override_settings(SOCIALACCOUNT_PROVIDERS={ + 'theid': { + 'MESSAGE_ON_LOGOUT_LEVEL': messages.WARNING, + }, + }) + def test_message_on_logout(self): + message_on_logout = self.provider.message_on_logout(self.request) + self.assertTrue(message_on_logout) + + message_level = self.provider.message_on_logout_level(self.request) + self.assertEqual(messages.WARNING, message_level) + + def test_extract_uid(self): + response = 'useRName', {}, None + uid = self.provider.extract_uid(response) + self.assertEqual('useRName', uid) + + def test_extract_common_fields(self): + response = 'useRName', {}, None + common_fields = self.provider.extract_common_fields(response) + self.assertDictEqual(common_fields, { + 'username': 'useRName', + }) + + def test_extract_extra_data(self): + attributes = {'user_attr': 'thevalue', 'another': 'value'} + response = 'useRName', attributes, None + extra_data = self.provider.extract_extra_data(response) + self.assertDictEqual(extra_data, { + 'user_attr': 'thevalue', + 'another': 'value', + }) diff --git a/tests/test_views.py b/tests/test_views.py new file mode 100644 index 0000000..18a985a --- /dev/null +++ b/tests/test_views.py @@ -0,0 +1,237 @@ +# -*- coding: utf-8 -*- +try: + from unittest.mock import patch +except ImportError: + from mock import patch + +import django +from django.test import RequestFactory, TestCase, override_settings + +from allauth_cas.exceptions import CASAuthenticationError +from allauth_cas.views import CASView + +from .example.views import ExampleCASAdapter +from .testcases import CASViewTestCase + +if django.VERSION >= (1, 10): + from django.urls import reverse +else: + from django.core.urlresolvers import reverse + + +class CASAdapterTests(TestCase): + + def setUp(self): + factory = RequestFactory() + self.request = factory.get('/path/') + self.adapter = ExampleCASAdapter(self.request) + + def test_get_service_url(self): + """ + Service url (used by CAS client) is the callback url. + """ + expected = 'http://testserver/accounts/theid/login/callback/' + service_url = self.adapter.get_service_url(self.request) + self.assertEqual(expected, service_url) + + def test_get_service_url_keep_next(self): + """ + Current GET paramater next is appended on service url. + """ + expected = ( + 'http://testserver/accounts/theid/login/callback/?next=%2Fnext%2F' + ) + factory = RequestFactory() + request = factory.get('/path/', {'next': '/next/'}) + adapter = ExampleCASAdapter(request) + service_url = adapter.get_service_url(request) + self.assertEqual(expected, service_url) + + def test_get_callback_url(self): + expected = '/accounts/theid/login/callback/' + callback_url = self.adapter.get_callback_url(self.request) + self.assertEqual(expected, callback_url) + + def test_get_callback_url_with_kwargs(self): + expected = ( + '/accounts/theid/login/callback/?next=%2Fpath%2F' + ) + callback_url = self.adapter.get_callback_url(self.request, **{ + 'next': '/path/', + }) + self.assertEqual(expected, callback_url) + + +class CASViewTests(CASViewTestCase): + + class BasicCASView(CASView): + def dispatch(self, request, *args, **kwargs): + return self + + def setUp(self): + factory = RequestFactory() + self.request = factory.get('path') + + self.cas_view = self.BasicCASView.adapter_view(ExampleCASAdapter) + + def test_adapter_view(self): + """ + adapter_view prepares the func view from a class view. + """ + view = self.cas_view( + self.request, + 'arg1', 'arg2', + kwarg1='kwarg1', kwarg2='kwarg2', + ) + + self.assertIsInstance(view, CASView) + + self.assertEqual(view.request, view.request) + self.assertTupleEqual(view.args, ('arg1', 'arg2')) + self.assertDictEqual(view.kwargs, { + 'kwarg1': 'kwarg1', + 'kwarg2': 'kwarg2', + }) + + self.assertIsInstance(view.adapter, ExampleCASAdapter) + + @patch('allauth_cas.views.cas.CASClient') + @override_settings(SOCIALACCOUNT_PROVIDERS={ + 'theid': { + 'AUTH_PARAMS': {'key': 'value'}, + }, + }) + def test_get_client(self, mock_casclient_class): + """ + get_client returns a CAS client, configured from settings. + """ + view = self.cas_view(self.request) + view.get_client(self.request) + + mock_casclient_class.assert_called_once_with( + service_url='http://testserver/accounts/theid/login/callback/', + server_url='https://server.cas', + version=2, + renew=False, + extra_login_params={'key': 'value'}, + ) + + def test_render_error_on_failure(self): + """ + A common login failure page is rendered if CASAuthenticationError is + raised by dispatch. + """ + def dispatch_raise(self, request): + raise CASAuthenticationError("failure") + + with patch.object(self.BasicCASView, 'dispatch', dispatch_raise): + resp = self.cas_view(self.request) + self.assertLoginFailure(resp) + + +class CASLoginViewTests(CASViewTestCase): + + def test_reverse(self): + """ + Login view name is "{provider_id}_login". + """ + url = reverse('theid_login') + self.assertEqual('/accounts/theid/login/', url) + + def test_execute(self): + """ + Login view redirects to the CAS server login url. + Service is the callback url, as absolute uri. + """ + r = self.client.get('/accounts/theid/login/') + + expected = ( + 'https://server.cas/login?service=http%3A%2F%2Ftestserver%2F' + 'accounts%2Ftheid%2Flogin%2Fcallback%2F' + ) + + self.assertRedirects(r, expected, fetch_redirect_response=False) + + def test_execute_keep_next(self): + """ + Current GET parameter 'next' is kept on service url. + """ + r = self.client.get('/accounts/theid/login/?next=/path/') + + expected = ( + 'https://server.cas/login?service=http%3A%2F%2Ftestserver%2F' + 'accounts%2Ftheid%2Flogin%2Fcallback%2F%3Fnext%3D%252Fpath%252F' + ) + + self.assertRedirects(r, expected, fetch_redirect_response=False) + + +class CASCallbackViewTests(CASViewTestCase): + + def test_reverse(self): + """ + Callback view name is "{provider_id}_callback". + """ + url = reverse('theid_callback') + self.assertEqual('/accounts/theid/login/callback/', url) + + def test_ticket_valid(self): + """ + If ticket is valid, the user is logged in. + """ + self.patch_cas_client('verify') + r = self.client.get('/accounts/theid/login/callback/', { + 'ticket': '123456', + }) + self.assertLoginSuccess(r) + + def test_ticket_invalid(self): + """ + Login failure page is returned if the ticket is invalid. + """ + self.patch_cas_client('verify') + r = self.client.get('/accounts/theid/login/callback/', { + 'ticket': '000000', + }) + self.assertLoginFailure(r) + + def test_ticket_missing(self): + """ + Login failure page is returned if request lacks a ticket. + """ + self.patch_cas_client('verify') + r = self.client.get('/accounts/theid/login/callback/') + self.assertLoginFailure(r) + + +class CASLogoutViewTests(CASViewTestCase): + + def test_reverse(self): + """ + Callback view name is "{provider_id}_logout". + """ + url = reverse('theid_logout') + self.assertEqual('/accounts/theid/logout/', url) + + def test_execute(self): + """ + Logout view redirects to the CAS server logout url. + Service is a url to here, as absolute uri. + """ + r = self.client.get('/accounts/theid/logout/') + + expected = 'https://server.cas/logout?url=http%3A%2F%2Ftestserver%2F' + + self.assertRedirects(r, expected, fetch_redirect_response=False) + + def test_execute_with_next(self): + """ + GET parameter 'next' is set as service url. + """ + r = self.client.get('/accounts/theid/logout/?next=/path/') + + expected = ( + 'https://server.cas/logout?url=http%3A%2F%2Ftestserver%2Fpath%2F' + ) + + self.assertRedirects(r, expected, fetch_redirect_response=False) diff --git a/tests/testcases.py b/tests/testcases.py new file mode 100644 index 0000000..6b039ad --- /dev/null +++ b/tests/testcases.py @@ -0,0 +1,76 @@ +# -*- coding: utf-8 -*- +try: + from unittest.mock import patch +except ImportError: + from mock import patch + +from django.test import TestCase as DjangoTestCase + +from allauth_cas import CAS_PROVIDER_SESSION_KEY + +from . import cas_clients + + +class TestCase(DjangoTestCase): + + def patch_cas_client(self, label): + """ + Patch cas.CASClient in allauth_cs.views module with another CAS client + selectable with label argument. + + Patch is stopped at the end of the current test. + """ + if hasattr(self, '_patch_cas_client'): + self.patch_cas_client_stop() + + if label == 'verify': + new = cas_clients.VerifyCASClient + elif label == 'accept': + new = cas_clients.AcceptCASClient + elif label == 'reject': + new = cas_clients.RejectCASClient + + self._patch_cas_client = patch('allauth_cas.views.cas.CASClient', new) + self._patch_cas_client.start() + + def patch_cas_client_stop(self): + self._patch_cas_client.stop() + + def tearDown(self): + if hasattr(self, '_patch_cas_client'): + self.patch_cas_client_stop() + + +class CASViewTestCase(TestCase): + + def assertLoginSuccess(self, response, redirect_to=None, client=None): + """ + Asserts response corresponds to a successful login. + + To check this, the response should redirect to redirect_to (default to + /accounts/profile/, the default redirect after a successful login). + Also CAS_PROVIDER_SESSION_KEY should be set in the client' session. By + default, self.client is used. + """ + if client is None: + client = self.client + if redirect_to is None: + redirect_to = '/accounts/profile/' + + self.assertRedirects( + response, redirect_to, + fetch_redirect_response=False, + ) + self.assertIn( + CAS_PROVIDER_SESSION_KEY, + client.session, + ) + + def assertLoginFailure(self, response): + """ + Asserts response corresponds to a failed login. + """ + return self.assertInHTML( + '

Social Network Login Failure

', + str(response.content), + ) diff --git a/tests/urls.py b/tests/urls.py new file mode 100644 index 0000000..1933289 --- /dev/null +++ b/tests/urls.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- +from django.conf.urls import include, url + +urlpatterns = [ + url(r'^accounts/', include('allauth.urls')), +] diff --git a/tox.ini b/tox.ini index 1a15627..050912f 100644 --- a/tox.ini +++ b/tox.ini @@ -1,7 +1,8 @@ [tox] envlist = - py{34,35}-django{18,19,110} - py{34,35,36}-django111, + py{27,34,35}-django{18,19,110} + py{27,34,35,36}-django111, + cov_combine, flake8, isort @@ -12,12 +13,21 @@ deps = django110: django>=1.10,<1.11 django111: django>=1.11,<2.0 coverage + mock ; python_version < "3.0" usedevelop= True commands = + python -V coverage run \ --branch \ --source=allauth_cas --omit=*migrations* \ - ./runtests.py {posargs} + --parallel-mode \ + runtests.py {posargs} + +[testenv:cov_combine] +deps = + coverage +commands = + coverage combine coverage report --show-missing [testenv:flake8]