add SASL support!

This commit is contained in:
jesopo 2020-04-02 17:04:08 +01:00
parent 16e500fd43
commit 7a1373d9b2

View file

@ -4,12 +4,13 @@ from queue import Queue
from typing import Awaitable, Callable, Dict, List, Optional, Set, Tuple from typing import Awaitable, Callable, Dict, List, Optional, Set, Tuple
from enum import Enum from enum import Enum
from dataclasses import dataclass from dataclasses import dataclass
from base64 import b64encode
from asyncio_throttle import Throttler from asyncio_throttle import Throttler
from ircstates import Emit from ircstates import Emit
from irctokens import build, Line, tokenise from irctokens import build, Line, tokenise
from .ircv3 import Capability, CAPS from .ircv3 import Capability, CAPS, CAP_SASL
from .interface import ConnectionParams, IServer, PriorityLine, SendPriority from .interface import ConnectionParams, IServer, PriorityLine, SendPriority
from .matching import BaseResponse, Response, Numerics, ParamAny, Literal from .matching import BaseResponse, Response, Numerics, ParamAny, Literal
@ -18,6 +19,23 @@ sc = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
THROTTLE_RATE = 4 # lines THROTTLE_RATE = 4 # lines
THROTTLE_TIME = 2 # seconds THROTTLE_TIME = 2 # seconds
SASL_USERPASS_MECHANISMS = [
"SCRAM-SHA-512",
"SCRAM-SHA-256",
"SCRAM-SHA-1",
"PLAIN"
]
class SASLResult(Enum):
SUCCESS = 1
FAILURE = 2
ALREADY = 3
class SASLError(Exception):
pass
class SASLUnkownMechanismError(SASLError):
pass
class Server(IServer): class Server(IServer):
_reader: asyncio.StreamReader _reader: asyncio.StreamReader
_writer: asyncio.StreamWriter _writer: asyncio.StreamWriter
@ -108,19 +126,110 @@ class Server(IServer):
# CAP-related # CAP-related
async def queue_capability(self, cap: Capability): async def queue_capability(self, cap: Capability):
self._cap_queue.add(cap) self._cap_queue.add(cap)
def cap_agreed(self, capability: Capability) -> bool:
return bool(self.cap_available(capability))
def cap_available(self, capability: Capability) -> Optional[str]:
return capability.available(self.agreed_caps)
async def _cap_ls_done(self): async def _cap_ls_done(self):
caps = CAPS+list(self._cap_queue) caps = CAPS+list(self._cap_queue)
self._cap_queue.clear() self._cap_queue.clear()
if not self.params.sasl is None:
caps.append(CAP_SASL)
matches = list(filter(bool, matches = list(filter(bool,
(c.available(self.available_caps) for c in caps))) (c.available(self.available_caps) for c in caps)))
if matches: if matches:
self._requested_caps = matches self._requested_caps = matches
await self.send(build("CAP", ["REQ", " ".join(matches)])) await self.send(build("CAP", ["REQ", " ".join(matches)]))
async def _cap_ack(self, emit: Emit): async def _cap_ack(self, emit: Emit):
if not self.params.sasl is None and self.cap_agreed(CAP_SASL):
if self.params.sasl.mechanism == "USERPASS":
await self.sasl_userpass(
self.params.sasl.username or "",
self.params.sasl.password or "")
elif self.params.sasl.mechanism == "EXTERNAL":
await self.sasl_external()
for cap in (emit.tokens or []): for cap in (emit.tokens or []):
if cap in self._requested_caps: if cap in self._requested_caps:
self._requested_caps.remove(cap) self._requested_caps.remove(cap)
if not self._requested_caps: if not self._requested_caps:
await self.send(build("CAP", ["END"])) await self.send(build("CAP", ["END"]))
# /CAP-related # /CAP-related
async def sasl_external(self) -> SASLResult:
await self.send(build("AUTHENTICATE", ["EXTERNAL"]))
line = await self.wait_for(Response("AUTHENTICATE", [ParamAny()],
errors=["904", "907", "908"]))
if line.command == "907":
# we've done SASL already. cleanly abort
return SASLResult.ALREADY
elif line.command == "908":
available = line.params[1].split(",")
raise SASLUnkownMechanismError(
"Server does not support SASL EXTERNAL "
f"(it supports {available}")
elif line.command == "AUTHENTICATE" and line.params[0] == "+":
await self.send(build("AUTHENTICATE", ["+"]))
line = await self.wait_for(Numerics(["903", "904"]))
if line.command == "903":
return SASLResult.SUCCESS
return SASLResult.FAILURE
async def sasl_userpass(self, username: str, password: str) -> SASLResult:
# this will, in the future, offer SCRAM support
def _common(server_mechs) -> str:
for our_mech in SASL_USERPASS_MECHANISMS:
if our_mech in server_mechs:
return our_mech
else:
raise SASLUnkownMechanismError(
"No matching SASL mechanims. "
f"(we have: {SASL_USERPASS_MECHANISMS} "
f"server has: {server_mechs})")
if not self.available_caps["sasl"] is None:
# CAP v3.2 tells us what mechs it supports
available = self.available_caps["sasl"].split(",")
match = _common(available)
else:
# CAP v3.1 does not. pick the pick and wait for 907 to inform us of
# what mechanisms are supported
match = SASL_USERPASS_MECHANISMS[0]
await self.send(build("AUTHENTICATE", [match]))
line = await self.wait_for(Response("AUTHENTICATE", [ParamAny()],
errors=["904", "907", "908"]))
if line.command == "907":
# we've done SASL already. cleanly abort
return SASLResult.ALREADY
elif line.command == "908":
# prior to CAP v3.2 - ERR telling us which mechs are supported
available = line.params[1].split(",")
match = _common(available)
await self.send(build("AUTHENTICATE", [match]))
line = await self.wait_for(Response("AUTHENTICATE",
[ParamAny()]))
if line.command == "AUTHENTICATE" and line.params[0] == "+":
auth_text: Optional[str] = None
if match == "PLAIN":
auth_text = f"{username}\0{username}\0{password}"
if not auth_text is None:
auth_b64 = b64encode(auth_text.encode("utf8")
).decode("ascii")
await self.send(build("AUTHENTICATE", [auth_b64]))
line = await self.wait_for(Numerics(["903", "904"]))
if line.command == "903":
return SASLResult.SUCCESS
return SASLResult.FAILURE