move SASL logic out in to sasl.py.SASLContext

This commit is contained in:
jesopo 2020-04-02 17:56:44 +01:00
parent 023107385e
commit f43cb75bfa
2 changed files with 85 additions and 92 deletions

View file

@ -1,6 +1,11 @@
from typing import Optional
from enum import Enum
from typing import Optional
from enum import Enum
from base64 import b64encode
from dataclasses import dataclass
from irctokens import build
from .matching import Response, Numerics, ParamAny
from .contexts import ServerContext
SASL_USERPASS_MECHANISMS = [
"SCRAM-SHA-512",
@ -19,14 +24,78 @@ class SASLError(Exception):
class SASLUnkownMechanismError(SASLError):
pass
@dataclass
class SASLParams(object):
mechanism: str
username: Optional[str] = None
password: Optional[str] = None
class SASLUserPass(SASLParams):
def __init__(self, username: str, password: str):
super().__init__("USERPASS", username, password)
class SASLExternal(SASLParams):
def __init__(self):
super().__init__("EXTERNAL")
class SASLContext(ServerContext):
async def external(self) -> SASLResult:
await self.server.send(build("AUTHENTICATE", ["EXTERNAL"]))
line = await self.server.wait_for(Response("AUTHENTICATE",
[ParamAny()], errors=["904", "907", "908"]))
if line.command == "907":
# we've done SASL already. cleanly abort
return SASLResult.ALREADY
elif line.command == "908":
available = line.params[1].split(",")
raise SASLUnkownMechanismError(
"Server does not support SASL EXTERNAL "
f"(it supports {available}")
elif line.command == "AUTHENTICATE" and line.params[0] == "+":
await self.server.send(build("AUTHENTICATE", ["+"]))
line = await self.server.wait_for(Numerics(["903", "904"]))
if line.command == "903":
return SASLResult.SUCCESS
return SASLResult.FAILURE
async def userpass(self, username: str, password: str) -> SASLResult:
# this will, in the future, offer SCRAM support
def _common(server_mechs) -> str:
for our_mech in SASL_USERPASS_MECHANISMS:
if our_mech in server_mechs:
return our_mech
else:
raise SASLUnkownMechanismError(
"No matching SASL mechanims. "
f"(we have: {SASL_USERPASS_MECHANISMS} "
f"server has: {server_mechs})")
if not self.server.available_caps["sasl"] is None:
# CAP v3.2 tells us what mechs it supports
available = self.server.available_caps["sasl"].split(",")
match = _common(available)
else:
# CAP v3.1 does not. pick the pick and wait for 907 to inform us of
# what mechanisms are supported
match = SASL_USERPASS_MECHANISMS[0]
await self.server.send(build("AUTHENTICATE", [match]))
line = await self.server.wait_for(Response("AUTHENTICATE",
[ParamAny()], errors=["904", "907", "908"]))
if line.command == "907":
# we've done SASL already. cleanly abort
return SASLResult.ALREADY
elif line.command == "908":
# prior to CAP v3.2 - ERR telling us which mechs are supported
available = line.params[1].split(",")
match = _common(available)
await self.server.send(build("AUTHENTICATE", [match]))
line = await self.server.wait_for(Response("AUTHENTICATE",
[ParamAny()]))
if line.command == "AUTHENTICATE" and line.params[0] == "+":
auth_text: Optional[str] = None
if match == "PLAIN":
auth_text = f"{username}\0{username}\0{password}"
if not auth_text is None:
auth_b64 = b64encode(auth_text.encode("utf8")
).decode("ascii")
await self.server.send(build("AUTHENTICATE", [auth_b64]))
line = await self.server.wait_for(Numerics(["903", "904"]))
if line.command == "903":
return SASLResult.SUCCESS
return SASLResult.FAILURE

View file

@ -4,7 +4,6 @@ from queue import Queue
from typing import Awaitable, Callable, Dict, List, Optional, Set, Tuple
from enum import Enum
from dataclasses import dataclass
from base64 import b64encode
from asyncio_throttle import Throttler
from ircstates import Emit
@ -13,8 +12,7 @@ from irctokens import build, Line, tokenise
from .ircv3 import Capability, CAPS, CAP_SASL
from .interface import ConnectionParams, IServer, PriorityLine, SendPriority
from .matching import BaseResponse, Response, Numerics, ParamAny, Literal
from .sasl import (SASLResult, SASLUnkownMechanismError,
SASL_USERPASS_MECHANISMS)
from .sasl import SASLContext
sc = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
@ -132,11 +130,11 @@ class Server(IServer):
async def _cap_ack(self, emit: Emit):
if not self.params.sasl is None and self.cap_agreed(CAP_SASL):
if self.params.sasl.mechanism == "USERPASS":
await self.sasl_userpass(
await SASLContext(self).userpass(
self.params.sasl.username or "",
self.params.sasl.password or "")
elif self.params.sasl.mechanism == "EXTERNAL":
await self.sasl_external()
await SASLContext(self).external()
for cap in (emit.tokens or []):
if cap in self._requested_caps:
@ -144,77 +142,3 @@ class Server(IServer):
if not self._requested_caps:
await self.send(build("CAP", ["END"]))
# /CAP-related
async def sasl_external(self) -> SASLResult:
await self.send(build("AUTHENTICATE", ["EXTERNAL"]))
line = await self.wait_for(Response("AUTHENTICATE", [ParamAny()],
errors=["904", "907", "908"]))
if line.command == "907":
# we've done SASL already. cleanly abort
return SASLResult.ALREADY
elif line.command == "908":
available = line.params[1].split(",")
raise SASLUnkownMechanismError(
"Server does not support SASL EXTERNAL "
f"(it supports {available}")
elif line.command == "AUTHENTICATE" and line.params[0] == "+":
await self.send(build("AUTHENTICATE", ["+"]))
line = await self.wait_for(Numerics(["903", "904"]))
if line.command == "903":
return SASLResult.SUCCESS
return SASLResult.FAILURE
async def sasl_userpass(self, username: str, password: str) -> SASLResult:
# this will, in the future, offer SCRAM support
def _common(server_mechs) -> str:
for our_mech in SASL_USERPASS_MECHANISMS:
if our_mech in server_mechs:
return our_mech
else:
raise SASLUnkownMechanismError(
"No matching SASL mechanims. "
f"(we have: {SASL_USERPASS_MECHANISMS} "
f"server has: {server_mechs})")
if not self.available_caps["sasl"] is None:
# CAP v3.2 tells us what mechs it supports
available = self.available_caps["sasl"].split(",")
match = _common(available)
else:
# CAP v3.1 does not. pick the pick and wait for 907 to inform us of
# what mechanisms are supported
match = SASL_USERPASS_MECHANISMS[0]
await self.send(build("AUTHENTICATE", [match]))
line = await self.wait_for(Response("AUTHENTICATE", [ParamAny()],
errors=["904", "907", "908"]))
if line.command == "907":
# we've done SASL already. cleanly abort
return SASLResult.ALREADY
elif line.command == "908":
# prior to CAP v3.2 - ERR telling us which mechs are supported
available = line.params[1].split(",")
match = _common(available)
await self.send(build("AUTHENTICATE", [match]))
line = await self.wait_for(Response("AUTHENTICATE",
[ParamAny()]))
if line.command == "AUTHENTICATE" and line.params[0] == "+":
auth_text: Optional[str] = None
if match == "PLAIN":
auth_text = f"{username}\0{username}\0{password}"
if not auth_text is None:
auth_b64 = b64encode(auth_text.encode("utf8")
).decode("ascii")
await self.send(build("AUTHENTICATE", [auth_b64]))
line = await self.wait_for(Numerics(["903", "904"]))
if line.command == "903":
return SASLResult.SUCCESS
return SASLResult.FAILURE