pass IServer to Response so we can have FoldString (match with casefold rules)

This commit is contained in:
jesopo 2020-04-05 13:00:13 +01:00
parent f70932ac44
commit e470d57780
3 changed files with 46 additions and 35 deletions

View file

@ -5,7 +5,6 @@ from enum import IntEnum
from ircstates import Server from ircstates import Server
from irctokens import Line from irctokens import Line
from .matching import BaseResponse
from .params import ConnectionParams, SASLParams from .params import ConnectionParams, SASLParams
class SendPriority(IntEnum): class SendPriority(IntEnum):
@ -32,6 +31,13 @@ class ICapability(object):
def copy(self) -> "ICapability": def copy(self) -> "ICapability":
pass pass
class IMatchResponse(object):
def match(self, server: "IServer", line: Line) -> bool:
pass
class IMatchResponseParam(object):
def match(self, server: "IServer", arg: str) -> bool:
pass
class IServer(Server): class IServer(Server):
params: ConnectionParams params: ConnectionParams
desired_caps: Set[ICapability] desired_caps: Set[ICapability]
@ -41,7 +47,7 @@ class IServer(Server):
async def send(self, line: Line, priority=SendPriority.DEFAULT): async def send(self, line: Line, priority=SendPriority.DEFAULT):
pass pass
def wait_for(self, response: BaseResponse) -> Awaitable[Line]: def wait_for(self, response: IMatchResponse) -> Awaitable[Line]:
pass pass
def set_throttle(self, rate: int, time: float): def set_throttle(self, rate: int, time: float):

View file

@ -1,73 +1,79 @@
from typing import List from typing import List, Optional
from irctokens import Line from irctokens import Line
from .numerics import NUMERIC_NAMES from .numerics import NUMERIC_NAMES
from .interface import IServer, IMatchResponse, IMatchResponseParam
class ResponseParam(object): class Numerics(IMatchResponse):
def match(self, arg: str) -> bool:
return False
class BaseResponse(object):
def match(self, line: Line) -> bool:
return False
class Numerics(BaseResponse):
def __init__(self, def __init__(self,
numerics: List[str]): numerics: List[str]):
self._numerics = [NUMERIC_NAMES.get(n, n) for n in numerics] self._numerics = [NUMERIC_NAMES.get(n, n) for n in numerics]
def __repr__(self) -> str: def __repr__(self) -> str:
return f"Numerics({self._numerics!r})" return f"Numerics({self._numerics!r})"
def match(self, line: Line): def match(self, server: IServer, line: Line):
return line.command in self._numerics return line.command in self._numerics
class Response(BaseResponse): class Response(IMatchResponse):
def __init__(self, def __init__(self,
command: str, command: str,
params: List[ResponseParam]): params: List[IMatchResponseParam]):
self._command = command self._command = command
self._params = params self._params = params
def __repr__(self) -> str: def __repr__(self) -> str:
return f"Response({self._command}: {self._params!r})" return f"Response({self._command}: {self._params!r})"
def match(self, line: Line) -> bool: def match(self, server: IServer, line: Line) -> bool:
if line.command == self._command: if line.command == self._command:
for i, param in enumerate(self._params): for i, param in enumerate(self._params):
if (i >= len(line.params) or if (i >= len(line.params) or
not param.match(line.params[i])): not param.match(server, line.params[i])):
return False return False
else: else:
return True return True
else: else:
return False return False
class ResponseOr(BaseResponse): class ResponseOr(IMatchResponse):
def __init__(self, *responses: BaseResponse): def __init__(self, *responses: IMatchResponse):
self._responses = responses self._responses = responses
def __repr__(self) -> str: def __repr__(self) -> str:
return f"ResponseOr({self._responses!r})" return f"ResponseOr({self._responses!r})"
def match(self, line: Line) -> bool: def match(self, server: IServer, line: Line) -> bool:
for response in self._responses: for response in self._responses:
if response.match(line): if response.match(server, line):
return True return True
else: else:
return False return False
class ParamAny(ResponseParam): class ParamAny(IMatchResponseParam):
def __repr__(self) -> str: def __repr__(self) -> str:
return "Any()" return "Any()"
def match(self, arg: str) -> bool: def match(self, server: IServer, arg: str) -> bool:
return True return True
class ParamLiteral(ResponseParam):
class ParamLiteral(IMatchResponseParam):
def __init__(self, value: str): def __init__(self, value: str):
self._value = value self._value = value
def __repr__(self) -> str: def __repr__(self) -> str:
return f"Literal({self._value!r})" return f"Literal({self._value!r})"
def match(self, arg: str) -> bool: def match(self, server: IServer, arg: str) -> bool:
return self._value == arg return self._value == arg
class ParamNot(ResponseParam):
def __init__(self, param: ResponseParam): class FoldString(IMatchResponseParam):
def __init__(self, value: str):
self._value = value
self._folded: Optional[str] = None
def __repr__(self) -> str:
return f"FoldString({self._value!r})"
def match(self, server: IServer, arg: str) -> bool:
if self._folded is None:
self._folded = server.casefold(self._value)
return self._folded == server.casefold(arg)
class ParamNot(IMatchResponseParam):
def __init__(self, param: IMatchResponseParam):
self._param = param self._param = param
def __repr__(self) -> str: def __repr__(self) -> str:
return f"Not({self._param!r})" return f"Not({self._param!r})"
def match(self, arg: str) -> bool: def match(self, server: IServer, arg: str) -> bool:
return not self._param.match(arg) return not self._param.match(server, arg)

View file

@ -9,8 +9,7 @@ from irctokens import build, Line, tokenise
from .ircv3 import CAPContext, CAP_SASL from .ircv3 import CAPContext, CAP_SASL
from .interface import (ConnectionParams, ICapability, IServer, SentLine, from .interface import (ConnectionParams, ICapability, IServer, SentLine,
SendPriority, SASLParams) SendPriority, SASLParams, IMatchResponse)
from .matching import BaseResponse
from .sasl import SASLContext, SASLResult from .sasl import SASLContext, SASLResult
from .security import ssl_context from .security import ssl_context
@ -104,7 +103,7 @@ class Server(IServer):
line, emits = await self._read_queue.get() line, emits = await self._read_queue.get()
return line return line
async def wait_for(self, response: BaseResponse) -> Line: async def wait_for(self, response: IMatchResponse) -> Line:
while True: while True:
lines = self._wait_for_cache.copy() lines = self._wait_for_cache.copy()
self._wait_for_cache.clear() self._wait_for_cache.clear()
@ -113,7 +112,7 @@ class Server(IServer):
lines += await self._read_lines() lines += await self._read_lines()
for i, (line, emits) in enumerate(lines): for i, (line, emits) in enumerate(lines):
if response.match(line): if response.match(self, line):
self._wait_for_cache = lines[i+1:] self._wait_for_cache = lines[i+1:]
return line return line