scrap deferred wait_for, actually catch server disconnection

This commit is contained in:
jesopo 2020-09-24 19:43:03 +00:00
parent eb9888d0c4
commit a264e4e347
3 changed files with 83 additions and 61 deletions

View file

@ -1,7 +1,8 @@
from asyncio import Future from asyncio import Future
from irctokens import Line
from typing import (Any, Awaitable, Callable, Generator, Generic, Optional, from typing import (Any, Awaitable, Callable, Generator, Generic, Optional,
TypeVar) TypeVar)
from irctokens import Line
from .matching import IMatchResponse from .matching import IMatchResponse
from .interface import IServer from .interface import IServer
from .ircv3 import TAG_LABEL from .ircv3 import TAG_LABEL
@ -17,8 +18,10 @@ class MaybeAwait(Generic[TEvent]):
class WaitFor(object): class WaitFor(object):
def __init__(self, def __init__(self,
response: IMatchResponse): response: IMatchResponse,
self.response = response deadline: float):
self.response = response
self.deadline = deadline
self._label: Optional[str] = None self._label: Optional[str] = None
self._our_fut: "Future[Line]" = Future() self._our_fut: "Future[Line]" = Future()

View file

@ -36,13 +36,12 @@ class Bot(IBot):
return server return server
async def _run_server(self, server: Server): async def _run_server(self, server: Server):
async with anyio.create_task_group() as tg: try:
async def _read(): async with anyio.create_task_group() as tg:
while True: await tg.spawn(server._read_lines)
async for line, emit in server._read_lines(): await tg.spawn(server._send_lines)
pass except ServerDisconnectedException:
await tg.spawn(_read) server.disconnected = True
await tg.spawn(server._send_lines)
await self.disconnected(server) await self.disconnected(server)

View file

