fix PriorityQueue, use both Line and Emit, add CAP REQ

This commit is contained in:
jesopo 2020-04-01 23:06:41 +01:00
parent fb892c584e
commit fd934b1101
4 changed files with 115 additions and 16 deletions

View file

@ -19,8 +19,12 @@ class Bot(object):
async def disconnected(self, server: Server): async def disconnected(self, server: Server):
await asyncio.sleep(RECONNECT_DELAY) await asyncio.sleep(RECONNECT_DELAY)
await self.add_server(server.name, server.params) await self.add_server(server.name, server.params)
async def line_read(self, server: Server, line: Line): async def line_read(self, server: Server, line: Line):
pass pass
async def emit_read(self, server: Server, line: Line):
pass
async def line_send(self, server: Server, line: Line): async def line_send(self, server: Server, line: Line):
pass pass
@ -36,7 +40,9 @@ class Bot(object):
async def _read(): async def _read():
while not tg.cancel_scope.cancel_called: while not tg.cancel_scope.cancel_called:
lines = await server._read_lines() lines = await server._read_lines()
for line in lines: for line, emits in lines:
for emit in emits:
await self.emit_read(server, emit)
await self.line_read(server, line) await self.line_read(server, line)
await tg.cancel_scope.cancel() await tg.cancel_scope.cancel()

49
ircrobots/ircv3.py Normal file
View file

@ -0,0 +1,49 @@
from typing import Callable, Iterable, List, Optional
class Capability(object):
def __init__(self,
ratified_name: Optional[str],
draft_name: Optional[str]=None,
alias: Optional[str]=None,
depends_on: List[str]=[]):
self.name = ratified_name
self.draft = draft_name
self.alias = alias or ratified_name
self.depends_on = depends_on.copy()
self._caps = set((ratified_name, draft_name))
def available(self, capabilities: Iterable[str]
) -> Optional[str]:
match = list(set(capabilities)&self._caps)
return match[0] if match else None
def match(self, capability: str) -> Optional[str]:
cap = list(set([capability])&self._caps)
return cap[0] if cap else None
def copy(self):
return Capability(
self.name,
self.draft,
alias=self.alias,
depends_on=self.depends_on[:])
CAPS = [
Capability("multi-prefix"),
Capability("chghost"),
Capability("away-notify"),
Capability("userhost-in-names"),
Capability("invite-notify"),
Capability("account-tag"),
Capability("account-notify"),
Capability("extended-join"),
Capability("message-tags", "draft/message-tags-0.2"),
Capability("cap-notify"),
Capability("batch"),
Capability(None, "draft/rename", alias="rename"),
Capability("setname", "draft/setname")
]

View file

