fix PriorityQueue, use both Line and Emit, add CAP REQ
This commit is contained in:
parent
fb892c584e
commit
fd934b1101
4 changed files with 115 additions and 16 deletions
|
@ -19,8 +19,12 @@ class Bot(object):
|
|||
async def disconnected(self, server: Server):
|
||||
await asyncio.sleep(RECONNECT_DELAY)
|
||||
await self.add_server(server.name, server.params)
|
||||
|
||||
async def line_read(self, server: Server, line: Line):
|
||||
pass
|
||||
async def emit_read(self, server: Server, line: Line):
|
||||
pass
|
||||
|
||||
async def line_send(self, server: Server, line: Line):
|
||||
pass
|
||||
|
||||
|
@ -36,7 +40,9 @@ class Bot(object):
|
|||
async def _read():
|
||||
while not tg.cancel_scope.cancel_called:
|
||||
lines = await server._read_lines()
|
||||
for line in lines:
|
||||
for line, emits in lines:
|
||||
for emit in emits:
|
||||
await self.emit_read(server, emit)
|
||||
await self.line_read(server, line)
|
||||
await tg.cancel_scope.cancel()
|
||||
|
||||
|
|
49
ircrobots/ircv3.py
Normal file
49
ircrobots/ircv3.py
Normal file
|
@ -0,0 +1,49 @@
|
|||
from typing import Callable, Iterable, List, Optional
|
||||
|
||||
class Capability(object):
|
||||
def __init__(self,
|
||||
ratified_name: Optional[str],
|
||||
draft_name: Optional[str]=None,
|
||||
alias: Optional[str]=None,
|
||||
depends_on: List[str]=[]):
|
||||
self.name = ratified_name
|
||||
self.draft = draft_name
|
||||
self.alias = alias or ratified_name
|
||||
self.depends_on = depends_on.copy()
|
||||
|
||||
self._caps = set((ratified_name, draft_name))
|
||||
|
||||
def available(self, capabilities: Iterable[str]
|
||||
) -> Optional[str]:
|
||||
match = list(set(capabilities)&self._caps)
|
||||
return match[0] if match else None
|
||||
|
||||
def match(self, capability: str) -> Optional[str]:
|
||||
cap = list(set([capability])&self._caps)
|
||||
return cap[0] if cap else None
|
||||
|
||||
def copy(self):
|
||||
return Capability(
|
||||
self.name,
|
||||
self.draft,
|
||||
alias=self.alias,
|
||||
depends_on=self.depends_on[:])
|
||||
|
||||
CAPS = [
|
||||
Capability("multi-prefix"),
|
||||
Capability("chghost"),
|
||||
Capability("away-notify"),
|
||||
Capability("userhost-in-names"),
|
||||
|
||||
Capability("invite-notify"),
|
||||
Capability("account-tag"),
|
||||
Capability("account-notify"),
|
||||
Capability("extended-join"),
|
||||
|
||||
Capability("message-tags", "draft/message-tags-0.2"),
|
||||
Capability("cap-notify"),
|
||||
Capability("batch"),
|
||||
|
||||
Capability(None, "draft/rename", alias="rename"),
|
||||
Capability("setname", "draft/setname")
|
||||
]
|
|
@ -1,13 +1,16 @@
|
|||
import asyncio, ssl
|
||||
from asyncio import PriorityQueue
|
||||
from queue import Queue
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
from enum import Enum
|
||||
from typing import Callable, Dict, List, Optional, Set, Tuple
|
||||
from enum import IntEnum
|
||||
from dataclasses import dataclass
|
||||
|
||||
from asyncio_throttle import Throttler
|
||||
from ircstates import Server as BaseServer
|
||||
from ircstates import Emit
|
||||
from irctokens import build, Line, tokenise
|
||||
|
||||
from .ircv3 import Capability, CAPS
|
||||
sc = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
|
||||
|
||||
THROTTLE_RATE = 4 # lines
|
||||
|
@ -24,13 +27,19 @@ class ConnectionParams(object):
|
|||
realname: Optional[str] = None
|
||||
bindhost: Optional[str] = None
|
||||
|
||||
class SendPriority(Enum):
|
||||
class SendPriority(IntEnum):
|
||||
HIGH = 0
|
||||
MEDIUM = 10
|
||||
LOW = 20
|
||||
|
||||
DEFAULT = MEDIUM
|
||||
|
||||
class PriorityLine(object):
|
||||
def __init__(self, priority: int, line: Line):
|
||||
self.priority = priority
|
||||
self.line = line
|
||||
def __lt__(self, other: "PriorityLine") -> bool:
|
||||
return self.priority < other.priority
|
||||
|
||||
class Server(BaseServer):
|
||||
_reader: asyncio.StreamReader
|
||||
_writer: asyncio.StreamWriter
|
||||
|
@ -38,14 +47,19 @@ class Server(BaseServer):
|
|||
|
||||
def __init__(self, name: str):
|
||||
super().__init__(name)
|
||||
|
||||
self.throttle = Throttler(
|
||||
rate_limit=THROTTLE_RATE, period=THROTTLE_TIME)
|
||||
self._write_queue: asyncio.PriorityQueue[Tuple[int, Line]] = asyncio.PriorityQueue()
|
||||
|
||||
self._write_queue: PriorityQueue[PriorityLine] = PriorityQueue()
|
||||
|
||||
self._cap_queue: Set[Capability] = set([])
|
||||
self._requested_caps: List[str] = []
|
||||
|
||||
async def send_raw(self, line: str, priority=SendPriority.DEFAULT):
|
||||
await self.send(tokenise(line), priority)
|
||||
async def send(self, line: Line, priority=SendPriority.DEFAULT):
|
||||
await self._write_queue.put((priority, line))
|
||||
await self._write_queue.put(PriorityLine(priority, line))
|
||||
|
||||
def set_throttle(self, rate: int, time: float):
|
||||
self.throttle.rate_limit = rate
|
||||
|
@ -62,19 +76,49 @@ class Server(BaseServer):
|
|||
username = params.username or nickname
|
||||
realname = params.realname or nickname
|
||||
|
||||
await self.send(build("CAP", ["LS"]))
|
||||
await self.send(build("NICK", [nickname]))
|
||||
await self.send(build("USER", [username, "0", "*", realname]))
|
||||
|
||||
self.params = params
|
||||
|
||||
async def line_received(self, line: Line):
|
||||
async def queue_capability(self, cap: Capability):
|
||||
self._cap_queue.add(cap)
|
||||
async def _cap_ls_done(self):
|
||||
caps = CAPS+list(self._cap_queue)
|
||||
self._cap_queue.clear()
|
||||
|
||||
matches = list(filter(bool,
|
||||
(c.available(self.available_caps) for c in caps)))
|
||||
if matches:
|
||||
self._requested_caps = matches
|
||||
await self.send(build("CAP", ["REQ", " ".join(matches)]))
|
||||
async def _cap_ack(self, line: Line):
|
||||
caps = line.params[2].split(" ")
|
||||
for cap in caps:
|
||||
if cap in self._requested_caps:
|
||||
self._requested_caps.remove(cap)
|
||||
if not self._requested_caps:
|
||||
await self.send(build("CAP", ["END"]))
|
||||
|
||||
async def _on_read_emit(self, line: Line, emit: Emit):
|
||||
if emit.command == "CAP":
|
||||
if emit.subcommand == "LS" and emit.finished:
|
||||
await self._cap_ls_done()
|
||||
elif emit.subcommand in ["ACK", "NAK"]:
|
||||
await self._cap_ack(line)
|
||||
|
||||
async def _on_read_line(self, line: Line):
|
||||
pass
|
||||
async def _read_lines(self) -> List[Line]:
|
||||
|
||||
async def _read_lines(self) -> List[Tuple[Line, List[Emit]]]:
|
||||
data = await self._reader.read(1024)
|
||||
lines = self.recv(data)
|
||||
for line in lines:
|
||||
print(f"{self.name}< {line.format()}")
|
||||
await self.line_received(line)
|
||||
|
||||
for line, emits in lines:
|
||||
for emit in emits:
|
||||
await self._on_read_emit(line, emit)
|
||||
await self._on_read_line(line)
|
||||
return lines
|
||||
|
||||
async def line_written(self, line: Line):
|
||||
|
@ -84,8 +128,8 @@ class Server(BaseServer):
|
|||
|
||||
while (not lines or
|
||||
(len(lines) < 5 and self._write_queue.qsize() > 0)):
|
||||
prio, line = await self._write_queue.get()
|
||||
lines.append(line)
|
||||
prio_line = await self._write_queue.get()
|
||||
lines.append(prio_line.line)
|
||||
|
||||
for line in lines:
|
||||
async with self.throttle:
|
||||
|
|
|
@ -1,2 +1,2 @@
|
|||
ircstates ==0.7.0
|
||||
asyncio-throttle ==0.1.1
|
||||
ircstates ==0.8.0
|
||||
asyncio-throttle ==1.0.1
|
||||
|
|
Loading…
Reference in a new issue