diff --git a/authens/backends.py b/authens/backends.py index f606f1d..fb164a7 100644 --- a/authens/backends.py +++ b/authens/backends.py @@ -2,7 +2,7 @@ from django.contrib.auth import get_user_model from django.db import transaction from authens.models import CASAccount, OldCASAccount -from authens.utils import get_cas_client +from authens.utils import get_cas_client, parse_entrance_year UserModel = get_user_model() @@ -11,30 +11,6 @@ class ENSCASError(Exception): pass -def get_entrance_year(attributes): - """Infer the entrance year of a CAS account holder from her home directory.""" - - # The home directory of a user is of the form /users/YEAR/DEPARTMENT/CAS_LOGIN where - # YEAR is a 2-digit number representing the entrance year of the student. We get the - # entrance year from there. - - home_dir = attributes.get("homeDirectory") - if home_dir is None: - raise ENSCASError("Entrance year not available") - - dirs = home_dir.split("/") - if len(dirs) < 3 or not dirs[2].isdecimal(): - raise ENSCASError("Invalid homeDirectory: {}".format(home_dir)) - - # Expand the 2-digit entrance year into 4 digits. - # This will break in 2080. - year = int(dirs[2]) - if year >= 80: - return 1900 + year - else: - return 2000 + year - - class ENSCASBackend: """AuthENS CAS authentication backend. @@ -106,8 +82,10 @@ class ENSCASBackend: - If a matching CAS account exists, retrieve it. """ - entrance_year = get_entrance_year(attributes) - email = attributes.get("email", None) + email = attributes.get("email") + entrance_year = parse_entrance_year(attributes.get("homeDirectory")) + if entrance_year is None: + raise ENSCASError("Entrance year not available") with transaction.atomic(): try: diff --git a/authens/conf.py b/authens/conf.py new file mode 100644 index 0000000..04b5cb0 --- /dev/null +++ b/authens/conf.py @@ -0,0 +1,2 @@ +LDAP_SERVER_URL = "ldaps://ldap.spi.ens.fr:636" +# TODO: CAS_SERVER_URL diff --git a/authens/shortcuts.py b/authens/shortcuts.py new file mode 100644 index 0000000..8a8dd97 --- /dev/null +++ b/authens/shortcuts.py @@ -0,0 +1,69 @@ +"""Helper functions to get CAS metadata and create CAS accounts.""" + +# TODO: make the python-ldap dependency optional +import ldap + +from django.conf import settings +from django.contrib.auth import get_user_model + +from authens import conf as default_conf +from authens.models import CASAccount, OldCASAccount +from authens.utils import parse_entrance_year + +User = get_user_model() + + +def _extract_ldap_info(entry, field): + dn, attrs = entry + return attrs[field][0].decode("utf-8") + + +def fetch_cas_account(cas_login): + """Issue an LDAP connection to retrieve metadata associated to a CAS account.""" + + # Don't trust the user! Only accept alphanumeric account names. + if not cas_login.isalnum(): + raise ValueError("Illegal CAS login: {}".format(cas_login)) + + ldap_url = getattr(settings, "LDAP_SERVER_URL", default_conf.LDAP_SERVER_URL) + ldap_obj = ldap.initialize(ldap_url) + res = ldap_obj.search_s( + "dc=spi,dc=ens,dc=fr", + ldap.SCOPE_SUBTREE, + "(uid={})".format(cas_login), + ["uid", "cn", "homeDirectory"], + ) + if not res: + return None + + if len(res) != 1: + raise RuntimeError("LDAP returned too many results: {}".format(res)) + (res,) = res + + assert _extract_ldap_info(res, "uid") == cas_login + return { + "cn": _extract_ldap_info(res, "cn"), + "entrance_year": parse_entrance_year(_extract_ldap_info(res, "homeDirectory")), + } + + +def register_cas_account(user: User, cas_login: str) -> CASAccount: + """Register a user as a CAS user and return the newly created CASAccount.""" + + if not cas_login: + raise ValueError("cas_login must be non-empty") + if CASAccount.objects.filter(cas_login=cas_login).exists(): + raise ValueError("A CAS account named '{}' exists already".format(cas_login)) + if CASAccount.objects.filter(user=user).exists(): + raise ValueError("User '{}' already has a CAS account".format(user)) + if OldCASAccount.objects.filter(user=user).exists(): + raise ValueError("User '{}' has an old CAS account".format(user)) + + ldap_info = fetch_cas_account(cas_login) + if ldap_info is None: + raise ValueError("There is no LDAP user for id '{}'".format(cas_login)) + + entrance_year = ldap_info["entrance_year"] + return CASAccount.objects.create( + user=user, cas_login=cas_login, entrance_year=entrance_year + ) diff --git a/authens/tests/cas_utils.py b/authens/tests/cas_utils.py index 089ffb7..64cdfc9 100644 --- a/authens/tests/cas_utils.py +++ b/authens/tests/cas_utils.py @@ -17,3 +17,27 @@ class FakeCASClient: ) } return self.cas_login, attributes, None + + +class FakeLDAPObject: + """Fake object to be used in place of the result of `ldap.initialize`. + + Always return the same user information (configured at class initialization). + """ + + def __init__(self, cas_login: str, entrance_year: int): + self.cas_login = cas_login + self.entrance_year = entrance_year + + def search_s(self, base, scope, request, *args): + if request != "(uid={})".format(self.cas_login): + raise ValueError("I don't know how to answer this request!") + + home_dir = "/users/{}/info/{}".format(self.entrance_year % 100, self.cas_login) + dn = "whatever" + attrs = { + "uid": [self.cas_login.encode("utf-8")], + "cn": ["{}'s long name".format(self.cas_login).encode("utf-8")], + "homeDirectory": [home_dir.encode("utf-8")], + } + return [(dn, attrs)] diff --git a/authens/tests/test_shortcuts.py b/authens/tests/test_shortcuts.py new file mode 100644 index 0000000..307eb42 --- /dev/null +++ b/authens/tests/test_shortcuts.py @@ -0,0 +1,50 @@ +from unittest import mock + +from django.contrib.auth import get_user_model +from django.test import TestCase + +from authens.models import CASAccount, OldCASAccount +from authens.shortcuts import register_cas_account +from authens.tests.cas_utils import FakeLDAPObject + +User = get_user_model() + + +class TestRegisterCasAccount(TestCase): + @mock.patch("authens.shortcuts.ldap.initialize") + def test_register(self, mock_ldap_obj): + mock_ldap_obj.return_value = FakeLDAPObject("johndoe", 2019) + + user = User.objects.create_user(username="whatever") + self.assertFalse(hasattr(user, "cas_account")) + + register_cas_account(user, cas_login="johndoe") + user.refresh_from_db() + self.assertTrue(hasattr(user, "cas_account")) + self.assertEqual(user.cas_account.cas_login, "johndoe") + self.assertEqual(user.cas_account.entrance_year, 2019) + + def test_cant_register_twice(self): + john = User.objects.create_user(username="whatever") + CASAccount.objects.create(user=john, cas_login="johndoe", entrance_year=2019) + + janis = User.objects.create_user(username="janisjoplin") + + # John cannot have two CAS accounts + with self.assertRaises(ValueError): + register_cas_account(john, cas_login="joplin") + + # Janis cannot steal John's account + with self.assertRaises(ValueError): + register_cas_account(janis, cas_login="johndoe") + + self.assertEqual(CASAccount.objects.count(), 1) + + def test_cant_register_old_account(self): + user = User.objects.create_user(username="whatever") + OldCASAccount.objects.create(user=user, cas_login="toto", entrance_year=2012) + + with self.assertRaises(ValueError): + register_cas_account(user, cas_login="toto") + + self.assertFalse(CASAccount.objects.exists()) diff --git a/authens/utils.py b/authens/utils.py index a2a5499..a7cbbb6 100644 --- a/authens/utils.py +++ b/authens/utils.py @@ -1,3 +1,5 @@ +"""Internal utility functions used by authens.""" + from cas import CASClient from urllib.parse import urlunparse @@ -11,3 +13,26 @@ def get_cas_client(request): ), server_url="https://cas.eleves.ens.fr/", ) + + +def parse_entrance_year(home_dir): + """Infer the entrance year of a CAS account from their home directory.""" + + # The home directory of a user is of the form /users/YEAR/DEPARTMENT/CAS_LOGIN where + # YEAR is a 2-digit number representing the entrance year of the student. We get the + # entrance year from there. + + if home_dir is None: + return None + + dirs = home_dir.split("/") + if len(dirs) < 3 or not dirs[2].isdecimal() or dirs[1] != "users": + raise ValueError("Invalid home directory: {}".format(home_dir)) + + # Expand the 2-digit entrance year into 4 digits. + # This will break in 2080. + year = int(dirs[2]) + if year >= 80: + return 1900 + year + else: + return 2000 + year