From bc70afe04b358eeeb0cdfc6ec07095c2f7840fe1 Mon Sep 17 00:00:00 2001 From: jesopo Date: Thu, 30 Apr 2020 11:22:47 +0100 Subject: [PATCH] move setting wait_for_fut result to WaitFor await in case it isn't awaited --- ircrobots/asyncs.py | 18 +++++++++++++----- ircrobots/server.py | 10 ++++------ 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/ircrobots/asyncs.py b/ircrobots/asyncs.py index ee548cb..9b9009c 100644 --- a/ircrobots/asyncs.py +++ b/ircrobots/asyncs.py @@ -14,15 +14,23 @@ class MaybeAwait(Generic[TEvent]): return coro.__await__() class WaitFor(object): - def __init__(self, response: IMatchResponse): - self.response = response - self._fut: "Future[Line]" = Future() + def __init__(self, + wait_fut: "Future[WaitFor]", + response: IMatchResponse): + self._wait_fut = wait_fut + self.response = response + self.deferred = False + self._our_fut: "Future[Line]" = Future() def __await__(self) -> Generator[Any, None, Line]: - return self._fut.__await__() + self._wait_fut.set_result(self) + return self._our_fut.__await__() + async def defer(self): + self.deferred = True + await self def match(self, server: IServer, line: Line): return self.response.match(server, line) def resolve(self, line: Line): - self._fut.set_result(line) + self._our_fut.set_result(line) diff --git a/ircrobots/server.py b/ircrobots/server.py index 8d7f3a3..84ca721 100644 --- a/ircrobots/server.py +++ b/ircrobots/server.py @@ -258,9 +258,9 @@ class Server(IServer): if wait_for is not None: break - async def wait_for(self, + def wait_for(self, response: Union[IMatchResponse, Set[IMatchResponse]] - ) -> Line: + ) -> Awaitable[Line]: response_obj: IMatchResponse if isinstance(response, set): response_obj = ResponseOr(*response) @@ -270,10 +270,8 @@ class Server(IServer): wait_for_fut = self._wait_for_fut if wait_for_fut is not None: self._wait_for_fut = None - - our_wait_for = WaitFor(response_obj) - wait_for_fut.set_result(our_wait_for) - return await our_wait_for + our_wait_for = WaitFor(wait_for_fut, response_obj) + return our_wait_for raise Exception() async def _on_send_line(self, line: Line):