rewrite how wait_for works - only one at a time now

This commit is contained in:
jesopo 2020-04-29 12:13:06 +01:00
parent a3abae811a
commit 51cfd0f36b
4 changed files with 95 additions and 85 deletions

View file

@ -2,6 +2,7 @@ from asyncio import Future
from irctokens import Line from irctokens import Line
from typing import Any, Awaitable, Callable, Generator, Generic, TypeVar from typing import Any, Awaitable, Callable, Generator, Generic, TypeVar
from .matching import IMatchResponse from .matching import IMatchResponse
from .interface import IServer
TEvent = TypeVar("TEvent") TEvent = TypeVar("TEvent")
class MaybeAwait(Generic[TEvent]): class MaybeAwait(Generic[TEvent]):
@ -20,5 +21,8 @@ class WaitFor(object):
def __await__(self) -> Generator[Any, None, Line]: def __await__(self) -> Generator[Any, None, Line]:
return self._fut.__await__() return self._fut.__await__()
def match(self, server: IServer, line: Line):
return self.response.match(server, line)
def resolve(self, line: Line): def resolve(self, line: Line):
self._fut.set_result(line) self._fut.set_result(line)

View file

@ -37,21 +37,8 @@ class Bot(IBot):
async def _run_server(self, server: Server): async def _run_server(self, server: Server):
async with anyio.create_task_group() as tg: async with anyio.create_task_group() as tg:
async def _read(): await tg.spawn(server._read_lines)
while not tg.cancel_scope.cancel_called: await tg.spawn(server._send_lines)
try:
both = await server.next_line()
except ServerDisconnectedException:
break
await tg.cancel_scope.cancel()
async def _write():
while not tg.cancel_scope.cancel_called:
lines = await server._write_lines()
await tg.cancel_scope.cancel()
await tg.spawn(_write)
await tg.spawn(_read)
await self.disconnected(server) await self.disconnected(server)

View file

@ -103,6 +103,10 @@ class IServer(Server):
async def disconnect(self): async def disconnect(self):
pass pass
def line_preread(self, line: Line):
pass
def line_presend(self, line: Line):
pass
async def line_read(self, line: Line): async def line_read(self, line: Line):
pass pass
async def line_send(self, line: Line): async def line_send(self, line: Line):
@ -112,9 +116,6 @@ class IServer(Server):
async def resume_policy(self, resume: ResumePolicy): async def resume_policy(self, resume: ResumePolicy):
pass pass
async def next_line(self) -> Optional[Tuple[Line, Optional[Emit]]]:
pass
def cap_agreed(self, capability: ICapability) -> bool: def cap_agreed(self, capability: ICapability) -> bool:
pass pass
def cap_available(self, capability: ICapability) -> Optional[str]: def cap_available(self, capability: ICapability) -> Optional[str]:

View file

@ -59,7 +59,7 @@ class Server(IServer):
self.last_read = -1.0 self.last_read = -1.0
self._sent_count: int = 0 self._sent_count: int = 0
self._write_queue: PriorityQueue[SentLine] = PriorityQueue() self._send_queue: PriorityQueue[SentLine] = PriorityQueue()
self.desired_caps: Set[ICapability] = set([]) self.desired_caps: Set[ICapability] = set([])
self._read_queue: Deque[Tuple[Line, Optional[Emit]]] = deque() self._read_queue: Deque[Tuple[Line, Optional[Emit]]] = deque()
@ -80,6 +80,7 @@ class Server(IServer):
return self.send(tokenise(line), priority) return self.send(tokenise(line), priority)
def send(self, line: Line, priority=SendPriority.DEFAULT def send(self, line: Line, priority=SendPriority.DEFAULT
) -> Awaitable[SentLine]: ) -> Awaitable[SentLine]:
self.line_presend(line)
sent_line = SentLine(self._sent_count, priority, line) sent_line = SentLine(self._sent_count, priority, line)
self._sent_count += 1 self._sent_count += 1
@ -91,7 +92,7 @@ class Server(IServer):
line.tags = {} line.tags = {}
line.tags[tag] = str(sent_line.id) line.tags[tag] = str(sent_line.id)
self._write_queue.put_nowait(sent_line) self._send_queue.put_nowait(sent_line)
async def _assure() -> SentLine: async def _assure() -> SentLine:
await sent_line.future await sent_line.future
@ -142,6 +143,10 @@ class Server(IServer):
self.send(build("USER", [username, "0", "*", realname])) self.send(build("USER", [username, "0", "*", realname]))
# to be overridden # to be overridden
def line_preread(self, line: Line):
pass
def line_presend(self, line: Line):
pass
async def line_read(self, line: Line): async def line_read(self, line: Line):
pass pass
async def line_send(self, line: Line): async def line_send(self, line: Line):
@ -152,42 +157,33 @@ class Server(IServer):
pass pass
# /to be overriden # /to be overriden
async def _on_read_emit(self, line: Line, emit: Emit): async def _on_read(self, line: Line, emit: Optional[Emit]):
if emit.command == "001":
await self.send(build("WHO", [self.nickname]))
self.set_throttle(THROTTLE_RATE, THROTTLE_TIME)
elif emit.command == "CAP":
if emit.subcommand == "NEW":
await self._cap_ls(emit)
elif (emit.subcommand == "LS" and
emit.finished):
if not self.registered:
await CAPContext(self).handshake()
else:
await self._cap_ls(emit)
elif emit.command == "JOIN":
if emit.self and not emit.channel is None:
await self.send(build("MODE", [emit.channel.name]))
await WHOContext(self).ensure(emit.channel.name)
async def _on_read_line(self, line: Line):
if line.command == "PING": if line.command == "PING":
await self.send(build("PONG", line.params)) await self.send(build("PONG", line.params))
async def _line_or_wait(self, line_aw: Awaitable): elif emit is not None:
wait_for_fut: Future[WaitFor] = Future() if emit.command == "001":
self._wait_for_fut = wait_for_fut await self.send(build("WHO", [self.nickname]))
self.set_throttle(THROTTLE_RATE, THROTTLE_TIME)
done, pend = await asyncio.wait([line_aw, wait_for_fut], elif emit.command == "CAP":
return_when=asyncio.FIRST_COMPLETED) if emit.subcommand == "NEW":
await self._cap_ls(emit)
elif (emit.subcommand == "LS" and
emit.finished):
if not self.registered:
await CAPContext(self).handshake()
else:
await self._cap_ls(emit)
if wait_for_fut.done(): elif emit.command == "JOIN":
new_line_aw = list(pend)[0] if emit.self and not emit.channel is None:
self._wait_for.append((new_line_aw, await wait_for_fut)) await self.send(build("MODE", [emit.channel.name]))
await WHOContext(self).ensure(emit.channel.name)
async def next_line(self) -> Tuple[Line, Optional[Emit]]: await self.line_read(line)
async def _next_line(self) -> Tuple[Line, Optional[Emit]]:
if self._read_queue: if self._read_queue:
both = self._read_queue.popleft() both = self._read_queue.popleft()
else: else:
@ -217,25 +213,48 @@ class Server(IServer):
self._read_queue.extend(lines[1:]) self._read_queue.extend(lines[1:])
both = lines[0] both = lines[0]
break break
line, emit = both
async def _line():
if emit is not None:
await self._on_read_emit(line, emit)
await self._on_read_line(line)
await self.line_read(line)
for i, (aw, wait_for) in enumerate(self._wait_for):
if wait_for.response.match(self, line):
wait_for.resolve(line)
self._wait_for.pop(i)
await self._line_or_wait(aw)
break
await self._line_or_wait(_line())
return both return both
async def _line_or_wait(self, line_aw: Awaitable):
wait_for_fut: Future[WaitFor] = Future()
self._wait_for_fut = wait_for_fut
done, pend = await asyncio.wait([line_aw, wait_for_fut],
return_when=asyncio.FIRST_COMPLETED)
if wait_for_fut.done():
new_line_aw = list(pend)[0]
return (await wait_for_fut), new_line_aw
else:
return None, None
async def _read_lines(self):
waited_reads: List[Tuple[Line, Optional[Emit]]] = []
wait_for: Optional[WaitFor] = None
wait_for_aw: Optional[Awaitable] = None
while True:
if not waited_reads or wait_for is not None:
both = await self._next_line()
waited_reads.append(both)
line, emit = both
self.line_preread(line)
if wait_for is not None:
line, emit = waited_reads[-1]
if wait_for.response.match(self, line):
wait_for.resolve(line)
wait_for, wait_for_aw = await self._line_or_wait(
wait_for_aw)
else:
while waited_reads:
new_line, new_emit = waited_reads.pop(0)
line_aw = self._on_read(new_line, new_emit)
wait_for, wait_for_aw = await self._line_or_wait(line_aw)
if wait_for is not None:
break
async def wait_for(self, async def wait_for(self,
response: Union[IMatchResponse, Set[IMatchResponse]] response: Union[IMatchResponse, Set[IMatchResponse]]
) -> Line: ) -> Line:
@ -254,34 +273,33 @@ class Server(IServer):
return await our_wait_for return await our_wait_for
raise Exception() raise Exception()
async def _on_write_line(self, line: Line): async def _on_send_line(self, line: Line):
if (line.command == "PRIVMSG" and if (line.command == "PRIVMSG" and
not self.cap_agreed(CAP_ECHO)): not self.cap_agreed(CAP_ECHO)):
new_line = line.with_source(self.hostmask()) new_line = line.with_source(self.hostmask())
emit = self.parse_tokens(new_line) emit = self.parse_tokens(new_line)
self._read_queue.append((new_line, emit)) self._read_queue.append((new_line, emit))
async def _write_lines(self) -> List[Line]: async def _send_lines(self):
lines: List[SentLine] = [] while True:
lines: List[SentLine] = []
while (not lines or while (not lines or
(len(lines) < 5 and self._write_queue.qsize() > 0)): (len(lines) < 5 and self._send_queue.qsize() > 0)):
prio_line = await self._write_queue.get() prio_line = await self._send_queue.get()
lines.append(prio_line) lines.append(prio_line)
for line in lines: for line in lines:
async with self.throttle: async with self.throttle:
self._writer.write( self._writer.write(
f"{line.line.format()}\r\n".encode("utf8")) f"{line.line.format()}\r\n".encode("utf8"))
await self._writer.drain() await self._writer.drain()
for line in lines: for line in lines:
await self._on_write_line(line.line) await self._on_send_line(line.line)
await self.line_send(line.line) await self.line_send(line.line)
line.future.set_result(line) line.future.set_result(line)
return [l.line for l in lines]
# CAP-related # CAP-related
def cap_agreed(self, capability: ICapability) -> bool: def cap_agreed(self, capability: ICapability) -> bool: