add Server.wait_for(), so you can await until getting a matching message
This commit is contained in:
parent
4f61b89012
commit
8148fead51
2 changed files with 66 additions and 3 deletions
49
ircrobots/matching.py
Normal file
49
ircrobots/matching.py
Normal file
|
@ -0,0 +1,49 @@
|
||||||
|
from typing import List, Optional
|
||||||
|
from irctokens import Line
|
||||||
|
|
||||||
|
class ResponseParam(object):
|
||||||
|
def match(self, arg: str) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
class BaseResponse(object):
|
||||||
|
def match(self, line: Line) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
class Numerics(BaseResponse):
|
||||||
|
def __init__(self,
|
||||||
|
numerics: List[str]):
|
||||||
|
self._numerics = numerics
|
||||||
|
|
||||||
|
def match(self, line: Line):
|
||||||
|
return line.command in self._numerics
|
||||||
|
|
||||||
|
class Response(BaseResponse):
|
||||||
|
def __init__(self,
|
||||||
|
command: str,
|
||||||
|
params: List[ResponseParam],
|
||||||
|
errors: Optional[List[str]] = None):
|
||||||
|
self._command = command
|
||||||
|
self._params = params
|
||||||
|
self._errors = errors or []
|
||||||
|
|
||||||
|
def match(self, line: Line) -> bool:
|
||||||
|
if line.command == self._command:
|
||||||
|
for i, param in enumerate(self._params):
|
||||||
|
if (i >= len(line.params) or
|
||||||
|
not param.match(line.params[i])):
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
elif line.command in self._errors:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
class ParamAny(ResponseParam):
|
||||||
|
def match(self, arg: str) -> bool:
|
||||||
|
return True
|
||||||
|
class Literal(ResponseParam):
|
||||||
|
def __init__(self, value: str):
|
||||||
|
self._value = value
|
||||||
|
def match(self, arg: str) -> bool:
|
||||||
|
return self._value == arg
|
|
@ -1,8 +1,8 @@
|
||||||
import asyncio, ssl
|
import asyncio, ssl
|
||||||
from asyncio import PriorityQueue
|
from asyncio import Future, PriorityQueue
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
from typing import Callable, Dict, List, Optional, Set, Tuple
|
from typing import Awaitable, Callable, Dict, List, Optional, Set, Tuple
|
||||||
from enum import IntEnum
|
from enum import Enum
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from asyncio_throttle import Throttler
|
from asyncio_throttle import Throttler
|
||||||
|
@ -11,6 +11,7 @@ from irctokens import build, Line, tokenise
|
||||||
|
|
||||||
from .ircv3 import Capability, CAPS
|
from .ircv3 import Capability, CAPS
|
||||||
from .interface import ConnectionParams, IServer, PriorityLine, SendPriority
|
from .interface import ConnectionParams, IServer, PriorityLine, SendPriority
|
||||||
|
from .matching import BaseResponse, Response, Numerics, ParamAny, Literal
|
||||||
|
|
||||||
sc = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
|
sc = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
|
||||||
|
|
||||||
|
@ -33,6 +34,8 @@ class Server(IServer):
|
||||||
self._cap_queue: Set[Capability] = set([])
|
self._cap_queue: Set[Capability] = set([])
|
||||||
self._requested_caps: List[str] = []
|
self._requested_caps: List[str] = []
|
||||||
|
|
||||||
|
self._wait_for: List[Tuple[BaseResponse, Future]] = []
|
||||||
|
|
||||||
async def send_raw(self, line: str, priority=SendPriority.DEFAULT):
|
async def send_raw(self, line: str, priority=SendPriority.DEFAULT):
|
||||||
await self.send(tokenise(line), priority)
|
await self.send(tokenise(line), priority)
|
||||||
async def send(self, line: Line, priority=SendPriority.DEFAULT):
|
async def send(self, line: Line, priority=SendPriority.DEFAULT):
|
||||||
|
@ -67,6 +70,12 @@ class Server(IServer):
|
||||||
await self._cap_ack(emit)
|
await self._cap_ack(emit)
|
||||||
|
|
||||||
async def _on_read_line(self, line: Line):
|
async def _on_read_line(self, line: Line):
|
||||||
|
for i, (response, future) in enumerate(self._wait_for):
|
||||||
|
if response.match(line):
|
||||||
|
self._wait_for.pop(i)
|
||||||
|
future.set_result(line)
|
||||||
|
break
|
||||||
|
|
||||||
if line.command == "PING":
|
if line.command == "PING":
|
||||||
await self.send(build("PONG", line.params))
|
await self.send(build("PONG", line.params))
|
||||||
|
|
||||||
|
@ -75,6 +84,11 @@ class Server(IServer):
|
||||||
lines = self.recv(data)
|
lines = self.recv(data)
|
||||||
return lines
|
return lines
|
||||||
|
|
||||||
|
def wait_for(self, response: BaseResponse) -> Awaitable[Line]:
|
||||||
|
future: "Future[Line]" = asyncio.Future()
|
||||||
|
self._wait_for.append((response, future))
|
||||||
|
return future
|
||||||
|
|
||||||
async def line_written(self, line: Line):
|
async def line_written(self, line: Line):
|
||||||
pass
|
pass
|
||||||
async def _write_lines(self) -> List[Line]:
|
async def _write_lines(self) -> List[Line]:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue