refactor and simplify normal reading vs wait_for

This commit is contained in:
jesopo 2020-04-05 17:11:04 +01:00
parent b8ddc6883d
commit 0829fd9499
3 changed files with 32 additions and 43 deletions

View file

@ -20,12 +20,6 @@ class Bot(object):
await asyncio.sleep(RECONNECT_DELAY) await asyncio.sleep(RECONNECT_DELAY)
await self.add_server(server.name, server.params) await self.add_server(server.name, server.params)
async def line_read(self, server: Server, line: Line):
pass
async def line_send(self, server: Server, line: Line):
pass
async def add_server(self, name: str, params: ConnectionParams) -> Server: async def add_server(self, name: str, params: ConnectionParams) -> Server:
server = self.create_server(name) server = self.create_server(name)
self.servers[name] = server self.servers[name] = server
@ -35,27 +29,18 @@ class Bot(object):
async def _run_server(self, server: Server): async def _run_server(self, server: Server):
async with anyio.create_task_group() as tg: async with anyio.create_task_group() as tg:
async def _read_query():
while not tg.cancel_scope.cancel_called:
await server._read_lines()
await tg.cancel_scope.cancel()
async def _read(): async def _read():
while not tg.cancel_scope.cancel_called: while not tg.cancel_scope.cancel_called:
line = await server.next_line() line, emits = await server.next_line()
await self.line_read(server, line)
await tg.cancel_scope.cancel() await tg.cancel_scope.cancel()
async def _write(): async def _write():
while not tg.cancel_scope.cancel_called: while not tg.cancel_scope.cancel_called:
lines = await server._write_lines() lines = await server._write_lines()
for line in lines:
await self.line_send(server, line)
await tg.cancel_scope.cancel() await tg.cancel_scope.cancel()
await tg.spawn(_write) await tg.spawn(_write)
await tg.spawn(_read) await tg.spawn(_read)
await tg.spawn(_read_query)
del self.servers[server.name] del self.servers[server.name]
await self.disconnected(server) await self.disconnected(server)

View file

@ -1,8 +1,8 @@
from asyncio import Future from asyncio import Future
from typing import Awaitable, Iterable, Set, Optional from typing import Awaitable, Iterable, List, Optional, Set, Tuple
from enum import IntEnum from enum import IntEnum
from ircstates import Server from ircstates import Server, Emit
from irctokens import Line from irctokens import Line
from .params import ConnectionParams, SASLParams from .params import ConnectionParams, SASLParams
@ -61,7 +61,7 @@ class IServer(Server):
async def line_send(self, line: Line): async def line_send(self, line: Line):
pass pass
async def next_line(self) -> Line: async def next_line(self) -> Tuple[Line, List[Emit]]:
pass pass
def cap_agreed(self, capability: ICapability) -> bool: def cap_agreed(self, capability: ICapability) -> bool:

View file

@ -1,7 +1,8 @@
import asyncio import asyncio
from ssl import SSLContext from ssl import SSLContext
from asyncio import Future, PriorityQueue, Queue from asyncio import Future, PriorityQueue, Queue
from typing import Dict, List, Optional, Set, Tuple from typing import Deque, Dict, List, Optional, Set, Tuple
from collections import deque
from asyncio_throttle import Throttler from asyncio_throttle import Throttler
from ircstates import Emit from ircstates import Emit
@ -29,11 +30,13 @@ class Server(IServer):
self.sasl_state = SASLResult.NONE self.sasl_state = SASLResult.NONE
self._wait_for_cache: List[Tuple[Line, List[Emit]]] = [] self._wait_for_cache: List[Tuple[Line, List[Emit]]] = []
self._write_queue: PriorityQueue[SentLine] = PriorityQueue() self._write_queue: PriorityQueue[SentLine] = PriorityQueue()
self._read_queue: Queue[Tuple[Line, List[Emit]]] = Queue()
self.desired_caps: Set[ICapability] = set([]) self.desired_caps: Set[ICapability] = set([])
self._read_queue: Deque[Tuple[Line, List[Emit]]] = deque()
async def send_raw(self, line: str, priority=SendPriority.DEFAULT async def send_raw(self, line: str, priority=SendPriority.DEFAULT
) -> Future: ) -> Future:
return await self.send(tokenise(line), priority) return await self.send(tokenise(line), priority)
@ -96,34 +99,34 @@ class Server(IServer):
async def line_read(self, line: Line): async def line_read(self, line: Line):
pass pass
async def _read_lines(self) -> List[Tuple[Line, List[Emit]]]:
data = await self._reader.read(1024)
lines = self.recv(data)
for line, emits in lines:
for emit in emits:
await self._on_read_emit(line, emit)
await self._on_read_line(line) async def next_line(self, wait_for: bool = False
await self.line_read(line) ) -> Tuple[Line, List[Emit]]:
if self._read_queue:
both = self._read_queue.popleft()
else:
data = await self._reader.read(1024)
lines = self.recv(data)
self._read_queue.extend(lines[1:])
both = lines[0]
line, emits = both
for emit in emits:
await self._on_read_emit(line, emit)
await self._on_read_line(line)
await self.line_read(line)
return both
await self._read_queue.put((line, emits))
return lines
async def next_line(self) -> Line:
line, emits = await self._read_queue.get()
return line
async def wait_for(self, response: IMatchResponse) -> Line: async def wait_for(self, response: IMatchResponse) -> Line:
while True: while True:
lines = self._wait_for_cache.copy() both = await self.next_line(wait_for=True)
self._wait_for_cache.clear() line, emits = both
if not lines: if response.match(self, line):
lines += await self._read_lines() return line
for i, (line, emits) in enumerate(lines):
if response.match(self, line):
self._wait_for_cache = lines[i+1:]
return line
async def line_send(self, line: Line): async def line_send(self, line: Line):
pass pass
@ -144,6 +147,7 @@ class Server(IServer):
for line in lines: for line in lines:
line.future.set_result(None) line.future.set_result(None)
await self.line_send(line.line)
return [l.line for l in lines] return [l.line for l in lines]