rename some matching Params, restructure matching

This commit is contained in:
jesopo 2020-04-21 21:40:46 +01:00
parent f2ba7c2512
commit 89c7ac15dd
8 changed files with 126 additions and 109 deletions

View file

@ -5,7 +5,7 @@ from irctokens import build
from ircstates.server import ServerDisconnectedException from ircstates.server import ServerDisconnectedException
from .contexts import ServerContext from .contexts import ServerContext
from .matching import Response, ResponseOr, ParamAny from .matching import Response, ResponseOr, ANY
from .interface import ICapability from .interface import ICapability
from .params import ConnectionParams, STSPolicy from .params import ConnectionParams, STSPolicy
@ -101,8 +101,8 @@ class CAPContext(ServerContext):
while cap_names: while cap_names:
line = await self.server.wait_for(ResponseOr( line = await self.server.wait_for(ResponseOr(
Response("CAP", [ParamAny(), "ACK"]), Response("CAP", [ANY, "ACK"]),
Response("CAP", [ParamAny(), "NAK"]) Response("CAP", [ANY, "NAK"])
)) ))
current_caps = line.params[2].split(" ") current_caps = line.params[2].split(" ")

View file

@ -3,7 +3,7 @@ from irctokens import build
from ircstates.numerics import * from ircstates.numerics import *
from .contexts import ServerContext from .contexts import ServerContext
from .matching import Response, ResponseOr, ParamAny, ParamFolded from .matching import Response, ResponseOr, ANY, Folded
class WHOContext(ServerContext): class WHOContext(ServerContext):
async def ensure(self, channel: str): async def ensure(self, channel: str):
@ -15,5 +15,5 @@ class WHOContext(ServerContext):
await self.server.send(build("WHO", [channel])) await self.server.send(build("WHO", [channel]))
line = await self.server.wait_for( line = await self.server.wait_for(
Response(RPL_ENDOFWHO, [ParamAny(), ParamFolded(folded)]) Response(RPL_ENDOFWHO, [ANY, Folded(folded)])
) )

View file

@ -1,98 +0,0 @@
from typing import List, Optional, Union
from irctokens import Line, Hostmask
from .interface import (IServer, IMatchResponse, IMatchResponseParam,
IMatchResponseHostmask)
TYPE_PARAM = Union[str, IMatchResponseParam]
class Responses(IMatchResponse):
def __init__(self,
commands: List[str],
params: List[TYPE_PARAM]=[],
source: Optional[IMatchResponseHostmask]=None):
self._commands = commands
self._params = params
self._source = source
def __repr__(self) -> str:
return f"Responses({self._commands!r}: {self._params!r})"
def match(self, server: IServer, line: Line) -> bool:
for command in self._commands:
if (line.command == command and (
self._source is None or (
line.hostmask is not None and
self._source.match(server, line.hostmask)
))):
for i, param in enumerate(self._params):
if i >= len(line.params):
break
elif (isinstance(param, str) and
not param == line.params[i]):
break
elif (isinstance(param, IMatchResponseParam) and
not param.match(server, line.params[i])):
break
else:
return True
else:
return False
class Response(Responses):
def __init__(self,
command: str,
params: List[TYPE_PARAM]=[],
source: Optional[IMatchResponseHostmask]=None):
super().__init__([command], params, source=source)
def __repr__(self) -> str:
return f"Response({self._commands[0]}: {self._params!r})"
class ResponseOr(IMatchResponse):
def __init__(self, *responses: IMatchResponse):
self._responses = responses
def __repr__(self) -> str:
return f"ResponseOr({self._responses!r})"
def match(self, server: IServer, line: Line) -> bool:
for response in self._responses:
if response.match(server, line):
return True
else:
return False
class ParamAny(IMatchResponseParam):
def __repr__(self) -> str:
return "Any()"
def match(self, server: IServer, arg: str) -> bool:
return True
class ParamFolded(IMatchResponseParam):
def __init__(self, value: str):
self._value = value
self._folded: Optional[str] = None
def __repr__(self) -> str:
return f"FoldString({self._value!r})"
def match(self, server: IServer, arg: str) -> bool:
if self._folded is None:
self._folded = server.casefold(self._value)
return self._folded == server.casefold(arg)
class ParamNot(IMatchResponseParam):
def __init__(self, param: IMatchResponseParam):
self._param = param
def __repr__(self) -> str:
return f"Not({self._param!r})"
def match(self, server: IServer, arg: str) -> bool:
return not self._param.match(server, arg)
class Nickname(IMatchResponseHostmask):
def __init__(self, nickname: str):
self._nickname = nickname
self._folded: Optional[str] = None
def __repr__(self) -> str:
mask = f"{self._nickname}!*@*"
return f"Hostmask({mask!r})"
def match(self, server: IServer, hostmask: Hostmask):
if self._folded is None:
self._folded = server.casefold(self._nickname)
return self._folded == server.casefold(hostmask.nickname)

View file

@ -0,0 +1,3 @@
from .responses import *
from .params import *

View file

@ -0,0 +1,50 @@
from typing import Optional
from irctokens import Hostmask
from ..interface import IMatchResponseParam, IMatchResponseHostmask, IServer
class Any(IMatchResponseParam):
def __repr__(self) -> str:
return "Any()"
def match(self, server: IServer, arg: str) -> bool:
return True
ANY = Any()
class Literal(IMatchResponseParam):
def __init__(self, value: str):
self._value = value
def __repr__(self) -> str:
return f"{self._value!r}"
def match(self, server: IServer, arg: str) -> bool:
return arg == self._value
class Folded(IMatchResponseParam):
def __init__(self, value: str):
self._value = value
self._folded: Optional[str] = None
def __repr__(self) -> str:
return f"FoldString({self._value!r})"
def match(self, server: IServer, arg: str) -> bool:
if self._folded is None:
self._folded = server.casefold(self._value)
return self._folded == server.casefold(arg)
class Not(IMatchResponseParam):
def __init__(self, param: IMatchResponseParam):
self._param = param
def __repr__(self) -> str:
return f"Not({self._param!r})"
def match(self, server: IServer, arg: str) -> bool:
return not self._param.match(server, arg)
class Nickname(IMatchResponseHostmask):
def __init__(self, nickname: str):
self._nickname = nickname
self._folded: Optional[str] = None
def __repr__(self) -> str:
mask = f"{self._nickname}!*@*"
return f"Hostmask({mask!r})"
def match(self, server: IServer, hostmask: Hostmask):
if self._folded is None:
self._folded = server.casefold(self._nickname)
return self._folded == server.casefold(hostmask.nickname)

View file

@ -0,0 +1,63 @@
from typing import List, Optional, Union
from irctokens import Line
from ..interface import (IServer, IMatchResponse, IMatchResponseParam,
IMatchResponseHostmask)
from .params import *
TYPE_PARAM = Union[str, IMatchResponseParam]
class Responses(IMatchResponse):
def __init__(self,
commands: List[str],
params: List[TYPE_PARAM]=[],
source: Optional[IMatchResponseHostmask]=None):
self._commands = commands
self._source = source
self._params: List[IMatchResponseParam] = []
for param in params:
if isinstance(param, str):
self._params.append(Literal(param))
elif isinstance(param, IMatchResponseParam):
self._params.append(param)
def __repr__(self) -> str:
return f"Responses({self._commands!r}: {self._params!r})"
def match(self, server: IServer, line: Line) -> bool:
for command in self._commands:
if (line.command == command and (
self._source is None or (
line.hostmask is not None and
self._source.match(server, line.hostmask)
))):
for i, param in enumerate(self._params):
if (i >= len(line.params) or
not param.match(server, line.params[i])):
break
else:
return True
else:
return False
class Response(Responses):
def __init__(self,
command: str,
params: List[TYPE_PARAM]=[],
source: Optional[IMatchResponseHostmask]=None):
super().__init__([command], params, source=source)
def __repr__(self) -> str:
return f"Response({self._commands[0]}: {self._params!r})"
class ResponseOr(IMatchResponse):
def __init__(self, *responses: IMatchResponse):
self._responses = responses
def __repr__(self) -> str:
return f"ResponseOr({self._responses!r})"
def match(self, server: IServer, line: Line) -> bool:
for response in self._responses:
if response.match(server, line):
return True
else:
return False

View file

@ -4,7 +4,7 @@ from base64 import b64decode, b64encode
from irctokens import build from irctokens import build
from ircstates.numerics import * from ircstates.numerics import *
from .matching import ResponseOr, Responses, Response, ParamAny from .matching import ResponseOr, Responses, Response, ANY
from .contexts import ServerContext from .contexts import ServerContext
from .params import SASLParams from .params import SASLParams
from .scram import SCRAMContext from .scram import SCRAMContext
@ -29,7 +29,7 @@ class SASLUnknownMechanismError(SASLError):
AUTH_BYTE_MAX = 400 AUTH_BYTE_MAX = 400
AUTHENTICATE_ANY = Response("AUTHENTICATE", [ParamAny()]) AUTHENTICATE_ANY = Response("AUTHENTICATE", [ANY])
NUMERICS_FAIL = Response(ERR_SASLFAIL) NUMERICS_FAIL = Response(ERR_SASLFAIL)
NUMERICS_INITIAL = Responses([ERR_SASLFAIL, ERR_SASLALREADY, RPL_SASLMECHS]) NUMERICS_INITIAL = Responses([ERR_SASLFAIL, ERR_SASLALREADY, RPL_SASLMECHS])

View file

@ -12,8 +12,7 @@ from .ircv3 import (CAPContext, sts_transmute, CAP_ECHO, CAP_SASL,
CAP_LABEL, LABEL_TAG) CAP_LABEL, LABEL_TAG)
from .sasl import SASLContext, SASLResult from .sasl import SASLContext, SASLResult
from .join_info import WHOContext from .join_info import WHOContext
from .matching import (ResponseOr, Responses, Response, ParamAny, ParamFolded, from .matching import ResponseOr, Responses, Response, ANY, Folded, Nickname
Nickname)
from .asyncs import MaybeAwait from .asyncs import MaybeAwait
from .struct import Whois from .struct import Whois
from .params import ConnectionParams, SASLParams, STSPolicy from .params import ConnectionParams, SASLParams, STSPolicy
@ -285,7 +284,7 @@ class Server(IServer):
while folded_names: while folded_names:
line = await self.wait_for( line = await self.wait_for(
Response(RPL_CHANNELMODEIS, [ParamAny(), ParamAny()]) Response(RPL_CHANNELMODEIS, [ANY, ANY])
) )
folded = self.casefold(line.params[1]) folded = self.casefold(line.params[1])
@ -302,7 +301,7 @@ class Server(IServer):
async def _assure(): async def _assure():
await fut await fut
params = [ParamAny(), ParamFolded(folded)] params = [ANY, Folded(folded)]
obj = Whois() obj = Whois()
while True: while True:
line = await self.wait_for(Responses([ line = await self.wait_for(Responses([