Renew on already authenticated + Misc tests
Renew paramater: - By default, CAS client will use renew parameter if user already authenticates via a CAS server. If always False, he can't reauthenticate on a CAS server due to the single sign-on CAS feature (except if he logouts of CAS on his own). Tests: - patch_cas_reponse now returns a correct CAS client taking into account the version attribute of the CAS adapter. - Some moves happens between testcases et al. - Delete old and now unused fake CAS client classes.
This commit is contained in:
parent
b1165d39af
commit
049cf22b42
6 changed files with 147 additions and 77 deletions
|
@ -4,6 +4,7 @@ try:
|
|||
except ImportError:
|
||||
from mock import patch
|
||||
|
||||
import django
|
||||
from django.conf import settings
|
||||
from django.test import TestCase
|
||||
|
||||
|
@ -11,9 +12,35 @@ import cas
|
|||
|
||||
from allauth_cas import CAS_PROVIDER_SESSION_KEY
|
||||
|
||||
if django.VERSION >= (1, 10):
|
||||
from django.urls import reverse
|
||||
else:
|
||||
from django.core.urlresolvers import reverse
|
||||
|
||||
|
||||
class CASTestCase(TestCase):
|
||||
|
||||
def client_cas_login(
|
||||
self,
|
||||
client, provider_id='theid',
|
||||
username=None, attributes={}):
|
||||
"""
|
||||
Authenticate client through provider_id.
|
||||
|
||||
Returns the response of the callback view.
|
||||
|
||||
username and attributes control the CAS server response when ticket is
|
||||
checked.
|
||||
"""
|
||||
self.patch_cas_response(
|
||||
valid_ticket='__all__',
|
||||
username=username, attributes=attributes,
|
||||
)
|
||||
callback_url = reverse('{id}_callback'.format(id=provider_id))
|
||||
r = client.get(callback_url, {'ticket': 'fake-ticket'})
|
||||
self.patch_cas_response_stop()
|
||||
return r
|
||||
|
||||
def patch_cas_response(
|
||||
self,
|
||||
valid_ticket,
|
||||
|
@ -41,20 +68,39 @@ class CASTestCase(TestCase):
|
|||
ticket retrieved from GET parameter on request on the callback view).
|
||||
"""
|
||||
if hasattr(self, '_patch_cas_client'):
|
||||
self.patch_cas_client_stop()
|
||||
self.patch_cas_response_stop()
|
||||
|
||||
class MockCASClient(cas.CASClientV2):
|
||||
class MockCASClient(object):
|
||||
_username = username
|
||||
|
||||
def __init__(self_client, *args, **kwargs):
|
||||
kwargs.pop('version')
|
||||
super(MockCASClient, self_client).__init__(*args, **kwargs)
|
||||
def __new__(self_client, *args, **kwargs):
|
||||
version = kwargs.pop('version')
|
||||
if version in (1, '1'):
|
||||
client_class = cas.CASClientV1
|
||||
elif version in (2, '2'):
|
||||
client_class = cas.CASClientV2
|
||||
elif version in (3, '3'):
|
||||
client_class = cas.CASClientV3
|
||||
elif version == 'CAS_2_SAML_1_0':
|
||||
client_class = cas.CASClientWithSAMLV1
|
||||
else:
|
||||
raise ValueError('Unsupported CAS_VERSION %r' % version)
|
||||
|
||||
def verify_ticket(self_client, ticket):
|
||||
if valid_ticket == '__all__' or ticket == valid_ticket:
|
||||
username = self_client._username or 'username'
|
||||
return username, attributes, None
|
||||
return None, {}, None
|
||||
client_class._username = self_client._username
|
||||
|
||||
def verify_ticket(self, ticket):
|
||||
if valid_ticket == '__all__' or ticket == valid_ticket:
|
||||
username = self._username or 'username'
|
||||
return username, attributes, None
|
||||
return None, {}, None
|
||||
|
||||
patcher = patch.object(
|
||||
client_class, 'verify_ticket',
|
||||
new=verify_ticket,
|
||||
)
|
||||
patcher.start()
|
||||
|
||||
return client_class(*args, **kwargs)
|
||||
|
||||
self._patch_cas_client = patch(
|
||||
'allauth_cas.views.cas.CASClient',
|
||||
|
@ -62,13 +108,13 @@ class CASTestCase(TestCase):
|
|||
)
|
||||
self._patch_cas_client.start()
|
||||
|
||||
def patch_cas_client_stop(self):
|
||||
def patch_cas_response_stop(self):
|
||||
self._patch_cas_client.stop()
|
||||
del self._patch_cas_client
|
||||
|
||||
def tearDown(self):
|
||||
if hasattr(self, '_patch_cas_client'):
|
||||
self.patch_cas_client_stop()
|
||||
self.patch_cas_response_stop()
|
||||
|
||||
|
||||
class CASViewTestCase(CASTestCase):
|
||||
|
|
|
@ -28,12 +28,22 @@ class AuthAction(object):
|
|||
|
||||
|
||||
class CASAdapter(object):
|
||||
# CAS client parameters
|
||||
renew = False
|
||||
|
||||
def __init__(self, request):
|
||||
self.request = request
|
||||
|
||||
@property
|
||||
def renew(self):
|
||||
"""
|
||||
If user is already authenticated on Django, he may already been
|
||||
connected to CAS, but still may want to use another CAS account.
|
||||
We set renew to True in this case, as the CAS server won't use the
|
||||
single sign-on.
|
||||
To specifically check, if the current user has used a CAS server,
|
||||
we check if the CAS session key is set.
|
||||
"""
|
||||
return CAS_PROVIDER_SESSION_KEY in self.request.session
|
||||
|
||||
def get_provider(self):
|
||||
"""
|
||||
Returns a provider instance for the current request.
|
||||
|
|
|
@ -1,37 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
import cas
|
||||
|
||||
|
||||
class MockCASClient(cas.CASClientV2):
|
||||
"""
|
||||
Base class to mock cas.CASClient
|
||||
"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs.pop('version')
|
||||
super(MockCASClient, self).__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class VerifyCASClient(MockCASClient):
|
||||
"""
|
||||
CAS client which verifies ticket is '123456'.
|
||||
"""
|
||||
def verify_ticket(self, ticket):
|
||||
if ticket == '123456':
|
||||
return 'username', {}, None
|
||||
return None, {}, None
|
||||
|
||||
|
||||
class AcceptCASClient(MockCASClient):
|
||||
"""
|
||||
CAS client which accepts all tickets.
|
||||
"""
|
||||
def verify_ticket(self, ticket):
|
||||
return 'username', {}, None
|
||||
|
||||
|
||||
class RejectCASClient(MockCASClient):
|
||||
"""
|
||||
CAS client which rejects all tickets.
|
||||
"""
|
||||
def verify_ticket(self, ticket):
|
||||
return None, {}, None
|
|
@ -1,33 +1,16 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
try:
|
||||
from unittest.mock import patch
|
||||
except ImportError:
|
||||
from mock import patch
|
||||
|
||||
from django.contrib import messages
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.contrib.messages.api import get_messages
|
||||
from django.contrib.messages.storage.base import Message
|
||||
from django.test import TestCase, override_settings
|
||||
from django.test import override_settings
|
||||
|
||||
from .cas_clients import AcceptCASClient
|
||||
from allauth_cas.test.testcases import CASTestCase
|
||||
|
||||
User = get_user_model()
|
||||
|
||||
|
||||
@patch('allauth_cas.views.cas.CASClient', AcceptCASClient)
|
||||
def client_cas_login(client):
|
||||
"""
|
||||
Sign in client through the example CAS provider.
|
||||
Returns the response of callbacK view.
|
||||
"""
|
||||
r = client.get('/accounts/theid/login/callback/', {
|
||||
'ticket': 'fake-ticket',
|
||||
})
|
||||
return r
|
||||
|
||||
|
||||
class LogoutFlowTests(TestCase):
|
||||
class LogoutFlowTests(CASTestCase):
|
||||
expected_msg_str = (
|
||||
"To logout of CAS, please close your browser, or visit this "
|
||||
"<a href=\"/accounts/theid/logout/?next=%2Faccounts%2Flogout%2F\">"
|
||||
|
@ -35,7 +18,7 @@ class LogoutFlowTests(TestCase):
|
|||
)
|
||||
|
||||
def setUp(self):
|
||||
client_cas_login(self.client)
|
||||
self.client_cas_login(self.client)
|
||||
|
||||
def assertCASLogoutNotInMessages(self, response):
|
||||
r_messages = get_messages(response.wsgi_request)
|
||||
|
|
|
@ -1,11 +1,60 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
from django.test import Client
|
||||
from django.test import Client, RequestFactory
|
||||
|
||||
from allauth_cas.test.testcases import CASViewTestCase
|
||||
from allauth_cas.views import CASView
|
||||
|
||||
from .example.views import ExampleCASAdapter
|
||||
|
||||
|
||||
class CASTestCaseTests(CASViewTestCase):
|
||||
|
||||
def test_patch_cas_response_client_version(self):
|
||||
"""
|
||||
python-cas uses multiple client classes depending on the CAS server
|
||||
version.
|
||||
|
||||
patch_cas_response patch must also returns the correct class.
|
||||
|
||||
"""
|
||||
valid_versions = [
|
||||
1, '1',
|
||||
2, '2',
|
||||
3, '3',
|
||||
'CAS_2_SAML_1_0',
|
||||
]
|
||||
invalid_versions = [
|
||||
'not_supported',
|
||||
]
|
||||
|
||||
factory = RequestFactory()
|
||||
request = factory.get('/path/')
|
||||
request.session = {}
|
||||
|
||||
for _version in valid_versions + invalid_versions:
|
||||
class BasicCASAdapter(ExampleCASAdapter):
|
||||
version = _version
|
||||
|
||||
class BasicCASView(CASView):
|
||||
def dispatch(self, request, *args, **kwargs):
|
||||
return self.get_client(request)
|
||||
|
||||
view = BasicCASView.adapter_view(BasicCASAdapter)
|
||||
|
||||
if _version in valid_versions:
|
||||
raw_client = view(request)
|
||||
|
||||
self.patch_cas_response(valid_ticket='__all__')
|
||||
mocked_client = view(request)
|
||||
|
||||
self.assertEqual(type(raw_client), type(mocked_client))
|
||||
else:
|
||||
# This is a sanity check.
|
||||
self.assertRaises(ValueError, view, request)
|
||||
|
||||
self.patch_cas_response(valid_ticket='__all__')
|
||||
self.assertRaises(ValueError, view, request)
|
||||
|
||||
def test_patch_cas_response_verify_success(self):
|
||||
self.patch_cas_response(valid_ticket='123456')
|
||||
r = self.client.get('/accounts/theid/login/callback/', {
|
||||
|
|
|
@ -5,10 +5,10 @@ except ImportError:
|
|||
from mock import patch
|
||||
|
||||
import django
|
||||
from django.test import RequestFactory, TestCase, override_settings
|
||||
from django.test import RequestFactory, override_settings
|
||||
|
||||
from allauth_cas.exceptions import CASAuthenticationError
|
||||
from allauth_cas.test.testcases import CASViewTestCase
|
||||
from allauth_cas.test.testcases import CASTestCase, CASViewTestCase
|
||||
from allauth_cas.views import CASView
|
||||
|
||||
from .example.views import ExampleCASAdapter
|
||||
|
@ -19,11 +19,12 @@ else:
|
|||
from django.core.urlresolvers import reverse
|
||||
|
||||
|
||||
class CASAdapterTests(TestCase):
|
||||
class CASAdapterTests(CASTestCase):
|
||||
|
||||
def setUp(self):
|
||||
factory = RequestFactory()
|
||||
self.request = factory.get('/path/')
|
||||
self.request.session = {}
|
||||
self.adapter = ExampleCASAdapter(self.request)
|
||||
|
||||
def test_get_service_url(self):
|
||||
|
@ -61,6 +62,23 @@ class CASAdapterTests(TestCase):
|
|||
})
|
||||
self.assertEqual(expected, callback_url)
|
||||
|
||||
def test_renew(self):
|
||||
"""
|
||||
From an anonymous request, renew is False to let using the single
|
||||
sign-on.
|
||||
"""
|
||||
self.assertFalse(self.adapter.renew)
|
||||
|
||||
def test_renew_authenticated(self):
|
||||
"""
|
||||
If user has been authenticated to the application through CAS, and
|
||||
tries to reauthenticate, renew is set to True to opt-out the single
|
||||
sign-on.
|
||||
"""
|
||||
r = self.client_cas_login(self.client)
|
||||
adapter = ExampleCASAdapter(r.wsgi_request)
|
||||
self.assertTrue(adapter.renew)
|
||||
|
||||
|
||||
class CASViewTests(CASViewTestCase):
|
||||
|
||||
|
@ -70,7 +88,8 @@ class CASViewTests(CASViewTestCase):
|
|||
|
||||
def setUp(self):
|
||||
factory = RequestFactory()
|
||||
self.request = factory.get('path')
|
||||
self.request = factory.get('/path/')
|
||||
self.request.session = {}
|
||||
|
||||
self.cas_view = self.BasicCASView.adapter_view(ExampleCASAdapter)
|
||||
|
||||
|
|
Loading…
Reference in a new issue