diff --git a/ircrobots/server.py b/ircrobots/server.py index 05c721a..ffb8617 100644 --- a/ircrobots/server.py +++ b/ircrobots/server.py @@ -4,12 +4,13 @@ from queue import Queue from typing import Awaitable, Callable, Dict, List, Optional, Set, Tuple from enum import Enum from dataclasses import dataclass +from base64 import b64encode from asyncio_throttle import Throttler from ircstates import Emit from irctokens import build, Line, tokenise -from .ircv3 import Capability, CAPS +from .ircv3 import Capability, CAPS, CAP_SASL from .interface import ConnectionParams, IServer, PriorityLine, SendPriority from .matching import BaseResponse, Response, Numerics, ParamAny, Literal @@ -18,6 +19,23 @@ sc = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) THROTTLE_RATE = 4 # lines THROTTLE_TIME = 2 # seconds +SASL_USERPASS_MECHANISMS = [ + "SCRAM-SHA-512", + "SCRAM-SHA-256", + "SCRAM-SHA-1", + "PLAIN" +] + +class SASLResult(Enum): + SUCCESS = 1 + FAILURE = 2 + ALREADY = 3 + +class SASLError(Exception): + pass +class SASLUnkownMechanismError(SASLError): + pass + class Server(IServer): _reader: asyncio.StreamReader _writer: asyncio.StreamWriter @@ -108,19 +126,110 @@ class Server(IServer): # CAP-related async def queue_capability(self, cap: Capability): self._cap_queue.add(cap) + + def cap_agreed(self, capability: Capability) -> bool: + return bool(self.cap_available(capability)) + def cap_available(self, capability: Capability) -> Optional[str]: + return capability.available(self.agreed_caps) + async def _cap_ls_done(self): caps = CAPS+list(self._cap_queue) self._cap_queue.clear() + if not self.params.sasl is None: + caps.append(CAP_SASL) + matches = list(filter(bool, (c.available(self.available_caps) for c in caps))) if matches: self._requested_caps = matches await self.send(build("CAP", ["REQ", " ".join(matches)])) async def _cap_ack(self, emit: Emit): + if not self.params.sasl is None and self.cap_agreed(CAP_SASL): + if self.params.sasl.mechanism == "USERPASS": + await self.sasl_userpass( + self.params.sasl.username or "", + self.params.sasl.password or "") + elif self.params.sasl.mechanism == "EXTERNAL": + await self.sasl_external() + for cap in (emit.tokens or []): if cap in self._requested_caps: self._requested_caps.remove(cap) if not self._requested_caps: await self.send(build("CAP", ["END"])) # /CAP-related + + async def sasl_external(self) -> SASLResult: + await self.send(build("AUTHENTICATE", ["EXTERNAL"])) + line = await self.wait_for(Response("AUTHENTICATE", [ParamAny()], + errors=["904", "907", "908"])) + + if line.command == "907": + # we've done SASL already. cleanly abort + return SASLResult.ALREADY + elif line.command == "908": + available = line.params[1].split(",") + raise SASLUnkownMechanismError( + "Server does not support SASL EXTERNAL " + f"(it supports {available}") + elif line.command == "AUTHENTICATE" and line.params[0] == "+": + await self.send(build("AUTHENTICATE", ["+"])) + + line = await self.wait_for(Numerics(["903", "904"])) + if line.command == "903": + return SASLResult.SUCCESS + return SASLResult.FAILURE + + async def sasl_userpass(self, username: str, password: str) -> SASLResult: + # this will, in the future, offer SCRAM support + + def _common(server_mechs) -> str: + for our_mech in SASL_USERPASS_MECHANISMS: + if our_mech in server_mechs: + return our_mech + else: + raise SASLUnkownMechanismError( + "No matching SASL mechanims. " + f"(we have: {SASL_USERPASS_MECHANISMS} " + f"server has: {server_mechs})") + + if not self.available_caps["sasl"] is None: + # CAP v3.2 tells us what mechs it supports + available = self.available_caps["sasl"].split(",") + match = _common(available) + else: + # CAP v3.1 does not. pick the pick and wait for 907 to inform us of + # what mechanisms are supported + match = SASL_USERPASS_MECHANISMS[0] + + await self.send(build("AUTHENTICATE", [match])) + line = await self.wait_for(Response("AUTHENTICATE", [ParamAny()], + errors=["904", "907", "908"])) + + if line.command == "907": + # we've done SASL already. cleanly abort + return SASLResult.ALREADY + elif line.command == "908": + # prior to CAP v3.2 - ERR telling us which mechs are supported + available = line.params[1].split(",") + match = _common(available) + + await self.send(build("AUTHENTICATE", [match])) + line = await self.wait_for(Response("AUTHENTICATE", + [ParamAny()])) + + if line.command == "AUTHENTICATE" and line.params[0] == "+": + auth_text: Optional[str] = None + if match == "PLAIN": + auth_text = f"{username}\0{username}\0{password}" + + if not auth_text is None: + auth_b64 = b64encode(auth_text.encode("utf8") + ).decode("ascii") + await self.send(build("AUTHENTICATE", [auth_b64])) + + line = await self.wait_for(Numerics(["903", "904"])) + if line.command == "903": + return SASLResult.SUCCESS + return SASLResult.FAILURE