refactor and simplify normal reading vs wait_for

This commit is contained in:
jesopo 2020-04-05 17:11:04 +01:00
parent b8ddc6883d
commit 0829fd9499
3 changed files with 32 additions and 43 deletions

View file

@ -20,12 +20,6 @@ class Bot(object):
await asyncio.sleep(RECONNECT_DELAY)
await self.add_server(server.name, server.params)
async def line_read(self, server: Server, line: Line):
pass
async def line_send(self, server: Server, line: Line):
pass
async def add_server(self, name: str, params: ConnectionParams) -> Server:
server = self.create_server(name)
self.servers[name] = server
@ -35,27 +29,18 @@ class Bot(object):
async def _run_server(self, server: Server):
async with anyio.create_task_group() as tg:
async def _read_query():
while not tg.cancel_scope.cancel_called:
await server._read_lines()
await tg.cancel_scope.cancel()
async def _read():
while not tg.cancel_scope.cancel_called:
line = await server.next_line()
await self.line_read(server, line)
line, emits = await server.next_line()
await tg.cancel_scope.cancel()
async def _write():
while not tg.cancel_scope.cancel_called:
lines = await server._write_lines()
for line in lines:
await self.line_send(server, line)
await tg.cancel_scope.cancel()
await tg.spawn(_write)
await tg.spawn(_read)
await tg.spawn(_read_query)
del self.servers[server.name]
await self.disconnected(server)

View file

@ -1,8 +1,8 @@
from asyncio import Future
from typing import Awaitable, Iterable, Set, Optional
from typing import Awaitable, Iterable, List, Optional, Set, Tuple
from enum import IntEnum
from ircstates import Server
from ircstates import Server, Emit
from irctokens import Line
from .params import ConnectionParams, SASLParams
@ -61,7 +61,7 @@ class IServer(Server):
async def line_send(self, line: Line):
pass
async def next_line(self) -> Line:
async def next_line(self) -> Tuple[Line, List[Emit]]:
pass
def cap_agreed(self, capability: ICapability) -> bool:

View file

@ -1,7 +1,8 @@
import asyncio
from ssl import SSLContext
from asyncio import Future, PriorityQueue, Queue
from typing import Dict, List, Optional, Set, Tuple
from typing import Deque, Dict, List, Optional, Set, Tuple
from collections import deque
from asyncio_throttle import Throttler
from ircstates import Emit
@ -29,11 +30,13 @@ class Server(IServer):
self.sasl_state = SASLResult.NONE
self._wait_for_cache: List[Tuple[Line, List[Emit]]] = []
self._write_queue: PriorityQueue[SentLine] = PriorityQueue()
self._read_queue: Queue[Tuple[Line, List[Emit]]] = Queue()
self.desired_caps: Set[ICapability] = set([])
self._read_queue: Deque[Tuple[Line, List[Emit]]] = deque()
async def send_raw(self, line: str, priority=SendPriority.DEFAULT
) -> Future:
return await self.send(tokenise(line), priority)
@ -96,34 +99,34 @@ class Server(IServer):
async def line_read(self, line: Line):
pass
async def _read_lines(self) -> List[Tuple[Line, List[Emit]]]:
data = await self._reader.read(1024)
lines = self.recv(data)
for line, emits in lines:
for emit in emits:
await self._on_read_emit(line, emit)
await self._on_read_line(line)
await self.line_read(line)
async def next_line(self, wait_for: bool = False
) -> Tuple[Line, List[Emit]]:
if self._read_queue:
both = self._read_queue.popleft()
else:
data = await self._reader.read(1024)
lines = self.recv(data)
self._read_queue.extend(lines[1:])
both = lines[0]
line, emits = both
for emit in emits:
await self._on_read_emit(line, emit)
await self._on_read_line(line)
await self.line_read(line)
return both
await self._read_queue.put((line, emits))
return lines
async def next_line(self) -> Line:
line, emits = await self._read_queue.get()
return line
async def wait_for(self, response: IMatchResponse) -> Line:
while True:
lines = self._wait_for_cache.copy()
self._wait_for_cache.clear()
both = await self.next_line(wait_for=True)
line, emits = both
if not lines:
lines += await self._read_lines()
for i, (line, emits) in enumerate(lines):
if response.match(self, line):
self._wait_for_cache = lines[i+1:]
return line
if response.match(self, line):
return line
async def line_send(self, line: Line):
pass
@ -144,6 +147,7 @@ class Server(IServer):
for line in lines:
line.future.set_result(None)
await self.line_send(line.line)
return [l.line for l in lines]