add SASL support!
This commit is contained in:
parent
16e500fd43
commit
7a1373d9b2
1 changed files with 110 additions and 1 deletions
|
@ -4,12 +4,13 @@ from queue import Queue
|
|||
from typing import Awaitable, Callable, Dict, List, Optional, Set, Tuple
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass
|
||||
from base64 import b64encode
|
||||
|
||||
from asyncio_throttle import Throttler
|
||||
from ircstates import Emit
|
||||
from irctokens import build, Line, tokenise
|
||||
|
||||
from .ircv3 import Capability, CAPS
|
||||
from .ircv3 import Capability, CAPS, CAP_SASL
|
||||
from .interface import ConnectionParams, IServer, PriorityLine, SendPriority
|
||||
from .matching import BaseResponse, Response, Numerics, ParamAny, Literal
|
||||
|
||||
|
@ -18,6 +19,23 @@ sc = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
|
|||
THROTTLE_RATE = 4 # lines
|
||||
THROTTLE_TIME = 2 # seconds
|
||||
|
||||
SASL_USERPASS_MECHANISMS = [
|
||||
"SCRAM-SHA-512",
|
||||
"SCRAM-SHA-256",
|
||||
"SCRAM-SHA-1",
|
||||
"PLAIN"
|
||||
]
|
||||
|
||||
class SASLResult(Enum):
|
||||
SUCCESS = 1
|
||||
FAILURE = 2
|
||||
ALREADY = 3
|
||||
|
||||
class SASLError(Exception):
|
||||
pass
|
||||
class SASLUnkownMechanismError(SASLError):
|
||||
pass
|
||||
|
||||
class Server(IServer):
|
||||
_reader: asyncio.StreamReader
|
||||
_writer: asyncio.StreamWriter
|
||||
|
@ -108,19 +126,110 @@ class Server(IServer):
|
|||
# CAP-related
|
||||
async def queue_capability(self, cap: Capability):
|
||||
self._cap_queue.add(cap)
|
||||
|
||||
def cap_agreed(self, capability: Capability) -> bool:
|
||||
return bool(self.cap_available(capability))
|
||||
def cap_available(self, capability: Capability) -> Optional[str]:
|
||||
return capability.available(self.agreed_caps)
|
||||
|
||||
async def _cap_ls_done(self):
|
||||
caps = CAPS+list(self._cap_queue)
|
||||
self._cap_queue.clear()
|
||||
|
||||
if not self.params.sasl is None:
|
||||
caps.append(CAP_SASL)
|
||||
|
||||
matches = list(filter(bool,
|
||||
(c.available(self.available_caps) for c in caps)))
|
||||
if matches:
|
||||
self._requested_caps = matches
|
||||
await self.send(build("CAP", ["REQ", " ".join(matches)]))
|
||||
async def _cap_ack(self, emit: Emit):
|
||||
if not self.params.sasl is None and self.cap_agreed(CAP_SASL):
|
||||
if self.params.sasl.mechanism == "USERPASS":
|
||||
await self.sasl_userpass(
|
||||
self.params.sasl.username or "",
|
||||
self.params.sasl.password or "")
|
||||
elif self.params.sasl.mechanism == "EXTERNAL":
|
||||
await self.sasl_external()
|
||||
|
||||
for cap in (emit.tokens or []):
|
||||
if cap in self._requested_caps:
|
||||
self._requested_caps.remove(cap)
|
||||
if not self._requested_caps:
|
||||
await self.send(build("CAP", ["END"]))
|
||||
# /CAP-related
|
||||
|
||||
async def sasl_external(self) -> SASLResult:
|
||||
await self.send(build("AUTHENTICATE", ["EXTERNAL"]))
|
||||
line = await self.wait_for(Response("AUTHENTICATE", [ParamAny()],
|
||||
errors=["904", "907", "908"]))
|
||||
|
||||
if line.command == "907":
|
||||
# we've done SASL already. cleanly abort
|
||||
return SASLResult.ALREADY
|
||||
elif line.command == "908":
|
||||
available = line.params[1].split(",")
|
||||
raise SASLUnkownMechanismError(
|
||||
"Server does not support SASL EXTERNAL "
|
||||
f"(it supports {available}")
|
||||
elif line.command == "AUTHENTICATE" and line.params[0] == "+":
|
||||
await self.send(build("AUTHENTICATE", ["+"]))
|
||||
|
||||
line = await self.wait_for(Numerics(["903", "904"]))
|
||||
if line.command == "903":
|
||||
return SASLResult.SUCCESS
|
||||
return SASLResult.FAILURE
|
||||
|
||||
async def sasl_userpass(self, username: str, password: str) -> SASLResult:
|
||||
# this will, in the future, offer SCRAM support
|
||||
|
||||
def _common(server_mechs) -> str:
|
||||
for our_mech in SASL_USERPASS_MECHANISMS:
|
||||
if our_mech in server_mechs:
|
||||
return our_mech
|
||||
else:
|
||||
raise SASLUnkownMechanismError(
|
||||
"No matching SASL mechanims. "
|
||||
f"(we have: {SASL_USERPASS_MECHANISMS} "
|
||||
f"server has: {server_mechs})")
|
||||
|
||||
if not self.available_caps["sasl"] is None:
|
||||
# CAP v3.2 tells us what mechs it supports
|
||||
available = self.available_caps["sasl"].split(",")
|
||||
match = _common(available)
|
||||
else:
|
||||
# CAP v3.1 does not. pick the pick and wait for 907 to inform us of
|
||||
# what mechanisms are supported
|
||||
match = SASL_USERPASS_MECHANISMS[0]
|
||||
|
||||
await self.send(build("AUTHENTICATE", [match]))
|
||||
line = await self.wait_for(Response("AUTHENTICATE", [ParamAny()],
|
||||
errors=["904", "907", "908"]))
|
||||
|
||||
if line.command == "907":
|
||||
# we've done SASL already. cleanly abort
|
||||
return SASLResult.ALREADY
|
||||
elif line.command == "908":
|
||||
# prior to CAP v3.2 - ERR telling us which mechs are supported
|
||||
available = line.params[1].split(",")
|
||||
match = _common(available)
|
||||
|
||||
await self.send(build("AUTHENTICATE", [match]))
|
||||
line = await self.wait_for(Response("AUTHENTICATE",
|
||||
[ParamAny()]))
|
||||
|
||||
if line.command == "AUTHENTICATE" and line.params[0] == "+":
|
||||
auth_text: Optional[str] = None
|
||||
if match == "PLAIN":
|
||||
auth_text = f"{username}\0{username}\0{password}"
|
||||
|
||||
if not auth_text is None:
|
||||
auth_b64 = b64encode(auth_text.encode("utf8")
|
||||
).decode("ascii")
|
||||
await self.send(build("AUTHENTICATE", [auth_b64]))
|
||||
|
||||
line = await self.wait_for(Numerics(["903", "904"]))
|
||||
if line.command == "903":
|
||||
return SASLResult.SUCCESS
|
||||
return SASLResult.FAILURE
|
||||
|
|
Loading…
Reference in a new issue