rewrite how wait_for works - only one at a time now
This commit is contained in:
parent
a3abae811a
commit
51cfd0f36b
4 changed files with 95 additions and 85 deletions
|
@ -2,6 +2,7 @@ from asyncio import Future
|
|||
from irctokens import Line
|
||||
from typing import Any, Awaitable, Callable, Generator, Generic, TypeVar
|
||||
from .matching import IMatchResponse
|
||||
from .interface import IServer
|
||||
|
||||
TEvent = TypeVar("TEvent")
|
||||
class MaybeAwait(Generic[TEvent]):
|
||||
|
@ -20,5 +21,8 @@ class WaitFor(object):
|
|||
def __await__(self) -> Generator[Any, None, Line]:
|
||||
return self._fut.__await__()
|
||||
|
||||
def match(self, server: IServer, line: Line):
|
||||
return self.response.match(server, line)
|
||||
|
||||
def resolve(self, line: Line):
|
||||
self._fut.set_result(line)
|
||||
|
|
|
@ -37,21 +37,8 @@ class Bot(IBot):
|
|||
|
||||
async def _run_server(self, server: Server):
|
||||
async with anyio.create_task_group() as tg:
|
||||
async def _read():
|
||||
while not tg.cancel_scope.cancel_called:
|
||||
try:
|
||||
both = await server.next_line()
|
||||
except ServerDisconnectedException:
|
||||
break
|
||||
await tg.cancel_scope.cancel()
|
||||
|
||||
async def _write():
|
||||
while not tg.cancel_scope.cancel_called:
|
||||
lines = await server._write_lines()
|
||||
await tg.cancel_scope.cancel()
|
||||
|
||||
await tg.spawn(_write)
|
||||
await tg.spawn(_read)
|
||||
await tg.spawn(server._read_lines)
|
||||
await tg.spawn(server._send_lines)
|
||||
|
||||
await self.disconnected(server)
|
||||
|
||||
|
|
|
@ -103,6 +103,10 @@ class IServer(Server):
|
|||
async def disconnect(self):
|
||||
pass
|
||||
|
||||
def line_preread(self, line: Line):
|
||||
pass
|
||||
def line_presend(self, line: Line):
|
||||
pass
|
||||
async def line_read(self, line: Line):
|
||||
pass
|
||||
async def line_send(self, line: Line):
|
||||
|
@ -112,9 +116,6 @@ class IServer(Server):
|
|||
async def resume_policy(self, resume: ResumePolicy):
|
||||
pass
|
||||
|
||||
async def next_line(self) -> Optional[Tuple[Line, Optional[Emit]]]:
|
||||
pass
|
||||
|
||||
def cap_agreed(self, capability: ICapability) -> bool:
|
||||
pass
|
||||
def cap_available(self, capability: ICapability) -> Optional[str]:
|
||||
|
|
|
@ -59,7 +59,7 @@ class Server(IServer):
|
|||
self.last_read = -1.0
|
||||
|
||||
self._sent_count: int = 0
|
||||
self._write_queue: PriorityQueue[SentLine] = PriorityQueue()
|
||||
self._send_queue: PriorityQueue[SentLine] = PriorityQueue()
|
||||
self.desired_caps: Set[ICapability] = set([])
|
||||
|
||||
self._read_queue: Deque[Tuple[Line, Optional[Emit]]] = deque()
|
||||
|
@ -80,6 +80,7 @@ class Server(IServer):
|
|||
return self.send(tokenise(line), priority)
|
||||
def send(self, line: Line, priority=SendPriority.DEFAULT
|
||||
) -> Awaitable[SentLine]:
|
||||
self.line_presend(line)
|
||||
sent_line = SentLine(self._sent_count, priority, line)
|
||||
self._sent_count += 1
|
||||
|
||||
|
@ -91,7 +92,7 @@ class Server(IServer):
|
|||
line.tags = {}
|
||||
line.tags[tag] = str(sent_line.id)
|
||||
|
||||
self._write_queue.put_nowait(sent_line)
|
||||
self._send_queue.put_nowait(sent_line)
|
||||
|
||||
async def _assure() -> SentLine:
|
||||
await sent_line.future
|
||||
|
@ -142,6 +143,10 @@ class Server(IServer):
|
|||
self.send(build("USER", [username, "0", "*", realname]))
|
||||
|
||||
# to be overridden
|
||||
def line_preread(self, line: Line):
|
||||
pass
|
||||
def line_presend(self, line: Line):
|
||||
pass
|
||||
async def line_read(self, line: Line):
|
||||
pass
|
||||
async def line_send(self, line: Line):
|
||||
|
@ -152,42 +157,33 @@ class Server(IServer):
|
|||
pass
|
||||
# /to be overriden
|
||||
|
||||
async def _on_read_emit(self, line: Line, emit: Emit):
|
||||
if emit.command == "001":
|
||||
await self.send(build("WHO", [self.nickname]))
|
||||
self.set_throttle(THROTTLE_RATE, THROTTLE_TIME)
|
||||
|
||||
elif emit.command == "CAP":
|
||||
if emit.subcommand == "NEW":
|
||||
await self._cap_ls(emit)
|
||||
elif (emit.subcommand == "LS" and
|
||||
emit.finished):
|
||||
if not self.registered:
|
||||
await CAPContext(self).handshake()
|
||||
else:
|
||||
await self._cap_ls(emit)
|
||||
|
||||
elif emit.command == "JOIN":
|
||||
if emit.self and not emit.channel is None:
|
||||
await self.send(build("MODE", [emit.channel.name]))
|
||||
await WHOContext(self).ensure(emit.channel.name)
|
||||
|
||||
async def _on_read_line(self, line: Line):
|
||||
async def _on_read(self, line: Line, emit: Optional[Emit]):
|
||||
if line.command == "PING":
|
||||
await self.send(build("PONG", line.params))
|
||||
|
||||
async def _line_or_wait(self, line_aw: Awaitable):
|
||||
wait_for_fut: Future[WaitFor] = Future()
|
||||
self._wait_for_fut = wait_for_fut
|
||||
elif emit is not None:
|
||||
if emit.command == "001":
|
||||
await self.send(build("WHO", [self.nickname]))
|
||||
self.set_throttle(THROTTLE_RATE, THROTTLE_TIME)
|
||||
|
||||
done, pend = await asyncio.wait([line_aw, wait_for_fut],
|
||||
return_when=asyncio.FIRST_COMPLETED)
|
||||
elif emit.command == "CAP":
|
||||
if emit.subcommand == "NEW":
|
||||
await self._cap_ls(emit)
|
||||
elif (emit.subcommand == "LS" and
|
||||
emit.finished):
|
||||
if not self.registered:
|
||||
await CAPContext(self).handshake()
|
||||
else:
|
||||
await self._cap_ls(emit)
|
||||
|
||||
if wait_for_fut.done():
|
||||
new_line_aw = list(pend)[0]
|
||||
self._wait_for.append((new_line_aw, await wait_for_fut))
|
||||
elif emit.command == "JOIN":
|
||||
if emit.self and not emit.channel is None:
|
||||
await self.send(build("MODE", [emit.channel.name]))
|
||||
await WHOContext(self).ensure(emit.channel.name)
|
||||
|
||||
async def next_line(self) -> Tuple[Line, Optional[Emit]]:
|
||||
await self.line_read(line)
|
||||
|
||||
async def _next_line(self) -> Tuple[Line, Optional[Emit]]:
|
||||
if self._read_queue:
|
||||
both = self._read_queue.popleft()
|
||||
else:
|
||||
|
@ -217,25 +213,48 @@ class Server(IServer):
|
|||
self._read_queue.extend(lines[1:])
|
||||
both = lines[0]
|
||||
break
|
||||
|
||||
line, emit = both
|
||||
async def _line():
|
||||
if emit is not None:
|
||||
await self._on_read_emit(line, emit)
|
||||
await self._on_read_line(line)
|
||||
await self.line_read(line)
|
||||
|
||||
for i, (aw, wait_for) in enumerate(self._wait_for):
|
||||
if wait_for.response.match(self, line):
|
||||
wait_for.resolve(line)
|
||||
self._wait_for.pop(i)
|
||||
await self._line_or_wait(aw)
|
||||
break
|
||||
|
||||
await self._line_or_wait(_line())
|
||||
|
||||
return both
|
||||
|
||||
async def _line_or_wait(self, line_aw: Awaitable):
|
||||
wait_for_fut: Future[WaitFor] = Future()
|
||||
self._wait_for_fut = wait_for_fut
|
||||
|
||||
done, pend = await asyncio.wait([line_aw, wait_for_fut],
|
||||
return_when=asyncio.FIRST_COMPLETED)
|
||||
|
||||
if wait_for_fut.done():
|
||||
new_line_aw = list(pend)[0]
|
||||
return (await wait_for_fut), new_line_aw
|
||||
else:
|
||||
return None, None
|
||||
|
||||
async def _read_lines(self):
|
||||
waited_reads: List[Tuple[Line, Optional[Emit]]] = []
|
||||
wait_for: Optional[WaitFor] = None
|
||||
wait_for_aw: Optional[Awaitable] = None
|
||||
|
||||
while True:
|
||||
if not waited_reads or wait_for is not None:
|
||||
both = await self._next_line()
|
||||
waited_reads.append(both)
|
||||
line, emit = both
|
||||
self.line_preread(line)
|
||||
|
||||
if wait_for is not None:
|
||||
line, emit = waited_reads[-1]
|
||||
if wait_for.response.match(self, line):
|
||||
wait_for.resolve(line)
|
||||
|
||||
wait_for, wait_for_aw = await self._line_or_wait(
|
||||
wait_for_aw)
|
||||
else:
|
||||
while waited_reads:
|
||||
new_line, new_emit = waited_reads.pop(0)
|
||||
line_aw = self._on_read(new_line, new_emit)
|
||||
wait_for, wait_for_aw = await self._line_or_wait(line_aw)
|
||||
if wait_for is not None:
|
||||
break
|
||||
|
||||
async def wait_for(self,
|
||||
response: Union[IMatchResponse, Set[IMatchResponse]]
|
||||
) -> Line:
|
||||
|
@ -254,34 +273,33 @@ class Server(IServer):
|
|||
return await our_wait_for
|
||||
raise Exception()
|
||||
|
||||
async def _on_write_line(self, line: Line):
|
||||
async def _on_send_line(self, line: Line):
|
||||
if (line.command == "PRIVMSG" and
|
||||
not self.cap_agreed(CAP_ECHO)):
|
||||
new_line = line.with_source(self.hostmask())
|
||||
emit = self.parse_tokens(new_line)
|
||||
self._read_queue.append((new_line, emit))
|
||||
|
||||
async def _write_lines(self) -> List[Line]:
|
||||
lines: List[SentLine] = []
|
||||
async def _send_lines(self):
|
||||
while True:
|
||||
lines: List[SentLine] = []
|
||||
|
||||
while (not lines or
|
||||
(len(lines) < 5 and self._write_queue.qsize() > 0)):
|
||||
prio_line = await self._write_queue.get()
|
||||
lines.append(prio_line)
|
||||
while (not lines or
|
||||
(len(lines) < 5 and self._send_queue.qsize() > 0)):
|
||||
prio_line = await self._send_queue.get()
|
||||
lines.append(prio_line)
|
||||
|
||||
for line in lines:
|
||||
async with self.throttle:
|
||||
self._writer.write(
|
||||
f"{line.line.format()}\r\n".encode("utf8"))
|
||||
for line in lines:
|
||||
async with self.throttle:
|
||||
self._writer.write(
|
||||
f"{line.line.format()}\r\n".encode("utf8"))
|
||||
|
||||
await self._writer.drain()
|
||||
await self._writer.drain()
|
||||
|
||||
for line in lines:
|
||||
await self._on_write_line(line.line)
|
||||
await self.line_send(line.line)
|
||||
line.future.set_result(line)
|
||||
|
||||
return [l.line for l in lines]
|
||||
for line in lines:
|
||||
await self._on_send_line(line.line)
|
||||
await self.line_send(line.line)
|
||||
line.future.set_result(line)
|
||||
|
||||
# CAP-related
|
||||
def cap_agreed(self, capability: ICapability) -> bool:
|
||||
|
|
Loading…
Reference in a new issue