add SALContext().from_params()
This commit is contained in:
parent
0dd7121469
commit
d310dad471
3 changed files with 23 additions and 13 deletions
|
@ -1,11 +1,15 @@
|
|||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class SASLParams(object):
|
||||
mechanism: str
|
||||
username: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
def __init__(self,
|
||||
mechanism: str,
|
||||
username: str="",
|
||||
password: str=""):
|
||||
self.mechanism = mechanism.upper()
|
||||
self.username = username
|
||||
self.password = password
|
||||
|
||||
class SASLUserPass(SASLParams):
|
||||
def __init__(self, username: str, password: str):
|
||||
super().__init__("USERPASS", username, password)
|
||||
|
|
|
@ -5,6 +5,7 @@ from irctokens import build
|
|||
|
||||
from .matching import Response, Numerics, ParamAny
|
||||
from .contexts import ServerContext
|
||||
from .params import SASLParams
|
||||
|
||||
SASL_USERPASS_MECHANISMS = [
|
||||
"SCRAM-SHA-512",
|
||||
|
@ -20,10 +21,20 @@ class SASLResult(Enum):
|
|||
|
||||
class SASLError(Exception):
|
||||
pass
|
||||
class SASLUnkownMechanismError(SASLError):
|
||||
class SASLUnknownMechanismError(SASLError):
|
||||
pass
|
||||
|
||||
class SASLContext(ServerContext):
|
||||
async def from_params(self, params: SASLParams) -> SASLResult:
|
||||
if params.mechanism == "USERPASS":
|
||||
return await self.userpass(params.username, params.password)
|
||||
elif params.mechanism == "EXTERNAL":
|
||||
return await self.external()
|
||||
else:
|
||||
raise SASLUnknownMechanismError(
|
||||
"SASLParams given with unknown mechanism "
|
||||
f"{params.mechanism!r}")
|
||||
|
||||
async def external(self) -> SASLResult:
|
||||
await self.server.send(build("AUTHENTICATE", ["EXTERNAL"]))
|
||||
line = await self.server.wait_for(Response("AUTHENTICATE",
|
||||
|
@ -34,7 +45,7 @@ class SASLContext(ServerContext):
|
|||
return SASLResult.ALREADY
|
||||
elif line.command == "908":
|
||||
available = line.params[1].split(",")
|
||||
raise SASLUnkownMechanismError(
|
||||
raise SASLUnknownMechanismError(
|
||||
"Server does not support SASL EXTERNAL "
|
||||
f"(it supports {available}")
|
||||
elif line.command == "AUTHENTICATE" and line.params[0] == "+":
|
||||
|
@ -53,7 +64,7 @@ class SASLContext(ServerContext):
|
|||
if our_mech in server_mechs:
|
||||
return our_mech
|
||||
else:
|
||||
raise SASLUnkownMechanismError(
|
||||
raise SASLUnknownMechanismError(
|
||||
"No matching SASL mechanims. "
|
||||
f"(we have: {SASL_USERPASS_MECHANISMS} "
|
||||
f"server has: {server_mechs})")
|
||||
|
|
|
@ -126,12 +126,7 @@ class Server(IServer):
|
|||
await self.send(build("CAP", ["REQ", " ".join(matches)]))
|
||||
async def _cap_ack(self, emit: Emit):
|
||||
if not self.params.sasl is None and self.cap_agreed(CAP_SASL):
|
||||
if self.params.sasl.mechanism == "USERPASS":
|
||||
await SASLContext(self).userpass(
|
||||
self.params.sasl.username or "",
|
||||
self.params.sasl.password or "")
|
||||
elif self.params.sasl.mechanism == "EXTERNAL":
|
||||
await SASLContext(self).external()
|
||||
await SASLContext(self).from_params(self.params.sasl)
|
||||
|
||||
for cap in (emit.tokens or []):
|
||||
if cap in self._requested_caps:
|
||||
|
|
Loading…
Reference in a new issue