from django import forms
from django.contrib.auth import authenticate
from django.contrib.auth import forms as auth_forms
from django.contrib.auth import get_user_model
from django.utils.translation import gettext_lazy as _

# The project's active user model, resolved once when this module is imported.
# NOTE(review): not referenced by the forms in this file; possibly kept for
# importers of this module — confirm before removing.
UserModel = get_user_model()


class ElectionAuthForm(forms.Form):
    """Adapts Django's AuthenticationForm to allow for an election specific login.

    The form collects a login, a password and the id of the election being
    logged into. It delegates the credential check to
    ``django.contrib.auth.authenticate`` with the ``login``, ``password`` and
    ``election_id`` keyword arguments. Like ``AuthenticationForm``, it takes
    the current request as its first argument and exposes the authenticated
    user through ``get_user()``, which ``LoginView`` relies on.
    """

    login = auth_forms.UsernameField(label=_("Identifiant"), max_length=255)
    password = forms.CharField(
        label=_("Mot de passe"),
        strip=False,
        widget=forms.PasswordInput(attrs={"autocomplete": "current-password"}),
    )
    # Rendered as a hidden input: the voter does not type it in.
    # NOTE(review): presumably pre-filled by the login view/template — confirm.
    election_id = forms.IntegerField(widget=forms.HiddenInput())

    def __init__(self, request=None, *args, **kwargs):
        """Keep ``request`` for the auth backends; other args go to ``Form``."""
        self.request = request
        # Set by clean() to the authenticated user, if any.
        self.user_cache = None
        super().__init__(*args, **kwargs)

    def clean(self):
        """Authenticate the submitted credentials for the submitted election.

        Authentication is only attempted when ``login``, ``password`` and
        ``election_id`` all passed field validation. If one of them is missing
        or invalid, the form already carries a field error. Calling the
        backends anyway (e.g. with ``election_id=None``) would only cost a
        lookup and add a misleading "invalid login" error.

        Returns:
            The form's ``cleaned_data``.

        Raises:
            forms.ValidationError: code ``"invalid_login"`` when no backend
                accepts the credentials for this election.
        """
        login = self.cleaned_data.get("login")
        password = self.cleaned_data.get("password")
        election_id = self.cleaned_data.get("election_id")

        if login is not None and password and election_id is not None:
            self.user_cache = authenticate(
                self.request,
                login=login,
                password=password,
                election_id=election_id,
            )
            if self.user_cache is None:
                raise self.get_invalid_login_error()

        return self.cleaned_data

    def get_user(self):
        """Return the user authenticated by ``clean()``, or ``None``."""
        # Necessary API for LoginView
        return self.user_cache

    def get_invalid_login_error(self):
        """Build the error raised when the credentials are rejected."""
        return forms.ValidationError(
            _(
                "Aucun·e électeur·ice avec cet identifiant et mot de passe n'existe "
                "pour cette élection. Vérifiez que les informations rentrées sont "
                "correctes, les champs sont sensibles à la casse."
            ),
            code="invalid_login",
        )


class PwdResetForm(auth_forms.PasswordResetForm):
    """Restricts the search for password users, i.e. whose username starts with pwd__."""

    def get_users(self, email):
        """Yield the users the parent form would email for ``email``.

        Only users whose username starts with ``pwd__`` are kept. The
        filtering stays lazy: it wraps the parent's iterable in a generator.

        ``str.startswith`` matches the documented prefix exactly. The previous
        ``username.split("__")[0] == "pwd"`` check also accepted a bare
        ``"pwd"`` username, which has no ``__`` separator.
        """
        users = super().get_users(email)
        return (u for u in users if u.username.startswith("pwd__"))