refactor and simplify normal reading vs wait_for
This commit is contained in:
parent
b8ddc6883d
commit
0829fd9499
3 changed files with 32 additions and 43 deletions
|
@ -20,12 +20,6 @@ class Bot(object):
|
|||
await asyncio.sleep(RECONNECT_DELAY)
|
||||
await self.add_server(server.name, server.params)
|
||||
|
||||
async def line_read(self, server: Server, line: Line):
|
||||
pass
|
||||
|
||||
async def line_send(self, server: Server, line: Line):
|
||||
pass
|
||||
|
||||
async def add_server(self, name: str, params: ConnectionParams) -> Server:
|
||||
server = self.create_server(name)
|
||||
self.servers[name] = server
|
||||
|
@ -35,27 +29,18 @@ class Bot(object):
|
|||
|
||||
async def _run_server(self, server: Server):
|
||||
async with anyio.create_task_group() as tg:
|
||||
async def _read_query():
|
||||
while not tg.cancel_scope.cancel_called:
|
||||
await server._read_lines()
|
||||
await tg.cancel_scope.cancel()
|
||||
|
||||
async def _read():
|
||||
while not tg.cancel_scope.cancel_called:
|
||||
line = await server.next_line()
|
||||
await self.line_read(server, line)
|
||||
line, emits = await server.next_line()
|
||||
await tg.cancel_scope.cancel()
|
||||
|
||||
async def _write():
|
||||
while not tg.cancel_scope.cancel_called:
|
||||
lines = await server._write_lines()
|
||||
for line in lines:
|
||||
await self.line_send(server, line)
|
||||
await tg.cancel_scope.cancel()
|
||||
|
||||
await tg.spawn(_write)
|
||||
await tg.spawn(_read)
|
||||
await tg.spawn(_read_query)
|
||||
|
||||
del self.servers[server.name]
|
||||
await self.disconnected(server)
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
from asyncio import Future
|
||||
from typing import Awaitable, Iterable, Set, Optional
|
||||
from typing import Awaitable, Iterable, List, Optional, Set, Tuple
|
||||
from enum import IntEnum
|
||||
|
||||
from ircstates import Server
|
||||
from ircstates import Server, Emit
|
||||
from irctokens import Line
|
||||
|
||||
from .params import ConnectionParams, SASLParams
|
||||
|
@ -61,7 +61,7 @@ class IServer(Server):
|
|||
async def line_send(self, line: Line):
|
||||
pass
|
||||
|
||||
async def next_line(self) -> Line:
|
||||
async def next_line(self) -> Tuple[Line, List[Emit]]:
|
||||
pass
|
||||
|
||||
def cap_agreed(self, capability: ICapability) -> bool:
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
import asyncio
|
||||
from ssl import SSLContext
|
||||
from asyncio import Future, PriorityQueue, Queue
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
from typing import Deque, Dict, List, Optional, Set, Tuple
|
||||
from collections import deque
|
||||
|
||||
from asyncio_throttle import Throttler
|
||||
from ircstates import Emit
|
||||
|
@ -29,11 +30,13 @@ class Server(IServer):
|
|||
|
||||
self.sasl_state = SASLResult.NONE
|
||||
|
||||
|
||||
self._wait_for_cache: List[Tuple[Line, List[Emit]]] = []
|
||||
self._write_queue: PriorityQueue[SentLine] = PriorityQueue()
|
||||
self._read_queue: Queue[Tuple[Line, List[Emit]]] = Queue()
|
||||
self.desired_caps: Set[ICapability] = set([])
|
||||
|
||||
self._read_queue: Deque[Tuple[Line, List[Emit]]] = deque()
|
||||
|
||||
async def send_raw(self, line: str, priority=SendPriority.DEFAULT
|
||||
) -> Future:
|
||||
return await self.send(tokenise(line), priority)
|
||||
|
@ -96,34 +99,34 @@ class Server(IServer):
|
|||
|
||||
async def line_read(self, line: Line):
|
||||
pass
|
||||
async def _read_lines(self) -> List[Tuple[Line, List[Emit]]]:
|
||||
data = await self._reader.read(1024)
|
||||
lines = self.recv(data)
|
||||
for line, emits in lines:
|
||||
for emit in emits:
|
||||
await self._on_read_emit(line, emit)
|
||||
|
||||
await self._on_read_line(line)
|
||||
await self.line_read(line)
|
||||
async def next_line(self, wait_for: bool = False
|
||||
) -> Tuple[Line, List[Emit]]:
|
||||
if self._read_queue:
|
||||
both = self._read_queue.popleft()
|
||||
else:
|
||||
data = await self._reader.read(1024)
|
||||
lines = self.recv(data)
|
||||
|
||||
self._read_queue.extend(lines[1:])
|
||||
both = lines[0]
|
||||
|
||||
line, emits = both
|
||||
for emit in emits:
|
||||
await self._on_read_emit(line, emit)
|
||||
await self._on_read_line(line)
|
||||
await self.line_read(line)
|
||||
|
||||
return both
|
||||
|
||||
await self._read_queue.put((line, emits))
|
||||
return lines
|
||||
async def next_line(self) -> Line:
|
||||
line, emits = await self._read_queue.get()
|
||||
return line
|
||||
|
||||
async def wait_for(self, response: IMatchResponse) -> Line:
|
||||
while True:
|
||||
lines = self._wait_for_cache.copy()
|
||||
self._wait_for_cache.clear()
|
||||
both = await self.next_line(wait_for=True)
|
||||
line, emits = both
|
||||
|
||||
if not lines:
|
||||
lines += await self._read_lines()
|
||||
|
||||
for i, (line, emits) in enumerate(lines):
|
||||
if response.match(self, line):
|
||||
self._wait_for_cache = lines[i+1:]
|
||||
return line
|
||||
if response.match(self, line):
|
||||
return line
|
||||
|
||||
async def line_send(self, line: Line):
|
||||
pass
|
||||
|
@ -144,6 +147,7 @@ class Server(IServer):
|
|||
|
||||
for line in lines:
|
||||
line.future.set_result(None)
|
||||
await self.line_send(line.line)
|
||||
|
||||
return [l.line for l in lines]
|
||||
|
||||
|
|
Loading…
Reference in a new issue