add numerics.py to translate names, remove Response(errors=)
This commit is contained in:
parent
06a4d20fc8
commit
a4f5d8045f
4 changed files with 48 additions and 23 deletions
|
@ -2,7 +2,8 @@ from typing import Iterable, List, Optional
|
|||
from irctokens import build
|
||||
|
||||
from .contexts import ServerContext
|
||||
from .matching import Response, ResponseOr, ParamAny, ParamNot, ParamLiteral
|
||||
from .matching import (Response, Numerics, ResponseOr, ParamAny, ParamNot,
|
||||
ParamLiteral)
|
||||
from .interface import ICapability
|
||||
|
||||
class Capability(ICapability):
|
||||
|
@ -57,10 +58,13 @@ CAPS: List[ICapability] = [
|
|||
class CAPContext(ServerContext):
|
||||
async def handshake(self) -> bool:
|
||||
# improve this by being able to wait_for Emit objects
|
||||
line = await self.server.wait_for(Response(
|
||||
"CAP",
|
||||
[ParamAny(), ParamLiteral("LS"), ParamNot(ParamLiteral("*"))],
|
||||
errors=["001"]))
|
||||
line = await self.server.wait_for(ResponseOr(
|
||||
Response(
|
||||
"CAP",
|
||||
[ParamAny(), ParamLiteral("LS"), ParamNot(ParamLiteral("*"))]
|
||||
),
|
||||
Numerics(["RPL_WELCOME"])
|
||||
))
|
||||
|
||||
if line.command == "CAP":
|
||||
caps = self.server.collect_caps()
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from typing import List, Optional
|
||||
from typing import List
|
||||
from irctokens import Line
|
||||
from .numerics import NUMERIC_NAMES
|
||||
|
||||
class ResponseParam(object):
|
||||
def match(self, arg: str) -> bool:
|
||||
|
@ -12,7 +13,7 @@ class BaseResponse(object):
|
|||
class Numerics(BaseResponse):
|
||||
def __init__(self,
|
||||
numerics: List[str]):
|
||||
self._numerics = numerics
|
||||
self._numerics = [NUMERIC_NAMES.get(n, n) for n in numerics]
|
||||
|
||||
def match(self, line: Line):
|
||||
return line.command in self._numerics
|
||||
|
@ -20,11 +21,9 @@ class Numerics(BaseResponse):
|
|||
class Response(BaseResponse):
|
||||
def __init__(self,
|
||||
command: str,
|
||||
params: List[ResponseParam],
|
||||
errors: Optional[List[str]] = None):
|
||||
params: List[ResponseParam]):
|
||||
self._command = command
|
||||
self._params = params
|
||||
self._errors = errors or []
|
||||
|
||||
def match(self, line: Line) -> bool:
|
||||
if line.command == self._command:
|
||||
|
@ -34,8 +33,6 @@ class Response(BaseResponse):
|
|||
return False
|
||||
else:
|
||||
return True
|
||||
elif line.command in self._errors:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
|
16
ircrobots/numerics.py
Normal file
16
ircrobots/numerics.py
Normal file
|
@ -0,0 +1,16 @@
|
|||
NUMERIC_NUMBERS = {}
|
||||
NUMERIC_NAMES = {}
|
||||
|
||||
def _numeric(number: str, name: str):
|
||||
NUMERIC_NUMBERS[number] = name
|
||||
NUMERIC_NAMES[name] = number
|
||||
|
||||
_numeric("001", "RPL_WELCOME")
|
||||
_numeric("005", "RPL_ISUPPORT")
|
||||
|
||||
_numeric("903", "RPL_SASLSUCCESS")
|
||||
_numeric("904", "ERR_SASLFAIL")
|
||||
_numeric("905", "ERR_SASLTOOLONG")
|
||||
_numeric("906", "ERR_SASLABORTED")
|
||||
_numeric("907", "ERR_SASLALREADY")
|
||||
_numeric("908", "RPL_SASLMECHS")
|
|
@ -3,7 +3,7 @@ from enum import Enum
|
|||
from base64 import b64encode
|
||||
from irctokens import build
|
||||
|
||||
from .matching import Response, Numerics, ParamAny
|
||||
from .matching import Response, ResponseOr, Numerics, ParamAny
|
||||
from .contexts import ServerContext
|
||||
from .params import SASLParams
|
||||
|
||||
|
@ -25,6 +25,10 @@ class SASLError(Exception):
|
|||
class SASLUnknownMechanismError(SASLError):
|
||||
pass
|
||||
|
||||
NUMERICS_INITIAL = Numerics(
|
||||
["ERR_SASLFAIL", "ERR_SASLALREADY", "RPL_SASLMECHS"])
|
||||
NUMERICS_LAST = Numerics(["RPL_SASLSUCCESS", "ERR_SASLFAIL"])
|
||||
|
||||
class SASLContext(ServerContext):
|
||||
async def from_params(self, params: SASLParams) -> SASLResult:
|
||||
if params.mechanism == "USERPASS":
|
||||
|
@ -38,8 +42,10 @@ class SASLContext(ServerContext):
|
|||
|
||||
async def external(self) -> SASLResult:
|
||||
await self.server.send(build("AUTHENTICATE", ["EXTERNAL"]))
|
||||
line = await self.server.wait_for(Response("AUTHENTICATE",
|
||||
[ParamAny()], errors=["904", "907", "908"]))
|
||||
line = await self.server.wait_for(ResponseOr(
|
||||
Response("AUTHENTICATE", [ParamAny()]),
|
||||
NUMERICS_INITIAL
|
||||
))
|
||||
|
||||
if line.command == "907":
|
||||
# we've done SASL already. cleanly abort
|
||||
|
@ -52,7 +58,7 @@ class SASLContext(ServerContext):
|
|||
elif line.command == "AUTHENTICATE" and line.params[0] == "+":
|
||||
await self.server.send(build("AUTHENTICATE", ["+"]))
|
||||
|
||||
line = await self.server.wait_for(Numerics(["903", "904"]))
|
||||
line = await self.server.wait_for(NUMERICS_LAST)
|
||||
if line.command == "903":
|
||||
return SASLResult.SUCCESS
|
||||
return SASLResult.FAILURE
|
||||
|
@ -80,8 +86,10 @@ class SASLContext(ServerContext):
|
|||
match = SASL_USERPASS_MECHANISMS[0]
|
||||
|
||||
await self.server.send(build("AUTHENTICATE", [match]))
|
||||
line = await self.server.wait_for(Response("AUTHENTICATE",
|
||||
[ParamAny()], errors=["904", "907", "908"]))
|
||||
line = await self.server.wait_for(ResponseOr(
|
||||
Response("AUTHENTICATE", [ParamAny()]),
|
||||
NUMERICS_INITIAL
|
||||
))
|
||||
|
||||
if line.command == "907":
|
||||
# we've done SASL already. cleanly abort
|
||||
|
@ -92,8 +100,8 @@ class SASLContext(ServerContext):
|
|||
match = _common(available)
|
||||
|
||||
await self.server.send(build("AUTHENTICATE", [match]))
|
||||
line = await self.server.wait_for(Response("AUTHENTICATE",
|
||||
[ParamAny()]))
|
||||
line = await self.server.wait_for(
|
||||
Response("AUTHENTICATE", [ParamAny()]))
|
||||
|
||||
if line.command == "AUTHENTICATE" and line.params[0] == "+":
|
||||
auth_text: Optional[str] = None
|
||||
|
@ -101,11 +109,11 @@ class SASLContext(ServerContext):
|
|||
auth_text = f"{username}\0{username}\0{password}"
|
||||
|
||||
if not auth_text is None:
|
||||
auth_b64 = b64encode(auth_text.encode("utf8")
|
||||
).decode("ascii")
|
||||
auth_b64 = b64encode(
|
||||
auth_text.encode("utf8")).decode("ascii")
|
||||
await self.server.send(build("AUTHENTICATE", [auth_b64]))
|
||||
|
||||
line = await self.server.wait_for(Numerics(["903", "904"]))
|
||||
line = await self.server.wait_for(NUMERICS_LAST)
|
||||
if line.command == "903":
|
||||
return SASLResult.SUCCESS
|
||||
return SASLResult.FAILURE
|
||||
|
|
Loading…
Add table
Reference in a new issue