diff --git a/authens/backends.py b/authens/backends.py index d5d1b74..fcfd57a 100644 --- a/authens/backends.py +++ b/authens/backends.py @@ -57,12 +57,12 @@ class ENSCASBackend: if cas_login is None: # Authentication failed return None - cas_login = self.clean_cas_login(cas_login) - year = get_entrance_year(attributes) + if request: request.session["CASCONNECTED"] = True - return self._get_or_create(cas_login, year) + + return self._get_or_create(cas_login, attributes) def clean_cas_login(self, cas_login): return cas_login.strip().lower() @@ -97,7 +97,7 @@ class ENSCASBackend: i += 1 return radical + str(i) - def _get_or_create(self, cas_login, entrance_year): + def _get_or_create(self, cas_login, attributes): """Handles account retrieval, creation and invalidation as described above. - If no CAS account exists, create one; @@ -106,6 +106,9 @@ class ENSCASBackend: - If a matching CAS account exists, retrieve it. """ + entrance_year = get_entrance_year(attributes) + email = attributes.get("email", None) + with transaction.atomic(): try: user = UserModel.objects.get(cas_account__cas_login=cas_login) @@ -117,7 +120,7 @@ class ENSCASBackend: if user is None: username = self.get_free_username(cas_login) - user = UserModel.objects.create_user(username=username) + user = UserModel.objects.create_user(username=username, email=email) CASAccount.objects.create( user=user, entrance_year=entrance_year, cas_login=cas_login ) diff --git a/authens/tests/test_backend.py b/authens/tests/test_backend.py index 2cf1f39..6c77c6d 100644 --- a/authens/tests/test_backend.py +++ b/authens/tests/test_backend.py @@ -24,6 +24,15 @@ class TestCASBackend(TestCase): self.assertFalse(UserModel.objects.filter(username=username).exists()) UserModel.objects.create(username=username) + def test_email(self): + backend = ENSCASBackend() + attributes = { + "email": "toto@example.com", + "homeDirectory": "/users/19/info/toto", + } + user = backend._get_or_create("toto", attributes) + self.assertEqual(user.email, "toto@example.com") + @mock.patch("authens.backends.get_cas_client") def test_cas_user_creation(self, mock_cas_client): # Make `get_cas_client` return a dummy CAS client for testing purpose.