scrap deferred wait_for, actually catch server disconnection
This commit is contained in:
parent
eb9888d0c4
commit
a264e4e347
3 changed files with 83 additions and 61 deletions
|
@ -1,7 +1,8 @@
|
||||||
from asyncio import Future
|
from asyncio import Future
|
||||||
from irctokens import Line
|
|
||||||
from typing import (Any, Awaitable, Callable, Generator, Generic, Optional,
|
from typing import (Any, Awaitable, Callable, Generator, Generic, Optional,
|
||||||
TypeVar)
|
TypeVar)
|
||||||
|
|
||||||
|
from irctokens import Line
|
||||||
from .matching import IMatchResponse
|
from .matching import IMatchResponse
|
||||||
from .interface import IServer
|
from .interface import IServer
|
||||||
from .ircv3 import TAG_LABEL
|
from .ircv3 import TAG_LABEL
|
||||||
|
@ -17,8 +18,10 @@ class MaybeAwait(Generic[TEvent]):
|
||||||
|
|
||||||
class WaitFor(object):
|
class WaitFor(object):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
response: IMatchResponse):
|
response: IMatchResponse,
|
||||||
self.response = response
|
deadline: float):
|
||||||
|
self.response = response
|
||||||
|
self.deadline = deadline
|
||||||
self._label: Optional[str] = None
|
self._label: Optional[str] = None
|
||||||
self._our_fut: "Future[Line]" = Future()
|
self._our_fut: "Future[Line]" = Future()
|
||||||
|
|
||||||
|
|
|
@ -36,13 +36,12 @@ class Bot(IBot):
|
||||||
return server
|
return server
|
||||||
|
|
||||||
async def _run_server(self, server: Server):
|
async def _run_server(self, server: Server):
|
||||||
async with anyio.create_task_group() as tg:
|
try:
|
||||||
async def _read():
|
async with anyio.create_task_group() as tg:
|
||||||
while True:
|
await tg.spawn(server._read_lines)
|
||||||
async for line, emit in server._read_lines():
|
await tg.spawn(server._send_lines)
|
||||||
pass
|
except ServerDisconnectedException:
|
||||||
await tg.spawn(_read)
|
server.disconnected = True
|
||||||
await tg.spawn(server._send_lines)
|
|
||||||
|
|
||||||
await self.disconnected(server)
|
await self.disconnected(server)
|
||||||
|
|
||||||
|
|
|
@ -58,16 +58,16 @@ class Server(IServer):
|
||||||
rate_limit=100, period=THROTTLE_TIME)
|
rate_limit=100, period=THROTTLE_TIME)
|
||||||
|
|
||||||
self.sasl_state = SASLResult.NONE
|
self.sasl_state = SASLResult.NONE
|
||||||
self.last_read = -1.0
|
self.last_read = monotonic()
|
||||||
|
|
||||||
self._sent_count: int = 0
|
self._sent_count: int = 0
|
||||||
self._send_queue: PriorityQueue[SentLine] = PriorityQueue()
|
self._send_queue: PriorityQueue[SentLine] = PriorityQueue()
|
||||||
self.desired_caps: Set[ICapability] = set([])
|
self.desired_caps: Set[ICapability] = set([])
|
||||||
|
|
||||||
self._read_queue: Deque[Tuple[Line, Optional[Emit]]] = deque()
|
self._read_queue: Deque[Line] = deque()
|
||||||
|
|
||||||
self._wait_fors: List[WaitFor] = []
|
self._wait_for: Optional[Tuple[Awaitable, WaitFor]] = None
|
||||||
self._wait_for_fut: Dict[asyncio.Task, Future[bool]] = {}
|
self._wait_for_fut: Optional[Future[WaitFor]] = None
|
||||||
|
|
||||||
self._pending_who: Deque[str] = deque()
|
self._pending_who: Deque[str] = deque()
|
||||||
|
|
||||||
|
@ -202,73 +202,93 @@ class Server(IServer):
|
||||||
line = await self.wait_for(end)
|
line = await self.wait_for(end)
|
||||||
|
|
||||||
|
|
||||||
async def _next_lines(self) -> AsyncIterable[Line]:
|
async def _read_line(self, timeout: float) -> Optional[Line]:
|
||||||
ping_sent = False
|
|
||||||
while True:
|
while True:
|
||||||
|
if self._read_queue:
|
||||||
|
return self._read_queue.popleft()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with timeout_(PING_TIMEOUT):
|
async with timeout_(timeout):
|
||||||
data = await self._reader.read(1024)
|
data = await self._reader.read(1024)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
if ping_sent:
|
return None
|
||||||
data = b"" # empty data means the socket disconnected
|
|
||||||
else:
|
|
||||||
ping_sent = True
|
|
||||||
await self.send(build("PING", ["hello"]))
|
|
||||||
continue
|
|
||||||
|
|
||||||
self.last_read = monotonic()
|
self.last_read = monotonic()
|
||||||
ping_sent = False
|
lines = self.recv(data)
|
||||||
|
|
||||||
try:
|
|
||||||
lines = self.recv(data)
|
|
||||||
except ServerDisconnectedException:
|
|
||||||
self.disconnected = True
|
|
||||||
raise
|
|
||||||
|
|
||||||
for line in lines:
|
for line in lines:
|
||||||
yield line
|
self._read_queue.append(line)
|
||||||
|
|
||||||
async def _line_or_wait(self,
|
async def _line_or_wait(self,
|
||||||
line_aw: asyncio.Task
|
line_aw: asyncio.Task
|
||||||
) -> Optional[Awaitable]:
|
) -> Optional[Tuple[Awaitable, WaitFor]]:
|
||||||
wait_for_fut: Future[bool] = Future()
|
wait_for_fut: Future[WaitFor] = Future()
|
||||||
self._wait_for_fut[line_aw] = wait_for_fut
|
self._wait_for_fut = wait_for_fut
|
||||||
|
|
||||||
done, pend = await asyncio.wait([line_aw, wait_for_fut],
|
done, pend = await asyncio.wait([line_aw, wait_for_fut],
|
||||||
return_when=asyncio.FIRST_COMPLETED)
|
return_when=asyncio.FIRST_COMPLETED)
|
||||||
del self._wait_for_fut[line_aw]
|
self._wait_for_fut = None
|
||||||
|
|
||||||
if wait_for_fut.done():
|
if wait_for_fut.done():
|
||||||
new_line_aw = list(pend)[0]
|
new_line_aw = list(pend)[0]
|
||||||
return new_line_aw
|
return (new_line_aw, wait_for_fut.result())
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def _read_lines(self) -> AsyncIterable[Tuple[Line, Optional[Emit]]]:
|
async def _read_lines(self):
|
||||||
async with anyio.create_task_group() as tg:
|
waiting_lines: List[Tuple[Line, Optional[Emit]]] = []
|
||||||
async for line in self._next_lines():
|
sent_ping = True
|
||||||
|
while True:
|
||||||
|
now = monotonic()
|
||||||
|
timeouts: List[float] = []
|
||||||
|
timeouts.append((self.last_read+PING_TIMEOUT)-now)
|
||||||
|
if self._wait_for is not None:
|
||||||
|
_, wait_for = self._wait_for
|
||||||
|
timeouts.append(wait_for.deadline-now)
|
||||||
|
|
||||||
|
line = await self._read_line(max([0.1, min(timeouts)]))
|
||||||
|
if line is None:
|
||||||
|
now = monotonic()
|
||||||
|
since = now-self.last_read
|
||||||
|
|
||||||
|
if self._wait_for is not None:
|
||||||
|
aw, wait_for = self._wait_for
|
||||||
|
if wait_for.deadline <= now:
|
||||||
|
self._wait_for = None
|
||||||
|
await aw
|
||||||
|
|
||||||
|
if since >= PING_TIMEOUT:
|
||||||
|
if since >= (PING_TIMEOUT*2):
|
||||||
|
raise ServerDisconnectedException()
|
||||||
|
elif not sent_ping:
|
||||||
|
await self.send(build("PING", ["hello"]))
|
||||||
|
continue
|
||||||
|
else:
|
||||||
emit = self.parse_tokens(line)
|
emit = self.parse_tokens(line)
|
||||||
|
|
||||||
|
waiting_lines.append((line, emit))
|
||||||
self.line_preread(line)
|
self.line_preread(line)
|
||||||
|
|
||||||
for i, wait_for in enumerate(self._wait_fors):
|
if self._wait_for is not None:
|
||||||
|
aw, wait_for = self._wait_for
|
||||||
if wait_for.match(self, line):
|
if wait_for.match(self, line):
|
||||||
wait_for.resolve(line)
|
wait_for.resolve(line)
|
||||||
self._wait_fors.pop(i)
|
self._wait_for = await self._line_or_wait(aw)
|
||||||
|
if self._wait_for is not None:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for i in range(len(waiting_lines)):
|
||||||
|
line, emit = waiting_lines.pop(0)
|
||||||
|
line_aw = self._on_read(line, emit)
|
||||||
|
self._wait_for = await self._line_or_wait(line_aw)
|
||||||
|
if self._wait_for is not None:
|
||||||
break
|
break
|
||||||
|
|
||||||
line_aw = asyncio.create_task(self._on_read(line, emit))
|
|
||||||
new_wait = await self._line_or_wait(line_aw)
|
|
||||||
if new_wait is not None:
|
|
||||||
async def _aw():
|
|
||||||
await new_wait
|
|
||||||
await tg.spawn(_aw)
|
|
||||||
|
|
||||||
yield (line, emit)
|
|
||||||
|
|
||||||
async def wait_for(self,
|
async def wait_for(self,
|
||||||
response: Union[IMatchResponse, Set[IMatchResponse]],
|
response: Union[IMatchResponse, Set[IMatchResponse]],
|
||||||
sent_aw: Optional[Awaitable[SentLine]]=None,
|
sent_aw: Optional[Awaitable[SentLine]]=None,
|
||||||
timeout: float=WAIT_TIMEOUT
|
timeout: float=WAIT_TIMEOUT
|
||||||
) -> Line:
|
) -> Line:
|
||||||
|
|
||||||
response_obj: IMatchResponse
|
response_obj: IMatchResponse
|
||||||
|
@ -277,13 +297,14 @@ class Server(IServer):
|
||||||
else:
|
else:
|
||||||
response_obj = response
|
response_obj = response
|
||||||
|
|
||||||
our_wait_for = WaitFor(response_obj)
|
deadline = monotonic()+timeout
|
||||||
self._wait_fors.append(our_wait_for)
|
our_wait_for = WaitFor(response_obj, deadline)
|
||||||
|
if self._wait_for_fut is not None:
|
||||||
cur_task = asyncio.current_task()
|
self._wait_for_fut.set_result(our_wait_for)
|
||||||
if cur_task is not None and cur_task in self._wait_for_fut:
|
else:
|
||||||
wait_for_fut = self._wait_for_fut[cur_task]
|
cur_task = asyncio.current_task()
|
||||||
wait_for_fut.set_result(True)
|
if cur_task is not None:
|
||||||
|
self._wait_for = (cur_task, our_wait_for)
|
||||||
|
|
||||||
if sent_aw is not None:
|
if sent_aw is not None:
|
||||||
sent_line = await sent_aw
|
sent_line = await sent_aw
|
||||||
|
@ -297,8 +318,7 @@ class Server(IServer):
|
||||||
if (line.command in ["PRIVMSG", "NOTICE", "TAGMSG"] and
|
if (line.command in ["PRIVMSG", "NOTICE", "TAGMSG"] and
|
||||||
not self.cap_agreed(CAP_ECHO)):
|
not self.cap_agreed(CAP_ECHO)):
|
||||||
new_line = line.with_source(self.hostmask())
|
new_line = line.with_source(self.hostmask())
|
||||||
emit = self.parse_tokens(new_line)
|
self._read_queue.append(new_line)
|
||||||
self._read_queue.append((new_line, emit))
|
|
||||||
|
|
||||||
async def _send_lines(self):
|
async def _send_lines(self):
|
||||||
while True:
|
while True:
|
||||||
|
|
Loading…
Add table
Reference in a new issue