Connect a CAS Account for an authenticated works...

...thanks to using SocialLogin.stash/unstash_state.

Also:
- Some tests are updated to get the stash_state. Requests are denied if
client session doesn't go through stash_state (called in LoginCASView).
- Fix an assert statement in a test.
This commit is contained in:
Aurélien Delobelle 2017-09-16 01:58:09 +02:00
parent 819f50e86d
commit bb8a3b16bf
5 changed files with 22 additions and 13 deletions

View file

@ -32,6 +32,7 @@ class CASTestCase(TestCase):
username and attributes control the CAS server response when ticket is
checked.
"""
client.get('/accounts/theid/login/')
self.patch_cas_response(
valid_ticket='__all__',
username=username, attributes=attributes,

View file

@ -9,6 +9,7 @@ from allauth.socialaccount import providers
from allauth.socialaccount.helpers import (
complete_social_login, render_authentication_error,
)
from allauth.socialaccount.models import SocialLogin
import cas
@ -133,6 +134,7 @@ class CASView(object):
# Setup and store adapter as view attribute.
self.adapter = adapter(request)
self.provider = self.adapter.get_provider()
try:
return self.dispatch(request, *args, **kwargs)
@ -145,8 +147,7 @@ class CASView(object):
"""
Returns the CAS client to interact with the CAS server.
"""
provider = self.adapter.get_provider()
auth_params = provider.get_auth_params(request, action)
auth_params = self.provider.get_auth_params(request, action)
service_url = self.adapter.get_service_url(request)
@ -164,10 +165,7 @@ class CASView(object):
"""
Returns an HTTP response in case an authentication failure happens.
"""
return render_authentication_error(
self.request,
self.adapter.provider_id,
)
return render_authentication_error(self.request, self.provider.id)
class CASLoginView(CASView):
@ -177,6 +175,7 @@ class CASLoginView(CASView):
Redirects to the CAS server login page.
"""
action = request.GET.get('action', AuthAction.AUTHENTICATE)
SocialLogin.stash_state(request)
client = self.get_client(request, action=action)
return HttpResponseRedirect(client.get_login_url())
@ -192,7 +191,6 @@ class CASCallbackView(CASView):
here. If ticket is valid, CAS server may also return extra attributes
about user.
"""
provider = self.adapter.get_provider()
client = self.get_client(request)
# CAS server should let a ticket.
@ -216,10 +214,11 @@ class CASCallbackView(CASView):
# The CAS provider in use is stored to propose to the user to
# disconnect from the latter when he logouts.
request.session[CAS_PROVIDER_SESSION_KEY] = provider.id
request.session[CAS_PROVIDER_SESSION_KEY] = self.provider.id
# Finish the login flow
login = self.adapter.complete_login(request, response)
login.state = SocialLogin.unstash_state(request)
return complete_social_login(request, login)

View file

@ -3,7 +3,7 @@ from django.contrib import messages
from django.contrib.auth import get_user_model
from django.contrib.messages.api import get_messages
from django.contrib.messages.storage.base import Message
from django.test import override_settings
from django.test import Client, override_settings
from allauth_cas.test.testcases import CASTestCase
@ -69,7 +69,8 @@ class LogoutFlowTests(CASTestCase):
The CAS logout message doesn't appear with other login methods.
"""
User.objects.create_user('user', '', 'user')
self.client.login(username='user', password='user')
client = Client()
client.login(username='user', password='user')
r = self.client.post('/accounts/logout/')
r = client.post('/accounts/logout/')
self.assertCASLogoutNotInMessages(r)

View file

@ -9,6 +9,9 @@ from .example.views import ExampleCASAdapter
class CASTestCaseTests(CASViewTestCase):
def setUp(self):
self.client.get('/accounts/theid/login/')
def test_patch_cas_response_client_version(self):
"""
python-cas uses multiple client classes depending on the CAS server
@ -86,6 +89,7 @@ class CASTestCaseTests(CASViewTestCase):
def test_patch_cas_reponse_multiple(self):
self.patch_cas_response(valid_ticket='__all__')
client_0 = Client()
client_0.get('/accounts/theid/login/')
r_0 = client_0.get('/accounts/theid/login/callback/', {
'ticket': '000000',
})
@ -93,6 +97,7 @@ class CASTestCaseTests(CASViewTestCase):
self.patch_cas_response(valid_ticket=None)
client_1 = Client()
client_1.get('/accounts/theid/login/')
r_1 = client_1.get('/accounts/theid/login/callback/', {
'ticket': '111111',
})

View file

@ -105,7 +105,7 @@ class CASViewTests(CASViewTestCase):
self.assertIsInstance(view, CASView)
self.assertEqual(view.request, view.request)
self.assertEqual(view.request, self.request)
self.assertTupleEqual(view.args, ('arg1', 'arg2'))
self.assertDictEqual(view.kwargs, {
'kwarg1': 'kwarg1',
@ -187,6 +187,9 @@ class CASLoginViewTests(CASViewTestCase):
class CASCallbackViewTests(CASViewTestCase):
def setUp(self):
self.client.get('/accounts/theid/login/')
def test_reverse(self):
"""
Callback view name is "{provider_id}_callback".
@ -195,7 +198,7 @@ class CASCallbackViewTests(CASViewTestCase):
self.assertEqual('/accounts/theid/login/callback/', url)
def test_ticket_valid(self):
"""
"""p(
If ticket is valid, the user is logged in.
"""
self.patch_cas_response(username='username', valid_ticket='123456')