diff --git a/ircrobots/bot.py b/ircrobots/bot.py index bef2e11..e3c5bfe 100644 --- a/ircrobots/bot.py +++ b/ircrobots/bot.py @@ -19,8 +19,12 @@ class Bot(object): async def disconnected(self, server: Server): await asyncio.sleep(RECONNECT_DELAY) await self.add_server(server.name, server.params) + async def line_read(self, server: Server, line: Line): pass + async def emit_read(self, server: Server, line: Line): + pass + async def line_send(self, server: Server, line: Line): pass @@ -36,7 +40,9 @@ class Bot(object): async def _read(): while not tg.cancel_scope.cancel_called: lines = await server._read_lines() - for line in lines: + for line, emits in lines: + for emit in emits: + await self.emit_read(server, emit) await self.line_read(server, line) await tg.cancel_scope.cancel() diff --git a/ircrobots/ircv3.py b/ircrobots/ircv3.py new file mode 100644 index 0000000..7c4ec3d --- /dev/null +++ b/ircrobots/ircv3.py @@ -0,0 +1,49 @@ +from typing import Callable, Iterable, List, Optional + +class Capability(object): + def __init__(self, + ratified_name: Optional[str], + draft_name: Optional[str]=None, + alias: Optional[str]=None, + depends_on: List[str]=[]): + self.name = ratified_name + self.draft = draft_name + self.alias = alias or ratified_name + self.depends_on = depends_on.copy() + + self._caps = set((ratified_name, draft_name)) + + def available(self, capabilities: Iterable[str] + ) -> Optional[str]: + match = list(set(capabilities)&self._caps) + return match[0] if match else None + + def match(self, capability: str) -> Optional[str]: + cap = list(set([capability])&self._caps) + return cap[0] if cap else None + + def copy(self): + return Capability( + self.name, + self.draft, + alias=self.alias, + depends_on=self.depends_on[:]) + +CAPS = [ + Capability("multi-prefix"), + Capability("chghost"), + Capability("away-notify"), + Capability("userhost-in-names"), + + Capability("invite-notify"), + Capability("account-tag"), + Capability("account-notify"), + Capability("extended-join"), + + Capability("message-tags", "draft/message-tags-0.2"), + Capability("cap-notify"), + Capability("batch"), + + Capability(None, "draft/rename", alias="rename"), + Capability("setname", "draft/setname") +] diff --git a/ircrobots/server.py b/ircrobots/server.py index 5018a37..0540fc0 100644 --- a/ircrobots/server.py +++ b/ircrobots/server.py @@ -1,13 +1,16 @@ import asyncio, ssl +from asyncio import PriorityQueue from queue import Queue -from typing import Callable, Dict, List, Optional, Tuple -from enum import Enum +from typing import Callable, Dict, List, Optional, Set, Tuple +from enum import IntEnum from dataclasses import dataclass from asyncio_throttle import Throttler from ircstates import Server as BaseServer +from ircstates import Emit from irctokens import build, Line, tokenise +from .ircv3 import Capability, CAPS sc = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) THROTTLE_RATE = 4 # lines @@ -24,13 +27,19 @@ class ConnectionParams(object): realname: Optional[str] = None bindhost: Optional[str] = None -class SendPriority(Enum): +class SendPriority(IntEnum): HIGH = 0 MEDIUM = 10 LOW = 20 - DEFAULT = MEDIUM +class PriorityLine(object): + def __init__(self, priority: int, line: Line): + self.priority = priority + self.line = line + def __lt__(self, other: "PriorityLine") -> bool: + return self.priority < other.priority + class Server(BaseServer): _reader: asyncio.StreamReader _writer: asyncio.StreamWriter @@ -38,14 +47,19 @@ class Server(BaseServer): def __init__(self, name: str): super().__init__(name) + self.throttle = Throttler( rate_limit=THROTTLE_RATE, period=THROTTLE_TIME) - self._write_queue: asyncio.PriorityQueue[Tuple[int, Line]] = asyncio.PriorityQueue() + + self._write_queue: PriorityQueue[PriorityLine] = PriorityQueue() + + self._cap_queue: Set[Capability] = set([]) + self._requested_caps: List[str] = [] async def send_raw(self, line: str, priority=SendPriority.DEFAULT): await self.send(tokenise(line), priority) async def send(self, line: Line, priority=SendPriority.DEFAULT): - await self._write_queue.put((priority, line)) + await self._write_queue.put(PriorityLine(priority, line)) def set_throttle(self, rate: int, time: float): self.throttle.rate_limit = rate @@ -62,19 +76,49 @@ class Server(BaseServer): username = params.username or nickname realname = params.realname or nickname + await self.send(build("CAP", ["LS"])) await self.send(build("NICK", [nickname])) await self.send(build("USER", [username, "0", "*", realname])) self.params = params - async def line_received(self, line: Line): + async def queue_capability(self, cap: Capability): + self._cap_queue.add(cap) + async def _cap_ls_done(self): + caps = CAPS+list(self._cap_queue) + self._cap_queue.clear() + + matches = list(filter(bool, + (c.available(self.available_caps) for c in caps))) + if matches: + self._requested_caps = matches + await self.send(build("CAP", ["REQ", " ".join(matches)])) + async def _cap_ack(self, line: Line): + caps = line.params[2].split(" ") + for cap in caps: + if cap in self._requested_caps: + self._requested_caps.remove(cap) + if not self._requested_caps: + await self.send(build("CAP", ["END"])) + + async def _on_read_emit(self, line: Line, emit: Emit): + if emit.command == "CAP": + if emit.subcommand == "LS" and emit.finished: + await self._cap_ls_done() + elif emit.subcommand in ["ACK", "NAK"]: + await self._cap_ack(line) + + async def _on_read_line(self, line: Line): pass - async def _read_lines(self) -> List[Line]: + + async def _read_lines(self) -> List[Tuple[Line, List[Emit]]]: data = await self._reader.read(1024) lines = self.recv(data) - for line in lines: - print(f"{self.name}< {line.format()}") - await self.line_received(line) + + for line, emits in lines: + for emit in emits: + await self._on_read_emit(line, emit) + await self._on_read_line(line) return lines async def line_written(self, line: Line): @@ -84,8 +128,8 @@ class Server(BaseServer): while (not lines or (len(lines) < 5 and self._write_queue.qsize() > 0)): - prio, line = await self._write_queue.get() - lines.append(line) + prio_line = await self._write_queue.get() + lines.append(prio_line.line) for line in lines: async with self.throttle: diff --git a/requirements.txt b/requirements.txt index ecb74bc..18f582f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ -ircstates ==0.7.0 -asyncio-throttle ==0.1.1 +ircstates ==0.8.0 +asyncio-throttle ==1.0.1