from unittest.mock import patch

import cas
from django.conf import settings
from django.test import TestCase
from django.urls import reverse

from allauth_cas import CAS_PROVIDER_SESSION_KEY


class CASTestCase(TestCase):
    """Test case with helpers to fake the responses of a CAS server."""

    def client_cas_login(
        self, client, provider_id=None, username=None, attributes=None
    ):
        """
        Authenticate client through provider_id.

        Returns the response of the callback view.

        username and attributes control the CAS server response when ticket is
        checked. provider_id defaults to "theid" and attributes to {}.
        """
        if attributes is None:
            attributes = {}
        if provider_id is None:
            provider_id = "theid"

        # The login view is requested before the CAS client gets patched.
        client.get(reverse(f"{provider_id}_login"))

        self.patch_cas_response(
            valid_ticket="__all__",
            username=username,
            attributes=attributes,
        )
        try:
            callback_url = reverse(f"{provider_id}_callback")
            return client.get(callback_url, {"ticket": "fake-ticket"})
        finally:
            # Undo the patch even when the callback request raises.
            self.patch_cas_response_stop()

    def patch_cas_response(self, valid_ticket, username=None, attributes=None):
        """
        Patch the CASClient class used by views of CAS providers.

        Arguments determine the response of the verify_ticket method:

        - If the ticket given as parameter to this method is equal to
          valid_ticket, its return value corresponds to a successful
          authentication on CAS server for user whose login is the username
          argument (default: "username") and extra attributes (provided by
          the server) are the attributes argument (default: {}).
        - If the ticket doesn't match valid_ticket, the response corresponds
          to a reject from CAS server.

        Special values for valid_ticket:

        - If valid_ticket is '__all__', a success response is always returned.
        - If valid_ticket is None, a failure response is always returned.

        Note that valid_ticket should be a string (which is the type of the
        ticket retrieved from GET parameter on request on the callback view).

        A patch that is still active is stopped before the new one starts.
        The patch stays active until patch_cas_response_stop() is called
        (tearDown does it as a last resort).
        """
        if attributes is None:
            attributes = {}

        # Only one patch may be active at a time.
        self.patch_cas_response_stop()

        class MockCASClient:
            """Stand-in for cas.CASClient: a factory keyed on ``version``."""

            _username = username

            def __new__(cls, *args, **kwargs):
                version = kwargs.pop("version")
                if version in (1, "1"):
                    client_class = cas.CASClientV1
                elif version in (2, "2"):
                    client_class = cas.CASClientV2
                elif version in (3, "3"):
                    client_class = cas.CASClientV3
                elif version == "CAS_2_SAML_1_0":
                    client_class = cas.CASClientWithSAMLV1
                else:
                    raise ValueError(f"Unsupported CAS_VERSION {version!r}")

                # Override verify_ticket on a throwaway subclass rather than
                # patching the library class itself: nothing has to be
                # restored afterwards and the fake response cannot leak into
                # other tests.
                class PatchedCASClient(client_class):
                    _username = cls._username

                    def verify_ticket(self, ticket):
                        """Return the canned (user, attributes, pgtiou)."""
                        if valid_ticket == "__all__" or ticket == valid_ticket:
                            return (
                                self._username or "username",
                                attributes,
                                None,
                            )
                        return None, {}, None

                return PatchedCASClient(*args, **kwargs)

        self._patch_cas_client = patch(
            "allauth_cas.views.cas.CASClient",
            MockCASClient,
        )
        self._patch_cas_client.start()

    def patch_cas_response_stop(self):
        """Stop the patch started by patch_cas_response, if one is active."""
        patcher = getattr(self, "_patch_cas_client", None)
        if patcher is None:
            return
        del self._patch_cas_client
        patcher.stop()

    def tearDown(self):
        """Make sure no CAS client patch outlives the test."""
        self.patch_cas_response_stop()
        super().tearDown()


class CASViewTestCase(CASTestCase):
    """Adds assertions about the outcome of the CAS callback view."""

    def assertLoginSuccess(self, response, redirect_to=None):
        """
        Asserts response corresponds to a successful login.

        To check this, the response should redirect to redirect_to (defaults
        to settings.LOGIN_REDIRECT_URL). Also CAS_PROVIDER_SESSION_KEY should
        be set in the session of the request which produced the response.
        """
        if redirect_to is None:
            redirect_to = settings.LOGIN_REDIRECT_URL

        self.assertRedirects(
            response,
            redirect_to,
            fetch_redirect_response=False,
        )
        self.assertIn(
            CAS_PROVIDER_SESSION_KEY,
            response.wsgi_request.session,
        )

    def assertLoginFailure(self, response):
        """
        Asserts response corresponds to a failed login.

        The failure message is searched for in the decoded response body.
        """
        # NOTE(review): the needle is plain text surrounded by blank lines; if
        # it was meant to include markup (e.g. a heading tag), restore it.
        return self.assertInHTML(
            "\n\nSocial Network Login Failure\n\n",
            # Decode the bytes body; str() would give its "b'...'" repr.
            response.content.decode(response.charset),
        )