add Server.wait_for(), so you can await until getting a matching message
This commit is contained in:
parent
4f61b89012
commit
8148fead51
2 changed files with 66 additions and 3 deletions
49
ircrobots/matching.py
Normal file
49
ircrobots/matching.py
Normal file
|
@ -0,0 +1,49 @@
|
|||
from typing import List, Optional
|
||||
from irctokens import Line
|
||||
|
||||
class ResponseParam(object):
|
||||
def match(self, arg: str) -> bool:
|
||||
return False
|
||||
|
||||
class BaseResponse(object):
|
||||
def match(self, line: Line) -> bool:
|
||||
return False
|
||||
|
||||
class Numerics(BaseResponse):
|
||||
def __init__(self,
|
||||
numerics: List[str]):
|
||||
self._numerics = numerics
|
||||
|
||||
def match(self, line: Line):
|
||||
return line.command in self._numerics
|
||||
|
||||
class Response(BaseResponse):
|
||||
def __init__(self,
|
||||
command: str,
|
||||
params: List[ResponseParam],
|
||||
errors: Optional[List[str]] = None):
|
||||
self._command = command
|
||||
self._params = params
|
||||
self._errors = errors or []
|
||||
|
||||
def match(self, line: Line) -> bool:
|
||||
if line.command == self._command:
|
||||
for i, param in enumerate(self._params):
|
||||
if (i >= len(line.params) or
|
||||
not param.match(line.params[i])):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
elif line.command in self._errors:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
class ParamAny(ResponseParam):
|
||||
def match(self, arg: str) -> bool:
|
||||
return True
|
||||
class Literal(ResponseParam):
|
||||
def __init__(self, value: str):
|
||||
self._value = value
|
||||
def match(self, arg: str) -> bool:
|
||||
return self._value == arg
|
|
@ -1,8 +1,8 @@
|
|||
import asyncio, ssl
|
||||
from asyncio import PriorityQueue
|
||||
from asyncio import Future, PriorityQueue
|
||||
from queue import Queue
|
||||
from typing import Callable, Dict, List, Optional, Set, Tuple
|
||||
from enum import IntEnum
|
||||
from typing import Awaitable, Callable, Dict, List, Optional, Set, Tuple
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass
|
||||
|
||||
from asyncio_throttle import Throttler
|
||||
|
@ -11,6 +11,7 @@ from irctokens import build, Line, tokenise
|
|||
|
||||
from .ircv3 import Capability, CAPS
|
||||
from .interface import ConnectionParams, IServer, PriorityLine, SendPriority
|
||||
from .matching import BaseResponse, Response, Numerics, ParamAny, Literal
|
||||
|
||||
sc = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
|
||||
|
||||
|
@ -33,6 +34,8 @@ class Server(IServer):
|
|||
self._cap_queue: Set[Capability] = set([])
|
||||
self._requested_caps: List[str] = []
|
||||
|
||||
self._wait_for: List[Tuple[BaseResponse, Future]] = []
|
||||
|
||||
async def send_raw(self, line: str, priority=SendPriority.DEFAULT):
|
||||
await self.send(tokenise(line), priority)
|
||||
async def send(self, line: Line, priority=SendPriority.DEFAULT):
|
||||
|
@ -67,6 +70,12 @@ class Server(IServer):
|
|||
await self._cap_ack(emit)
|
||||
|
||||
async def _on_read_line(self, line: Line):
|
||||
for i, (response, future) in enumerate(self._wait_for):
|
||||
if response.match(line):
|
||||
self._wait_for.pop(i)
|
||||
future.set_result(line)
|
||||
break
|
||||
|
||||
if line.command == "PING":
|
||||
await self.send(build("PONG", line.params))
|
||||
|
||||
|
@ -75,6 +84,11 @@ class Server(IServer):
|
|||
lines = self.recv(data)
|
||||
return lines
|
||||
|
||||
def wait_for(self, response: BaseResponse) -> Awaitable[Line]:
|
||||
future: "Future[Line]" = asyncio.Future()
|
||||
self._wait_for.append((response, future))
|
||||
return future
|
||||
|
||||
async def line_written(self, line: Line):
|
||||
pass
|
||||
async def _write_lines(self) -> List[Line]:
|
||||
|
|
Loading…
Add table
Reference in a new issue