pass IServer to Response so we can have FoldString (match with casefold rules)

This commit is contained in:
jesopo 2020-04-05 13:00:13 +01:00
parent f70932ac44
commit e470d57780
3 changed files with 46 additions and 35 deletions

View file

@ -5,7 +5,6 @@ from enum import IntEnum
from ircstates import Server
from irctokens import Line
from .matching import BaseResponse
from .params import ConnectionParams, SASLParams
class SendPriority(IntEnum):
@ -32,6 +31,13 @@ class ICapability(object):
def copy(self) -> "ICapability":
pass
class IMatchResponse(object):
def match(self, server: "IServer", line: Line) -> bool:
pass
class IMatchResponseParam(object):
def match(self, server: "IServer", arg: str) -> bool:
pass
class IServer(Server):
params: ConnectionParams
desired_caps: Set[ICapability]
@ -41,7 +47,7 @@ class IServer(Server):
async def send(self, line: Line, priority=SendPriority.DEFAULT):
pass
def wait_for(self, response: BaseResponse) -> Awaitable[Line]:
def wait_for(self, response: IMatchResponse) -> Awaitable[Line]:
pass
def set_throttle(self, rate: int, time: float):

View file

@ -1,73 +1,79 @@
from typing import List
from typing import List, Optional
from irctokens import Line
from .numerics import NUMERIC_NAMES
from .interface import IServer, IMatchResponse, IMatchResponseParam
class ResponseParam(object):
def match(self, arg: str) -> bool:
return False
class BaseResponse(object):
def match(self, line: Line) -> bool:
return False
class Numerics(BaseResponse):
class Numerics(IMatchResponse):
def __init__(self,
numerics: List[str]):
self._numerics = [NUMERIC_NAMES.get(n, n) for n in numerics]
def __repr__(self) -> str:
return f"Numerics({self._numerics!r})"
def match(self, line: Line):
def match(self, server: IServer, line: Line):
return line.command in self._numerics
class Response(BaseResponse):
class Response(IMatchResponse):
def __init__(self,
command: str,
params: List[ResponseParam]):
params: List[IMatchResponseParam]):
self._command = command
self._params = params
def __repr__(self) -> str:
return f"Response({self._command}: {self._params!r})"
def match(self, line: Line) -> bool:
def match(self, server: IServer, line: Line) -> bool:
if line.command == self._command:
for i, param in enumerate(self._params):
if (i >= len(line.params) or
not param.match(line.params[i])):
not param.match(server, line.params[i])):
return False
else:
return True
else:
return False
class ResponseOr(BaseResponse):
def __init__(self, *responses: BaseResponse):
class ResponseOr(IMatchResponse):
def __init__(self, *responses: IMatchResponse):
self._responses = responses
def __repr__(self) -> str:
return f"ResponseOr({self._responses!r})"
def match(self, line: Line) -> bool:
def match(self, server: IServer, line: Line) -> bool:
for response in self._responses:
if response.match(line):
if response.match(server, line):
return True
else:
return False
class ParamAny(ResponseParam):
class ParamAny(IMatchResponseParam):
def __repr__(self) -> str:
return "Any()"
def match(self, arg: str) -> bool:
def match(self, server: IServer, arg: str) -> bool:
return True
class ParamLiteral(ResponseParam):
class ParamLiteral(IMatchResponseParam):
def __init__(self, value: str):
self._value = value
def __repr__(self) -> str:
return f"Literal({self._value!r})"
def match(self, arg: str) -> bool:
def match(self, server: IServer, arg: str) -> bool:
return self._value == arg
class ParamNot(ResponseParam):
def __init__(self, param: ResponseParam):
class FoldString(IMatchResponseParam):
def __init__(self, value: str):
self._value = value
self._folded: Optional[str] = None
def __repr__(self) -> str:
return f"FoldString({self._value!r})"
def match(self, server: IServer, arg: str) -> bool:
if self._folded is None:
self._folded = server.casefold(self._value)
return self._folded == server.casefold(arg)
class ParamNot(IMatchResponseParam):
def __init__(self, param: IMatchResponseParam):
self._param = param
def __repr__(self) -> str:
return f"Not({self._param!r})"
def match(self, arg: str) -> bool:
return not self._param.match(arg)
def match(self, server: IServer, arg: str) -> bool:
return not self._param.match(server, arg)

View file

@ -9,8 +9,7 @@ from irctokens import build, Line, tokenise
from .ircv3 import CAPContext, CAP_SASL
from .interface import (ConnectionParams, ICapability, IServer, SentLine,
SendPriority, SASLParams)
from .matching import BaseResponse
SendPriority, SASLParams, IMatchResponse)
from .sasl import SASLContext, SASLResult
from .security import ssl_context
@ -104,7 +103,7 @@ class Server(IServer):
line, emits = await self._read_queue.get()
return line
async def wait_for(self, response: BaseResponse) -> Line:
async def wait_for(self, response: IMatchResponse) -> Line:
while True:
lines = self._wait_for_cache.copy()
self._wait_for_cache.clear()
@ -113,7 +112,7 @@ class Server(IServer):
lines += await self._read_lines()
for i, (line, emits) in enumerate(lines):
if response.match(line):
if response.match(self, line):
self._wait_for_cache = lines[i+1:]
return line