On branche authens
This commit is contained in:
parent
c70fcefa86
commit
6a59163dea
12 changed files with 280 additions and 66 deletions
|
@ -1,62 +1,37 @@
|
|||
from authens.backends import ENSCASBackend
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.contrib.auth.backends import ModelBackend
|
||||
|
||||
from .utils import get_cas_client
|
||||
|
||||
UserModel = get_user_model()
|
||||
|
||||
|
||||
class ENSCASBackend:
|
||||
"""ENS CAS authentication backend.
|
||||
|
||||
Implement standard CAS v3 authentication
|
||||
"""
|
||||
|
||||
def authenticate(self, request, ticket=None):
|
||||
cas_client = get_cas_client(request)
|
||||
cas_login, attributes, _ = cas_client.verify_ticket(ticket)
|
||||
|
||||
if cas_login is None:
|
||||
# Authentication failed
|
||||
return None
|
||||
cas_login = self.clean_cas_login(cas_login)
|
||||
|
||||
if request:
|
||||
request.session["CASCONNECTED"] = True
|
||||
|
||||
return self._get_or_create(cas_login, attributes)
|
||||
class CASBackend(ENSCASBackend):
|
||||
"""ENS CAS authentication backend, customized to get the full name at connection."""
|
||||
|
||||
def clean_cas_login(self, cas_login):
|
||||
return cas_login.strip().lower()
|
||||
|
||||
def _get_or_create(self, cas_login, attributes):
|
||||
"""Handles account retrieval and creation for CAS authentication.
|
||||
|
||||
- If no CAS account exists, create one;
|
||||
- If a matching CAS account exists, retrieve it.
|
||||
"""
|
||||
return f"cas__{cas_login.strip().lower()}"
|
||||
|
||||
def create_user(self, username, attributes):
|
||||
email = attributes.get("email")
|
||||
name = attributes.get("name")
|
||||
|
||||
try:
|
||||
user = UserModel.objects.get(username=cas_login)
|
||||
except UserModel.DoesNotExist:
|
||||
user = None
|
||||
return UserModel.objects.create_user(
|
||||
username=username, email=email, full_name=name
|
||||
)
|
||||
|
||||
if user is None:
|
||||
user = UserModel.objects.create_user(
|
||||
username=cas_login, email=email, full_name=name
|
||||
)
|
||||
return user
|
||||
|
||||
# Django boilerplate.
|
||||
def get_user(self, user_id):
|
||||
try:
|
||||
return UserModel.objects.get(pk=user_id)
|
||||
except UserModel.DoesNotExist:
|
||||
class PwdBackend(ModelBackend):
|
||||
"""Password authentication"""
|
||||
|
||||
def authenticate(self, request, username=None, password=None):
|
||||
if username is None or password is None:
|
||||
return None
|
||||
|
||||
return super().authenticate(
|
||||
request, username=f"pwd__{username}", password=password
|
||||
)
|
||||
|
||||
|
||||
class ElectionBackend(ModelBackend):
|
||||
"""Authentication for a specific election.
|
||||
|
@ -70,17 +45,12 @@ class ElectionBackend(ModelBackend):
|
|||
return None
|
||||
|
||||
try:
|
||||
user = UserModel.objects.get(username=f"{election_id}__{login}")
|
||||
user = UserModel.objects.get(
|
||||
username=f"{election_id}__{login}", election=election_id
|
||||
)
|
||||
except UserModel.DoesNotExist:
|
||||
return None
|
||||
|
||||
if user.check_password(password):
|
||||
return user
|
||||
return None
|
||||
|
||||
# Django boilerplate.
|
||||
def get_user(self, user_id):
|
||||
try:
|
||||
return UserModel.objects.get(pk=user_id)
|
||||
except UserModel.DoesNotExist:
|
||||
return None
|
||||
|
|
|
@ -1,8 +1,11 @@
|
|||
from django import forms
|
||||
from django.contrib.auth import authenticate
|
||||
from django.contrib.auth import forms as auth_forms
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
UserModel = get_user_model()
|
||||
|
||||
|
||||
class ElectionAuthForm(forms.Form):
|
||||
"""Adapts Django's AuthenticationForm to allow for an election specific login."""
|
||||
|
@ -50,3 +53,11 @@ class ElectionAuthForm(forms.Form):
|
|||
),
|
||||
code="invalid_login",
|
||||
)
|
||||
|
||||
|
||||
class PwdResetForm(auth_forms.PasswordResetForm):
|
||||
"""Restricts the search for password users, i.e. whose username starts with pwd__."""
|
||||
|
||||
def get_users(self, email):
|
||||
users = super().get_users(email)
|
||||
return (u for u in users if u.username.split("__")[0] == "pwd")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue