Connect a CAS Account for an authenticated works...
...thanks to using SocialLogin.stash/unstash_state. Also: - Some tests are updated to get the stash_state. Requests are denied if client session doesn't go through stash_state (called in LoginCASView). - Fix an assert statement in a test.
This commit is contained in:
parent
819f50e86d
commit
bb8a3b16bf
5 changed files with 22 additions and 13 deletions
|
@ -32,6 +32,7 @@ class CASTestCase(TestCase):
|
||||||
username and attributes control the CAS server response when ticket is
|
username and attributes control the CAS server response when ticket is
|
||||||
checked.
|
checked.
|
||||||
"""
|
"""
|
||||||
|
client.get('/accounts/theid/login/')
|
||||||
self.patch_cas_response(
|
self.patch_cas_response(
|
||||||
valid_ticket='__all__',
|
valid_ticket='__all__',
|
||||||
username=username, attributes=attributes,
|
username=username, attributes=attributes,
|
||||||
|
|
|
@ -9,6 +9,7 @@ from allauth.socialaccount import providers
|
||||||
from allauth.socialaccount.helpers import (
|
from allauth.socialaccount.helpers import (
|
||||||
complete_social_login, render_authentication_error,
|
complete_social_login, render_authentication_error,
|
||||||
)
|
)
|
||||||
|
from allauth.socialaccount.models import SocialLogin
|
||||||
|
|
||||||
import cas
|
import cas
|
||||||
|
|
||||||
|
@ -133,6 +134,7 @@ class CASView(object):
|
||||||
|
|
||||||
# Setup and store adapter as view attribute.
|
# Setup and store adapter as view attribute.
|
||||||
self.adapter = adapter(request)
|
self.adapter = adapter(request)
|
||||||
|
self.provider = self.adapter.get_provider()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return self.dispatch(request, *args, **kwargs)
|
return self.dispatch(request, *args, **kwargs)
|
||||||
|
@ -145,8 +147,7 @@ class CASView(object):
|
||||||
"""
|
"""
|
||||||
Returns the CAS client to interact with the CAS server.
|
Returns the CAS client to interact with the CAS server.
|
||||||
"""
|
"""
|
||||||
provider = self.adapter.get_provider()
|
auth_params = self.provider.get_auth_params(request, action)
|
||||||
auth_params = provider.get_auth_params(request, action)
|
|
||||||
|
|
||||||
service_url = self.adapter.get_service_url(request)
|
service_url = self.adapter.get_service_url(request)
|
||||||
|
|
||||||
|
@ -164,10 +165,7 @@ class CASView(object):
|
||||||
"""
|
"""
|
||||||
Returns an HTTP response in case an authentication failure happens.
|
Returns an HTTP response in case an authentication failure happens.
|
||||||
"""
|
"""
|
||||||
return render_authentication_error(
|
return render_authentication_error(self.request, self.provider.id)
|
||||||
self.request,
|
|
||||||
self.adapter.provider_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class CASLoginView(CASView):
|
class CASLoginView(CASView):
|
||||||
|
@ -177,6 +175,7 @@ class CASLoginView(CASView):
|
||||||
Redirects to the CAS server login page.
|
Redirects to the CAS server login page.
|
||||||
"""
|
"""
|
||||||
action = request.GET.get('action', AuthAction.AUTHENTICATE)
|
action = request.GET.get('action', AuthAction.AUTHENTICATE)
|
||||||
|
SocialLogin.stash_state(request)
|
||||||
client = self.get_client(request, action=action)
|
client = self.get_client(request, action=action)
|
||||||
return HttpResponseRedirect(client.get_login_url())
|
return HttpResponseRedirect(client.get_login_url())
|
||||||
|
|
||||||
|
@ -192,7 +191,6 @@ class CASCallbackView(CASView):
|
||||||
here. If ticket is valid, CAS server may also return extra attributes
|
here. If ticket is valid, CAS server may also return extra attributes
|
||||||
about user.
|
about user.
|
||||||
"""
|
"""
|
||||||
provider = self.adapter.get_provider()
|
|
||||||
client = self.get_client(request)
|
client = self.get_client(request)
|
||||||
|
|
||||||
# CAS server should let a ticket.
|
# CAS server should let a ticket.
|
||||||
|
@ -216,10 +214,11 @@ class CASCallbackView(CASView):
|
||||||
|
|
||||||
# The CAS provider in use is stored to propose to the user to
|
# The CAS provider in use is stored to propose to the user to
|
||||||
# disconnect from the latter when he logouts.
|
# disconnect from the latter when he logouts.
|
||||||
request.session[CAS_PROVIDER_SESSION_KEY] = provider.id
|
request.session[CAS_PROVIDER_SESSION_KEY] = self.provider.id
|
||||||
|
|
||||||
# Finish the login flow
|
# Finish the login flow
|
||||||
login = self.adapter.complete_login(request, response)
|
login = self.adapter.complete_login(request, response)
|
||||||
|
login.state = SocialLogin.unstash_state(request)
|
||||||
return complete_social_login(request, login)
|
return complete_social_login(request, login)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,7 @@ from django.contrib import messages
|
||||||
from django.contrib.auth import get_user_model
|
from django.contrib.auth import get_user_model
|
||||||
from django.contrib.messages.api import get_messages
|
from django.contrib.messages.api import get_messages
|
||||||
from django.contrib.messages.storage.base import Message
|
from django.contrib.messages.storage.base import Message
|
||||||
from django.test import override_settings
|
from django.test import Client, override_settings
|
||||||
|
|
||||||
from allauth_cas.test.testcases import CASTestCase
|
from allauth_cas.test.testcases import CASTestCase
|
||||||
|
|
||||||
|
@ -69,7 +69,8 @@ class LogoutFlowTests(CASTestCase):
|
||||||
The CAS logout message doesn't appear with other login methods.
|
The CAS logout message doesn't appear with other login methods.
|
||||||
"""
|
"""
|
||||||
User.objects.create_user('user', '', 'user')
|
User.objects.create_user('user', '', 'user')
|
||||||
self.client.login(username='user', password='user')
|
client = Client()
|
||||||
|
client.login(username='user', password='user')
|
||||||
|
|
||||||
r = self.client.post('/accounts/logout/')
|
r = client.post('/accounts/logout/')
|
||||||
self.assertCASLogoutNotInMessages(r)
|
self.assertCASLogoutNotInMessages(r)
|
||||||
|
|
|
@ -9,6 +9,9 @@ from .example.views import ExampleCASAdapter
|
||||||
|
|
||||||
class CASTestCaseTests(CASViewTestCase):
|
class CASTestCaseTests(CASViewTestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.client.get('/accounts/theid/login/')
|
||||||
|
|
||||||
def test_patch_cas_response_client_version(self):
|
def test_patch_cas_response_client_version(self):
|
||||||
"""
|
"""
|
||||||
python-cas uses multiple client classes depending on the CAS server
|
python-cas uses multiple client classes depending on the CAS server
|
||||||
|
@ -86,6 +89,7 @@ class CASTestCaseTests(CASViewTestCase):
|
||||||
def test_patch_cas_reponse_multiple(self):
|
def test_patch_cas_reponse_multiple(self):
|
||||||
self.patch_cas_response(valid_ticket='__all__')
|
self.patch_cas_response(valid_ticket='__all__')
|
||||||
client_0 = Client()
|
client_0 = Client()
|
||||||
|
client_0.get('/accounts/theid/login/')
|
||||||
r_0 = client_0.get('/accounts/theid/login/callback/', {
|
r_0 = client_0.get('/accounts/theid/login/callback/', {
|
||||||
'ticket': '000000',
|
'ticket': '000000',
|
||||||
})
|
})
|
||||||
|
@ -93,6 +97,7 @@ class CASTestCaseTests(CASViewTestCase):
|
||||||
|
|
||||||
self.patch_cas_response(valid_ticket=None)
|
self.patch_cas_response(valid_ticket=None)
|
||||||
client_1 = Client()
|
client_1 = Client()
|
||||||
|
client_1.get('/accounts/theid/login/')
|
||||||
r_1 = client_1.get('/accounts/theid/login/callback/', {
|
r_1 = client_1.get('/accounts/theid/login/callback/', {
|
||||||
'ticket': '111111',
|
'ticket': '111111',
|
||||||
})
|
})
|
||||||
|
|
|
@ -105,7 +105,7 @@ class CASViewTests(CASViewTestCase):
|
||||||
|
|
||||||
self.assertIsInstance(view, CASView)
|
self.assertIsInstance(view, CASView)
|
||||||
|
|
||||||
self.assertEqual(view.request, view.request)
|
self.assertEqual(view.request, self.request)
|
||||||
self.assertTupleEqual(view.args, ('arg1', 'arg2'))
|
self.assertTupleEqual(view.args, ('arg1', 'arg2'))
|
||||||
self.assertDictEqual(view.kwargs, {
|
self.assertDictEqual(view.kwargs, {
|
||||||
'kwarg1': 'kwarg1',
|
'kwarg1': 'kwarg1',
|
||||||
|
@ -187,6 +187,9 @@ class CASLoginViewTests(CASViewTestCase):
|
||||||
|
|
||||||
class CASCallbackViewTests(CASViewTestCase):
|
class CASCallbackViewTests(CASViewTestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.client.get('/accounts/theid/login/')
|
||||||
|
|
||||||
def test_reverse(self):
|
def test_reverse(self):
|
||||||
"""
|
"""
|
||||||
Callback view name is "{provider_id}_callback".
|
Callback view name is "{provider_id}_callback".
|
||||||
|
@ -195,7 +198,7 @@ class CASCallbackViewTests(CASViewTestCase):
|
||||||
self.assertEqual('/accounts/theid/login/callback/', url)
|
self.assertEqual('/accounts/theid/login/callback/', url)
|
||||||
|
|
||||||
def test_ticket_valid(self):
|
def test_ticket_valid(self):
|
||||||
"""
|
"""p(
|
||||||
If ticket is valid, the user is logged in.
|
If ticket is valid, the user is logged in.
|
||||||
"""
|
"""
|
||||||
self.patch_cas_response(username='username', valid_ticket='123456')
|
self.patch_cas_response(username='username', valid_ticket='123456')
|
||||||
|
|
Loading…
Reference in a new issue