scrap deferred wait_for, actually catch server disconnection
This commit is contained in:
parent
eb9888d0c4
commit
a264e4e347
3 changed files with 83 additions and 61 deletions
|
@ -1,7 +1,8 @@
|
|||
from asyncio import Future
|
||||
from irctokens import Line
|
||||
from typing import (Any, Awaitable, Callable, Generator, Generic, Optional,
|
||||
TypeVar)
|
||||
|
||||
from irctokens import Line
|
||||
from .matching import IMatchResponse
|
||||
from .interface import IServer
|
||||
from .ircv3 import TAG_LABEL
|
||||
|
@ -17,8 +18,10 @@ class MaybeAwait(Generic[TEvent]):
|
|||
|
||||
class WaitFor(object):
|
||||
def __init__(self,
|
||||
response: IMatchResponse):
|
||||
response: IMatchResponse,
|
||||
deadline: float):
|
||||
self.response = response
|
||||
self.deadline = deadline
|
||||
self._label: Optional[str] = None
|
||||
self._our_fut: "Future[Line]" = Future()
|
||||
|
||||
|
|
|
@ -36,13 +36,12 @@ class Bot(IBot):
|
|||
return server
|
||||
|
||||
async def _run_server(self, server: Server):
|
||||
try:
|
||||
async with anyio.create_task_group() as tg:
|
||||
async def _read():
|
||||
while True:
|
||||
async for line, emit in server._read_lines():
|
||||
pass
|
||||
await tg.spawn(_read)
|
||||
await tg.spawn(server._read_lines)
|
||||
await tg.spawn(server._send_lines)
|
||||
except ServerDisconnectedException:
|
||||
server.disconnected = True
|
||||
|
||||
await self.disconnected(server)
|
||||
|
||||
|
|
|
@ -58,16 +58,16 @@ class Server(IServer):
|
|||
rate_limit=100, period=THROTTLE_TIME)
|
||||
|
||||
self.sasl_state = SASLResult.NONE
|
||||
self.last_read = -1.0
|
||||
self.last_read = monotonic()
|
||||
|
||||
self._sent_count: int = 0
|
||||
self._send_queue: PriorityQueue[SentLine] = PriorityQueue()
|
||||
self.desired_caps: Set[ICapability] = set([])
|
||||
|
||||
self._read_queue: Deque[Tuple[Line, Optional[Emit]]] = deque()
|
||||
self._read_queue: Deque[Line] = deque()
|
||||
|
||||
self._wait_fors: List[WaitFor] = []
|
||||
self._wait_for_fut: Dict[asyncio.Task, Future[bool]] = {}
|
||||
self._wait_for: Optional[Tuple[Awaitable, WaitFor]] = None
|
||||
self._wait_for_fut: Optional[Future[WaitFor]] = None
|
||||
|
||||
self._pending_who: Deque[str] = deque()
|
||||
|
||||
|
@ -202,69 +202,89 @@ class Server(IServer):
|
|||
line = await self.wait_for(end)
|
||||
|
||||
|
||||
async def _next_lines(self) -> AsyncIterable[Line]:
|
||||
ping_sent = False
|
||||
async def _read_line(self, timeout: float) -> Optional[Line]:
|
||||
while True:
|
||||
if self._read_queue:
|
||||
return self._read_queue.popleft()
|
||||
|
||||
try:
|
||||
async with timeout_(PING_TIMEOUT):
|
||||
async with timeout_(timeout):
|
||||
data = await self._reader.read(1024)
|
||||
except asyncio.TimeoutError:
|
||||
if ping_sent:
|
||||
data = b"" # empty data means the socket disconnected
|
||||
else:
|
||||
ping_sent = True
|
||||
await self.send(build("PING", ["hello"]))
|
||||
continue
|
||||
return None
|
||||
|
||||
self.last_read = monotonic()
|
||||
ping_sent = False
|
||||
|
||||
try:
|
||||
lines = self.recv(data)
|
||||
except ServerDisconnectedException:
|
||||
self.disconnected = True
|
||||
raise
|
||||
|
||||
for line in lines:
|
||||
yield line
|
||||
self._read_queue.append(line)
|
||||
|
||||
async def _line_or_wait(self,
|
||||
line_aw: asyncio.Task
|
||||
) -> Optional[Awaitable]:
|
||||
wait_for_fut: Future[bool] = Future()
|
||||
self._wait_for_fut[line_aw] = wait_for_fut
|
||||
) -> Optional[Tuple[Awaitable, WaitFor]]:
|
||||
wait_for_fut: Future[WaitFor] = Future()
|
||||
self._wait_for_fut = wait_for_fut
|
||||
|
||||
done, pend = await asyncio.wait([line_aw, wait_for_fut],
|
||||
return_when=asyncio.FIRST_COMPLETED)
|
||||
del self._wait_for_fut[line_aw]
|
||||
self._wait_for_fut = None
|
||||
|
||||
if wait_for_fut.done():
|
||||
new_line_aw = list(pend)[0]
|
||||
return new_line_aw
|
||||
return (new_line_aw, wait_for_fut.result())
|
||||
else:
|
||||
return None
|
||||
|
||||
async def _read_lines(self) -> AsyncIterable[Tuple[Line, Optional[Emit]]]:
|
||||
async with anyio.create_task_group() as tg:
|
||||
async for line in self._next_lines():
|
||||
async def _read_lines(self):
|
||||
waiting_lines: List[Tuple[Line, Optional[Emit]]] = []
|
||||
sent_ping = True
|
||||
while True:
|
||||
now = monotonic()
|
||||
timeouts: List[float] = []
|
||||
timeouts.append((self.last_read+PING_TIMEOUT)-now)
|
||||
if self._wait_for is not None:
|
||||
_, wait_for = self._wait_for
|
||||
timeouts.append(wait_for.deadline-now)
|
||||
|
||||
line = await self._read_line(max([0.1, min(timeouts)]))
|
||||
if line is None:
|
||||
now = monotonic()
|
||||
since = now-self.last_read
|
||||
|
||||
if self._wait_for is not None:
|
||||
aw, wait_for = self._wait_for
|
||||
if wait_for.deadline <= now:
|
||||
self._wait_for = None
|
||||
await aw
|
||||
|
||||
if since >= PING_TIMEOUT:
|
||||
if since >= (PING_TIMEOUT*2):
|
||||
raise ServerDisconnectedException()
|
||||
elif not sent_ping:
|
||||
await self.send(build("PING", ["hello"]))
|
||||
continue
|
||||
else:
|
||||
emit = self.parse_tokens(line)
|
||||
|
||||
waiting_lines.append((line, emit))
|
||||
self.line_preread(line)
|
||||
|
||||
for i, wait_for in enumerate(self._wait_fors):
|
||||
if self._wait_for is not None:
|
||||
aw, wait_for = self._wait_for
|
||||
if wait_for.match(self, line):
|
||||
wait_for.resolve(line)
|
||||
self._wait_fors.pop(i)
|
||||
self._wait_for = await self._line_or_wait(aw)
|
||||
if self._wait_for is not None:
|
||||
continue
|
||||
else:
|
||||
continue
|
||||
|
||||
for i in range(len(waiting_lines)):
|
||||
line, emit = waiting_lines.pop(0)
|
||||
line_aw = self._on_read(line, emit)
|
||||
self._wait_for = await self._line_or_wait(line_aw)
|
||||
if self._wait_for is not None:
|
||||
break
|
||||
|
||||
line_aw = asyncio.create_task(self._on_read(line, emit))
|
||||
new_wait = await self._line_or_wait(line_aw)
|
||||
if new_wait is not None:
|
||||
async def _aw():
|
||||
await new_wait
|
||||
await tg.spawn(_aw)
|
||||
|
||||
yield (line, emit)
|
||||
|
||||
async def wait_for(self,
|
||||
response: Union[IMatchResponse, Set[IMatchResponse]],
|
||||
sent_aw: Optional[Awaitable[SentLine]]=None,
|
||||
|
@ -277,13 +297,14 @@ class Server(IServer):
|
|||
else:
|
||||
response_obj = response
|
||||
|
||||
our_wait_for = WaitFor(response_obj)
|
||||
self._wait_fors.append(our_wait_for)
|
||||
|
||||
deadline = monotonic()+timeout
|
||||
our_wait_for = WaitFor(response_obj, deadline)
|
||||
if self._wait_for_fut is not None:
|
||||
self._wait_for_fut.set_result(our_wait_for)
|
||||
else:
|
||||
cur_task = asyncio.current_task()
|
||||
if cur_task is not None and cur_task in self._wait_for_fut:
|
||||
wait_for_fut = self._wait_for_fut[cur_task]
|
||||
wait_for_fut.set_result(True)
|
||||
if cur_task is not None:
|
||||
self._wait_for = (cur_task, our_wait_for)
|
||||
|
||||
if sent_aw is not None:
|
||||
sent_line = await sent_aw
|
||||
|
@ -297,8 +318,7 @@ class Server(IServer):
|
|||
if (line.command in ["PRIVMSG", "NOTICE", "TAGMSG"] and
|
||||
not self.cap_agreed(CAP_ECHO)):
|
||||
new_line = line.with_source(self.hostmask())
|
||||
emit = self.parse_tokens(new_line)
|
||||
self._read_queue.append((new_line, emit))
|
||||
self._read_queue.append(new_line)
|
||||
|
||||
async def _send_lines(self):
|
||||
while True:
|
||||
|
|
Loading…
Add table
Reference in a new issue