fix PriorityQueue, use both Line and Emit, add CAP REQ
This commit is contained in:
parent
fb892c584e
commit
fd934b1101
4 changed files with 115 additions and 16 deletions
|
@ -19,8 +19,12 @@ class Bot(object):
|
||||||
async def disconnected(self, server: Server):
|
async def disconnected(self, server: Server):
|
||||||
await asyncio.sleep(RECONNECT_DELAY)
|
await asyncio.sleep(RECONNECT_DELAY)
|
||||||
await self.add_server(server.name, server.params)
|
await self.add_server(server.name, server.params)
|
||||||
|
|
||||||
async def line_read(self, server: Server, line: Line):
|
async def line_read(self, server: Server, line: Line):
|
||||||
pass
|
pass
|
||||||
|
async def emit_read(self, server: Server, line: Line):
|
||||||
|
pass
|
||||||
|
|
||||||
async def line_send(self, server: Server, line: Line):
|
async def line_send(self, server: Server, line: Line):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -36,7 +40,9 @@ class Bot(object):
|
||||||
async def _read():
|
async def _read():
|
||||||
while not tg.cancel_scope.cancel_called:
|
while not tg.cancel_scope.cancel_called:
|
||||||
lines = await server._read_lines()
|
lines = await server._read_lines()
|
||||||
for line in lines:
|
for line, emits in lines:
|
||||||
|
for emit in emits:
|
||||||
|
await self.emit_read(server, emit)
|
||||||
await self.line_read(server, line)
|
await self.line_read(server, line)
|
||||||
await tg.cancel_scope.cancel()
|
await tg.cancel_scope.cancel()
|
||||||
|
|
||||||
|
|
49
ircrobots/ircv3.py
Normal file
49
ircrobots/ircv3.py
Normal file
|
@ -0,0 +1,49 @@
|
||||||
|
from typing import Callable, Iterable, List, Optional
|
||||||
|
|
||||||
|
class Capability(object):
|
||||||
|
def __init__(self,
|
||||||
|
ratified_name: Optional[str],
|
||||||
|
draft_name: Optional[str]=None,
|
||||||
|
alias: Optional[str]=None,
|
||||||
|
depends_on: List[str]=[]):
|
||||||
|
self.name = ratified_name
|
||||||
|
self.draft = draft_name
|
||||||
|
self.alias = alias or ratified_name
|
||||||
|
self.depends_on = depends_on.copy()
|
||||||
|
|
||||||
|
self._caps = set((ratified_name, draft_name))
|
||||||
|
|
||||||
|
def available(self, capabilities: Iterable[str]
|
||||||
|
) -> Optional[str]:
|
||||||
|
match = list(set(capabilities)&self._caps)
|
||||||
|
return match[0] if match else None
|
||||||
|
|
||||||
|
def match(self, capability: str) -> Optional[str]:
|
||||||
|
cap = list(set([capability])&self._caps)
|
||||||
|
return cap[0] if cap else None
|
||||||
|
|
||||||
|
def copy(self):
|
||||||
|
return Capability(
|
||||||
|
self.name,
|
||||||
|
self.draft,
|
||||||
|
alias=self.alias,
|
||||||
|
depends_on=self.depends_on[:])
|
||||||
|
|
||||||
|
CAPS = [
|
||||||
|
Capability("multi-prefix"),
|
||||||
|
Capability("chghost"),
|
||||||
|
Capability("away-notify"),
|
||||||
|
Capability("userhost-in-names"),
|
||||||
|
|
||||||
|
Capability("invite-notify"),
|
||||||
|
Capability("account-tag"),
|
||||||
|
Capability("account-notify"),
|
||||||
|
Capability("extended-join"),
|
||||||
|
|
||||||
|
Capability("message-tags", "draft/message-tags-0.2"),
|
||||||
|
Capability("cap-notify"),
|
||||||
|
Capability("batch"),
|
||||||
|
|
||||||
|
Capability(None, "draft/rename", alias="rename"),
|
||||||
|
Capability("setname", "draft/setname")
|
||||||
|
]
|
|
@ -1,13 +1,16 @@
|
||||||
import asyncio, ssl
|
import asyncio, ssl
|
||||||
|
from asyncio import PriorityQueue
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
from typing import Callable, Dict, List, Optional, Tuple
|
from typing import Callable, Dict, List, Optional, Set, Tuple
|
||||||
from enum import Enum
|
from enum import IntEnum
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from asyncio_throttle import Throttler
|
from asyncio_throttle import Throttler
|
||||||
from ircstates import Server as BaseServer
|
from ircstates import Server as BaseServer
|
||||||
|
from ircstates import Emit
|
||||||
from irctokens import build, Line, tokenise
|
from irctokens import build, Line, tokenise
|
||||||
|
|
||||||
|
from .ircv3 import Capability, CAPS
|
||||||
sc = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
|
sc = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
|
||||||
|
|
||||||
THROTTLE_RATE = 4 # lines
|
THROTTLE_RATE = 4 # lines
|
||||||
|
@ -24,13 +27,19 @@ class ConnectionParams(object):
|
||||||
realname: Optional[str] = None
|
realname: Optional[str] = None
|
||||||
bindhost: Optional[str] = None
|
bindhost: Optional[str] = None
|
||||||
|
|
||||||
class SendPriority(Enum):
|
class SendPriority(IntEnum):
|
||||||
HIGH = 0
|
HIGH = 0
|
||||||
MEDIUM = 10
|
MEDIUM = 10
|
||||||
LOW = 20
|
LOW = 20
|
||||||
|
|
||||||
DEFAULT = MEDIUM
|
DEFAULT = MEDIUM
|
||||||
|
|
||||||
|
class PriorityLine(object):
|
||||||
|
def __init__(self, priority: int, line: Line):
|
||||||
|
self.priority = priority
|
||||||
|
self.line = line
|
||||||
|
def __lt__(self, other: "PriorityLine") -> bool:
|
||||||
|
return self.priority < other.priority
|
||||||
|
|
||||||
class Server(BaseServer):
|
class Server(BaseServer):
|
||||||
_reader: asyncio.StreamReader
|
_reader: asyncio.StreamReader
|
||||||
_writer: asyncio.StreamWriter
|
_writer: asyncio.StreamWriter
|
||||||
|
@ -38,14 +47,19 @@ class Server(BaseServer):
|
||||||
|
|
||||||
def __init__(self, name: str):
|
def __init__(self, name: str):
|
||||||
super().__init__(name)
|
super().__init__(name)
|
||||||
|
|
||||||
self.throttle = Throttler(
|
self.throttle = Throttler(
|
||||||
rate_limit=THROTTLE_RATE, period=THROTTLE_TIME)
|
rate_limit=THROTTLE_RATE, period=THROTTLE_TIME)
|
||||||
self._write_queue: asyncio.PriorityQueue[Tuple[int, Line]] = asyncio.PriorityQueue()
|
|
||||||
|
self._write_queue: PriorityQueue[PriorityLine] = PriorityQueue()
|
||||||
|
|
||||||
|
self._cap_queue: Set[Capability] = set([])
|
||||||
|
self._requested_caps: List[str] = []
|
||||||
|
|
||||||
async def send_raw(self, line: str, priority=SendPriority.DEFAULT):
|
async def send_raw(self, line: str, priority=SendPriority.DEFAULT):
|
||||||
await self.send(tokenise(line), priority)
|
await self.send(tokenise(line), priority)
|
||||||
async def send(self, line: Line, priority=SendPriority.DEFAULT):
|
async def send(self, line: Line, priority=SendPriority.DEFAULT):
|
||||||
await self._write_queue.put((priority, line))
|
await self._write_queue.put(PriorityLine(priority, line))
|
||||||
|
|
||||||
def set_throttle(self, rate: int, time: float):
|
def set_throttle(self, rate: int, time: float):
|
||||||
self.throttle.rate_limit = rate
|
self.throttle.rate_limit = rate
|
||||||
|
@ -62,19 +76,49 @@ class Server(BaseServer):
|
||||||
username = params.username or nickname
|
username = params.username or nickname
|
||||||
realname = params.realname or nickname
|
realname = params.realname or nickname
|
||||||
|
|
||||||
|
await self.send(build("CAP", ["LS"]))
|
||||||
await self.send(build("NICK", [nickname]))
|
await self.send(build("NICK", [nickname]))
|
||||||
await self.send(build("USER", [username, "0", "*", realname]))
|
await self.send(build("USER", [username, "0", "*", realname]))
|
||||||
|
|
||||||
self.params = params
|
self.params = params
|
||||||
|
|
||||||
async def line_received(self, line: Line):
|
async def queue_capability(self, cap: Capability):
|
||||||
|
self._cap_queue.add(cap)
|
||||||
|
async def _cap_ls_done(self):
|
||||||
|
caps = CAPS+list(self._cap_queue)
|
||||||
|
self._cap_queue.clear()
|
||||||
|
|
||||||
|
matches = list(filter(bool,
|
||||||
|
(c.available(self.available_caps) for c in caps)))
|
||||||
|
if matches:
|
||||||
|
self._requested_caps = matches
|
||||||
|
await self.send(build("CAP", ["REQ", " ".join(matches)]))
|
||||||
|
async def _cap_ack(self, line: Line):
|
||||||
|
caps = line.params[2].split(" ")
|
||||||
|
for cap in caps:
|
||||||
|
if cap in self._requested_caps:
|
||||||
|
self._requested_caps.remove(cap)
|
||||||
|
if not self._requested_caps:
|
||||||
|
await self.send(build("CAP", ["END"]))
|
||||||
|
|
||||||
|
async def _on_read_emit(self, line: Line, emit: Emit):
|
||||||
|
if emit.command == "CAP":
|
||||||
|
if emit.subcommand == "LS" and emit.finished:
|
||||||
|
await self._cap_ls_done()
|
||||||
|
elif emit.subcommand in ["ACK", "NAK"]:
|
||||||
|
await self._cap_ack(line)
|
||||||
|
|
||||||
|
async def _on_read_line(self, line: Line):
|
||||||
pass
|
pass
|
||||||
async def _read_lines(self) -> List[Line]:
|
|
||||||
|
async def _read_lines(self) -> List[Tuple[Line, List[Emit]]]:
|
||||||
data = await self._reader.read(1024)
|
data = await self._reader.read(1024)
|
||||||
lines = self.recv(data)
|
lines = self.recv(data)
|
||||||
for line in lines:
|
|
||||||
print(f"{self.name}< {line.format()}")
|
for line, emits in lines:
|
||||||
await self.line_received(line)
|
for emit in emits:
|
||||||
|
await self._on_read_emit(line, emit)
|
||||||
|
await self._on_read_line(line)
|
||||||
return lines
|
return lines
|
||||||
|
|
||||||
async def line_written(self, line: Line):
|
async def line_written(self, line: Line):
|
||||||
|
@ -84,8 +128,8 @@ class Server(BaseServer):
|
||||||
|
|
||||||
while (not lines or
|
while (not lines or
|
||||||
(len(lines) < 5 and self._write_queue.qsize() > 0)):
|
(len(lines) < 5 and self._write_queue.qsize() > 0)):
|
||||||
prio, line = await self._write_queue.get()
|
prio_line = await self._write_queue.get()
|
||||||
lines.append(line)
|
lines.append(prio_line.line)
|
||||||
|
|
||||||
for line in lines:
|
for line in lines:
|
||||||
async with self.throttle:
|
async with self.throttle:
|
||||||
|
|
|
@ -1,2 +1,2 @@
|
||||||
ircstates ==0.7.0
|
ircstates ==0.8.0
|
||||||
asyncio-throttle ==0.1.1
|
asyncio-throttle ==1.0.1
|
||||||
|
|
Loading…
Reference in a new issue