move SASL logic out in to sasl.py.SASLContext
This commit is contained in:
parent
023107385e
commit
f43cb75bfa
2 changed files with 85 additions and 92 deletions
|
@ -1,6 +1,11 @@
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from base64 import b64encode
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from irctokens import build
|
||||||
|
|
||||||
|
from .matching import Response, Numerics, ParamAny
|
||||||
|
from .contexts import ServerContext
|
||||||
|
|
||||||
SASL_USERPASS_MECHANISMS = [
|
SASL_USERPASS_MECHANISMS = [
|
||||||
"SCRAM-SHA-512",
|
"SCRAM-SHA-512",
|
||||||
|
@ -19,14 +24,78 @@ class SASLError(Exception):
|
||||||
class SASLUnkownMechanismError(SASLError):
|
class SASLUnkownMechanismError(SASLError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@dataclass
|
class SASLContext(ServerContext):
|
||||||
class SASLParams(object):
|
async def external(self) -> SASLResult:
|
||||||
mechanism: str
|
await self.server.send(build("AUTHENTICATE", ["EXTERNAL"]))
|
||||||
username: Optional[str] = None
|
line = await self.server.wait_for(Response("AUTHENTICATE",
|
||||||
password: Optional[str] = None
|
[ParamAny()], errors=["904", "907", "908"]))
|
||||||
class SASLUserPass(SASLParams):
|
|
||||||
def __init__(self, username: str, password: str):
|
if line.command == "907":
|
||||||
super().__init__("USERPASS", username, password)
|
# we've done SASL already. cleanly abort
|
||||||
class SASLExternal(SASLParams):
|
return SASLResult.ALREADY
|
||||||
def __init__(self):
|
elif line.command == "908":
|
||||||
super().__init__("EXTERNAL")
|
available = line.params[1].split(",")
|
||||||
|
raise SASLUnkownMechanismError(
|
||||||
|
"Server does not support SASL EXTERNAL "
|
||||||
|
f"(it supports {available}")
|
||||||
|
elif line.command == "AUTHENTICATE" and line.params[0] == "+":
|
||||||
|
await self.server.send(build("AUTHENTICATE", ["+"]))
|
||||||
|
|
||||||
|
line = await self.server.wait_for(Numerics(["903", "904"]))
|
||||||
|
if line.command == "903":
|
||||||
|
return SASLResult.SUCCESS
|
||||||
|
return SASLResult.FAILURE
|
||||||
|
|
||||||
|
async def userpass(self, username: str, password: str) -> SASLResult:
|
||||||
|
# this will, in the future, offer SCRAM support
|
||||||
|
|
||||||
|
def _common(server_mechs) -> str:
|
||||||
|
for our_mech in SASL_USERPASS_MECHANISMS:
|
||||||
|
if our_mech in server_mechs:
|
||||||
|
return our_mech
|
||||||
|
else:
|
||||||
|
raise SASLUnkownMechanismError(
|
||||||
|
"No matching SASL mechanims. "
|
||||||
|
f"(we have: {SASL_USERPASS_MECHANISMS} "
|
||||||
|
f"server has: {server_mechs})")
|
||||||
|
|
||||||
|
if not self.server.available_caps["sasl"] is None:
|
||||||
|
# CAP v3.2 tells us what mechs it supports
|
||||||
|
available = self.server.available_caps["sasl"].split(",")
|
||||||
|
match = _common(available)
|
||||||
|
else:
|
||||||
|
# CAP v3.1 does not. pick the pick and wait for 907 to inform us of
|
||||||
|
# what mechanisms are supported
|
||||||
|
match = SASL_USERPASS_MECHANISMS[0]
|
||||||
|
|
||||||
|
await self.server.send(build("AUTHENTICATE", [match]))
|
||||||
|
line = await self.server.wait_for(Response("AUTHENTICATE",
|
||||||
|
[ParamAny()], errors=["904", "907", "908"]))
|
||||||
|
|
||||||
|
if line.command == "907":
|
||||||
|
# we've done SASL already. cleanly abort
|
||||||
|
return SASLResult.ALREADY
|
||||||
|
elif line.command == "908":
|
||||||
|
# prior to CAP v3.2 - ERR telling us which mechs are supported
|
||||||
|
available = line.params[1].split(",")
|
||||||
|
match = _common(available)
|
||||||
|
|
||||||
|
await self.server.send(build("AUTHENTICATE", [match]))
|
||||||
|
line = await self.server.wait_for(Response("AUTHENTICATE",
|
||||||
|
[ParamAny()]))
|
||||||
|
|
||||||
|
if line.command == "AUTHENTICATE" and line.params[0] == "+":
|
||||||
|
auth_text: Optional[str] = None
|
||||||
|
if match == "PLAIN":
|
||||||
|
auth_text = f"{username}\0{username}\0{password}"
|
||||||
|
|
||||||
|
if not auth_text is None:
|
||||||
|
auth_b64 = b64encode(auth_text.encode("utf8")
|
||||||
|
).decode("ascii")
|
||||||
|
await self.server.send(build("AUTHENTICATE", [auth_b64]))
|
||||||
|
|
||||||
|
line = await self.server.wait_for(Numerics(["903", "904"]))
|
||||||
|
if line.command == "903":
|
||||||
|
return SASLResult.SUCCESS
|
||||||
|
return SASLResult.FAILURE
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,6 @@ from queue import Queue
|
||||||
from typing import Awaitable, Callable, Dict, List, Optional, Set, Tuple
|
from typing import Awaitable, Callable, Dict, List, Optional, Set, Tuple
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from base64 import b64encode
|
|
||||||
|
|
||||||
from asyncio_throttle import Throttler
|
from asyncio_throttle import Throttler
|
||||||
from ircstates import Emit
|
from ircstates import Emit
|
||||||
|
@ -13,8 +12,7 @@ from irctokens import build, Line, tokenise
|
||||||
from .ircv3 import Capability, CAPS, CAP_SASL
|
from .ircv3 import Capability, CAPS, CAP_SASL
|
||||||
from .interface import ConnectionParams, IServer, PriorityLine, SendPriority
|
from .interface import ConnectionParams, IServer, PriorityLine, SendPriority
|
||||||
from .matching import BaseResponse, Response, Numerics, ParamAny, Literal
|
from .matching import BaseResponse, Response, Numerics, ParamAny, Literal
|
||||||
from .sasl import (SASLResult, SASLUnkownMechanismError,
|
from .sasl import SASLContext
|
||||||
SASL_USERPASS_MECHANISMS)
|
|
||||||
|
|
||||||
sc = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
|
sc = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
|
||||||
|
|
||||||
|
@ -132,11 +130,11 @@ class Server(IServer):
|
||||||
async def _cap_ack(self, emit: Emit):
|
async def _cap_ack(self, emit: Emit):
|
||||||
if not self.params.sasl is None and self.cap_agreed(CAP_SASL):
|
if not self.params.sasl is None and self.cap_agreed(CAP_SASL):
|
||||||
if self.params.sasl.mechanism == "USERPASS":
|
if self.params.sasl.mechanism == "USERPASS":
|
||||||
await self.sasl_userpass(
|
await SASLContext(self).userpass(
|
||||||
self.params.sasl.username or "",
|
self.params.sasl.username or "",
|
||||||
self.params.sasl.password or "")
|
self.params.sasl.password or "")
|
||||||
elif self.params.sasl.mechanism == "EXTERNAL":
|
elif self.params.sasl.mechanism == "EXTERNAL":
|
||||||
await self.sasl_external()
|
await SASLContext(self).external()
|
||||||
|
|
||||||
for cap in (emit.tokens or []):
|
for cap in (emit.tokens or []):
|
||||||
if cap in self._requested_caps:
|
if cap in self._requested_caps:
|
||||||
|
@ -144,77 +142,3 @@ class Server(IServer):
|
||||||
if not self._requested_caps:
|
if not self._requested_caps:
|
||||||
await self.send(build("CAP", ["END"]))
|
await self.send(build("CAP", ["END"]))
|
||||||
# /CAP-related
|
# /CAP-related
|
||||||
|
|
||||||
async def sasl_external(self) -> SASLResult:
|
|
||||||
await self.send(build("AUTHENTICATE", ["EXTERNAL"]))
|
|
||||||
line = await self.wait_for(Response("AUTHENTICATE", [ParamAny()],
|
|
||||||
errors=["904", "907", "908"]))
|
|
||||||
|
|
||||||
if line.command == "907":
|
|
||||||
# we've done SASL already. cleanly abort
|
|
||||||
return SASLResult.ALREADY
|
|
||||||
elif line.command == "908":
|
|
||||||
available = line.params[1].split(",")
|
|
||||||
raise SASLUnkownMechanismError(
|
|
||||||
"Server does not support SASL EXTERNAL "
|
|
||||||
f"(it supports {available}")
|
|
||||||
elif line.command == "AUTHENTICATE" and line.params[0] == "+":
|
|
||||||
await self.send(build("AUTHENTICATE", ["+"]))
|
|
||||||
|
|
||||||
line = await self.wait_for(Numerics(["903", "904"]))
|
|
||||||
if line.command == "903":
|
|
||||||
return SASLResult.SUCCESS
|
|
||||||
return SASLResult.FAILURE
|
|
||||||
|
|
||||||
async def sasl_userpass(self, username: str, password: str) -> SASLResult:
|
|
||||||
# this will, in the future, offer SCRAM support
|
|
||||||
|
|
||||||
def _common(server_mechs) -> str:
|
|
||||||
for our_mech in SASL_USERPASS_MECHANISMS:
|
|
||||||
if our_mech in server_mechs:
|
|
||||||
return our_mech
|
|
||||||
else:
|
|
||||||
raise SASLUnkownMechanismError(
|
|
||||||
"No matching SASL mechanims. "
|
|
||||||
f"(we have: {SASL_USERPASS_MECHANISMS} "
|
|
||||||
f"server has: {server_mechs})")
|
|
||||||
|
|
||||||
if not self.available_caps["sasl"] is None:
|
|
||||||
# CAP v3.2 tells us what mechs it supports
|
|
||||||
available = self.available_caps["sasl"].split(",")
|
|
||||||
match = _common(available)
|
|
||||||
else:
|
|
||||||
# CAP v3.1 does not. pick the pick and wait for 907 to inform us of
|
|
||||||
# what mechanisms are supported
|
|
||||||
match = SASL_USERPASS_MECHANISMS[0]
|
|
||||||
|
|
||||||
await self.send(build("AUTHENTICATE", [match]))
|
|
||||||
line = await self.wait_for(Response("AUTHENTICATE", [ParamAny()],
|
|
||||||
errors=["904", "907", "908"]))
|
|
||||||
|
|
||||||
if line.command == "907":
|
|
||||||
# we've done SASL already. cleanly abort
|
|
||||||
return SASLResult.ALREADY
|
|
||||||
elif line.command == "908":
|
|
||||||
# prior to CAP v3.2 - ERR telling us which mechs are supported
|
|
||||||
available = line.params[1].split(",")
|
|
||||||
match = _common(available)
|
|
||||||
|
|
||||||
await self.send(build("AUTHENTICATE", [match]))
|
|
||||||
line = await self.wait_for(Response("AUTHENTICATE",
|
|
||||||
[ParamAny()]))
|
|
||||||
|
|
||||||
if line.command == "AUTHENTICATE" and line.params[0] == "+":
|
|
||||||
auth_text: Optional[str] = None
|
|
||||||
if match == "PLAIN":
|
|
||||||
auth_text = f"{username}\0{username}\0{password}"
|
|
||||||
|
|
||||||
if not auth_text is None:
|
|
||||||
auth_b64 = b64encode(auth_text.encode("utf8")
|
|
||||||
).decode("ascii")
|
|
||||||
await self.send(build("AUTHENTICATE", [auth_b64]))
|
|
||||||
|
|
||||||
line = await self.wait_for(Numerics(["903", "904"]))
|
|
||||||
if line.command == "903":
|
|
||||||
return SASLResult.SUCCESS
|
|
||||||
return SASLResult.FAILURE
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue