pass IServer to Response so we can have FoldString (match with casefold rules)
This commit is contained in:
parent
f70932ac44
commit
e470d57780
3 changed files with 46 additions and 35 deletions
|
@ -5,7 +5,6 @@ from enum import IntEnum
|
|||
from ircstates import Server
|
||||
from irctokens import Line
|
||||
|
||||
from .matching import BaseResponse
|
||||
from .params import ConnectionParams, SASLParams
|
||||
|
||||
class SendPriority(IntEnum):
|
||||
|
@ -32,6 +31,13 @@ class ICapability(object):
|
|||
def copy(self) -> "ICapability":
|
||||
pass
|
||||
|
||||
class IMatchResponse(object):
|
||||
def match(self, server: "IServer", line: Line) -> bool:
|
||||
pass
|
||||
class IMatchResponseParam(object):
|
||||
def match(self, server: "IServer", arg: str) -> bool:
|
||||
pass
|
||||
|
||||
class IServer(Server):
|
||||
params: ConnectionParams
|
||||
desired_caps: Set[ICapability]
|
||||
|
@ -41,7 +47,7 @@ class IServer(Server):
|
|||
async def send(self, line: Line, priority=SendPriority.DEFAULT):
|
||||
pass
|
||||
|
||||
def wait_for(self, response: BaseResponse) -> Awaitable[Line]:
|
||||
def wait_for(self, response: IMatchResponse) -> Awaitable[Line]:
|
||||
pass
|
||||
|
||||
def set_throttle(self, rate: int, time: float):
|
||||
|
|
|
@ -1,73 +1,79 @@
|
|||
from typing import List
|
||||
from irctokens import Line
|
||||
from .numerics import NUMERIC_NAMES
|
||||
from typing import List, Optional
|
||||
from irctokens import Line
|
||||
from .numerics import NUMERIC_NAMES
|
||||
from .interface import IServer, IMatchResponse, IMatchResponseParam
|
||||
|
||||
class ResponseParam(object):
|
||||
def match(self, arg: str) -> bool:
|
||||
return False
|
||||
|
||||
class BaseResponse(object):
|
||||
def match(self, line: Line) -> bool:
|
||||
return False
|
||||
|
||||
class Numerics(BaseResponse):
|
||||
class Numerics(IMatchResponse):
|
||||
def __init__(self,
|
||||
numerics: List[str]):
|
||||
self._numerics = [NUMERIC_NAMES.get(n, n) for n in numerics]
|
||||
def __repr__(self) -> str:
|
||||
return f"Numerics({self._numerics!r})"
|
||||
|
||||
def match(self, line: Line):
|
||||
def match(self, server: IServer, line: Line):
|
||||
return line.command in self._numerics
|
||||
|
||||
class Response(BaseResponse):
|
||||
class Response(IMatchResponse):
|
||||
def __init__(self,
|
||||
command: str,
|
||||
params: List[ResponseParam]):
|
||||
params: List[IMatchResponseParam]):
|
||||
self._command = command
|
||||
self._params = params
|
||||
def __repr__(self) -> str:
|
||||
return f"Response({self._command}: {self._params!r})"
|
||||
|
||||
def match(self, line: Line) -> bool:
|
||||
def match(self, server: IServer, line: Line) -> bool:
|
||||
if line.command == self._command:
|
||||
for i, param in enumerate(self._params):
|
||||
if (i >= len(line.params) or
|
||||
not param.match(line.params[i])):
|
||||
not param.match(server, line.params[i])):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
class ResponseOr(BaseResponse):
|
||||
def __init__(self, *responses: BaseResponse):
|
||||
class ResponseOr(IMatchResponse):
|
||||
def __init__(self, *responses: IMatchResponse):
|
||||
self._responses = responses
|
||||
def __repr__(self) -> str:
|
||||
return f"ResponseOr({self._responses!r})"
|
||||
def match(self, line: Line) -> bool:
|
||||
def match(self, server: IServer, line: Line) -> bool:
|
||||
for response in self._responses:
|
||||
if response.match(line):
|
||||
if response.match(server, line):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
class ParamAny(ResponseParam):
|
||||
class ParamAny(IMatchResponseParam):
|
||||
def __repr__(self) -> str:
|
||||
return "Any()"
|
||||
def match(self, arg: str) -> bool:
|
||||
def match(self, server: IServer, arg: str) -> bool:
|
||||
return True
|
||||
class ParamLiteral(ResponseParam):
|
||||
|
||||
class ParamLiteral(IMatchResponseParam):
|
||||
def __init__(self, value: str):
|
||||
self._value = value
|
||||
def __repr__(self) -> str:
|
||||
return f"Literal({self._value!r})"
|
||||
def match(self, arg: str) -> bool:
|
||||
def match(self, server: IServer, arg: str) -> bool:
|
||||
return self._value == arg
|
||||
class ParamNot(ResponseParam):
|
||||
def __init__(self, param: ResponseParam):
|
||||
|
||||
class FoldString(IMatchResponseParam):
|
||||
def __init__(self, value: str):
|
||||
self._value = value
|
||||
self._folded: Optional[str] = None
|
||||
def __repr__(self) -> str:
|
||||
return f"FoldString({self._value!r})"
|
||||
def match(self, server: IServer, arg: str) -> bool:
|
||||
if self._folded is None:
|
||||
self._folded = server.casefold(self._value)
|
||||
return self._folded == server.casefold(arg)
|
||||
|
||||
class ParamNot(IMatchResponseParam):
|
||||
def __init__(self, param: IMatchResponseParam):
|
||||
self._param = param
|
||||
def __repr__(self) -> str:
|
||||
return f"Not({self._param!r})"
|
||||
def match(self, arg: str) -> bool:
|
||||
return not self._param.match(arg)
|
||||
def match(self, server: IServer, arg: str) -> bool:
|
||||
return not self._param.match(server, arg)
|
||||
|
|
|
@ -9,8 +9,7 @@ from irctokens import build, Line, tokenise
|
|||
|
||||
from .ircv3 import CAPContext, CAP_SASL
|
||||
from .interface import (ConnectionParams, ICapability, IServer, SentLine,
|
||||
SendPriority, SASLParams)
|
||||
from .matching import BaseResponse
|
||||
SendPriority, SASLParams, IMatchResponse)
|
||||
from .sasl import SASLContext, SASLResult
|
||||
from .security import ssl_context
|
||||
|
||||
|
@ -104,7 +103,7 @@ class Server(IServer):
|
|||
line, emits = await self._read_queue.get()
|
||||
return line
|
||||
|
||||
async def wait_for(self, response: BaseResponse) -> Line:
|
||||
async def wait_for(self, response: IMatchResponse) -> Line:
|
||||
while True:
|
||||
lines = self._wait_for_cache.copy()
|
||||
self._wait_for_cache.clear()
|
||||
|
@ -113,7 +112,7 @@ class Server(IServer):
|
|||
lines += await self._read_lines()
|
||||
|
||||
for i, (line, emits) in enumerate(lines):
|
||||
if response.match(line):
|
||||
if response.match(self, line):
|
||||
self._wait_for_cache = lines[i+1:]
|
||||
return line
|
||||
|
||||
|
|
Loading…
Reference in a new issue