diff --git a/ircrobots/params.py b/ircrobots/params.py index 8fd706c..3435f03 100644 --- a/ircrobots/params.py +++ b/ircrobots/params.py @@ -1,11 +1,15 @@ from typing import Optional from dataclasses import dataclass -@dataclass class SASLParams(object): - mechanism: str - username: Optional[str] = None - password: Optional[str] = None + def __init__(self, + mechanism: str, + username: str="", + password: str=""): + self.mechanism = mechanism.upper() + self.username = username + self.password = password + class SASLUserPass(SASLParams): def __init__(self, username: str, password: str): super().__init__("USERPASS", username, password) diff --git a/ircrobots/sasl.py b/ircrobots/sasl.py index fdd4d20..eac79f0 100644 --- a/ircrobots/sasl.py +++ b/ircrobots/sasl.py @@ -5,6 +5,7 @@ from irctokens import build from .matching import Response, Numerics, ParamAny from .contexts import ServerContext +from .params import SASLParams SASL_USERPASS_MECHANISMS = [ "SCRAM-SHA-512", @@ -20,10 +21,20 @@ class SASLResult(Enum): class SASLError(Exception): pass -class SASLUnkownMechanismError(SASLError): +class SASLUnknownMechanismError(SASLError): pass class SASLContext(ServerContext): + async def from_params(self, params: SASLParams) -> SASLResult: + if params.mechanism == "USERPASS": + return await self.userpass(params.username, params.password) + elif params.mechanism == "EXTERNAL": + return await self.external() + else: + raise SASLUnknownMechanismError( + "SASLParams given with unknown mechanism " + f"{params.mechanism!r}") + async def external(self) -> SASLResult: await self.server.send(build("AUTHENTICATE", ["EXTERNAL"])) line = await self.server.wait_for(Response("AUTHENTICATE", @@ -34,7 +45,7 @@ class SASLContext(ServerContext): return SASLResult.ALREADY elif line.command == "908": available = line.params[1].split(",") - raise SASLUnkownMechanismError( + raise SASLUnknownMechanismError( "Server does not support SASL EXTERNAL " f"(it supports {available}") elif line.command == "AUTHENTICATE" and line.params[0] == "+": @@ -53,7 +64,7 @@ class SASLContext(ServerContext): if our_mech in server_mechs: return our_mech else: - raise SASLUnkownMechanismError( + raise SASLUnknownMechanismError( "No matching SASL mechanims. " f"(we have: {SASL_USERPASS_MECHANISMS} " f"server has: {server_mechs})") diff --git a/ircrobots/server.py b/ircrobots/server.py index 080853d..3431443 100644 --- a/ircrobots/server.py +++ b/ircrobots/server.py @@ -126,12 +126,7 @@ class Server(IServer): await self.send(build("CAP", ["REQ", " ".join(matches)])) async def _cap_ack(self, emit: Emit): if not self.params.sasl is None and self.cap_agreed(CAP_SASL): - if self.params.sasl.mechanism == "USERPASS": - await SASLContext(self).userpass( - self.params.sasl.username or "", - self.params.sasl.password or "") - elif self.params.sasl.mechanism == "EXTERNAL": - await SASLContext(self).external() + await SASLContext(self).from_params(self.params.sasl) for cap in (emit.tokens or []): if cap in self._requested_caps: