From 0829fd9499e0702a6913f05128efe98d2ad077a8 Mon Sep 17 00:00:00 2001 From: jesopo Date: Sun, 5 Apr 2020 17:11:04 +0100 Subject: [PATCH] refactor and simplify normal reading vs wait_for --- ircrobots/bot.py | 17 +------------- ircrobots/interface.py | 6 ++--- ircrobots/server.py | 52 +++++++++++++++++++++++------------------- 3 files changed, 32 insertions(+), 43 deletions(-) diff --git a/ircrobots/bot.py b/ircrobots/bot.py index b177151..3bbb75a 100644 --- a/ircrobots/bot.py +++ b/ircrobots/bot.py @@ -20,12 +20,6 @@ class Bot(object): await asyncio.sleep(RECONNECT_DELAY) await self.add_server(server.name, server.params) - async def line_read(self, server: Server, line: Line): - pass - - async def line_send(self, server: Server, line: Line): - pass - async def add_server(self, name: str, params: ConnectionParams) -> Server: server = self.create_server(name) self.servers[name] = server @@ -35,27 +29,18 @@ class Bot(object): async def _run_server(self, server: Server): async with anyio.create_task_group() as tg: - async def _read_query(): - while not tg.cancel_scope.cancel_called: - await server._read_lines() - await tg.cancel_scope.cancel() - async def _read(): while not tg.cancel_scope.cancel_called: - line = await server.next_line() - await self.line_read(server, line) + line, emits = await server.next_line() await tg.cancel_scope.cancel() async def _write(): while not tg.cancel_scope.cancel_called: lines = await server._write_lines() - for line in lines: - await self.line_send(server, line) await tg.cancel_scope.cancel() await tg.spawn(_write) await tg.spawn(_read) - await tg.spawn(_read_query) del self.servers[server.name] await self.disconnected(server) diff --git a/ircrobots/interface.py b/ircrobots/interface.py index 201fd1d..88d75d1 100644 --- a/ircrobots/interface.py +++ b/ircrobots/interface.py @@ -1,8 +1,8 @@ from asyncio import Future -from typing import Awaitable, Iterable, Set, Optional +from typing import Awaitable, Iterable, List, Optional, Set, Tuple from enum import IntEnum -from ircstates import Server +from ircstates import Server, Emit from irctokens import Line from .params import ConnectionParams, SASLParams @@ -61,7 +61,7 @@ class IServer(Server): async def line_send(self, line: Line): pass - async def next_line(self) -> Line: + async def next_line(self) -> Tuple[Line, List[Emit]]: pass def cap_agreed(self, capability: ICapability) -> bool: diff --git a/ircrobots/server.py b/ircrobots/server.py index f11d459..9a69df9 100644 --- a/ircrobots/server.py +++ b/ircrobots/server.py @@ -1,7 +1,8 @@ import asyncio from ssl import SSLContext from asyncio import Future, PriorityQueue, Queue -from typing import Dict, List, Optional, Set, Tuple +from typing import Deque, Dict, List, Optional, Set, Tuple +from collections import deque from asyncio_throttle import Throttler from ircstates import Emit @@ -29,11 +30,13 @@ class Server(IServer): self.sasl_state = SASLResult.NONE + self._wait_for_cache: List[Tuple[Line, List[Emit]]] = [] self._write_queue: PriorityQueue[SentLine] = PriorityQueue() - self._read_queue: Queue[Tuple[Line, List[Emit]]] = Queue() self.desired_caps: Set[ICapability] = set([]) + self._read_queue: Deque[Tuple[Line, List[Emit]]] = deque() + async def send_raw(self, line: str, priority=SendPriority.DEFAULT ) -> Future: return await self.send(tokenise(line), priority) @@ -96,34 +99,34 @@ class Server(IServer): async def line_read(self, line: Line): pass - async def _read_lines(self) -> List[Tuple[Line, List[Emit]]]: - data = await self._reader.read(1024) - lines = self.recv(data) - for line, emits in lines: - for emit in emits: - await self._on_read_emit(line, emit) - await self._on_read_line(line) - await self.line_read(line) + async def next_line(self, wait_for: bool = False + ) -> Tuple[Line, List[Emit]]: + if self._read_queue: + both = self._read_queue.popleft() + else: + data = await self._reader.read(1024) + lines = self.recv(data) + + self._read_queue.extend(lines[1:]) + both = lines[0] + + line, emits = both + for emit in emits: + await self._on_read_emit(line, emit) + await self._on_read_line(line) + await self.line_read(line) + + return both - await self._read_queue.put((line, emits)) - return lines - async def next_line(self) -> Line: - line, emits = await self._read_queue.get() - return line async def wait_for(self, response: IMatchResponse) -> Line: while True: - lines = self._wait_for_cache.copy() - self._wait_for_cache.clear() + both = await self.next_line(wait_for=True) + line, emits = both - if not lines: - lines += await self._read_lines() - - for i, (line, emits) in enumerate(lines): - if response.match(self, line): - self._wait_for_cache = lines[i+1:] - return line + if response.match(self, line): + return line async def line_send(self, line: Line): pass @@ -144,6 +147,7 @@ class Server(IServer): for line in lines: line.future.set_result(None) + await self.line_send(line.line) return [l.line for l in lines]