diff --git a/ircrobots/asyncs.py b/ircrobots/asyncs.py index 3a85239..ce216c8 100644 --- a/ircrobots/asyncs.py +++ b/ircrobots/asyncs.py @@ -17,21 +17,16 @@ class MaybeAwait(Generic[TEvent]): class WaitFor(object): def __init__(self, - wait_fut: "Future[WaitFor]", - response: IMatchResponse, - label: Optional[str]): - self._wait_fut = wait_fut + response: IMatchResponse): self.response = response - self._label = label - self.deferred = False + self._label: Optional[str] = None self._our_fut: "Future[Line]" = Future() def __await__(self) -> Generator[Any, None, Line]: - self._wait_fut.set_result(self) return self._our_fut.__await__() - async def defer(self): - self.deferred = True - return await self + + def with_label(self, label: str): + self._label = label def match(self, server: IServer, line: Line): if (self._label is not None and diff --git a/ircrobots/bot.py b/ircrobots/bot.py index 5ab7829..a65cbd5 100644 --- a/ircrobots/bot.py +++ b/ircrobots/bot.py @@ -37,7 +37,10 @@ class Bot(IBot): async def _run_server(self, server: Server): async with anyio.create_task_group() as tg: - await tg.spawn(server._read_lines) + async def _read(): + while True: + await server._read_lines() + await tg.spawn(_read) await tg.spawn(server._send_lines) await self.disconnected(server) diff --git a/ircrobots/interface.py b/ircrobots/interface.py index 317223a..f680f5f 100644 --- a/ircrobots/interface.py +++ b/ircrobots/interface.py @@ -46,7 +46,7 @@ class SentLine(object): self.id = id self.priority = priority self.line = line - self.future: Future = Future() + self.future: "Future[SentLine]" = Future() def __lt__(self, other: "SentLine") -> bool: return self.priority < other.priority diff --git a/ircrobots/server.py b/ircrobots/server.py index e661607..24eaaed 100644 --- a/ircrobots/server.py +++ b/ircrobots/server.py @@ -64,6 +64,7 @@ class Server(IServer): self._read_queue: Deque[Tuple[Line, Optional[Emit]]] = deque() + self._wait_fors: List[Tuple[WaitFor, Optional[Awaitable]]] = [] self._wait_for_fut: Optional[Future[WaitFor]] = None def hostmask(self) -> str: @@ -77,8 +78,11 @@ class Server(IServer): def send_raw(self, line: str, priority=SendPriority.DEFAULT ) -> Awaitable[SentLine]: return self.send(tokenise(line), priority) - def send(self, line: Line, priority=SendPriority.DEFAULT + def send(self, + line: Line, + priority=SendPriority.DEFAULT ) -> Awaitable[SentLine]: + self.line_presend(line) sent_line = SentLine(self._sent_count, priority, line) self._sent_count += 1 @@ -93,10 +97,7 @@ class Server(IServer): self._send_queue.put_nowait(sent_line) - async def _assure() -> SentLine: - await sent_line.future - return sent_line - return MaybeAwait(_assure) + return sent_line.future def set_throttle(self, rate: int, time: float): self.throttle.rate_limit = rate @@ -182,39 +183,34 @@ class Server(IServer): await self.line_read(line) - async def _next_line(self) -> Tuple[Line, Optional[Emit]]: - if self._read_queue: - both = self._read_queue.popleft() - else: - ping_sent = False - while True: - try: - async with timeout(PING_TIMEOUT): - data = await self._reader.read(1024) - except asyncio.TimeoutError: - if ping_sent: - data = b"" # empty data means the socket disconnected - else: - ping_sent = True - await self.send(build("PING", ["hello"])) - continue + async def _next_lines(self) -> List[Tuple[Line, Optional[Emit]]]: + ping_sent = False + while True: + try: + async with timeout(PING_TIMEOUT): + data = await self._reader.read(1024) + except asyncio.TimeoutError: + if ping_sent: + data = b"" # empty data means the socket disconnected + else: + ping_sent = True + await self.send(build("PING", ["hello"])) + continue - self.last_read = monotonic() - ping_sent = False + self.last_read = monotonic() + ping_sent = False - try: - lines = self.recv(data) - except ServerDisconnectedException: - self.disconnected = True - raise + try: + lines = self.recv(data) + except ServerDisconnectedException: + self.disconnected = True + raise - if lines: - self._read_queue.extend(lines[1:]) - both = lines[0] - break - return both + return lines - async def _line_or_wait(self, line_aw: Awaitable): + async def _line_or_wait(self, + line_aw: Awaitable + ) -> Optional[Tuple[WaitFor, Awaitable]]: wait_for_fut: Future[WaitFor] = Future() self._wait_for_fut = wait_for_fut @@ -225,43 +221,34 @@ class Server(IServer): new_line_aw = list(pend)[0] return (await wait_for_fut), new_line_aw else: - return None, None + return None - async def _read_lines(self): - waited_reads: Deque[Tuple[Line, Optional[Emit]]] = deque() - wait_for: Optional[WaitFor] = None - wait_for_aw: Optional[Awaitable] = None - - async def _line() -> Tuple[Line, Optional[Emit]]: - both = await self._next_line() - waited_reads.append(both) - line, emit = both + async def _read_lines(self) -> List[Tuple[Line, Optional[Emit]]]: + lines = await self._next_lines() + for line, emit in lines: self.line_preread(line) - return both - while True: - if wait_for is not None: - line, emit = await _line() - if wait_for.response.match(self, line): + for i, (wait_for, aw) in enumerate(self._wait_fors): + if wait_for.match(self, line): wait_for.resolve(line) - wait_for, wait_for_aw = await self._line_or_wait( - wait_for_aw) - else: - if not waited_reads: - await _line() + if aw is not None: + new_wait_for = await self._line_or_wait(aw) + if new_wait_for is not None: + self._wait_fors.append(new_wait_for) + self._wait_fors.pop(i) + break - while waited_reads: - new_line, new_emit = waited_reads.popleft() - line_aw = self._on_read(new_line, new_emit) - wait_for, wait_for_aw = await self._line_or_wait(line_aw) - if wait_for is not None: - break + line_aw = self._on_read(line, emit) + new_wait_for = await self._line_or_wait(line_aw) + if new_wait_for is not None: + self._wait_fors.append(new_wait_for) + return lines - def wait_for(self, - response: Union[IMatchResponse, Set[IMatchResponse]], - sent_line: Optional[SentLine]=None - ) -> Awaitable[Line]: + async def wait_for(self, + response: Union[IMatchResponse, Set[IMatchResponse]], + sent_aw: Optional[Awaitable[SentLine]]=None + ) -> Line: response_obj: IMatchResponse if isinstance(response, set): @@ -269,17 +256,21 @@ class Server(IServer): else: response_obj = response + our_wait_for = WaitFor(response_obj) + wait_for_fut = self._wait_for_fut if wait_for_fut is not None: self._wait_for_fut = None + wait_for_fut.set_result(our_wait_for) + else: + self._wait_fors.append((our_wait_for, None)) - label: Optional[str] = None - if sent_line is not None: - label = str(sent_line.id) + if sent_aw is not None: + sent_line = await sent_aw + label = str(sent_line.id) + our_wait_for.with_label(label) - our_wait_for = WaitFor(wait_for_fut, response_obj, label) - return our_wait_for - raise Exception() + return (await our_wait_for) async def _on_send_line(self, line: Line): if (line.command == "PRIVMSG" and @@ -337,7 +328,6 @@ class Server(IServer): def send_nick(self, new_nick: str) -> Awaitable[bool]: fut = self.send(build("NICK", [new_nick])) async def _assure() -> bool: - await fut line = await self.wait_for({ Response("NICK", [Folded(new_nick)], source=MASK_SELF), Responses([ @@ -350,7 +340,7 @@ class Server(IServer): ERR_ERRONEUSNICKNAME, ERR_UNAVAILRESOURCE ], [ANY, Folded(new_nick)]) - }) + }, fut) return line.command == "NICK" return MaybeAwait(_assure) @@ -368,9 +358,9 @@ class Server(IServer): fut = self.send(build("PART", [name])) async def _assure(): - await fut line = await self.wait_for( - Response("PART", [Folded(name)], source=MASK_SELF) + Response("PART", [Folded(name)], source=MASK_SELF), + fut ) return return MaybeAwait(_assure) @@ -388,8 +378,6 @@ class Server(IServer): fut = self.send(build("JOIN", [",".join(names)]+keys)) async def _assure(): - await fut - channels: List[Channel] = [] while folded_names: @@ -398,7 +386,7 @@ class Server(IServer): Responses(JOIN_ERR_FIRST, [ANY, ANY]), Response(ERR_USERONCHANNEL, [ANY, SELF, ANY]), Response(ERR_LINKCHANNEL, [ANY, ANY, ANY]) - }) + }, fut) chan: Optional[str] = None if line.command == RPL_CHANNELMODEIS: @@ -429,9 +417,9 @@ class Server(IServer): ) -> Awaitable[Optional[str]]: fut = self.send(build("PRIVMSG", [target, message])) async def _assure(): - await fut line = await self.wait_for( - Response("PRIVMSG", [Folded(target), ANY], source=MASK_SELF) + Response("PRIVMSG", [Folded(target), ANY], source=MASK_SELF), + fut ) if line.command == "PRIVMSG": return line.params[1] @@ -449,7 +437,6 @@ class Server(IServer): fut = self.send(build("WHOIS", args)) async def _assure() -> Optional[Whois]: - await fut params = [ANY, Folded(self.casefold(target))] obj = Whois() while True: @@ -465,7 +452,7 @@ class Server(IServer): RPL_WHOISACCOUNT, RPL_WHOISSECURE, RPL_ENDOFWHOIS - ], params)) + ], params), fut) if line.command in [ERR_NOSUCHNICK, ERR_NOSUCHSERVER]: return None elif line.command == RPL_WHOISUSER: