add Server.wait_for(), so you can await until getting a matching message

This commit is contained in:
jesopo 2020-04-02 17:00:50 +01:00
parent 4f61b89012
commit 8148fead51
2 changed files with 66 additions and 3 deletions

49
ircrobots/matching.py Normal file
View file

@ -0,0 +1,49 @@
from typing import List, Optional
from irctokens import Line
class ResponseParam(object):
def match(self, arg: str) -> bool:
return False
class BaseResponse(object):
def match(self, line: Line) -> bool:
return False
class Numerics(BaseResponse):
def __init__(self,
numerics: List[str]):
self._numerics = numerics
def match(self, line: Line):
return line.command in self._numerics
class Response(BaseResponse):
def __init__(self,
command: str,
params: List[ResponseParam],
errors: Optional[List[str]] = None):
self._command = command
self._params = params
self._errors = errors or []
def match(self, line: Line) -> bool:
if line.command == self._command:
for i, param in enumerate(self._params):
if (i >= len(line.params) or
not param.match(line.params[i])):
return False
else:
return True
elif line.command in self._errors:
return True
else:
return False
class ParamAny(ResponseParam):
def match(self, arg: str) -> bool:
return True
class Literal(ResponseParam):
def __init__(self, value: str):
self._value = value
def match(self, arg: str) -> bool:
return self._value == arg

View file

@ -1,8 +1,8 @@
import asyncio, ssl
from asyncio import PriorityQueue
from asyncio import Future, PriorityQueue
from queue import Queue
from typing import Callable, Dict, List, Optional, Set, Tuple
from enum import IntEnum
from typing import Awaitable, Callable, Dict, List, Optional, Set, Tuple
from enum import Enum
from dataclasses import dataclass
from asyncio_throttle import Throttler
@ -11,6 +11,7 @@ from irctokens import build, Line, tokenise
from .ircv3 import Capability, CAPS
from .interface import ConnectionParams, IServer, PriorityLine, SendPriority
from .matching import BaseResponse, Response, Numerics, ParamAny, Literal
sc = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
@ -33,6 +34,8 @@ class Server(IServer):
self._cap_queue: Set[Capability] = set([])
self._requested_caps: List[str] = []
self._wait_for: List[Tuple[BaseResponse, Future]] = []
async def send_raw(self, line: str, priority=SendPriority.DEFAULT):
await self.send(tokenise(line), priority)
async def send(self, line: Line, priority=SendPriority.DEFAULT):
@ -67,6 +70,12 @@ class Server(IServer):
await self._cap_ack(emit)
async def _on_read_line(self, line: Line):
for i, (response, future) in enumerate(self._wait_for):
if response.match(line):
self._wait_for.pop(i)
future.set_result(line)
break
if line.command == "PING":
await self.send(build("PONG", line.params))
@ -75,6 +84,11 @@ class Server(IServer):
lines = self.recv(data)
return lines
def wait_for(self, response: BaseResponse) -> Awaitable[Line]:
future: "Future[Line]" = asyncio.Future()
self._wait_for.append((response, future))
return future
async def line_written(self, line: Line):
pass
async def _write_lines(self) -> List[Line]: