refactor and simplify normal reading vs wait_for
This commit is contained in:
parent
b8ddc6883d
commit
0829fd9499
3 changed files with 32 additions and 43 deletions
|
@ -20,12 +20,6 @@ class Bot(object):
|
||||||
await asyncio.sleep(RECONNECT_DELAY)
|
await asyncio.sleep(RECONNECT_DELAY)
|
||||||
await self.add_server(server.name, server.params)
|
await self.add_server(server.name, server.params)
|
||||||
|
|
||||||
async def line_read(self, server: Server, line: Line):
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def line_send(self, server: Server, line: Line):
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def add_server(self, name: str, params: ConnectionParams) -> Server:
|
async def add_server(self, name: str, params: ConnectionParams) -> Server:
|
||||||
server = self.create_server(name)
|
server = self.create_server(name)
|
||||||
self.servers[name] = server
|
self.servers[name] = server
|
||||||
|
@ -35,27 +29,18 @@ class Bot(object):
|
||||||
|
|
||||||
async def _run_server(self, server: Server):
|
async def _run_server(self, server: Server):
|
||||||
async with anyio.create_task_group() as tg:
|
async with anyio.create_task_group() as tg:
|
||||||
async def _read_query():
|
|
||||||
while not tg.cancel_scope.cancel_called:
|
|
||||||
await server._read_lines()
|
|
||||||
await tg.cancel_scope.cancel()
|
|
||||||
|
|
||||||
async def _read():
|
async def _read():
|
||||||
while not tg.cancel_scope.cancel_called:
|
while not tg.cancel_scope.cancel_called:
|
||||||
line = await server.next_line()
|
line, emits = await server.next_line()
|
||||||
await self.line_read(server, line)
|
|
||||||
await tg.cancel_scope.cancel()
|
await tg.cancel_scope.cancel()
|
||||||
|
|
||||||
async def _write():
|
async def _write():
|
||||||
while not tg.cancel_scope.cancel_called:
|
while not tg.cancel_scope.cancel_called:
|
||||||
lines = await server._write_lines()
|
lines = await server._write_lines()
|
||||||
for line in lines:
|
|
||||||
await self.line_send(server, line)
|
|
||||||
await tg.cancel_scope.cancel()
|
await tg.cancel_scope.cancel()
|
||||||
|
|
||||||
await tg.spawn(_write)
|
await tg.spawn(_write)
|
||||||
await tg.spawn(_read)
|
await tg.spawn(_read)
|
||||||
await tg.spawn(_read_query)
|
|
||||||
|
|
||||||
del self.servers[server.name]
|
del self.servers[server.name]
|
||||||
await self.disconnected(server)
|
await self.disconnected(server)
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
from asyncio import Future
|
from asyncio import Future
|
||||||
from typing import Awaitable, Iterable, Set, Optional
|
from typing import Awaitable, Iterable, List, Optional, Set, Tuple
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
|
|
||||||
from ircstates import Server
|
from ircstates import Server, Emit
|
||||||
from irctokens import Line
|
from irctokens import Line
|
||||||
|
|
||||||
from .params import ConnectionParams, SASLParams
|
from .params import ConnectionParams, SASLParams
|
||||||
|
@ -61,7 +61,7 @@ class IServer(Server):
|
||||||
async def line_send(self, line: Line):
|
async def line_send(self, line: Line):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def next_line(self) -> Line:
|
async def next_line(self) -> Tuple[Line, List[Emit]]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def cap_agreed(self, capability: ICapability) -> bool:
|
def cap_agreed(self, capability: ICapability) -> bool:
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from ssl import SSLContext
|
from ssl import SSLContext
|
||||||
from asyncio import Future, PriorityQueue, Queue
|
from asyncio import Future, PriorityQueue, Queue
|
||||||
from typing import Dict, List, Optional, Set, Tuple
|
from typing import Deque, Dict, List, Optional, Set, Tuple
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
from asyncio_throttle import Throttler
|
from asyncio_throttle import Throttler
|
||||||
from ircstates import Emit
|
from ircstates import Emit
|
||||||
|
@ -29,11 +30,13 @@ class Server(IServer):
|
||||||
|
|
||||||
self.sasl_state = SASLResult.NONE
|
self.sasl_state = SASLResult.NONE
|
||||||
|
|
||||||
|
|
||||||
self._wait_for_cache: List[Tuple[Line, List[Emit]]] = []
|
self._wait_for_cache: List[Tuple[Line, List[Emit]]] = []
|
||||||
self._write_queue: PriorityQueue[SentLine] = PriorityQueue()
|
self._write_queue: PriorityQueue[SentLine] = PriorityQueue()
|
||||||
self._read_queue: Queue[Tuple[Line, List[Emit]]] = Queue()
|
|
||||||
self.desired_caps: Set[ICapability] = set([])
|
self.desired_caps: Set[ICapability] = set([])
|
||||||
|
|
||||||
|
self._read_queue: Deque[Tuple[Line, List[Emit]]] = deque()
|
||||||
|
|
||||||
async def send_raw(self, line: str, priority=SendPriority.DEFAULT
|
async def send_raw(self, line: str, priority=SendPriority.DEFAULT
|
||||||
) -> Future:
|
) -> Future:
|
||||||
return await self.send(tokenise(line), priority)
|
return await self.send(tokenise(line), priority)
|
||||||
|
@ -96,34 +99,34 @@ class Server(IServer):
|
||||||
|
|
||||||
async def line_read(self, line: Line):
|
async def line_read(self, line: Line):
|
||||||
pass
|
pass
|
||||||
async def _read_lines(self) -> List[Tuple[Line, List[Emit]]]:
|
|
||||||
data = await self._reader.read(1024)
|
|
||||||
lines = self.recv(data)
|
|
||||||
for line, emits in lines:
|
|
||||||
for emit in emits:
|
|
||||||
await self._on_read_emit(line, emit)
|
|
||||||
|
|
||||||
await self._on_read_line(line)
|
async def next_line(self, wait_for: bool = False
|
||||||
await self.line_read(line)
|
) -> Tuple[Line, List[Emit]]:
|
||||||
|
if self._read_queue:
|
||||||
|
both = self._read_queue.popleft()
|
||||||
|
else:
|
||||||
|
data = await self._reader.read(1024)
|
||||||
|
lines = self.recv(data)
|
||||||
|
|
||||||
|
self._read_queue.extend(lines[1:])
|
||||||
|
both = lines[0]
|
||||||
|
|
||||||
|
line, emits = both
|
||||||
|
for emit in emits:
|
||||||
|
await self._on_read_emit(line, emit)
|
||||||
|
await self._on_read_line(line)
|
||||||
|
await self.line_read(line)
|
||||||
|
|
||||||
|
return both
|
||||||
|
|
||||||
await self._read_queue.put((line, emits))
|
|
||||||
return lines
|
|
||||||
async def next_line(self) -> Line:
|
|
||||||
line, emits = await self._read_queue.get()
|
|
||||||
return line
|
|
||||||
|
|
||||||
async def wait_for(self, response: IMatchResponse) -> Line:
|
async def wait_for(self, response: IMatchResponse) -> Line:
|
||||||
while True:
|
while True:
|
||||||
lines = self._wait_for_cache.copy()
|
both = await self.next_line(wait_for=True)
|
||||||
self._wait_for_cache.clear()
|
line, emits = both
|
||||||
|
|
||||||
if not lines:
|
if response.match(self, line):
|
||||||
lines += await self._read_lines()
|
return line
|
||||||
|
|
||||||
for i, (line, emits) in enumerate(lines):
|
|
||||||
if response.match(self, line):
|
|
||||||
self._wait_for_cache = lines[i+1:]
|
|
||||||
return line
|
|
||||||
|
|
||||||
async def line_send(self, line: Line):
|
async def line_send(self, line: Line):
|
||||||
pass
|
pass
|
||||||
|
@ -144,6 +147,7 @@ class Server(IServer):
|
||||||
|
|
||||||
for line in lines:
|
for line in lines:
|
||||||
line.future.set_result(None)
|
line.future.set_result(None)
|
||||||
|
await self.line_send(line.line)
|
||||||
|
|
||||||
return [l.line for l in lines]
|
return [l.line for l in lines]
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue