object-ify WaitFor future stuff

This commit is contained in:
jesopo 2020-04-23 15:22:30 +01:00
parent f48aaded5a
commit 955c284282
2 changed files with 26 additions and 17 deletions

View file

@ -1,4 +1,7 @@
from asyncio import Future
from irctokens import Line
from typing import Any, Awaitable, Callable, Generator, Generic, TypeVar from typing import Any, Awaitable, Callable, Generator, Generic, TypeVar
from .matching import IMatchResponse
TEvent = TypeVar("TEvent") TEvent = TypeVar("TEvent")
class MaybeAwait(Generic[TEvent]): class MaybeAwait(Generic[TEvent]):
@ -8,3 +11,14 @@ class MaybeAwait(Generic[TEvent]):
def __await__(self) -> Generator[Any, None, TEvent]: def __await__(self) -> Generator[Any, None, TEvent]:
coro = self._func() coro = self._func()
return coro.__await__() return coro.__await__()
class WaitFor(object):
def __init__(self, response: IMatchResponse):
self.response = response
self._fut: "Future[Line]" = Future()
def __await__(self) -> Generator[Any, None, Line]:
return self._fut.__await__()
def resolve(self, line: Line):
self._fut.set_result(line)

View file

@ -14,7 +14,7 @@ from .ircv3 import (CAPContext, sts_transmute, CAP_ECHO, CAP_SASL,
from .sasl import SASLContext, SASLResult from .sasl import SASLContext, SASLResult
from .join_info import WHOContext from .join_info import WHOContext
from .matching import ResponseOr, Responses, Response, ANY, Folded, Nickname from .matching import ResponseOr, Responses, Response, ANY, Folded, Nickname
from .asyncs import MaybeAwait from .asyncs import MaybeAwait, WaitFor
from .struct import Whois from .struct import Whois
from .params import ConnectionParams, SASLParams, STSPolicy from .params import ConnectionParams, SASLParams, STSPolicy
from .interface import (IBot, ICapability, IServer, SentLine, SendPriority, from .interface import (IBot, ICapability, IServer, SentLine, SendPriority,
@ -24,8 +24,6 @@ from .interface import ITCPTransport, ITCPReader, ITCPWriter
THROTTLE_RATE = 4 # lines THROTTLE_RATE = 4 # lines
THROTTLE_TIME = 2 # seconds THROTTLE_TIME = 2 # seconds
WAIT_TUP = Tuple[IMatchResponse, "Future[Line]"]
class Server(IServer): class Server(IServer):
_reader: ITCPReader _reader: ITCPReader
_writer: ITCPWriter _writer: ITCPWriter
@ -48,10 +46,8 @@ class Server(IServer):
self._read_queue: Deque[Tuple[Line, Optional[Emit]]] = deque() self._read_queue: Deque[Tuple[Line, Optional[Emit]]] = deque()
self._wait_for: List[ self._wait_for: List[Tuple[Awaitable, WaitFor]] = []
Tuple[Awaitable, IMatchResponse, "Future[Line]"]] = [] self._wait_for_fut: Optional["Future[WaitFor]"] = None
self._wait_for_fut: Optional["Future[WAIT_TUP]"] = None
def hostmask(self) -> str: def hostmask(self) -> str:
hostmask = self.nickname hostmask = self.nickname
@ -157,16 +153,15 @@ class Server(IServer):
await self.send(build("PONG", line.params)) await self.send(build("PONG", line.params))
async def _line_or_wait(self, line_fut: Awaitable): async def _line_or_wait(self, line_fut: Awaitable):
wait_for_fut: Future[WAIT_TUP] = Future() wait_for_fut: Future[WaitFor] = Future()
self._wait_for_fut = wait_for_fut self._wait_for_fut = wait_for_fut
done, pend = await asyncio.wait([line_fut, wait_for_fut], done, pend = await asyncio.wait([line_fut, wait_for_fut],
return_when=asyncio.FIRST_COMPLETED) return_when=asyncio.FIRST_COMPLETED)
if wait_for_fut.done(): if wait_for_fut.done():
response, fut = await wait_for_fut
new_line_fut = list(pend)[0] new_line_fut = list(pend)[0]
self._wait_for.append((new_line_fut, response, fut)) self._wait_for.append((new_line_fut, await wait_for_fut))
async def next_line(self) -> Tuple[Line, Optional[Emit]]: async def next_line(self) -> Tuple[Line, Optional[Emit]]:
if self._read_queue: if self._read_queue:
@ -192,9 +187,9 @@ class Server(IServer):
await self._on_read_line(line) await self._on_read_line(line)
await self.line_read(line) await self.line_read(line)
for i, (aw, response, fut) in enumerate(self._wait_for): for i, (aw, wait_for) in enumerate(self._wait_for):
if response.match(self, line): if wait_for.response.match(self, line):
fut.set_result(line) wait_for.resolve(line)
self._wait_for.pop(i) self._wait_for.pop(i)
await self._line_or_wait(aw) await self._line_or_wait(aw)
break break
@ -208,9 +203,9 @@ class Server(IServer):
if wait_for_fut is not None: if wait_for_fut is not None:
self._wait_for_fut = None self._wait_for_fut = None
our_fut: "Future[Line]" = Future() our_wait_for = WaitFor(response)
wait_for_fut.set_result((response, our_fut)) wait_for_fut.set_result(our_wait_for)
return await our_fut return await our_wait_for
raise Exception() raise Exception()
async def _on_write_line(self, line: Line): async def _on_write_line(self, line: Line):