simplify wait_for

This commit is contained in:
jesopo 2021-05-12 10:52:39 +00:00
parent 90fb4b7bba
commit 6a05370a12
2 changed files with 45 additions and 80 deletions

View file

@ -6,6 +6,7 @@ from collections import deque
from time import monotonic from time import monotonic
import anyio import anyio
from asyncio_rlock import RLock
from asyncio_throttle import Throttler from asyncio_throttle import Throttler
from async_timeout import timeout as timeout_ from async_timeout import timeout as timeout_
from ircstates import Emit, Channel, ChannelUser from ircstates import Emit, Channel, ChannelUser
@ -63,10 +64,13 @@ class Server(IServer):
self._send_queue: PriorityQueue[SentLine] = PriorityQueue() self._send_queue: PriorityQueue[SentLine] = PriorityQueue()
self.desired_caps: Set[ICapability] = set([]) self.desired_caps: Set[ICapability] = set([])
self._read_queue: Deque[Line] = deque() self._read_queue: Deque[Line] = deque()
self._process_queue: Deque[Line] = deque()
self._wait_for: Optional[Tuple[Awaitable, WaitFor]] = None self._read_lguard = RLock()
self._wait_for_fut: Optional[Future[WaitFor]] = None self.read_lock = self._read_lguard
self._read_lwork = asyncio.Lock()
self._wait_for = asyncio.Event()
self._pending_who: Deque[str] = deque() self._pending_who: Deque[str] = deque()
self._alt_nicks: List[str] = [] self._alt_nicks: List[str] = []
@ -273,76 +277,42 @@ class Server(IServer):
self.last_read = monotonic() self.last_read = monotonic()
lines = self.recv(data) lines = self.recv(data)
for line in lines: for line in lines:
self.line_preread(line)
self._read_queue.append(line) self._read_queue.append(line)
async def _line_or_wait(self,
line_aw: asyncio.Task
) -> Optional[Tuple[Awaitable, WaitFor]]:
wait_for_fut: Future[WaitFor] = Future()
self._wait_for_fut = wait_for_fut
done, pend = await asyncio.wait([line_aw, wait_for_fut],
return_when=asyncio.FIRST_COMPLETED)
self._wait_for_fut = None
if wait_for_fut.done():
new_line_aw = list(pend)[0]
return (new_line_aw, wait_for_fut.result())
else:
return None
async def _read_lines(self): async def _read_lines(self):
waiting_lines: List[Tuple[Line, Optional[Emit]]] = [] ping_sent = False
sent_ping = False
while True: while True:
now = monotonic() async with self._read_lguard:
timeouts: List[float] = [] pass
timeouts.append((self.last_read+PING_TIMEOUT)-now)
if self._wait_for is not None:
_, wait_for = self._wait_for
timeouts.append(wait_for.deadline-now)
line = await self._read_line(max([0.1, min(timeouts)])) if not self._process_queue:
if line is None: async with self._read_lwork:
now = monotonic() read_aw = self._read_line(PING_TIMEOUT)
since = now-self.last_read dones, notdones = await asyncio.wait(
[read_aw, self._wait_for.wait()],
return_when=asyncio.FIRST_COMPLETED
)
self._wait_for.clear()
if self._wait_for is not None: for done in dones:
aw, wait_for = self._wait_for if isinstance(done.result(), Line):
if wait_for.deadline <= now: line = done.result()
self._wait_for = None self._process_queue.append(line)
await aw elif done.result() is None:
if ping_sent:
await self.send(build("PING", ["hello"]))
ping_sent = True
else:
await self.disconnect()
raise ServerDisconnectedException()
for notdone in notdones:
notdone.cancel()
if since >= PING_TIMEOUT:
if since >= (PING_TIMEOUT*2):
raise ServerDisconnectedException()
elif not sent_ping:
sent_ping = True
await self.send(build("PING", ["hello"]))
continue
else: else:
sent_ping = False line = self._process_queue.popleft()
emit = self.parse_tokens(line) emit = self.parse_tokens(line)
await self._on_read(line, emit)
waiting_lines.append((line, emit))
self.line_preread(line)
if self._wait_for is not None:
aw, wait_for = self._wait_for
if wait_for.match(self, line):
wait_for.resolve(line)
self._wait_for = await self._line_or_wait(aw)
if self._wait_for is not None:
continue
else:
continue
for i in range(len(waiting_lines)):
line, emit = waiting_lines.pop(0)
line_aw = self._on_read(line, emit)
self._wait_for = await self._line_or_wait(line_aw)
if self._wait_for is not None:
break
async def wait_for(self, async def wait_for(self,
response: Union[IMatchResponse, Set[IMatchResponse]], response: Union[IMatchResponse, Set[IMatchResponse]],
@ -356,22 +326,16 @@ class Server(IServer):
else: else:
response_obj = response response_obj = response
deadline = monotonic()+timeout async with self._read_lguard:
our_wait_for = WaitFor(response_obj, deadline) self._wait_for.set()
if self._wait_for_fut is not None: async with self._read_lwork:
self._wait_for_fut.set_result(our_wait_for) async with timeout_(timeout):
else: while True:
cur_task = asyncio.current_task() line = await self._read_line(timeout)
if cur_task is not None: if line:
self._wait_for = (cur_task, our_wait_for) self._process_queue.append(line)
if response_obj.match(self, line):
if sent_aw is not None: return line
sent_line = await sent_aw
label = str(sent_line.id)
our_wait_for.with_label(label)
async with timeout_(timeout):
return (await our_wait_for)
async def _on_send_line(self, line: Line): async def _on_send_line(self, line: Line):
if (line.command in ["PRIVMSG", "NOTICE", "TAGMSG"] and if (line.command in ["PRIVMSG", "NOTICE", "TAGMSG"] and

View file

@ -1,4 +1,5 @@
anyio ~=2.0.2 anyio ~=2.0.2
asyncio-rlock ~=0.1.0
asyncio-throttle ~=1.0.1 asyncio-throttle ~=1.0.1
dataclasses ~=0.6; python_version<"3.7" dataclasses ~=0.6; python_version<"3.7"
ircstates ~=0.11.8 ircstates ~=0.11.8