@ -1,13 +1,16 @@
import asyncio, ssl import asyncio, ssl
from asyncio import PriorityQueue
from queue import Queue from queue import Queue
from typing import Callable, Dict, List, Optional, Tuple from typing import Callable, Dict, List, Optional, Set, Tuple
from enum import Enum from enum import IntEnum
from dataclasses import dataclass from dataclasses import dataclass
from asyncio_throttle import Throttler from asyncio_throttle import Throttler
from ircstates import Server as BaseServer from ircstates import Server as BaseServer
from ircstates import Emit
from irctokens import build, Line, tokenise from irctokens import build, Line, tokenise
from .ircv3 import Capability, CAPS
sc = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) sc = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
THROTTLE_RATE = 4 # lines THROTTLE_RATE = 4 # lines
@ -24,13 +27,19 @@ class ConnectionParams(object):
realname: Optional[str] = None realname: Optional[str] = None
bindhost: Optional[str] = None bindhost: Optional[str] = None
class SendPriority(Enum): class SendPriority(IntEnum):
HIGH = 0 HIGH = 0
MEDIUM = 10 MEDIUM = 10
LOW = 20 LOW = 20
DEFAULT = MEDIUM DEFAULT = MEDIUM
class PriorityLine(object):
def __init__(self, priority: int, line: Line):
self.priority = priority
self.line = line
def __lt__(self, other: "PriorityLine") -> bool:
return self.priority < other.priority
class Server(BaseServer): class Server(BaseServer):
_reader: asyncio.StreamReader _reader: asyncio.StreamReader
_writer: asyncio.StreamWriter _writer: asyncio.StreamWriter
@ -38,14 +47,19 @@ class Server(BaseServer):
def __init__(self, name: str): def __init__(self, name: str):
super().__init__(name) super().__init__(name)
self.throttle = Throttler( self.throttle = Throttler(
rate_limit=THROTTLE_RATE, period=THROTTLE_TIME) rate_limit=THROTTLE_RATE, period=THROTTLE_TIME)
self._write_queue: asyncio.PriorityQueue[Tuple[int, Line]] = asyncio.PriorityQueue()
self._write_queue: PriorityQueue[PriorityLine] = PriorityQueue()
self._cap_queue: Set[Capability] = set([])
self._requested_caps: List[str] = []
async def send_raw(self, line: str, priority=SendPriority.DEFAULT): async def send_raw(self, line: str, priority=SendPriority.DEFAULT):
await self.send(tokenise(line), priority) await self.send(tokenise(line), priority)
async def send(self, line: Line, priority=SendPriority.DEFAULT): async def send(self, line: Line, priority=SendPriority.DEFAULT):
await self._write_queue.put((priority, line)) await self._write_queue.put(PriorityLine(priority, line))
def set_throttle(self, rate: int, time: float): def set_throttle(self, rate: int, time: float):
self.throttle.rate_limit = rate self.throttle.rate_limit = rate
@ -62,19 +76,49 @@ class Server(BaseServer):
username = params.username or nickname username = params.username or nickname
realname = params.realname or nickname realname = params.realname or nickname
await self.send(build("CAP", ["LS"]))
await self.send(build("NICK", [nickname])) await self.send(build("NICK", [nickname]))
await self.send(build("USER", [username, "0", "*", realname])) await self.send(build("USER", [username, "0", "*", realname]))
self.params = params self.params = params
async def line_received(self, line: Line): async def queue_capability(self, cap: Capability):
self._cap_queue.add(cap)
async def _cap_ls_done(self):
caps = CAPS+list(self._cap_queue)
self._cap_queue.clear()
matches = list(filter(bool,
(c.available(self.available_caps) for c in caps)))
if matches:
self._requested_caps = matches
await self.send(build("CAP", ["REQ", " ".join(matches)]))
async def _cap_ack(self, line: Line):
caps = line.params[2].split(" ")
for cap in caps:
if cap in self._requested_caps:
self._requested_caps.remove(cap)
if not self._requested_caps:
await self.send(build("CAP", ["END"]))
async def _on_read_emit(self, line: Line, emit: Emit):
if emit.command == "CAP":
if emit.subcommand == "LS" and emit.finished:
await self._cap_ls_done()
elif emit.subcommand in ["ACK", "NAK"]:
await self._cap_ack(line)
async def _on_read_line(self, line: Line):
pass pass
async def _read_lines(self) -> List[Line]:
async def _read_lines(self) -> List[Tuple[Line, List[Emit]]]:
data = await self._reader.read(1024) data = await self._reader.read(1024)
lines = self.recv(data) lines = self.recv(data)
for line in lines:
print(f"{self.name}< {line.format()}") for line, emits in lines:
await self.line_received(line) for emit in emits:
await self._on_read_emit(line, emit)
await self._on_read_line(line)
return lines return lines
async def line_written(self, line: Line): async def line_written(self, line: Line):
@ -84,8 +128,8 @@ class Server(BaseServer):
while (not lines or while (not lines or
(len(lines) < 5 and self._write_queue.qsize() > 0)): (len(lines) < 5 and self._write_queue.qsize() > 0)):
prio, line = await self._write_queue.get() prio_line = await self._write_queue.get()
lines.append(line) lines.append(prio_line.line)
for line in lines: for line in lines:
async with self.throttle: async with self.throttle:

View file

@ -1,2 +1,2 @@
ircstates ==0.7.0 ircstates ==0.8.0
asyncio-throttle ==0.1.1 asyncio-throttle ==1.0.1