rewrite how wait_for works - only one at a time now

This commit is contained in:
jesopo 2020-04-29 12:13:06 +01:00
parent a3abae811a
commit 51cfd0f36b
4 changed files with 95 additions and 85 deletions

View file

@ -2,6 +2,7 @@ from asyncio import Future
from irctokens import Line
from typing import Any, Awaitable, Callable, Generator, Generic, TypeVar
from .matching import IMatchResponse
from .interface import IServer
TEvent = TypeVar("TEvent")
class MaybeAwait(Generic[TEvent]):
@ -20,5 +21,8 @@ class WaitFor(object):
def __await__(self) -> Generator[Any, None, Line]:
return self._fut.__await__()
def match(self, server: IServer, line: Line):
return self.response.match(server, line)
def resolve(self, line: Line):
self._fut.set_result(line)

View file

@ -37,21 +37,8 @@ class Bot(IBot):
async def _run_server(self, server: Server):
async with anyio.create_task_group() as tg:
async def _read():
while not tg.cancel_scope.cancel_called:
try:
both = await server.next_line()
except ServerDisconnectedException:
break
await tg.cancel_scope.cancel()
async def _write():
while not tg.cancel_scope.cancel_called:
lines = await server._write_lines()
await tg.cancel_scope.cancel()
await tg.spawn(_write)
await tg.spawn(_read)
await tg.spawn(server._read_lines)
await tg.spawn(server._send_lines)
await self.disconnected(server)

View file

@ -103,6 +103,10 @@ class IServer(Server):
async def disconnect(self):
pass
def line_preread(self, line: Line):
pass
def line_presend(self, line: Line):
pass
async def line_read(self, line: Line):
pass
async def line_send(self, line: Line):
@ -112,9 +116,6 @@ class IServer(Server):
async def resume_policy(self, resume: ResumePolicy):
pass
async def next_line(self) -> Optional[Tuple[Line, Optional[Emit]]]:
pass
def cap_agreed(self, capability: ICapability) -> bool:
pass
def cap_available(self, capability: ICapability) -> Optional[str]:

View file

