Renew on already authenticated + Misc tests
Renew paramater: - By default, CAS client will use renew parameter if user already authenticates via a CAS server. If always False, he can't reauthenticate on a CAS server due to the single sign-on CAS feature (except if he logouts of CAS on his own). Tests: - patch_cas_reponse now returns a correct CAS client taking into account the version attribute of the CAS adapter. - Some moves happens between testcases et al. - Delete old and now unused fake CAS client classes.
This commit is contained in:
parent
b1165d39af
commit
049cf22b42
6 changed files with 147 additions and 77 deletions
|
@ -4,6 +4,7 @@ try:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from mock import patch
|
from mock import patch
|
||||||
|
|
||||||
|
import django
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
|
|
||||||
|
@ -11,9 +12,35 @@ import cas
|
||||||
|
|
||||||
from allauth_cas import CAS_PROVIDER_SESSION_KEY
|
from allauth_cas import CAS_PROVIDER_SESSION_KEY
|
||||||
|
|
||||||
|
if django.VERSION >= (1, 10):
|
||||||
|
from django.urls import reverse
|
||||||
|
else:
|
||||||
|
from django.core.urlresolvers import reverse
|
||||||
|
|
||||||
|
|
||||||
class CASTestCase(TestCase):
|
class CASTestCase(TestCase):
|
||||||
|
|
||||||
|
def client_cas_login(
|
||||||
|
self,
|
||||||
|
client, provider_id='theid',
|
||||||
|
username=None, attributes={}):
|
||||||
|
"""
|
||||||
|
Authenticate client through provider_id.
|
||||||
|
|
||||||
|
Returns the response of the callback view.
|
||||||
|
|
||||||
|
username and attributes control the CAS server response when ticket is
|
||||||
|
checked.
|
||||||
|
"""
|
||||||
|
self.patch_cas_response(
|
||||||
|
valid_ticket='__all__',
|
||||||
|
username=username, attributes=attributes,
|
||||||
|
)
|
||||||
|
callback_url = reverse('{id}_callback'.format(id=provider_id))
|
||||||
|
r = client.get(callback_url, {'ticket': 'fake-ticket'})
|
||||||
|
self.patch_cas_response_stop()
|
||||||
|
return r
|
||||||
|
|
||||||
def patch_cas_response(
|
def patch_cas_response(
|
||||||
self,
|
self,
|
||||||
valid_ticket,
|
valid_ticket,
|
||||||
|
@ -41,20 +68,39 @@ class CASTestCase(TestCase):
|
||||||
ticket retrieved from GET parameter on request on the callback view).
|
ticket retrieved from GET parameter on request on the callback view).
|
||||||
"""
|
"""
|
||||||
if hasattr(self, '_patch_cas_client'):
|
if hasattr(self, '_patch_cas_client'):
|
||||||
self.patch_cas_client_stop()
|
self.patch_cas_response_stop()
|
||||||
|
|
||||||
class MockCASClient(cas.CASClientV2):
|
class MockCASClient(object):
|
||||||
_username = username
|
_username = username
|
||||||
|
|
||||||
def __init__(self_client, *args, **kwargs):
|
def __new__(self_client, *args, **kwargs):
|
||||||
kwargs.pop('version')
|
version = kwargs.pop('version')
|
||||||
super(MockCASClient, self_client).__init__(*args, **kwargs)
|
if version in (1, '1'):
|
||||||
|
client_class = cas.CASClientV1
|
||||||
|
elif version in (2, '2'):
|
||||||
|
client_class = cas.CASClientV2
|
||||||
|
elif version in (3, '3'):
|
||||||
|
client_class = cas.CASClientV3
|
||||||
|
elif version == 'CAS_2_SAML_1_0':
|
||||||
|
client_class = cas.CASClientWithSAMLV1
|
||||||
|
else:
|
||||||
|
raise ValueError('Unsupported CAS_VERSION %r' % version)
|
||||||
|
|
||||||
def verify_ticket(self_client, ticket):
|
client_class._username = self_client._username
|
||||||
if valid_ticket == '__all__' or ticket == valid_ticket:
|
|
||||||
username = self_client._username or 'username'
|
def verify_ticket(self, ticket):
|
||||||
return username, attributes, None
|
if valid_ticket == '__all__' or ticket == valid_ticket:
|
||||||
return None, {}, None
|
username = self._username or 'username'
|
||||||
|
return username, attributes, None
|
||||||
|
return None, {}, None
|
||||||
|
|
||||||
|
patcher = patch.object(
|
||||||
|
client_class, 'verify_ticket',
|
||||||
|
new=verify_ticket,
|
||||||
|
)
|
||||||
|
patcher.start()
|
||||||
|
|
||||||
|
return client_class(*args, **kwargs)
|
||||||
|
|
||||||
self._patch_cas_client = patch(
|
self._patch_cas_client = patch(
|
||||||
'allauth_cas.views.cas.CASClient',
|
'allauth_cas.views.cas.CASClient',
|
||||||
|
@ -62,13 +108,13 @@ class CASTestCase(TestCase):
|
||||||
)
|
)
|
||||||
self._patch_cas_client.start()
|
self._patch_cas_client.start()
|
||||||
|
|
||||||
def patch_cas_client_stop(self):
|
def patch_cas_response_stop(self):
|
||||||
self._patch_cas_client.stop()
|
self._patch_cas_client.stop()
|
||||||
del self._patch_cas_client
|
del self._patch_cas_client
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
if hasattr(self, '_patch_cas_client'):
|
if hasattr(self, '_patch_cas_client'):
|
||||||
self.patch_cas_client_stop()
|
self.patch_cas_response_stop()
|
||||||
|
|
||||||
|
|
||||||
class CASViewTestCase(CASTestCase):
|
class CASViewTestCase(CASTestCase):
|
||||||
|
|
|
@ -28,12 +28,22 @@ class AuthAction(object):
|
||||||
|
|
||||||
|
|
||||||
class CASAdapter(object):
|
class CASAdapter(object):
|
||||||
# CAS client parameters
|
|
||||||
renew = False
|
|
||||||
|
|
||||||
def __init__(self, request):
|
def __init__(self, request):
|
||||||
self.request = request
|
self.request = request
|
||||||
|
|
||||||
|
@property
|
||||||
|
def renew(self):
|
||||||
|
"""
|
||||||
|
If user is already authenticated on Django, he may already been
|
||||||
|
connected to CAS, but still may want to use another CAS account.
|
||||||
|
We set renew to True in this case, as the CAS server won't use the
|
||||||
|
single sign-on.
|
||||||
|
To specifically check, if the current user has used a CAS server,
|
||||||
|
we check if the CAS session key is set.
|
||||||
|
"""
|
||||||
|
return CAS_PROVIDER_SESSION_KEY in self.request.session
|
||||||
|
|
||||||
def get_provider(self):
|
def get_provider(self):
|
||||||
"""
|
"""
|
||||||
Returns a provider instance for the current request.
|
Returns a provider instance for the current request.
|
||||||
|
|
|
@ -1,37 +0,0 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
import cas
|
|
||||||
|
|
||||||
|
|
||||||
class MockCASClient(cas.CASClientV2):
|
|
||||||
"""
|
|
||||||
Base class to mock cas.CASClient
|
|
||||||
"""
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
kwargs.pop('version')
|
|
||||||
super(MockCASClient, self).__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class VerifyCASClient(MockCASClient):
|
|
||||||
"""
|
|
||||||
CAS client which verifies ticket is '123456'.
|
|
||||||
"""
|
|
||||||
def verify_ticket(self, ticket):
|
|
||||||
if ticket == '123456':
|
|
||||||
return 'username', {}, None
|
|
||||||
return None, {}, None
|
|
||||||
|
|
||||||
|
|
||||||
class AcceptCASClient(MockCASClient):
|
|
||||||
"""
|
|
||||||
CAS client which accepts all tickets.
|
|
||||||
"""
|
|
||||||
def verify_ticket(self, ticket):
|
|
||||||
return 'username', {}, None
|
|
||||||
|
|
||||||
|
|
||||||
class RejectCASClient(MockCASClient):
|
|
||||||
"""
|
|
||||||
CAS client which rejects all tickets.
|
|
||||||
"""
|
|
||||||
def verify_ticket(self, ticket):
|
|
||||||
return None, {}, None
|
|
|
@ -1,33 +1,16 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
try:
|
|
||||||
from unittest.mock import patch
|
|
||||||
except ImportError:
|
|
||||||
from mock import patch
|
|
||||||
|
|
||||||
from django.contrib import messages
|
from django.contrib import messages
|
||||||
from django.contrib.auth import get_user_model
|
from django.contrib.auth import get_user_model
|
||||||
from django.contrib.messages.api import get_messages
|
from django.contrib.messages.api import get_messages
|
||||||
from django.contrib.messages.storage.base import Message
|
from django.contrib.messages.storage.base import Message
|
||||||
from django.test import TestCase, override_settings
|
from django.test import override_settings
|
||||||
|
|
||||||
from .cas_clients import AcceptCASClient
|
from allauth_cas.test.testcases import CASTestCase
|
||||||
|
|
||||||
User = get_user_model()
|
User = get_user_model()
|
||||||
|
|
||||||
|
|
||||||
@patch('allauth_cas.views.cas.CASClient', AcceptCASClient)
|
class LogoutFlowTests(CASTestCase):
|
||||||
def client_cas_login(client):
|
|
||||||
"""
|
|
||||||
Sign in client through the example CAS provider.
|
|
||||||
Returns the response of callbacK view.
|
|
||||||
"""
|
|
||||||
r = client.get('/accounts/theid/login/callback/', {
|
|
||||||
'ticket': 'fake-ticket',
|
|
||||||
})
|
|
||||||
return r
|
|
||||||
|
|
||||||
|
|
||||||
class LogoutFlowTests(TestCase):
|
|
||||||
expected_msg_str = (
|
expected_msg_str = (
|
||||||
"To logout of CAS, please close your browser, or visit this "
|
"To logout of CAS, please close your browser, or visit this "
|
||||||
"<a href=\"/accounts/theid/logout/?next=%2Faccounts%2Flogout%2F\">"
|
"<a href=\"/accounts/theid/logout/?next=%2Faccounts%2Flogout%2F\">"
|
||||||
|
@ -35,7 +18,7 @@ class LogoutFlowTests(TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
client_cas_login(self.client)
|
self.client_cas_login(self.client)
|
||||||
|
|
||||||
def assertCASLogoutNotInMessages(self, response):
|
def assertCASLogoutNotInMessages(self, response):
|
||||||
r_messages = get_messages(response.wsgi_request)
|
r_messages = get_messages(response.wsgi_request)
|
||||||
|
|
|
@ -1,11 +1,60 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
from django.test import Client
|
from django.test import Client, RequestFactory
|
||||||
|
|
||||||
from allauth_cas.test.testcases import CASViewTestCase
|
from allauth_cas.test.testcases import CASViewTestCase
|
||||||
|
from allauth_cas.views import CASView
|
||||||
|
|
||||||
|
from .example.views import ExampleCASAdapter
|
||||||
|
|
||||||
|
|
||||||
class CASTestCaseTests(CASViewTestCase):
|
class CASTestCaseTests(CASViewTestCase):
|
||||||
|
|
||||||
|
def test_patch_cas_response_client_version(self):
|
||||||
|
"""
|
||||||
|
python-cas uses multiple client classes depending on the CAS server
|
||||||
|
version.
|
||||||
|
|
||||||
|
patch_cas_response patch must also returns the correct class.
|
||||||
|
|
||||||
|
"""
|
||||||
|
valid_versions = [
|
||||||
|
1, '1',
|
||||||
|
2, '2',
|
||||||
|
3, '3',
|
||||||
|
'CAS_2_SAML_1_0',
|
||||||
|
]
|
||||||
|
invalid_versions = [
|
||||||
|
'not_supported',
|
||||||
|
]
|
||||||
|
|
||||||
|
factory = RequestFactory()
|
||||||
|
request = factory.get('/path/')
|
||||||
|
request.session = {}
|
||||||
|
|
||||||
|
for _version in valid_versions + invalid_versions:
|
||||||
|
class BasicCASAdapter(ExampleCASAdapter):
|
||||||
|
version = _version
|
||||||
|
|
||||||
|
class BasicCASView(CASView):
|
||||||
|
def dispatch(self, request, *args, **kwargs):
|
||||||
|
return self.get_client(request)
|
||||||
|
|
||||||
|
view = BasicCASView.adapter_view(BasicCASAdapter)
|
||||||
|
|
||||||
|
if _version in valid_versions:
|
||||||
|
raw_client = view(request)
|
||||||
|
|
||||||
|
self.patch_cas_response(valid_ticket='__all__')
|
||||||
|
mocked_client = view(request)
|
||||||
|
|
||||||
|
self.assertEqual(type(raw_client), type(mocked_client))
|
||||||
|
else:
|
||||||
|
# This is a sanity check.
|
||||||
|
self.assertRaises(ValueError, view, request)
|
||||||
|
|
||||||
|
self.patch_cas_response(valid_ticket='__all__')
|
||||||
|
self.assertRaises(ValueError, view, request)
|
||||||
|
|
||||||
def test_patch_cas_response_verify_success(self):
|
def test_patch_cas_response_verify_success(self):
|
||||||
self.patch_cas_response(valid_ticket='123456')
|
self.patch_cas_response(valid_ticket='123456')
|
||||||
r = self.client.get('/accounts/theid/login/callback/', {
|
r = self.client.get('/accounts/theid/login/callback/', {
|
||||||
|
|
|
@ -5,10 +5,10 @@ except ImportError:
|
||||||
from mock import patch
|
from mock import patch
|
||||||
|
|
||||||
import django
|
import django
|
||||||
from django.test import RequestFactory, TestCase, override_settings
|
from django.test import RequestFactory, override_settings
|
||||||
|
|
||||||
from allauth_cas.exceptions import CASAuthenticationError
|
from allauth_cas.exceptions import CASAuthenticationError
|
||||||
from allauth_cas.test.testcases import CASViewTestCase
|
from allauth_cas.test.testcases import CASTestCase, CASViewTestCase
|
||||||
from allauth_cas.views import CASView
|
from allauth_cas.views import CASView
|
||||||
|
|
||||||
from .example.views import ExampleCASAdapter
|
from .example.views import ExampleCASAdapter
|
||||||
|
@ -19,11 +19,12 @@ else:
|
||||||
from django.core.urlresolvers import reverse
|
from django.core.urlresolvers import reverse
|
||||||
|
|
||||||
|
|
||||||
class CASAdapterTests(TestCase):
|
class CASAdapterTests(CASTestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
factory = RequestFactory()
|
factory = RequestFactory()
|
||||||
self.request = factory.get('/path/')
|
self.request = factory.get('/path/')
|
||||||
|
self.request.session = {}
|
||||||
self.adapter = ExampleCASAdapter(self.request)
|
self.adapter = ExampleCASAdapter(self.request)
|
||||||
|
|
||||||
def test_get_service_url(self):
|
def test_get_service_url(self):
|
||||||
|
@ -61,6 +62,23 @@ class CASAdapterTests(TestCase):
|
||||||
})
|
})
|
||||||
self.assertEqual(expected, callback_url)
|
self.assertEqual(expected, callback_url)
|
||||||
|
|
||||||
|
def test_renew(self):
|
||||||
|
"""
|
||||||
|
From an anonymous request, renew is False to let using the single
|
||||||
|
sign-on.
|
||||||
|
"""
|
||||||
|
self.assertFalse(self.adapter.renew)
|
||||||
|
|
||||||
|
def test_renew_authenticated(self):
|
||||||
|
"""
|
||||||
|
If user has been authenticated to the application through CAS, and
|
||||||
|
tries to reauthenticate, renew is set to True to opt-out the single
|
||||||
|
sign-on.
|
||||||
|
"""
|
||||||
|
r = self.client_cas_login(self.client)
|
||||||
|
adapter = ExampleCASAdapter(r.wsgi_request)
|
||||||
|
self.assertTrue(adapter.renew)
|
||||||
|
|
||||||
|
|
||||||
class CASViewTests(CASViewTestCase):
|
class CASViewTests(CASViewTestCase):
|
||||||
|
|
||||||
|
@ -70,7 +88,8 @@ class CASViewTests(CASViewTestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
factory = RequestFactory()
|
factory = RequestFactory()
|
||||||
self.request = factory.get('path')
|
self.request = factory.get('/path/')
|
||||||
|
self.request.session = {}
|
||||||
|
|
||||||
self.cas_view = self.BasicCASView.adapter_view(ExampleCASAdapter)
|
self.cas_view = self.BasicCASView.adapter_view(ExampleCASAdapter)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue