add SALContext().from_params()
This commit is contained in:
parent
0dd7121469
commit
d310dad471
3 changed files with 23 additions and 13 deletions
|
@ -1,11 +1,15 @@
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class SASLParams(object):
|
class SASLParams(object):
|
||||||
mechanism: str
|
def __init__(self,
|
||||||
username: Optional[str] = None
|
mechanism: str,
|
||||||
password: Optional[str] = None
|
username: str="",
|
||||||
|
password: str=""):
|
||||||
|
self.mechanism = mechanism.upper()
|
||||||
|
self.username = username
|
||||||
|
self.password = password
|
||||||
|
|
||||||
class SASLUserPass(SASLParams):
|
class SASLUserPass(SASLParams):
|
||||||
def __init__(self, username: str, password: str):
|
def __init__(self, username: str, password: str):
|
||||||
super().__init__("USERPASS", username, password)
|
super().__init__("USERPASS", username, password)
|
||||||
|
|
|
@ -5,6 +5,7 @@ from irctokens import build
|
||||||
|
|
||||||
from .matching import Response, Numerics, ParamAny
|
from .matching import Response, Numerics, ParamAny
|
||||||
from .contexts import ServerContext
|
from .contexts import ServerContext
|
||||||
|
from .params import SASLParams
|
||||||
|
|
||||||
SASL_USERPASS_MECHANISMS = [
|
SASL_USERPASS_MECHANISMS = [
|
||||||
"SCRAM-SHA-512",
|
"SCRAM-SHA-512",
|
||||||
|
@ -20,10 +21,20 @@ class SASLResult(Enum):
|
||||||
|
|
||||||
class SASLError(Exception):
|
class SASLError(Exception):
|
||||||
pass
|
pass
|
||||||
class SASLUnkownMechanismError(SASLError):
|
class SASLUnknownMechanismError(SASLError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class SASLContext(ServerContext):
|
class SASLContext(ServerContext):
|
||||||
|
async def from_params(self, params: SASLParams) -> SASLResult:
|
||||||
|
if params.mechanism == "USERPASS":
|
||||||
|
return await self.userpass(params.username, params.password)
|
||||||
|
elif params.mechanism == "EXTERNAL":
|
||||||
|
return await self.external()
|
||||||
|
else:
|
||||||
|
raise SASLUnknownMechanismError(
|
||||||
|
"SASLParams given with unknown mechanism "
|
||||||
|
f"{params.mechanism!r}")
|
||||||
|
|
||||||
async def external(self) -> SASLResult:
|
async def external(self) -> SASLResult:
|
||||||
await self.server.send(build("AUTHENTICATE", ["EXTERNAL"]))
|
await self.server.send(build("AUTHENTICATE", ["EXTERNAL"]))
|
||||||
line = await self.server.wait_for(Response("AUTHENTICATE",
|
line = await self.server.wait_for(Response("AUTHENTICATE",
|
||||||
|
@ -34,7 +45,7 @@ class SASLContext(ServerContext):
|
||||||
return SASLResult.ALREADY
|
return SASLResult.ALREADY
|
||||||
elif line.command == "908":
|
elif line.command == "908":
|
||||||
available = line.params[1].split(",")
|
available = line.params[1].split(",")
|
||||||
raise SASLUnkownMechanismError(
|
raise SASLUnknownMechanismError(
|
||||||
"Server does not support SASL EXTERNAL "
|
"Server does not support SASL EXTERNAL "
|
||||||
f"(it supports {available}")
|
f"(it supports {available}")
|
||||||
elif line.command == "AUTHENTICATE" and line.params[0] == "+":
|
elif line.command == "AUTHENTICATE" and line.params[0] == "+":
|
||||||
|
@ -53,7 +64,7 @@ class SASLContext(ServerContext):
|
||||||
if our_mech in server_mechs:
|
if our_mech in server_mechs:
|
||||||
return our_mech
|
return our_mech
|
||||||
else:
|
else:
|
||||||
raise SASLUnkownMechanismError(
|
raise SASLUnknownMechanismError(
|
||||||
"No matching SASL mechanims. "
|
"No matching SASL mechanims. "
|
||||||
f"(we have: {SASL_USERPASS_MECHANISMS} "
|
f"(we have: {SASL_USERPASS_MECHANISMS} "
|
||||||
f"server has: {server_mechs})")
|
f"server has: {server_mechs})")
|
||||||
|
|
|
@ -126,12 +126,7 @@ class Server(IServer):
|
||||||
await self.send(build("CAP", ["REQ", " ".join(matches)]))
|
await self.send(build("CAP", ["REQ", " ".join(matches)]))
|
||||||
async def _cap_ack(self, emit: Emit):
|
async def _cap_ack(self, emit: Emit):
|
||||||
if not self.params.sasl is None and self.cap_agreed(CAP_SASL):
|
if not self.params.sasl is None and self.cap_agreed(CAP_SASL):
|
||||||
if self.params.sasl.mechanism == "USERPASS":
|
await SASLContext(self).from_params(self.params.sasl)
|
||||||
await SASLContext(self).userpass(
|
|
||||||
self.params.sasl.username or "",
|
|
||||||
self.params.sasl.password or "")
|
|
||||||
elif self.params.sasl.mechanism == "EXTERNAL":
|
|
||||||
await SASLContext(self).external()
|
|
||||||
|
|
||||||
for cap in (emit.tokens or []):
|
for cap in (emit.tokens or []):
|
||||||
if cap in self._requested_caps:
|
if cap in self._requested_caps:
|
||||||
|
|
Loading…
Reference in a new issue