@ -59,7 +59,7 @@ class Server(IServer):
self.last_read = -1.0
self._sent_count: int = 0
self._write_queue: PriorityQueue[SentLine] = PriorityQueue()
self._send_queue: PriorityQueue[SentLine] = PriorityQueue()
self.desired_caps: Set[ICapability] = set([])
self._read_queue: Deque[Tuple[Line, Optional[Emit]]] = deque()
@ -80,6 +80,7 @@ class Server(IServer):
return self.send(tokenise(line), priority)
def send(self, line: Line, priority=SendPriority.DEFAULT
) -> Awaitable[SentLine]:
self.line_presend(line)
sent_line = SentLine(self._sent_count, priority, line)
self._sent_count += 1
@ -91,7 +92,7 @@ class Server(IServer):
line.tags = {}
line.tags[tag] = str(sent_line.id)
self._write_queue.put_nowait(sent_line)
self._send_queue.put_nowait(sent_line)
async def _assure() -> SentLine:
await sent_line.future
@ -142,6 +143,10 @@ class Server(IServer):
self.send(build("USER", [username, "0", "*", realname]))
# to be overridden
def line_preread(self, line: Line):
pass
def line_presend(self, line: Line):
pass
async def line_read(self, line: Line):
pass
async def line_send(self, line: Line):
@ -152,42 +157,33 @@ class Server(IServer):
pass
# /to be overriden
async def _on_read_emit(self, line: Line, emit: Emit):
if emit.command == "001":
await self.send(build("WHO", [self.nickname]))
self.set_throttle(THROTTLE_RATE, THROTTLE_TIME)
elif emit.command == "CAP":
if emit.subcommand == "NEW":
await self._cap_ls(emit)
elif (emit.subcommand == "LS" and
emit.finished):
if not self.registered:
await CAPContext(self).handshake()
else:
await self._cap_ls(emit)
elif emit.command == "JOIN":
if emit.self and not emit.channel is None:
await self.send(build("MODE", [emit.channel.name]))
await WHOContext(self).ensure(emit.channel.name)
async def _on_read_line(self, line: Line):
async def _on_read(self, line: Line, emit: Optional[Emit]):
if line.command == "PING":
await self.send(build("PONG", line.params))
async def _line_or_wait(self, line_aw: Awaitable):
wait_for_fut: Future[WaitFor] = Future()
self._wait_for_fut = wait_for_fut
elif emit is not None:
if emit.command == "001":
await self.send(build("WHO", [self.nickname]))
self.set_throttle(THROTTLE_RATE, THROTTLE_TIME)
done, pend = await asyncio.wait([line_aw, wait_for_fut],
return_when=asyncio.FIRST_COMPLETED)
elif emit.command == "CAP":
if emit.subcommand == "NEW":
await self._cap_ls(emit)
elif (emit.subcommand == "LS" and
emit.finished):
if not self.registered:
await CAPContext(self).handshake()
else:
await self._cap_ls(emit)
if wait_for_fut.done():
new_line_aw = list(pend)[0]
self._wait_for.append((new_line_aw, await wait_for_fut))
elif emit.command == "JOIN":
if emit.self and not emit.channel is None:
await self.send(build("MODE", [emit.channel.name]))
await WHOContext(self).ensure(emit.channel.name)
async def next_line(self) -> Tuple[Line, Optional[Emit]]:
await self.line_read(line)
async def _next_line(self) -> Tuple[Line, Optional[Emit]]:
if self._read_queue:
both = self._read_queue.popleft()
else:
@ -217,25 +213,48 @@ class Server(IServer):
self._read_queue.extend(lines[1:])
both = lines[0]
break
line, emit = both
async def _line():
if emit is not None:
await self._on_read_emit(line, emit)
await self._on_read_line(line)
await self.line_read(line)
for i, (aw, wait_for) in enumerate(self._wait_for):
if wait_for.response.match(self, line):
wait_for.resolve(line)
self._wait_for.pop(i)
await self._line_or_wait(aw)
break
await self._line_or_wait(_line())
return both
async def _line_or_wait(self, line_aw: Awaitable):
wait_for_fut: Future[WaitFor] = Future()
self._wait_for_fut = wait_for_fut
done, pend = await asyncio.wait([line_aw, wait_for_fut],
return_when=asyncio.FIRST_COMPLETED)
if wait_for_fut.done():
new_line_aw = list(pend)[0]
return (await wait_for_fut), new_line_aw
else:
return None, None
async def _read_lines(self):
waited_reads: List[Tuple[Line, Optional[Emit]]] = []
wait_for: Optional[WaitFor] = None
wait_for_aw: Optional[Awaitable] = None
while True:
if not waited_reads or wait_for is not None:
both = await self._next_line()
waited_reads.append(both)
line, emit = both
self.line_preread(line)
if wait_for is not None:
line, emit = waited_reads[-1]
if wait_for.response.match(self, line):
wait_for.resolve(line)
wait_for, wait_for_aw = await self._line_or_wait(
wait_for_aw)
else:
while waited_reads:
new_line, new_emit = waited_reads.pop(0)
line_aw = self._on_read(new_line, new_emit)
wait_for, wait_for_aw = await self._line_or_wait(line_aw)
if wait_for is not None:
break
async def wait_for(self,
response: Union[IMatchResponse, Set[IMatchResponse]]
) -> Line:
@ -254,34 +273,33 @@ class Server(IServer):
return await our_wait_for
raise Exception()
async def _on_write_line(self, line: Line):
async def _on_send_line(self, line: Line):
if (line.command == "PRIVMSG" and
not self.cap_agreed(CAP_ECHO)):
new_line = line.with_source(self.hostmask())
emit = self.parse_tokens(new_line)
self._read_queue.append((new_line, emit))
async def _write_lines(self) -> List[Line]:
lines: List[SentLine] = []
async def _send_lines(self):
while True:
lines: List[SentLine] = []
while (not lines or
(len(lines) < 5 and self._write_queue.qsize() > 0)):
prio_line = await self._write_queue.get()
lines.append(prio_line)
while (not lines or
(len(lines) < 5 and self._send_queue.qsize() > 0)):
prio_line = await self._send_queue.get()
lines.append(prio_line)
for line in lines:
async with self.throttle:
self._writer.write(
f"{line.line.format()}\r\n".encode("utf8"))
for line in lines:
async with self.throttle:
self._writer.write(
f"{line.line.format()}\r\n".encode("utf8"))
await self._writer.drain()
await self._writer.drain()
for line in lines:
await self._on_write_line(line.line)
await self.line_send(line.line)
line.future.set_result(line)
return [l.line for l in lines]
for line in lines:
await self._on_send_line(line.line)
await self.line_send(line.line)
line.future.set_result(line)
# CAP-related
def cap_agreed(self, capability: ICapability) -> bool: