await server.send() should block until it hits the wire

This commit is contained in:
jesopo 2020-04-02 23:53:32 +01:00
parent 8dde7b7216
commit b34e4fcc22
2 changed files with 23 additions and 13 deletions

View file

@ -1,5 +1,6 @@
from typing import Awaitable, Iterable, List, Optional from asyncio import Future
from enum import IntEnum from typing import Awaitable, Iterable, List, Optional
from enum import IntEnum
from ircstates import Server from ircstates import Server
from irctokens import Line from irctokens import Line
@ -13,11 +14,12 @@ class SendPriority(IntEnum):
LOW = 20 LOW = 20
DEFAULT = MEDIUM DEFAULT = MEDIUM
class PriorityLine(object): class SentLine(object):
def __init__(self, priority: int, line: Line): def __init__(self, priority: int, line: Line):
self.priority = priority self.priority = priority
self.line = line self.line = line
def __lt__(self, other: "PriorityLine") -> bool: self.future: Future = Future()
def __lt__(self, other: "SentLine") -> bool:
return self.priority < other.priority return self.priority < other.priority
class ICapability(object): class ICapability(object):

View file

@ -8,7 +8,7 @@ from ircstates import Emit
from irctokens import build, Line, tokenise from irctokens import build, Line, tokenise
from .ircv3 import CAPContext, CAPS, CAP_SASL from .ircv3 import CAPContext, CAPS, CAP_SASL
from .interface import (ConnectionParams, ICapability, IServer, PriorityLine, from .interface import (ConnectionParams, ICapability, IServer, SentLine,
SendPriority) SendPriority)
from .matching import BaseResponse from .matching import BaseResponse
from .sasl import SASLContext, SASLResult from .sasl import SASLContext, SASLResult
@ -29,7 +29,7 @@ class Server(IServer):
rate_limit=THROTTLE_RATE, period=THROTTLE_TIME) rate_limit=THROTTLE_RATE, period=THROTTLE_TIME)
self.sasl_state = SASLResult.NONE self.sasl_state = SASLResult.NONE
self._write_queue: PriorityQueue[PriorityLine] = PriorityQueue() self._write_queue: PriorityQueue[SentLine] = PriorityQueue()
self._cap_queue: Set[ICapability] = set([]) self._cap_queue: Set[ICapability] = set([])
self._wait_for: List[Tuple[BaseResponse, Future]] = [] self._wait_for: List[Tuple[BaseResponse, Future]] = []
@ -37,7 +37,9 @@ class Server(IServer):
async def send_raw(self, line: str, priority=SendPriority.DEFAULT): async def send_raw(self, line: str, priority=SendPriority.DEFAULT):
await self.send(tokenise(line), priority) await self.send(tokenise(line), priority)
async def send(self, line: Line, priority=SendPriority.DEFAULT): async def send(self, line: Line, priority=SendPriority.DEFAULT):
await self._write_queue.put(PriorityLine(priority, line)) prio_line = SentLine(priority, line)
await self._write_queue.put(prio_line)
await prio_line.future
def set_throttle(self, rate: int, time: float): def set_throttle(self, rate: int, time: float):
self.throttle.rate_limit = rate self.throttle.rate_limit = rate
@ -100,18 +102,24 @@ class Server(IServer):
async def line_written(self, line: Line): async def line_written(self, line: Line):
pass pass
async def _write_lines(self) -> List[Line]: async def _write_lines(self) -> List[Line]:
lines: List[Line] = [] lines: List[SentLine] = []
while (not lines or while (not lines or
(len(lines) < 5 and self._write_queue.qsize() > 0)): (len(lines) < 5 and self._write_queue.qsize() > 0)):
prio_line = await self._write_queue.get() prio_line = await self._write_queue.get()
lines.append(prio_line.line) lines.append(prio_line)
for line in lines: for line in lines:
async with self.throttle: async with self.throttle:
self._writer.write(f"{line.format()}\r\n".encode("utf8")) self._writer.write(
f"{line.line.format()}\r\n".encode("utf8"))
await self._writer.drain() await self._writer.drain()
return lines
for line in lines:
line.future.set_result(None)
return [l.line for l in lines]
# CAP-related # CAP-related
async def queue_capability(self, cap: ICapability): async def queue_capability(self, cap: ICapability):