diff --git a/ircrobots/asyncs.py b/ircrobots/asyncs.py index 78825fc..fe1df89 100644 --- a/ircrobots/asyncs.py +++ b/ircrobots/asyncs.py @@ -1,4 +1,7 @@ -from typing import Any, Awaitable, Callable, Generator, Generic, TypeVar +from asyncio import Future +from irctokens import Line +from typing import Any, Awaitable, Callable, Generator, Generic, TypeVar +from .matching import IMatchResponse TEvent = TypeVar("TEvent") class MaybeAwait(Generic[TEvent]): @@ -8,3 +11,14 @@ class MaybeAwait(Generic[TEvent]): def __await__(self) -> Generator[Any, None, TEvent]: coro = self._func() return coro.__await__() + +class WaitFor(object): + def __init__(self, response: IMatchResponse): + self.response = response + self._fut: "Future[Line]" = Future() + + def __await__(self) -> Generator[Any, None, Line]: + return self._fut.__await__() + + def resolve(self, line: Line): + self._fut.set_result(line) diff --git a/ircrobots/server.py b/ircrobots/server.py index 6ad95e6..5fc5244 100644 --- a/ircrobots/server.py +++ b/ircrobots/server.py @@ -14,7 +14,7 @@ from .ircv3 import (CAPContext, sts_transmute, CAP_ECHO, CAP_SASL, from .sasl import SASLContext, SASLResult from .join_info import WHOContext from .matching import ResponseOr, Responses, Response, ANY, Folded, Nickname -from .asyncs import MaybeAwait +from .asyncs import MaybeAwait, WaitFor from .struct import Whois from .params import ConnectionParams, SASLParams, STSPolicy from .interface import (IBot, ICapability, IServer, SentLine, SendPriority, @@ -24,8 +24,6 @@ from .interface import ITCPTransport, ITCPReader, ITCPWriter THROTTLE_RATE = 4 # lines THROTTLE_TIME = 2 # seconds -WAIT_TUP = Tuple[IMatchResponse, "Future[Line]"] - class Server(IServer): _reader: ITCPReader _writer: ITCPWriter @@ -48,10 +46,8 @@ class Server(IServer): self._read_queue: Deque[Tuple[Line, Optional[Emit]]] = deque() - self._wait_for: List[ - Tuple[Awaitable, IMatchResponse, "Future[Line]"]] = [] - - self._wait_for_fut: Optional["Future[WAIT_TUP]"] = None + self._wait_for: List[Tuple[Awaitable, WaitFor]] = [] + self._wait_for_fut: Optional["Future[WaitFor]"] = None def hostmask(self) -> str: hostmask = self.nickname @@ -157,16 +153,15 @@ class Server(IServer): await self.send(build("PONG", line.params)) async def _line_or_wait(self, line_fut: Awaitable): - wait_for_fut: Future[WAIT_TUP] = Future() + wait_for_fut: Future[WaitFor] = Future() self._wait_for_fut = wait_for_fut done, pend = await asyncio.wait([line_fut, wait_for_fut], return_when=asyncio.FIRST_COMPLETED) if wait_for_fut.done(): - response, fut = await wait_for_fut new_line_fut = list(pend)[0] - self._wait_for.append((new_line_fut, response, fut)) + self._wait_for.append((new_line_fut, await wait_for_fut)) async def next_line(self) -> Tuple[Line, Optional[Emit]]: if self._read_queue: @@ -192,9 +187,9 @@ class Server(IServer): await self._on_read_line(line) await self.line_read(line) - for i, (aw, response, fut) in enumerate(self._wait_for): - if response.match(self, line): - fut.set_result(line) + for i, (aw, wait_for) in enumerate(self._wait_for): + if wait_for.response.match(self, line): + wait_for.resolve(line) self._wait_for.pop(i) await self._line_or_wait(aw) break @@ -208,9 +203,9 @@ class Server(IServer): if wait_for_fut is not None: self._wait_for_fut = None - our_fut: "Future[Line]" = Future() - wait_for_fut.set_result((response, our_fut)) - return await our_fut + our_wait_for = WaitFor(response) + wait_for_fut.set_result(our_wait_for) + return await our_wait_for raise Exception() async def _on_write_line(self, line: Line):