@ -58,16 +58,16 @@ class Server(IServer):
rate_limit=100, period=THROTTLE_TIME) rate_limit=100, period=THROTTLE_TIME)
self.sasl_state = SASLResult.NONE self.sasl_state = SASLResult.NONE
self.last_read = -1.0 self.last_read = monotonic()
self._sent_count: int = 0 self._sent_count: int = 0
self._send_queue: PriorityQueue[SentLine] = PriorityQueue() self._send_queue: PriorityQueue[SentLine] = PriorityQueue()
self.desired_caps: Set[ICapability] = set([]) self.desired_caps: Set[ICapability] = set([])
self._read_queue: Deque[Tuple[Line, Optional[Emit]]] = deque() self._read_queue: Deque[Line] = deque()
self._wait_fors: List[WaitFor] = [] self._wait_for: Optional[Tuple[Awaitable, WaitFor]] = None
self._wait_for_fut: Dict[asyncio.Task, Future[bool]] = {} self._wait_for_fut: Optional[Future[WaitFor]] = None
self._pending_who: Deque[str] = deque() self._pending_who: Deque[str] = deque()
@ -202,73 +202,93 @@ class Server(IServer):
line = await self.wait_for(end) line = await self.wait_for(end)
async def _next_lines(self) -> AsyncIterable[Line]: async def _read_line(self, timeout: float) -> Optional[Line]:
ping_sent = False
while True: while True:
if self._read_queue:
return self._read_queue.popleft()
try: try:
async with timeout_(PING_TIMEOUT): async with timeout_(timeout):
data = await self._reader.read(1024) data = await self._reader.read(1024)
except asyncio.TimeoutError: except asyncio.TimeoutError:
if ping_sent: return None
data = b"" # empty data means the socket disconnected
else:
ping_sent = True
await self.send(build("PING", ["hello"]))
continue
self.last_read = monotonic() self.last_read = monotonic()
ping_sent = False lines = self.recv(data)
try:
lines = self.recv(data)
except ServerDisconnectedException:
self.disconnected = True
raise
for line in lines: for line in lines:
yield line self._read_queue.append(line)
async def _line_or_wait(self, async def _line_or_wait(self,
line_aw: asyncio.Task line_aw: asyncio.Task
) -> Optional[Awaitable]: ) -> Optional[Tuple[Awaitable, WaitFor]]:
wait_for_fut: Future[bool] = Future() wait_for_fut: Future[WaitFor] = Future()
self._wait_for_fut[line_aw] = wait_for_fut self._wait_for_fut = wait_for_fut
done, pend = await asyncio.wait([line_aw, wait_for_fut], done, pend = await asyncio.wait([line_aw, wait_for_fut],
return_when=asyncio.FIRST_COMPLETED) return_when=asyncio.FIRST_COMPLETED)
del self._wait_for_fut[line_aw] self._wait_for_fut = None
if wait_for_fut.done(): if wait_for_fut.done():
new_line_aw = list(pend)[0] new_line_aw = list(pend)[0]
return new_line_aw return (new_line_aw, wait_for_fut.result())
else: else:
return None return None
async def _read_lines(self) -> AsyncIterable[Tuple[Line, Optional[Emit]]]: async def _read_lines(self):
async with anyio.create_task_group() as tg: waiting_lines: List[Tuple[Line, Optional[Emit]]] = []
async for line in self._next_lines(): sent_ping = True
while True:
now = monotonic()
timeouts: List[float] = []
timeouts.append((self.last_read+PING_TIMEOUT)-now)
if self._wait_for is not None:
_, wait_for = self._wait_for
timeouts.append(wait_for.deadline-now)
line = await self._read_line(max([0.1, min(timeouts)]))
if line is None:
now = monotonic()
since = now-self.last_read
if self._wait_for is not None:
aw, wait_for = self._wait_for
if wait_for.deadline <= now:
self._wait_for = None
await aw
if since >= PING_TIMEOUT:
if since >= (PING_TIMEOUT*2):
raise ServerDisconnectedException()
elif not sent_ping:
await self.send(build("PING", ["hello"]))
continue
else:
emit = self.parse_tokens(line) emit = self.parse_tokens(line)
waiting_lines.append((line, emit))
self.line_preread(line) self.line_preread(line)
for i, wait_for in enumerate(self._wait_fors): if self._wait_for is not None:
aw, wait_for = self._wait_for
if wait_for.match(self, line): if wait_for.match(self, line):
wait_for.resolve(line) wait_for.resolve(line)
self._wait_fors.pop(i) self._wait_for = await self._line_or_wait(aw)
if self._wait_for is not None:
continue
else:
continue
for i in range(len(waiting_lines)):
line, emit = waiting_lines.pop(0)
line_aw = self._on_read(line, emit)
self._wait_for = await self._line_or_wait(line_aw)
if self._wait_for is not None:
break break
line_aw = asyncio.create_task(self._on_read(line, emit))
new_wait = await self._line_or_wait(line_aw)
if new_wait is not None:
async def _aw():
await new_wait
await tg.spawn(_aw)
yield (line, emit)
async def wait_for(self, async def wait_for(self,
response: Union[IMatchResponse, Set[IMatchResponse]], response: Union[IMatchResponse, Set[IMatchResponse]],
sent_aw: Optional[Awaitable[SentLine]]=None, sent_aw: Optional[Awaitable[SentLine]]=None,
timeout: float=WAIT_TIMEOUT timeout: float=WAIT_TIMEOUT
) -> Line: ) -> Line:
response_obj: IMatchResponse response_obj: IMatchResponse
@ -277,13 +297,14 @@ class Server(IServer):
else: else:
response_obj = response response_obj = response
our_wait_for = WaitFor(response_obj) deadline = monotonic()+timeout
self._wait_fors.append(our_wait_for) our_wait_for = WaitFor(response_obj, deadline)
if self._wait_for_fut is not None:
cur_task = asyncio.current_task() self._wait_for_fut.set_result(our_wait_for)
if cur_task is not None and cur_task in self._wait_for_fut: else:
wait_for_fut = self._wait_for_fut[cur_task] cur_task = asyncio.current_task()
wait_for_fut.set_result(True) if cur_task is not None:
self._wait_for = (cur_task, our_wait_for)
if sent_aw is not None: if sent_aw is not None:
sent_line = await sent_aw sent_line = await sent_aw
@ -297,8 +318,7 @@ class Server(IServer):
if (line.command in ["PRIVMSG", "NOTICE", "TAGMSG"] and if (line.command in ["PRIVMSG", "NOTICE", "TAGMSG"] and
not self.cap_agreed(CAP_ECHO)): not self.cap_agreed(CAP_ECHO)):
new_line = line.with_source(self.hostmask()) new_line = line.with_source(self.hostmask())
emit = self.parse_tokens(new_line) self._read_queue.append(new_line)
self._read_queue.append((new_line, emit))
async def _send_lines(self): async def _send_lines(self):
while True: while True: