From bb8a3b16bf55542dc525ec28d4acf5cd54e78e68 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Delobelle?= Date: Sat, 16 Sep 2017 01:58:09 +0200 Subject: [PATCH] Connect a CAS Account for an authenticated works... ...thanks to using SocialLogin.stash/unstash_state. Also: - Some tests are updated to get the stash_state. Requests are denied if client session doesn't go through stash_state (called in LoginCASView). - Fix an assert statement in a test. --- allauth_cas/test/testcases.py | 1 + allauth_cas/views.py | 15 +++++++-------- tests/test_flows.py | 7 ++++--- tests/test_testcases.py | 5 +++++ tests/test_views.py | 7 +++++-- 5 files changed, 22 insertions(+), 13 deletions(-) diff --git a/allauth_cas/test/testcases.py b/allauth_cas/test/testcases.py index d98ee11..0c67403 100644 --- a/allauth_cas/test/testcases.py +++ b/allauth_cas/test/testcases.py @@ -32,6 +32,7 @@ class CASTestCase(TestCase): username and attributes control the CAS server response when ticket is checked. """ + client.get('/accounts/theid/login/') self.patch_cas_response( valid_ticket='__all__', username=username, attributes=attributes, diff --git a/allauth_cas/views.py b/allauth_cas/views.py index 3a3bb83..e9a3685 100644 --- a/allauth_cas/views.py +++ b/allauth_cas/views.py @@ -9,6 +9,7 @@ from allauth.socialaccount import providers from allauth.socialaccount.helpers import ( complete_social_login, render_authentication_error, ) +from allauth.socialaccount.models import SocialLogin import cas @@ -133,6 +134,7 @@ class CASView(object): # Setup and store adapter as view attribute. self.adapter = adapter(request) + self.provider = self.adapter.get_provider() try: return self.dispatch(request, *args, **kwargs) @@ -145,8 +147,7 @@ class CASView(object): """ Returns the CAS client to interact with the CAS server. """ - provider = self.adapter.get_provider() - auth_params = provider.get_auth_params(request, action) + auth_params = self.provider.get_auth_params(request, action) service_url = self.adapter.get_service_url(request) @@ -164,10 +165,7 @@ class CASView(object): """ Returns an HTTP response in case an authentication failure happens. """ - return render_authentication_error( - self.request, - self.adapter.provider_id, - ) + return render_authentication_error(self.request, self.provider.id) class CASLoginView(CASView): @@ -177,6 +175,7 @@ class CASLoginView(CASView): Redirects to the CAS server login page. """ action = request.GET.get('action', AuthAction.AUTHENTICATE) + SocialLogin.stash_state(request) client = self.get_client(request, action=action) return HttpResponseRedirect(client.get_login_url()) @@ -192,7 +191,6 @@ class CASCallbackView(CASView): here. If ticket is valid, CAS server may also return extra attributes about user. """ - provider = self.adapter.get_provider() client = self.get_client(request) # CAS server should let a ticket. @@ -216,10 +214,11 @@ class CASCallbackView(CASView): # The CAS provider in use is stored to propose to the user to # disconnect from the latter when he logouts. - request.session[CAS_PROVIDER_SESSION_KEY] = provider.id + request.session[CAS_PROVIDER_SESSION_KEY] = self.provider.id # Finish the login flow login = self.adapter.complete_login(request, response) + login.state = SocialLogin.unstash_state(request) return complete_social_login(request, login) diff --git a/tests/test_flows.py b/tests/test_flows.py index 29b3c93..3c6dc4b 100644 --- a/tests/test_flows.py +++ b/tests/test_flows.py @@ -3,7 +3,7 @@ from django.contrib import messages from django.contrib.auth import get_user_model from django.contrib.messages.api import get_messages from django.contrib.messages.storage.base import Message -from django.test import override_settings +from django.test import Client, override_settings from allauth_cas.test.testcases import CASTestCase @@ -69,7 +69,8 @@ class LogoutFlowTests(CASTestCase): The CAS logout message doesn't appear with other login methods. """ User.objects.create_user('user', '', 'user') - self.client.login(username='user', password='user') + client = Client() + client.login(username='user', password='user') - r = self.client.post('/accounts/logout/') + r = client.post('/accounts/logout/') self.assertCASLogoutNotInMessages(r) diff --git a/tests/test_testcases.py b/tests/test_testcases.py index ef260bb..c7dacab 100644 --- a/tests/test_testcases.py +++ b/tests/test_testcases.py @@ -9,6 +9,9 @@ from .example.views import ExampleCASAdapter class CASTestCaseTests(CASViewTestCase): + def setUp(self): + self.client.get('/accounts/theid/login/') + def test_patch_cas_response_client_version(self): """ python-cas uses multiple client classes depending on the CAS server @@ -86,6 +89,7 @@ class CASTestCaseTests(CASViewTestCase): def test_patch_cas_reponse_multiple(self): self.patch_cas_response(valid_ticket='__all__') client_0 = Client() + client_0.get('/accounts/theid/login/') r_0 = client_0.get('/accounts/theid/login/callback/', { 'ticket': '000000', }) @@ -93,6 +97,7 @@ class CASTestCaseTests(CASViewTestCase): self.patch_cas_response(valid_ticket=None) client_1 = Client() + client_1.get('/accounts/theid/login/') r_1 = client_1.get('/accounts/theid/login/callback/', { 'ticket': '111111', }) diff --git a/tests/test_views.py b/tests/test_views.py index 0fa1208..dcc99af 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -105,7 +105,7 @@ class CASViewTests(CASViewTestCase): self.assertIsInstance(view, CASView) - self.assertEqual(view.request, view.request) + self.assertEqual(view.request, self.request) self.assertTupleEqual(view.args, ('arg1', 'arg2')) self.assertDictEqual(view.kwargs, { 'kwarg1': 'kwarg1', @@ -187,6 +187,9 @@ class CASLoginViewTests(CASViewTestCase): class CASCallbackViewTests(CASViewTestCase): + def setUp(self): + self.client.get('/accounts/theid/login/') + def test_reverse(self): """ Callback view name is "{provider_id}_callback". @@ -195,7 +198,7 @@ class CASCallbackViewTests(CASViewTestCase): self.assertEqual('/accounts/theid/login/callback/', url) def test_ticket_valid(self): - """ + """p( If ticket is valid, the user is logged in. """ self.patch_cas_response(username='username', valid_ticket='123456')