Connect a CAS Account for an authenticated works...

...thanks to using SocialLogin.stash/unstash_state.

Also:
- Some tests are updated to get the stash_state. Requests are denied if
client session doesn't go through stash_state (called in LoginCASView).
- Fix an assert statement in a test.
This commit is contained in:
Aurélien Delobelle 2017-09-16 01:58:09 +02:00
parent 819f50e86d
commit bb8a3b16bf
5 changed files with 22 additions and 13 deletions

View file

@ -32,6 +32,7 @@ class CASTestCase(TestCase):
username and attributes control the CAS server response when ticket is username and attributes control the CAS server response when ticket is
checked. checked.
""" """
client.get('/accounts/theid/login/')
self.patch_cas_response( self.patch_cas_response(
valid_ticket='__all__', valid_ticket='__all__',
username=username, attributes=attributes, username=username, attributes=attributes,

View file

@ -9,6 +9,7 @@ from allauth.socialaccount import providers
from allauth.socialaccount.helpers import ( from allauth.socialaccount.helpers import (
complete_social_login, render_authentication_error, complete_social_login, render_authentication_error,
) )
from allauth.socialaccount.models import SocialLogin
import cas import cas
@ -133,6 +134,7 @@ class CASView(object):
# Setup and store adapter as view attribute. # Setup and store adapter as view attribute.
self.adapter = adapter(request) self.adapter = adapter(request)
self.provider = self.adapter.get_provider()
try: try:
return self.dispatch(request, *args, **kwargs) return self.dispatch(request, *args, **kwargs)
@ -145,8 +147,7 @@ class CASView(object):
""" """
Returns the CAS client to interact with the CAS server. Returns the CAS client to interact with the CAS server.
""" """
provider = self.adapter.get_provider() auth_params = self.provider.get_auth_params(request, action)
auth_params = provider.get_auth_params(request, action)
service_url = self.adapter.get_service_url(request) service_url = self.adapter.get_service_url(request)
@ -164,10 +165,7 @@ class CASView(object):
""" """
Returns an HTTP response in case an authentication failure happens. Returns an HTTP response in case an authentication failure happens.
""" """
return render_authentication_error( return render_authentication_error(self.request, self.provider.id)
self.request,
self.adapter.provider_id,
)
class CASLoginView(CASView): class CASLoginView(CASView):
@ -177,6 +175,7 @@ class CASLoginView(CASView):
Redirects to the CAS server login page. Redirects to the CAS server login page.
""" """
action = request.GET.get('action', AuthAction.AUTHENTICATE) action = request.GET.get('action', AuthAction.AUTHENTICATE)
SocialLogin.stash_state(request)
client = self.get_client(request, action=action) client = self.get_client(request, action=action)
return HttpResponseRedirect(client.get_login_url()) return HttpResponseRedirect(client.get_login_url())
@ -192,7 +191,6 @@ class CASCallbackView(CASView):
here. If ticket is valid, CAS server may also return extra attributes here. If ticket is valid, CAS server may also return extra attributes
about user. about user.
""" """
provider = self.adapter.get_provider()
client = self.get_client(request) client = self.get_client(request)
# CAS server should let a ticket. # CAS server should let a ticket.
@ -216,10 +214,11 @@ class CASCallbackView(CASView):
# The CAS provider in use is stored to propose to the user to # The CAS provider in use is stored to propose to the user to
# disconnect from the latter when he logouts. # disconnect from the latter when he logouts.
request.session[CAS_PROVIDER_SESSION_KEY] = provider.id request.session[CAS_PROVIDER_SESSION_KEY] = self.provider.id
# Finish the login flow # Finish the login flow
login = self.adapter.complete_login(request, response) login = self.adapter.complete_login(request, response)
login.state = SocialLogin.unstash_state(request)
return complete_social_login(request, login) return complete_social_login(request, login)

View file

@ -3,7 +3,7 @@ from django.contrib import messages
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from django.contrib.messages.api import get_messages from django.contrib.messages.api import get_messages
from django.contrib.messages.storage.base import Message from django.contrib.messages.storage.base import Message
from django.test import override_settings from django.test import Client, override_settings
from allauth_cas.test.testcases import CASTestCase from allauth_cas.test.testcases import CASTestCase
@ -69,7 +69,8 @@ class LogoutFlowTests(CASTestCase):
The CAS logout message doesn't appear with other login methods. The CAS logout message doesn't appear with other login methods.
""" """
User.objects.create_user('user', '', 'user') User.objects.create_user('user', '', 'user')
self.client.login(username='user', password='user') client = Client()
client.login(username='user', password='user')
r = self.client.post('/accounts/logout/') r = client.post('/accounts/logout/')
self.assertCASLogoutNotInMessages(r) self.assertCASLogoutNotInMessages(r)

View file

@ -9,6 +9,9 @@ from .example.views import ExampleCASAdapter
class CASTestCaseTests(CASViewTestCase): class CASTestCaseTests(CASViewTestCase):
def setUp(self):
self.client.get('/accounts/theid/login/')
def test_patch_cas_response_client_version(self): def test_patch_cas_response_client_version(self):
""" """
python-cas uses multiple client classes depending on the CAS server python-cas uses multiple client classes depending on the CAS server
@ -86,6 +89,7 @@ class CASTestCaseTests(CASViewTestCase):
def test_patch_cas_reponse_multiple(self): def test_patch_cas_reponse_multiple(self):
self.patch_cas_response(valid_ticket='__all__') self.patch_cas_response(valid_ticket='__all__')
client_0 = Client() client_0 = Client()
client_0.get('/accounts/theid/login/')
r_0 = client_0.get('/accounts/theid/login/callback/', { r_0 = client_0.get('/accounts/theid/login/callback/', {
'ticket': '000000', 'ticket': '000000',
}) })
@ -93,6 +97,7 @@ class CASTestCaseTests(CASViewTestCase):
self.patch_cas_response(valid_ticket=None) self.patch_cas_response(valid_ticket=None)
client_1 = Client() client_1 = Client()
client_1.get('/accounts/theid/login/')
r_1 = client_1.get('/accounts/theid/login/callback/', { r_1 = client_1.get('/accounts/theid/login/callback/', {
'ticket': '111111', 'ticket': '111111',
}) })

View file

@ -105,7 +105,7 @@ class CASViewTests(CASViewTestCase):
self.assertIsInstance(view, CASView) self.assertIsInstance(view, CASView)
self.assertEqual(view.request, view.request) self.assertEqual(view.request, self.request)
self.assertTupleEqual(view.args, ('arg1', 'arg2')) self.assertTupleEqual(view.args, ('arg1', 'arg2'))
self.assertDictEqual(view.kwargs, { self.assertDictEqual(view.kwargs, {
'kwarg1': 'kwarg1', 'kwarg1': 'kwarg1',
@ -187,6 +187,9 @@ class CASLoginViewTests(CASViewTestCase):
class CASCallbackViewTests(CASViewTestCase): class CASCallbackViewTests(CASViewTestCase):
def setUp(self):
self.client.get('/accounts/theid/login/')
def test_reverse(self): def test_reverse(self):
""" """
Callback view name is "{provider_id}_callback". Callback view name is "{provider_id}_callback".
@ -195,7 +198,7 @@ class CASCallbackViewTests(CASViewTestCase):
self.assertEqual('/accounts/theid/login/callback/', url) self.assertEqual('/accounts/theid/login/callback/', url)
def test_ticket_valid(self): def test_ticket_valid(self):
""" """p(
If ticket is valid, the user is logged in. If ticket is valid, the user is logged in.
""" """
self.patch_cas_response(username='username', valid_ticket='123456') self.patch_cas_response(username='username', valid_ticket='123456')