switch _next_lines and _read_lines to generators. taskgroup wait_fors!

This commit is contained in:
jesopo 2020-06-21 16:46:36 +01:00
parent 75c12d83e8
commit 883f09e31c
2 changed files with 41 additions and 37 deletions

View file

@ -39,7 +39,8 @@ class Bot(IBot):
async with anyio.create_task_group() as tg: async with anyio.create_task_group() as tg:
async def _read(): async def _read():
while True: while True:
await server._read_lines() async for line, emit in server._read_lines():
pass
await tg.spawn(_read) await tg.spawn(_read)
await tg.spawn(server._send_lines) await tg.spawn(server._send_lines)

View file

@ -1,10 +1,11 @@
import asyncio import asyncio
from asyncio import Future, PriorityQueue from asyncio import Future, PriorityQueue
from typing import (Awaitable, Deque, Dict, List, Optional, Set, Tuple, from typing import (AsyncIterable, Awaitable, Deque, Dict, Iterable, List,
Union) Optional, Set, Tuple, Union)
from collections import deque from collections import deque
from time import monotonic from time import monotonic
import anyio
from asyncio_throttle import Throttler from asyncio_throttle import Throttler
from async_timeout import timeout from async_timeout import timeout
from ircstates import Emit, Channel, ChannelUser from ircstates import Emit, Channel, ChannelUser
@ -64,8 +65,8 @@ class Server(IServer):
self._read_queue: Deque[Tuple[Line, Optional[Emit]]] = deque() self._read_queue: Deque[Tuple[Line, Optional[Emit]]] = deque()
self._wait_fors: List[Tuple[WaitFor, Optional[Awaitable]]] = [] self._wait_fors: List[WaitFor] = []
self._wait_for_fut: Optional[Future[WaitFor]] = None self._wait_for_fut: Dict[str, Future[bool]] = {}
self._pending_who: Deque[str] = deque() self._pending_who: Deque[str] = deque()
@ -200,7 +201,7 @@ class Server(IServer):
line = await self.wait_for(end) line = await self.wait_for(end)
async def _next_lines(self) -> List[Tuple[Line, Optional[Emit]]]: async def _next_lines(self) -> AsyncIterable[Tuple[Line, Optional[Emit]]]:
ping_sent = False ping_sent = False
while True: while True:
try: try:
@ -223,45 +224,46 @@ class Server(IServer):
self.disconnected = True self.disconnected = True
raise raise
return lines for both in lines:
yield both
async def _line_or_wait(self, async def _line_or_wait(self,
line_aw: Awaitable line_aw: asyncio.Task
) -> Optional[Tuple[WaitFor, Awaitable]]: ) -> Optional[Awaitable]:
wait_for_fut: Future[WaitFor] = Future() task_name = line_aw.get_name()
self._wait_for_fut = wait_for_fut wait_for_fut: Future[bool] = Future()
self._wait_for_fut[task_name] = wait_for_fut
done, pend = await asyncio.wait([line_aw, wait_for_fut], done, pend = await asyncio.wait([line_aw, wait_for_fut],
return_when=asyncio.FIRST_COMPLETED) return_when=asyncio.FIRST_COMPLETED)
self._wait_for_fut = None del self._wait_for_fut[task_name]
if wait_for_fut.done(): if wait_for_fut.done():
new_line_aw = list(pend)[0] new_line_aw = list(pend)[0]
return (await wait_for_fut), new_line_aw return new_line_aw
else: else:
return None return None
async def _read_lines(self) -> List[Tuple[Line, Optional[Emit]]]: async def _read_lines(self) -> AsyncIterable[Tuple[Line, Optional[Emit]]]:
lines = await self._next_lines() async with anyio.create_task_group() as tg:
for line, emit in lines: async for both in self._next_lines():
self.line_preread(line) line, emit = both
self.line_preread(line)
for i, (wait_for, aw) in enumerate(self._wait_fors): for i, wait_for in enumerate(self._wait_fors):
if wait_for.match(self, line): if wait_for.match(self, line):
wait_for.resolve(line) wait_for.resolve(line)
self._wait_fors.pop(i)
break
if aw is not None: line_aw = asyncio.create_task(self._on_read(line, emit))
new_wait_for = await self._line_or_wait(aw) new_wait = await self._line_or_wait(line_aw)
if new_wait_for is not None: if new_wait is not None:
self._wait_fors.append(new_wait_for) async def _aw():
self._wait_fors.pop(i) await new_wait
break await tg.spawn(_aw)
line_aw = self._on_read(line, emit) yield both
new_wait_for = await self._line_or_wait(line_aw)
if new_wait_for is not None:
self._wait_fors.append(new_wait_for)
return lines
async def wait_for(self, async def wait_for(self,
response: Union[IMatchResponse, Set[IMatchResponse]], response: Union[IMatchResponse, Set[IMatchResponse]],
@ -275,13 +277,14 @@ class Server(IServer):
response_obj = response response_obj = response
our_wait_for = WaitFor(response_obj) our_wait_for = WaitFor(response_obj)
self._wait_fors.append(our_wait_for)
wait_for_fut = self._wait_for_fut cur_task = asyncio.current_task()
if wait_for_fut is not None: if cur_task is not None:
self._wait_for_fut = None cur_task_name = cur_task.get_name()
wait_for_fut.set_result(our_wait_for) if cur_task_name in self._wait_for_fut:
else: wait_for_fut = self._wait_for_fut[cur_task_name]
self._wait_fors.append((our_wait_for, None)) wait_for_fut.set_result(True)
if sent_aw is not None: if sent_aw is not None:
sent_line = await sent_aw sent_line = await sent_aw