135 lines
3.9 KiB
Python
135 lines
3.9 KiB
Python
from django.test import Client, RequestFactory
|
|
|
|
from allauth_cas.test.testcases import CASViewTestCase
|
|
from allauth_cas.views import CASView
|
|
|
|
from .example.views import ExampleCASAdapter
|
|
|
|
|
|
class CASTestCaseTests(CASViewTestCase):
|
|
def setUp(self):
|
|
self.client.get("/accounts/theid/login/")
|
|
|
|
def test_patch_cas_response_client_version(self):
|
|
"""
|
|
python-cas uses multiple client classes depending on the CAS server
|
|
version.
|
|
|
|
patch_cas_response patch must also returns the correct class.
|
|
|
|
"""
|
|
valid_versions = [
|
|
1,
|
|
"1",
|
|
2,
|
|
"2",
|
|
3,
|
|
"3",
|
|
"CAS_2_SAML_1_0",
|
|
]
|
|
invalid_versions = [
|
|
"not_supported",
|
|
]
|
|
|
|
factory = RequestFactory()
|
|
request = factory.get("/path/")
|
|
request.session = {}
|
|
|
|
for _version in valid_versions + invalid_versions:
|
|
|
|
class BasicCASAdapter(ExampleCASAdapter):
|
|
version = _version
|
|
|
|
class BasicCASView(CASView):
|
|
def dispatch(self, request, *args, **kwargs):
|
|
return self.get_client(request)
|
|
|
|
view = BasicCASView.adapter_view(BasicCASAdapter)
|
|
|
|
if _version in valid_versions:
|
|
raw_client = view(request)
|
|
|
|
self.patch_cas_response(valid_ticket="__all__")
|
|
mocked_client = view(request)
|
|
|
|
self.assertEqual(type(raw_client), type(mocked_client))
|
|
else:
|
|
# This is a sanity check.
|
|
self.assertRaises(ValueError, view, request)
|
|
|
|
self.patch_cas_response(valid_ticket="__all__")
|
|
self.assertRaises(ValueError, view, request)
|
|
|
|
def test_patch_cas_response_verify_success(self):
|
|
self.patch_cas_response(valid_ticket="123456")
|
|
r = self.client.get(
|
|
"/accounts/theid/login/callback/",
|
|
{
|
|
"ticket": "123456",
|
|
},
|
|
)
|
|
self.assertLoginSuccess(r)
|
|
|
|
def test_patch_cas_response_verify_failure(self):
|
|
self.patch_cas_response(valid_ticket="123456")
|
|
r = self.client.get(
|
|
"/accounts/theid/login/callback/",
|
|
{
|
|
"ticket": "000000",
|
|
},
|
|
)
|
|
self.assertLoginFailure(r)
|
|
|
|
def test_patch_cas_response_accept(self):
|
|
self.patch_cas_response(valid_ticket="__all__")
|
|
r = self.client.get(
|
|
"/accounts/theid/login/callback/",
|
|
{
|
|
"ticket": "000000",
|
|
},
|
|
)
|
|
self.assertLoginSuccess(r)
|
|
|
|
def test_patch_cas_response_reject(self):
|
|
self.patch_cas_response(valid_ticket=None)
|
|
r = self.client.get(
|
|
"/accounts/theid/login/callback/",
|
|
{
|
|
"ticket": "000000",
|
|
},
|
|
)
|
|
self.assertLoginFailure(r)
|
|
|
|
def test_patch_cas_reponse_multiple(self):
|
|
self.patch_cas_response(valid_ticket="__all__")
|
|
client_0 = Client()
|
|
client_0.get("/accounts/theid/login/")
|
|
r_0 = client_0.get(
|
|
"/accounts/theid/login/callback/",
|
|
{
|
|
"ticket": "000000",
|
|
},
|
|
)
|
|
self.assertLoginSuccess(r_0)
|
|
|
|
self.patch_cas_response(valid_ticket=None)
|
|
client_1 = Client()
|
|
client_1.get("/accounts/theid/login/")
|
|
r_1 = client_1.get(
|
|
"/accounts/theid/login/callback/",
|
|
{
|
|
"ticket": "111111",
|
|
},
|
|
)
|
|
self.assertLoginFailure(r_1)
|
|
|
|
def test_assertLoginSuccess(self):
|
|
self.patch_cas_response(valid_ticket="__all__")
|
|
r = self.client.get(
|
|
"/accounts/theid/login/callback/",
|
|
{
|
|
"ticket": "000000",
|
|
"next": "/path/",
|
|
},
|
|
)
|
|
self.assertLoginSuccess(r, redirect_to="/path/")
|