change wait_for to not spin up nested next_line() loops

This commit is contained in:
jesopo 2020-04-23 14:42:42 +01:00
parent d51fcf0987
commit 45269a98a9
2 changed files with 41 additions and 18 deletions

View file

@ -1,3 +1,4 @@
import asyncio
from asyncio import Future, PriorityQueue from asyncio import Future, PriorityQueue
from typing import Awaitable, Deque, Dict, List, Optional, Set, Tuple from typing import Awaitable, Deque, Dict, List, Optional, Set, Tuple
from collections import deque from collections import deque
@ -23,6 +24,8 @@ from .interface import ITCPTransport, ITCPReader, ITCPWriter
THROTTLE_RATE = 4 # lines THROTTLE_RATE = 4 # lines
THROTTLE_TIME = 2 # seconds THROTTLE_TIME = 2 # seconds
WAIT_TUP = Tuple[IMatchResponse, "Future[Line]"]
class Server(IServer): class Server(IServer):
_reader: ITCPReader _reader: ITCPReader
_writer: ITCPWriter _writer: ITCPWriter
@ -40,12 +43,16 @@ class Server(IServer):
self.sasl_state = SASLResult.NONE self.sasl_state = SASLResult.NONE
self._sent_count: int = 0 self._sent_count: int = 0
self._wait_for: List[Tuple["Future[Line]", IMatchResponse]] = []
self._write_queue: PriorityQueue[SentLine] = PriorityQueue() self._write_queue: PriorityQueue[SentLine] = PriorityQueue()
self.desired_caps: Set[ICapability] = set([]) self.desired_caps: Set[ICapability] = set([])
self._read_queue: Deque[Tuple[Line, Optional[Emit]]] = deque() self._read_queue: Deque[Tuple[Line, Optional[Emit]]] = deque()
self._wait_for: List[
Tuple[Awaitable, IMatchResponse, "Future[Line]"]] = []
self._wait_for_fut: Optional["Future[WAIT_TUP]"] = None
def hostmask(self) -> str: def hostmask(self) -> str:
hostmask = self.nickname hostmask = self.nickname
if not self.username is None: if not self.username is None:
@ -149,6 +156,18 @@ class Server(IServer):
if line.command == "PING": if line.command == "PING":
await self.send(build("PONG", line.params)) await self.send(build("PONG", line.params))
async def _line_or_wait(self, line_fut: Awaitable):
wait_for_fut: Future[WAIT_TUP] = Future()
self._wait_for_fut = wait_for_fut
done, pend = await asyncio.wait([line_fut, wait_for_fut],
return_when=asyncio.FIRST_COMPLETED)
if wait_for_fut.done():
response, fut = await wait_for_fut
new_line_fut = list(pend)[0]
self._wait_for.append((new_line_fut, response, fut))
async def next_line(self) -> Tuple[Line, Optional[Emit]]: async def next_line(self) -> Tuple[Line, Optional[Emit]]:
if self._read_queue: if self._read_queue:
both = self._read_queue.popleft() both = self._read_queue.popleft()
@ -167,27 +186,31 @@ class Server(IServer):
break break
line, emit = both line, emit = both
if emit is not None: async def _line():
await self._on_read_emit(line, emit) if emit is not None:
await self._on_read_line(line) await self._on_read_emit(line, emit)
await self.line_read(line) await self._on_read_line(line)
await self.line_read(line)
await self._line_or_wait(_line())
for i, (aw, response, fut) in enumerate(self._wait_for):
if response.match(self, line):
fut.set_result(line)
self._wait_for.pop(i)
await self._line_or_wait(aw)
break
return both return both
async def wait_for(self, response: IMatchResponse) -> Line: async def wait_for(self, response: IMatchResponse) -> Line:
our_fut: "Future[Line]" = Future() wait_for_fut = self._wait_for_fut
self._wait_for.append((our_fut, response)) if wait_for_fut is not None:
while self._wait_for: self._wait_for_fut = None
both = await self.next_line()
line, emit = both
for i, (fut, waiting) in enumerate(self._wait_for): our_fut: "Future[Line]" = Future()
if waiting.match(self, line): wait_for_fut.set_result((response, our_fut))
fut.set_result(line) return await our_fut
self._wait_for.pop(i) raise Exception()
break
return await our_fut
async def _on_write_line(self, line: Line): async def _on_write_line(self, line: Line):
if (line.command == "PRIVMSG" and if (line.command == "PRIVMSG" and

View file

@ -1,5 +1,5 @@
anyio ==1.3.0 anyio ==1.3.0
asyncio-throttle ==1.0.1 asyncio-throttle ==1.0.1
dataclasses ==0.6 dataclasses ==0.6
ircstates ==0.9.5 ircstates ==0.9.6
async_stagger ==0.3.0 async_stagger ==0.3.0