diff --git a/allauth_cas/test/testcases.py b/allauth_cas/test/testcases.py index d13500e..d98ee11 100644 --- a/allauth_cas/test/testcases.py +++ b/allauth_cas/test/testcases.py @@ -4,6 +4,7 @@ try: except ImportError: from mock import patch +import django from django.conf import settings from django.test import TestCase @@ -11,9 +12,35 @@ import cas from allauth_cas import CAS_PROVIDER_SESSION_KEY +if django.VERSION >= (1, 10): + from django.urls import reverse +else: + from django.core.urlresolvers import reverse + class CASTestCase(TestCase): + def client_cas_login( + self, + client, provider_id='theid', + username=None, attributes={}): + """ + Authenticate client through provider_id. + + Returns the response of the callback view. + + username and attributes control the CAS server response when ticket is + checked. + """ + self.patch_cas_response( + valid_ticket='__all__', + username=username, attributes=attributes, + ) + callback_url = reverse('{id}_callback'.format(id=provider_id)) + r = client.get(callback_url, {'ticket': 'fake-ticket'}) + self.patch_cas_response_stop() + return r + def patch_cas_response( self, valid_ticket, @@ -41,20 +68,39 @@ class CASTestCase(TestCase): ticket retrieved from GET parameter on request on the callback view). """ if hasattr(self, '_patch_cas_client'): - self.patch_cas_client_stop() + self.patch_cas_response_stop() - class MockCASClient(cas.CASClientV2): + class MockCASClient(object): _username = username - def __init__(self_client, *args, **kwargs): - kwargs.pop('version') - super(MockCASClient, self_client).__init__(*args, **kwargs) + def __new__(self_client, *args, **kwargs): + version = kwargs.pop('version') + if version in (1, '1'): + client_class = cas.CASClientV1 + elif version in (2, '2'): + client_class = cas.CASClientV2 + elif version in (3, '3'): + client_class = cas.CASClientV3 + elif version == 'CAS_2_SAML_1_0': + client_class = cas.CASClientWithSAMLV1 + else: + raise ValueError('Unsupported CAS_VERSION %r' % version) - def verify_ticket(self_client, ticket): - if valid_ticket == '__all__' or ticket == valid_ticket: - username = self_client._username or 'username' - return username, attributes, None - return None, {}, None + client_class._username = self_client._username + + def verify_ticket(self, ticket): + if valid_ticket == '__all__' or ticket == valid_ticket: + username = self._username or 'username' + return username, attributes, None + return None, {}, None + + patcher = patch.object( + client_class, 'verify_ticket', + new=verify_ticket, + ) + patcher.start() + + return client_class(*args, **kwargs) self._patch_cas_client = patch( 'allauth_cas.views.cas.CASClient', @@ -62,13 +108,13 @@ class CASTestCase(TestCase): ) self._patch_cas_client.start() - def patch_cas_client_stop(self): + def patch_cas_response_stop(self): self._patch_cas_client.stop() del self._patch_cas_client def tearDown(self): if hasattr(self, '_patch_cas_client'): - self.patch_cas_client_stop() + self.patch_cas_response_stop() class CASViewTestCase(CASTestCase): diff --git a/allauth_cas/views.py b/allauth_cas/views.py index 4c2b6a7..3a3bb83 100644 --- a/allauth_cas/views.py +++ b/allauth_cas/views.py @@ -28,12 +28,22 @@ class AuthAction(object): class CASAdapter(object): - # CAS client parameters - renew = False def __init__(self, request): self.request = request + @property + def renew(self): + """ + If user is already authenticated on Django, he may already been + connected to CAS, but still may want to use another CAS account. + We set renew to True in this case, as the CAS server won't use the + single sign-on. + To specifically check, if the current user has used a CAS server, + we check if the CAS session key is set. + """ + return CAS_PROVIDER_SESSION_KEY in self.request.session + def get_provider(self): """ Returns a provider instance for the current request. diff --git a/tests/cas_clients.py b/tests/cas_clients.py deleted file mode 100644 index 69c025e..0000000 --- a/tests/cas_clients.py +++ /dev/null @@ -1,37 +0,0 @@ -# -*- coding: utf-8 -*- -import cas - - -class MockCASClient(cas.CASClientV2): - """ - Base class to mock cas.CASClient - """ - def __init__(self, *args, **kwargs): - kwargs.pop('version') - super(MockCASClient, self).__init__(*args, **kwargs) - - -class VerifyCASClient(MockCASClient): - """ - CAS client which verifies ticket is '123456'. - """ - def verify_ticket(self, ticket): - if ticket == '123456': - return 'username', {}, None - return None, {}, None - - -class AcceptCASClient(MockCASClient): - """ - CAS client which accepts all tickets. - """ - def verify_ticket(self, ticket): - return 'username', {}, None - - -class RejectCASClient(MockCASClient): - """ - CAS client which rejects all tickets. - """ - def verify_ticket(self, ticket): - return None, {}, None diff --git a/tests/test_flows.py b/tests/test_flows.py index 8a61c87..29b3c93 100644 --- a/tests/test_flows.py +++ b/tests/test_flows.py @@ -1,33 +1,16 @@ # -*- coding: utf-8 -*- -try: - from unittest.mock import patch -except ImportError: - from mock import patch - from django.contrib import messages from django.contrib.auth import get_user_model from django.contrib.messages.api import get_messages from django.contrib.messages.storage.base import Message -from django.test import TestCase, override_settings +from django.test import override_settings -from .cas_clients import AcceptCASClient +from allauth_cas.test.testcases import CASTestCase User = get_user_model() -@patch('allauth_cas.views.cas.CASClient', AcceptCASClient) -def client_cas_login(client): - """ - Sign in client through the example CAS provider. - Returns the response of callbacK view. - """ - r = client.get('/accounts/theid/login/callback/', { - 'ticket': 'fake-ticket', - }) - return r - - -class LogoutFlowTests(TestCase): +class LogoutFlowTests(CASTestCase): expected_msg_str = ( "To logout of CAS, please close your browser, or visit this " "" @@ -35,7 +18,7 @@ class LogoutFlowTests(TestCase): ) def setUp(self): - client_cas_login(self.client) + self.client_cas_login(self.client) def assertCASLogoutNotInMessages(self, response): r_messages = get_messages(response.wsgi_request) diff --git a/tests/test_testcases.py b/tests/test_testcases.py index 993d6c3..ef260bb 100644 --- a/tests/test_testcases.py +++ b/tests/test_testcases.py @@ -1,11 +1,60 @@ # -*- coding: utf-8 -*- -from django.test import Client +from django.test import Client, RequestFactory from allauth_cas.test.testcases import CASViewTestCase +from allauth_cas.views import CASView + +from .example.views import ExampleCASAdapter class CASTestCaseTests(CASViewTestCase): + def test_patch_cas_response_client_version(self): + """ + python-cas uses multiple client classes depending on the CAS server + version. + + patch_cas_response patch must also returns the correct class. + + """ + valid_versions = [ + 1, '1', + 2, '2', + 3, '3', + 'CAS_2_SAML_1_0', + ] + invalid_versions = [ + 'not_supported', + ] + + factory = RequestFactory() + request = factory.get('/path/') + request.session = {} + + for _version in valid_versions + invalid_versions: + class BasicCASAdapter(ExampleCASAdapter): + version = _version + + class BasicCASView(CASView): + def dispatch(self, request, *args, **kwargs): + return self.get_client(request) + + view = BasicCASView.adapter_view(BasicCASAdapter) + + if _version in valid_versions: + raw_client = view(request) + + self.patch_cas_response(valid_ticket='__all__') + mocked_client = view(request) + + self.assertEqual(type(raw_client), type(mocked_client)) + else: + # This is a sanity check. + self.assertRaises(ValueError, view, request) + + self.patch_cas_response(valid_ticket='__all__') + self.assertRaises(ValueError, view, request) + def test_patch_cas_response_verify_success(self): self.patch_cas_response(valid_ticket='123456') r = self.client.get('/accounts/theid/login/callback/', { diff --git a/tests/test_views.py b/tests/test_views.py index ae5a83a..0fa1208 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -5,10 +5,10 @@ except ImportError: from mock import patch import django -from django.test import RequestFactory, TestCase, override_settings +from django.test import RequestFactory, override_settings from allauth_cas.exceptions import CASAuthenticationError -from allauth_cas.test.testcases import CASViewTestCase +from allauth_cas.test.testcases import CASTestCase, CASViewTestCase from allauth_cas.views import CASView from .example.views import ExampleCASAdapter @@ -19,11 +19,12 @@ else: from django.core.urlresolvers import reverse -class CASAdapterTests(TestCase): +class CASAdapterTests(CASTestCase): def setUp(self): factory = RequestFactory() self.request = factory.get('/path/') + self.request.session = {} self.adapter = ExampleCASAdapter(self.request) def test_get_service_url(self): @@ -61,6 +62,23 @@ class CASAdapterTests(TestCase): }) self.assertEqual(expected, callback_url) + def test_renew(self): + """ + From an anonymous request, renew is False to let using the single + sign-on. + """ + self.assertFalse(self.adapter.renew) + + def test_renew_authenticated(self): + """ + If user has been authenticated to the application through CAS, and + tries to reauthenticate, renew is set to True to opt-out the single + sign-on. + """ + r = self.client_cas_login(self.client) + adapter = ExampleCASAdapter(r.wsgi_request) + self.assertTrue(adapter.renew) + class CASViewTests(CASViewTestCase): @@ -70,7 +88,8 @@ class CASViewTests(CASViewTestCase): def setUp(self): factory = RequestFactory() - self.request = factory.get('path') + self.request = factory.get('/path/') + self.request.session = {} self.cas_view = self.BasicCASView.adapter_view(ExampleCASAdapter)