From 45269a98a94c3533cd961c2ae49a4d627ced2773 Mon Sep 17 00:00:00 2001 From: jesopo Date: Thu, 23 Apr 2020 14:42:42 +0100 Subject: [PATCH] change wait_for to not spin up nested next_line() loops --- ircrobots/server.py | 57 +++++++++++++++++++++++++++++++-------------- requirements.txt | 2 +- 2 files changed, 41 insertions(+), 18 deletions(-) diff --git a/ircrobots/server.py b/ircrobots/server.py index 058a7d4..1593de9 100644 --- a/ircrobots/server.py +++ b/ircrobots/server.py @@ -1,3 +1,4 @@ +import asyncio from asyncio import Future, PriorityQueue from typing import Awaitable, Deque, Dict, List, Optional, Set, Tuple from collections import deque @@ -23,6 +24,8 @@ from .interface import ITCPTransport, ITCPReader, ITCPWriter THROTTLE_RATE = 4 # lines THROTTLE_TIME = 2 # seconds +WAIT_TUP = Tuple[IMatchResponse, "Future[Line]"] + class Server(IServer): _reader: ITCPReader _writer: ITCPWriter @@ -40,12 +43,16 @@ class Server(IServer): self.sasl_state = SASLResult.NONE self._sent_count: int = 0 - self._wait_for: List[Tuple["Future[Line]", IMatchResponse]] = [] self._write_queue: PriorityQueue[SentLine] = PriorityQueue() self.desired_caps: Set[ICapability] = set([]) self._read_queue: Deque[Tuple[Line, Optional[Emit]]] = deque() + self._wait_for: List[ + Tuple[Awaitable, IMatchResponse, "Future[Line]"]] = [] + + self._wait_for_fut: Optional["Future[WAIT_TUP]"] = None + def hostmask(self) -> str: hostmask = self.nickname if not self.username is None: @@ -149,6 +156,18 @@ class Server(IServer): if line.command == "PING": await self.send(build("PONG", line.params)) + async def _line_or_wait(self, line_fut: Awaitable): + wait_for_fut: Future[WAIT_TUP] = Future() + self._wait_for_fut = wait_for_fut + + done, pend = await asyncio.wait([line_fut, wait_for_fut], + return_when=asyncio.FIRST_COMPLETED) + + if wait_for_fut.done(): + response, fut = await wait_for_fut + new_line_fut = list(pend)[0] + self._wait_for.append((new_line_fut, response, fut)) + async def next_line(self) -> Tuple[Line, Optional[Emit]]: if self._read_queue: both = self._read_queue.popleft() @@ -167,27 +186,31 @@ class Server(IServer): break line, emit = both - if emit is not None: - await self._on_read_emit(line, emit) - await self._on_read_line(line) - await self.line_read(line) + async def _line(): + if emit is not None: + await self._on_read_emit(line, emit) + await self._on_read_line(line) + await self.line_read(line) + + await self._line_or_wait(_line()) + for i, (aw, response, fut) in enumerate(self._wait_for): + if response.match(self, line): + fut.set_result(line) + self._wait_for.pop(i) + await self._line_or_wait(aw) + break return both async def wait_for(self, response: IMatchResponse) -> Line: - our_fut: "Future[Line]" = Future() - self._wait_for.append((our_fut, response)) - while self._wait_for: - both = await self.next_line() - line, emit = both + wait_for_fut = self._wait_for_fut + if wait_for_fut is not None: + self._wait_for_fut = None - for i, (fut, waiting) in enumerate(self._wait_for): - if waiting.match(self, line): - fut.set_result(line) - self._wait_for.pop(i) - break - - return await our_fut + our_fut: "Future[Line]" = Future() + wait_for_fut.set_result((response, our_fut)) + return await our_fut + raise Exception() async def _on_write_line(self, line: Line): if (line.command == "PRIVMSG" and diff --git a/requirements.txt b/requirements.txt index d4ebe13..c7b6c7c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ anyio ==1.3.0 asyncio-throttle ==1.0.1 dataclasses ==0.6 -ircstates ==0.9.5 +ircstates ==0.9.6 async_stagger ==0.3.0