simplify wait_for
This commit is contained in:
parent
90fb4b7bba
commit
6a05370a12
2 changed files with 45 additions and 80 deletions
|
@ -6,6 +6,7 @@ from collections import deque
|
|||
from time import monotonic
|
||||
|
||||
import anyio
|
||||
from asyncio_rlock import RLock
|
||||
from asyncio_throttle import Throttler
|
||||
from async_timeout import timeout as timeout_
|
||||
from ircstates import Emit, Channel, ChannelUser
|
||||
|
@ -63,10 +64,13 @@ class Server(IServer):
|
|||
self._send_queue: PriorityQueue[SentLine] = PriorityQueue()
|
||||
self.desired_caps: Set[ICapability] = set([])
|
||||
|
||||
self._read_queue: Deque[Line] = deque()
|
||||
self._read_queue: Deque[Line] = deque()
|
||||
self._process_queue: Deque[Line] = deque()
|
||||
|
||||
self._wait_for: Optional[Tuple[Awaitable, WaitFor]] = None
|
||||
self._wait_for_fut: Optional[Future[WaitFor]] = None
|
||||
self._read_lguard = RLock()
|
||||
self.read_lock = self._read_lguard
|
||||
self._read_lwork = asyncio.Lock()
|
||||
self._wait_for = asyncio.Event()
|
||||
|
||||
self._pending_who: Deque[str] = deque()
|
||||
self._alt_nicks: List[str] = []
|
||||
|
@ -273,76 +277,42 @@ class Server(IServer):
|
|||
self.last_read = monotonic()
|
||||
lines = self.recv(data)
|
||||
for line in lines:
|
||||
self.line_preread(line)
|
||||
self._read_queue.append(line)
|
||||
|
||||
async def _line_or_wait(self,
|
||||
line_aw: asyncio.Task
|
||||
) -> Optional[Tuple[Awaitable, WaitFor]]:
|
||||
wait_for_fut: Future[WaitFor] = Future()
|
||||
self._wait_for_fut = wait_for_fut
|
||||
|
||||
done, pend = await asyncio.wait([line_aw, wait_for_fut],
|
||||
return_when=asyncio.FIRST_COMPLETED)
|
||||
self._wait_for_fut = None
|
||||
|
||||
if wait_for_fut.done():
|
||||
new_line_aw = list(pend)[0]
|
||||
return (new_line_aw, wait_for_fut.result())
|
||||
else:
|
||||
return None
|
||||
|
||||
async def _read_lines(self):
|
||||
waiting_lines: List[Tuple[Line, Optional[Emit]]] = []
|
||||
sent_ping = False
|
||||
ping_sent = False
|
||||
while True:
|
||||
now = monotonic()
|
||||
timeouts: List[float] = []
|
||||
timeouts.append((self.last_read+PING_TIMEOUT)-now)
|
||||
if self._wait_for is not None:
|
||||
_, wait_for = self._wait_for
|
||||
timeouts.append(wait_for.deadline-now)
|
||||
async with self._read_lguard:
|
||||
pass
|
||||
|
||||
line = await self._read_line(max([0.1, min(timeouts)]))
|
||||
if line is None:
|
||||
now = monotonic()
|
||||
since = now-self.last_read
|
||||
if not self._process_queue:
|
||||
async with self._read_lwork:
|
||||
read_aw = self._read_line(PING_TIMEOUT)
|
||||
dones, notdones = await asyncio.wait(
|
||||
[read_aw, self._wait_for.wait()],
|
||||
return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
self._wait_for.clear()
|
||||
|
||||
if self._wait_for is not None:
|
||||
aw, wait_for = self._wait_for
|
||||
if wait_for.deadline <= now:
|
||||
self._wait_for = None
|
||||
await aw
|
||||
for done in dones:
|
||||
if isinstance(done.result(), Line):
|
||||
line = done.result()
|
||||
self._process_queue.append(line)
|
||||
elif done.result() is None:
|
||||
if ping_sent:
|
||||
await self.send(build("PING", ["hello"]))
|
||||
ping_sent = True
|
||||
else:
|
||||
await self.disconnect()
|
||||
raise ServerDisconnectedException()
|
||||
for notdone in notdones:
|
||||
notdone.cancel()
|
||||
|
||||
if since >= PING_TIMEOUT:
|
||||
if since >= (PING_TIMEOUT*2):
|
||||
raise ServerDisconnectedException()
|
||||
elif not sent_ping:
|
||||
sent_ping = True
|
||||
await self.send(build("PING", ["hello"]))
|
||||
continue
|
||||
else:
|
||||
sent_ping = False
|
||||
line = self._process_queue.popleft()
|
||||
emit = self.parse_tokens(line)
|
||||
|
||||
waiting_lines.append((line, emit))
|
||||
self.line_preread(line)
|
||||
|
||||
if self._wait_for is not None:
|
||||
aw, wait_for = self._wait_for
|
||||
if wait_for.match(self, line):
|
||||
wait_for.resolve(line)
|
||||
self._wait_for = await self._line_or_wait(aw)
|
||||
if self._wait_for is not None:
|
||||
continue
|
||||
else:
|
||||
continue
|
||||
|
||||
for i in range(len(waiting_lines)):
|
||||
line, emit = waiting_lines.pop(0)
|
||||
line_aw = self._on_read(line, emit)
|
||||
self._wait_for = await self._line_or_wait(line_aw)
|
||||
if self._wait_for is not None:
|
||||
break
|
||||
await self._on_read(line, emit)
|
||||
|
||||
async def wait_for(self,
|
||||
response: Union[IMatchResponse, Set[IMatchResponse]],
|
||||
|
@ -356,22 +326,16 @@ class Server(IServer):
|
|||
else:
|
||||
response_obj = response
|
||||
|
||||
deadline = monotonic()+timeout
|
||||
our_wait_for = WaitFor(response_obj, deadline)
|
||||
if self._wait_for_fut is not None:
|
||||
self._wait_for_fut.set_result(our_wait_for)
|
||||
else:
|
||||
cur_task = asyncio.current_task()
|
||||
if cur_task is not None:
|
||||
self._wait_for = (cur_task, our_wait_for)
|
||||
|
||||
if sent_aw is not None:
|
||||
sent_line = await sent_aw
|
||||
label = str(sent_line.id)
|
||||
our_wait_for.with_label(label)
|
||||
|
||||
async with timeout_(timeout):
|
||||
return (await our_wait_for)
|
||||
async with self._read_lguard:
|
||||
self._wait_for.set()
|
||||
async with self._read_lwork:
|
||||
async with timeout_(timeout):
|
||||
while True:
|
||||
line = await self._read_line(timeout)
|
||||
if line:
|
||||
self._process_queue.append(line)
|
||||
if response_obj.match(self, line):
|
||||
return line
|
||||
|
||||
async def _on_send_line(self, line: Line):
|
||||
if (line.command in ["PRIVMSG", "NOTICE", "TAGMSG"] and
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
anyio ~=2.0.2
|
||||
asyncio-rlock ~=0.1.0
|
||||
asyncio-throttle ~=1.0.1
|
||||
dataclasses ~=0.6; python_version<"3.7"
|
||||
ircstates ~=0.11.8
|
||||
|
|
Loading…
Add table
Reference in a new issue