diff --git a/ircrobots/ircv3.py b/ircrobots/ircv3.py index d35dd9e..f70b0d4 100644 --- a/ircrobots/ircv3.py +++ b/ircrobots/ircv3.py @@ -2,7 +2,8 @@ from typing import Iterable, List, Optional from irctokens import build from .contexts import ServerContext -from .matching import Response, ResponseOr, ParamAny, ParamNot, ParamLiteral +from .matching import (Response, Numerics, ResponseOr, ParamAny, ParamNot, + ParamLiteral) from .interface import ICapability class Capability(ICapability): @@ -57,10 +58,13 @@ CAPS: List[ICapability] = [ class CAPContext(ServerContext): async def handshake(self) -> bool: # improve this by being able to wait_for Emit objects - line = await self.server.wait_for(Response( - "CAP", - [ParamAny(), ParamLiteral("LS"), ParamNot(ParamLiteral("*"))], - errors=["001"])) + line = await self.server.wait_for(ResponseOr( + Response( + "CAP", + [ParamAny(), ParamLiteral("LS"), ParamNot(ParamLiteral("*"))] + ), + Numerics(["RPL_WELCOME"]) + )) if line.command == "CAP": caps = self.server.collect_caps() diff --git a/ircrobots/matching.py b/ircrobots/matching.py index 3fe43cd..2247f18 100644 --- a/ircrobots/matching.py +++ b/ircrobots/matching.py @@ -1,5 +1,6 @@ -from typing import List, Optional +from typing import List from irctokens import Line +from .numerics import NUMERIC_NAMES class ResponseParam(object): def match(self, arg: str) -> bool: @@ -12,7 +13,7 @@ class BaseResponse(object): class Numerics(BaseResponse): def __init__(self, numerics: List[str]): - self._numerics = numerics + self._numerics = [NUMERIC_NAMES.get(n, n) for n in numerics] def match(self, line: Line): return line.command in self._numerics @@ -20,11 +21,9 @@ class Numerics(BaseResponse): class Response(BaseResponse): def __init__(self, command: str, - params: List[ResponseParam], - errors: Optional[List[str]] = None): + params: List[ResponseParam]): self._command = command self._params = params - self._errors = errors or [] def match(self, line: Line) -> bool: if line.command == self._command: @@ -34,8 +33,6 @@ class Response(BaseResponse): return False else: return True - elif line.command in self._errors: - return True else: return False diff --git a/ircrobots/numerics.py b/ircrobots/numerics.py new file mode 100644 index 0000000..29ea54d --- /dev/null +++ b/ircrobots/numerics.py @@ -0,0 +1,16 @@ +NUMERIC_NUMBERS = {} +NUMERIC_NAMES = {} + +def _numeric(number: str, name: str): + NUMERIC_NUMBERS[number] = name + NUMERIC_NAMES[name] = number + +_numeric("001", "RPL_WELCOME") +_numeric("005", "RPL_ISUPPORT") + +_numeric("903", "RPL_SASLSUCCESS") +_numeric("904", "ERR_SASLFAIL") +_numeric("905", "ERR_SASLTOOLONG") +_numeric("906", "ERR_SASLABORTED") +_numeric("907", "ERR_SASLALREADY") +_numeric("908", "RPL_SASLMECHS") diff --git a/ircrobots/sasl.py b/ircrobots/sasl.py index ddba15c..c1885ba 100644 --- a/ircrobots/sasl.py +++ b/ircrobots/sasl.py @@ -3,7 +3,7 @@ from enum import Enum from base64 import b64encode from irctokens import build -from .matching import Response, Numerics, ParamAny +from .matching import Response, ResponseOr, Numerics, ParamAny from .contexts import ServerContext from .params import SASLParams @@ -25,6 +25,10 @@ class SASLError(Exception): class SASLUnknownMechanismError(SASLError): pass +NUMERICS_INITIAL = Numerics( + ["ERR_SASLFAIL", "ERR_SASLALREADY", "RPL_SASLMECHS"]) +NUMERICS_LAST = Numerics(["RPL_SASLSUCCESS", "ERR_SASLFAIL"]) + class SASLContext(ServerContext): async def from_params(self, params: SASLParams) -> SASLResult: if params.mechanism == "USERPASS": @@ -38,8 +42,10 @@ class SASLContext(ServerContext): async def external(self) -> SASLResult: await self.server.send(build("AUTHENTICATE", ["EXTERNAL"])) - line = await self.server.wait_for(Response("AUTHENTICATE", - [ParamAny()], errors=["904", "907", "908"])) + line = await self.server.wait_for(ResponseOr( + Response("AUTHENTICATE", [ParamAny()]), + NUMERICS_INITIAL + )) if line.command == "907": # we've done SASL already. cleanly abort @@ -52,7 +58,7 @@ class SASLContext(ServerContext): elif line.command == "AUTHENTICATE" and line.params[0] == "+": await self.server.send(build("AUTHENTICATE", ["+"])) - line = await self.server.wait_for(Numerics(["903", "904"])) + line = await self.server.wait_for(NUMERICS_LAST) if line.command == "903": return SASLResult.SUCCESS return SASLResult.FAILURE @@ -80,8 +86,10 @@ class SASLContext(ServerContext): match = SASL_USERPASS_MECHANISMS[0] await self.server.send(build("AUTHENTICATE", [match])) - line = await self.server.wait_for(Response("AUTHENTICATE", - [ParamAny()], errors=["904", "907", "908"])) + line = await self.server.wait_for(ResponseOr( + Response("AUTHENTICATE", [ParamAny()]), + NUMERICS_INITIAL + )) if line.command == "907": # we've done SASL already. cleanly abort @@ -92,8 +100,8 @@ class SASLContext(ServerContext): match = _common(available) await self.server.send(build("AUTHENTICATE", [match])) - line = await self.server.wait_for(Response("AUTHENTICATE", - [ParamAny()])) + line = await self.server.wait_for( + Response("AUTHENTICATE", [ParamAny()])) if line.command == "AUTHENTICATE" and line.params[0] == "+": auth_text: Optional[str] = None @@ -101,11 +109,11 @@ class SASLContext(ServerContext): auth_text = f"{username}\0{username}\0{password}" if not auth_text is None: - auth_b64 = b64encode(auth_text.encode("utf8") - ).decode("ascii") + auth_b64 = b64encode( + auth_text.encode("utf8")).decode("ascii") await self.server.send(build("AUTHENTICATE", [auth_b64])) - line = await self.server.wait_for(Numerics(["903", "904"])) + line = await self.server.wait_for(NUMERICS_LAST) if line.command == "903": return SASLResult.SUCCESS return SASLResult.FAILURE