diff --git a/ircrobots/interface.py b/ircrobots/interface.py index 6798518..751a0e5 100644 --- a/ircrobots/interface.py +++ b/ircrobots/interface.py @@ -1,5 +1,6 @@ -from typing import Awaitable, Iterable, List, Optional -from enum import IntEnum +from asyncio import Future +from typing import Awaitable, Iterable, List, Optional +from enum import IntEnum from ircstates import Server from irctokens import Line @@ -13,11 +14,12 @@ class SendPriority(IntEnum): LOW = 20 DEFAULT = MEDIUM -class PriorityLine(object): +class SentLine(object): def __init__(self, priority: int, line: Line): - self.priority = priority - self.line = line - def __lt__(self, other: "PriorityLine") -> bool: + self.priority = priority + self.line = line + self.future: Future = Future() + def __lt__(self, other: "SentLine") -> bool: return self.priority < other.priority class ICapability(object): diff --git a/ircrobots/server.py b/ircrobots/server.py index 9b2018a..09da06d 100644 --- a/ircrobots/server.py +++ b/ircrobots/server.py @@ -8,7 +8,7 @@ from ircstates import Emit from irctokens import build, Line, tokenise from .ircv3 import CAPContext, CAPS, CAP_SASL -from .interface import (ConnectionParams, ICapability, IServer, PriorityLine, +from .interface import (ConnectionParams, ICapability, IServer, SentLine, SendPriority) from .matching import BaseResponse from .sasl import SASLContext, SASLResult @@ -29,7 +29,7 @@ class Server(IServer): rate_limit=THROTTLE_RATE, period=THROTTLE_TIME) self.sasl_state = SASLResult.NONE - self._write_queue: PriorityQueue[PriorityLine] = PriorityQueue() + self._write_queue: PriorityQueue[SentLine] = PriorityQueue() self._cap_queue: Set[ICapability] = set([]) self._wait_for: List[Tuple[BaseResponse, Future]] = [] @@ -37,7 +37,9 @@ class Server(IServer): async def send_raw(self, line: str, priority=SendPriority.DEFAULT): await self.send(tokenise(line), priority) async def send(self, line: Line, priority=SendPriority.DEFAULT): - await self._write_queue.put(PriorityLine(priority, line)) + prio_line = SentLine(priority, line) + await self._write_queue.put(prio_line) + await prio_line.future def set_throttle(self, rate: int, time: float): self.throttle.rate_limit = rate @@ -100,18 +102,24 @@ class Server(IServer): async def line_written(self, line: Line): pass async def _write_lines(self) -> List[Line]: - lines: List[Line] = [] + lines: List[SentLine] = [] while (not lines or (len(lines) < 5 and self._write_queue.qsize() > 0)): prio_line = await self._write_queue.get() - lines.append(prio_line.line) + lines.append(prio_line) for line in lines: async with self.throttle: - self._writer.write(f"{line.format()}\r\n".encode("utf8")) + self._writer.write( + f"{line.line.format()}\r\n".encode("utf8")) + await self._writer.drain() - return lines + + for line in lines: + line.future.set_result(None) + + return [l.line for l in lines] # CAP-related async def queue_capability(self, cap: ICapability):