add SALContext().from_params()

This commit is contained in:
jesopo 2020-04-02 18:17:00 +01:00
parent 0dd7121469
commit d310dad471
3 changed files with 23 additions and 13 deletions

View file

@ -1,11 +1,15 @@
from typing import Optional from typing import Optional
from dataclasses import dataclass from dataclasses import dataclass
@dataclass
class SASLParams(object): class SASLParams(object):
mechanism: str def __init__(self,
username: Optional[str] = None mechanism: str,
password: Optional[str] = None username: str="",
password: str=""):
self.mechanism = mechanism.upper()
self.username = username
self.password = password
class SASLUserPass(SASLParams): class SASLUserPass(SASLParams):
def __init__(self, username: str, password: str): def __init__(self, username: str, password: str):
super().__init__("USERPASS", username, password) super().__init__("USERPASS", username, password)

View file

@ -5,6 +5,7 @@ from irctokens import build
from .matching import Response, Numerics, ParamAny from .matching import Response, Numerics, ParamAny
from .contexts import ServerContext from .contexts import ServerContext
from .params import SASLParams
SASL_USERPASS_MECHANISMS = [ SASL_USERPASS_MECHANISMS = [
"SCRAM-SHA-512", "SCRAM-SHA-512",
@ -20,10 +21,20 @@ class SASLResult(Enum):
class SASLError(Exception): class SASLError(Exception):
pass pass
class SASLUnkownMechanismError(SASLError): class SASLUnknownMechanismError(SASLError):
pass pass
class SASLContext(ServerContext): class SASLContext(ServerContext):
async def from_params(self, params: SASLParams) -> SASLResult:
if params.mechanism == "USERPASS":
return await self.userpass(params.username, params.password)
elif params.mechanism == "EXTERNAL":
return await self.external()
else:
raise SASLUnknownMechanismError(
"SASLParams given with unknown mechanism "
f"{params.mechanism!r}")
async def external(self) -> SASLResult: async def external(self) -> SASLResult:
await self.server.send(build("AUTHENTICATE", ["EXTERNAL"])) await self.server.send(build("AUTHENTICATE", ["EXTERNAL"]))
line = await self.server.wait_for(Response("AUTHENTICATE", line = await self.server.wait_for(Response("AUTHENTICATE",
@ -34,7 +45,7 @@ class SASLContext(ServerContext):
return SASLResult.ALREADY return SASLResult.ALREADY
elif line.command == "908": elif line.command == "908":
available = line.params[1].split(",") available = line.params[1].split(",")
raise SASLUnkownMechanismError( raise SASLUnknownMechanismError(
"Server does not support SASL EXTERNAL " "Server does not support SASL EXTERNAL "
f"(it supports {available}") f"(it supports {available}")
elif line.command == "AUTHENTICATE" and line.params[0] == "+": elif line.command == "AUTHENTICATE" and line.params[0] == "+":
@ -53,7 +64,7 @@ class SASLContext(ServerContext):
if our_mech in server_mechs: if our_mech in server_mechs:
return our_mech return our_mech
else: else:
raise SASLUnkownMechanismError( raise SASLUnknownMechanismError(
"No matching SASL mechanims. " "No matching SASL mechanims. "
f"(we have: {SASL_USERPASS_MECHANISMS} " f"(we have: {SASL_USERPASS_MECHANISMS} "
f"server has: {server_mechs})") f"server has: {server_mechs})")

View file

@ -126,12 +126,7 @@ class Server(IServer):
await self.send(build("CAP", ["REQ", " ".join(matches)])) await self.send(build("CAP", ["REQ", " ".join(matches)]))
async def _cap_ack(self, emit: Emit): async def _cap_ack(self, emit: Emit):
if not self.params.sasl is None and self.cap_agreed(CAP_SASL): if not self.params.sasl is None and self.cap_agreed(CAP_SASL):
if self.params.sasl.mechanism == "USERPASS": await SASLContext(self).from_params(self.params.sasl)
await SASLContext(self).userpass(
self.params.sasl.username or "",
self.params.sasl.password or "")
elif self.params.sasl.mechanism == "EXTERNAL":
await SASLContext(self).external()
for cap in (emit.tokens or []): for cap in (emit.tokens or []):
if cap in self._requested_caps: if cap in self._requested_caps: