add wait_for() hostmask matching functionality

This commit is contained in:
jesopo 2020-04-20 16:53:14 +01:00
parent 079460dd35
commit a79958affd
2 changed files with 29 additions and 5 deletions

View file

@ -3,7 +3,7 @@ from typing import Awaitable, Iterable, List, Optional, Set, Tuple
from enum import IntEnum
from ircstates import Server, Emit
from irctokens import Line
from irctokens import Line, Hostmask
from .params import ConnectionParams, SASLParams, STSPolicy
@ -61,6 +61,9 @@ class IMatchResponse(object):
class IMatchResponseParam(object):
def match(self, server: "IServer", arg: str) -> bool:
pass
class IMatchResponseHostmask(object):
def match(self, server: "IServer", hostmask: Hostmask) -> bool:
pass
class IServer(Server):
bot: "IBot"

View file

@ -1,19 +1,27 @@
from typing import List, Optional
from irctokens import Line
from .interface import IServer, IMatchResponse, IMatchResponseParam
from irctokens import Line, Hostmask
from .interface import (IServer, IMatchResponse, IMatchResponseParam,
IMatchResponseHostmask)
class Responses(IMatchResponse):
def __init__(self,
commands: List[str],
params: List[IMatchResponseParam]=[]):
params: List[IMatchResponseParam]=[],
source: Optional[IMatchResponseHostmask]=None):
self._commands = commands
self._params = params
self._source = source
def __repr__(self) -> str:
return f"Responses({self._commands!r}: {self._params!r})"
def match(self, server: IServer, line: Line) -> bool:
for command in self._commands:
if line.command == command:
if (line.command == command and (
self._source is None or (
line.hostmask is not None and
self._source.match(server, line.hostmask)
))):
for i, param in enumerate(self._params):
if (i >= len(line.params) or
not param.match(server, line.params[i])):
@ -76,3 +84,16 @@ class ParamNot(IMatchResponseParam):
return f"Not({self._param!r})"
def match(self, server: IServer, arg: str) -> bool:
return not self._param.match(server, arg)
class Nickname(IMatchResponseHostmask):
def __init__(self, nickname: str):
self._nickname = nickname
self._folded: Optional[str] = None
def __repr__(self) -> str:
mask = f"{self._nickname}!*@*"
return f"Hostmask({mask!r})"
def match(self, server: IServer, hostmask: Hostmask):
if self._folded is None:
self._folded = server.casefold(self._nickname)
return self._folded == server.casefold(hostmask.nickname)