# Listing metadata (not part of the module): 173 lines, 5.7 KiB
import asyncio
from ssl import SSLContext
from asyncio import Future, PriorityQueue, Queue
from typing import Dict, List, Optional, Set, Tuple
from asyncio_throttle import Throttler
from ircstates import Emit
from irctokens import build, Line, tokenise
from .ircv3 import CAPContext, CAP_SASL
from .interface import (ConnectionParams, ICapability, IServer, SentLine,
SendPriority, SASLParams, IMatchResponse)
from .sasl import SASLContext, SASLResult
from .security import ssl_context
# Throttle settings for outgoing lines.
# THROTTLE_TIME is a duration in seconds and is passed as the Throttler
# period in Server.__init__; THROTTLE_RATE is a line count.
THROTTLE_TIME: int = 2  # seconds
THROTTLE_RATE: int = 4  # lines
# IServer implementation that owns an asyncio stream pair, a prioritised
# queue of outgoing lines and a queue of incoming lines.
class Server(IServer):
# Stream reader/writer pair; both are assigned in connect() from the result
# of asyncio.open_connection().
_reader: asyncio.StreamReader
_writer: asyncio.StreamWriter
# Connection parameters; stored by connect() and read by handshake().
params: ConnectionParams
def __init__(self, name: str):
    """Create the throttle, SASL state, line queues and capability set.

    NOTE(review): ``name`` is not used anywhere in the visible body and no
    ``super().__init__`` call is present — confirm whether a base-class
    initialisation line was lost from this listing.
    """
    # Lines already read by wait_for() that came after its last match.
    self._wait_for_cache: List[Tuple[Line, List[Emit]]] = []
    # Incoming (line, emits) pairs, filled by _read_lines() and consumed
    # by next_line().
    self._read_queue: Queue[Tuple[Line, List[Emit]]] = Queue()
    # Outgoing lines wrapped in SentLine, filled by send().
    self._write_queue: PriorityQueue[SentLine] = PriorityQueue()

    # Capabilities this server wants; starts out empty.
    self.desired_caps: Set[ICapability] = set()
    self.sasl_state = SASLResult.NONE

    # Initial limit of 100 per THROTTLE_TIME seconds; set_throttle() can
    # change both values later.
    self.throttle = Throttler(
        rate_limit=100, period=THROTTLE_TIME)
async def send_raw(self, line: str, priority=SendPriority.DEFAULT
        ) -> Future:
    """Tokenise a raw line of text and queue it through ``send``.

    Args:
        line: the unparsed line to send.
        priority: queue priority handed on to ``send``.

    Returns:
        Whatever ``send`` returns for the tokenised line.
    """
    parsed = tokenise(line)
    queued = await self.send(parsed, priority)
    return queued
async def send(self, line: Line, priority=SendPriority.DEFAULT) -> Future:
    """Queue ``line`` for sending and return the future attached to it.

    The line is wrapped in a ``SentLine`` together with ``priority`` and put
    on the write queue. The wrapper's future is returned as-is; this method
    does not await it.
    """
    queued = SentLine(priority, line)
    await self._write_queue.put(queued)
    return queued.future
def set_throttle(self, rate: int, time: float):
    """Reconfigure the outgoing throttle in place.

    Args:
        rate: new value for the throttle's ``rate_limit``.
        time: new value for the throttle's ``period``.
    """
    throttle = self.throttle
    throttle.rate_limit, throttle.period = rate, time
async def connect(self, params: ConnectionParams):
    """Open the connection described by ``params`` and start the handshake.

    A TLS context is created when ``params.tls`` is truthy and is handed to
    ``asyncio.open_connection``. Once the streams exist they are stored on
    the instance together with ``params``, and ``handshake`` is awaited.

    Args:
        params: connection settings (TLS flags, bind host, identity).
    """
    cur_ssl: Optional[SSLContext] = None
    if params.tls:
        cur_ssl = ssl_context(params.tls_verify)

    # The TLS context was previously built but never passed on, and no
    # target address was given, so open_connection could not connect.
    # NOTE(review): `host` / `port` are assumed attribute names on
    # ConnectionParams — confirm against .interface.
    reader, writer = await asyncio.open_connection(
        params.host,
        params.port,
        ssl=cur_ssl,
        local_addr=(params.bindhost, 0))

    self._reader = reader
    self._writer = writer
    self.params = params

    await self.handshake()
async def handshake(self):
    """Queue the opening lines: ``CAP LS 302``, then ``NICK``, then ``USER``.

    The username and realname fall back to the nickname when they are
    falsy in ``self.params``.
    """
    nick = self.params.nickname
    user = self.params.username or nick
    real = self.params.realname or nick

    opening = (
        ("CAP", ["LS", "302"]),
        ("NICK", [nick]),
        ("USER", [user, "0", "*", real]),
    )
    # Build each line right before sending it, one at a time and in order.
    for command, args in opening:
        await self.send(build(command, args))
# React to a single emit produced for an incoming line. Called by
# _read_lines() once per emit, before the per-line hooks run.
# NOTE(review): this listing has lost its indentation and at least two
# fragments of this method; the code below is left exactly as found.
async def _on_read_emit(self, line: Line, emit: Emit):
# NOTE(review): the "001" branch has no body here — a statement appears to
# be missing; confirm against the upstream file.
if emit.command == "001":
elif emit.command == "CAP":
# CAP NEW is passed to the same handler as CAP LS.
if emit.subcommand == "NEW":
await self._cap_ls(emit)
# NOTE(review): this condition is cut off after `and`; its second operand
# is not present in this listing.
elif (emit.subcommand == "LS" and
# Not registered yet: run the CAPContext handshake.
if not self.registered:
await CAPContext(self).handshake()
# NOTE(review): with indentation lost it is unclear whether this call is
# unconditional or belonged to a missing `else` — TODO confirm.
await self._cap_ls(emit)
elif emit.command == "JOIN":
# When emit.self is truthy and a channel is attached, send a MODE command
# carrying just that channel's name.
if emit.self and not emit.channel is None:
await self.send(build("MODE", [emit.channel.name]))
async def _on_read_line(self, line: Line):
    """Answer an incoming ``PING`` with a ``PONG`` carrying the same params.

    Every other command is ignored here.
    """
    if line.command != "PING":
        return
    pong = build("PONG", line.params)
    await self.send(pong)
async def line_read(self, line: Line):
    """Hook awaited by ``_read_lines`` for every parsed incoming line.

    Does nothing by default.

    NOTE(review): the body was absent in this listing; a no-op is assumed
    so that the ``await self.line_read(line)`` call in ``_read_lines``
    works — confirm against the upstream file.
    """
    pass
async def _read_lines(self) -> List[Tuple[Line, List[Emit]]]:
    """Read up to 1024 bytes, parse them and dispatch each resulting line.

    The bytes go through ``self.recv``, which yields ``(line, emits)``
    pairs. For every pair, each emit goes to ``_on_read_emit``; then the
    line goes to ``_on_read_line`` and ``line_read``; finally the pair is
    put on the read queue.

    Returns:
        The list of pairs produced by ``self.recv`` for this read.
    """
    chunk = await self._reader.read(1024)
    parsed = self.recv(chunk)

    for line, emits in parsed:
        for emit in emits:
            await self._on_read_emit(line, emit)

        await self._on_read_line(line)
        await self.line_read(line)
        await self._read_queue.put((line, emits))

    return parsed
async def next_line(self) -> Line:
    """Wait for the next entry on the read queue and return its line.

    The emits stored alongside the line are discarded.
    """
    line, _emits = await self._read_queue.get()
    return line
async def wait_for(self, response: IMatchResponse) -> Line:
    """Return the next incoming line accepted by ``response``.

    Lines left over from a previous call (those that followed the earlier
    match) are examined first. If none of them match, fresh lines are read
    with ``_read_lines`` until one does. Lines before the match are dropped
    from this method's cache; lines after it are kept for the next call.

    Args:
        response: matcher whose ``match(self, line)`` decides acceptance.

    Returns:
        The first line for which ``response.match`` is truthy.
    """
    while True:
        # Take the cached lines out of the cache. Leaving unmatched lines
        # in place would make every later iteration re-scan the same lines
        # without ever awaiting new input, spinning forever and starving
        # the event loop.
        lines = self._wait_for_cache.copy()
        self._wait_for_cache.clear()

        if not lines:
            lines += await self._read_lines()

        for i, (line, _emits) in enumerate(lines):
            if response.match(self, line):
                self._wait_for_cache = lines[i+1:]
                return line
async def line_send(self, line: Line):
    """Hook for an outgoing line; does nothing by default.

    Nothing in the visible part of this file calls it.

    NOTE(review): the body was absent in this listing; a no-op is assumed,
    mirroring ``line_read`` — confirm against the upstream file.
    """
    pass
# Take a batch of queued SentLine items and return their underlying lines.
# NOTE(review): this listing has lost its indentation and several
# statements of this method; the code below is left exactly as found.
async def _write_lines(self) -> List[Line]:
lines: List[SentLine] = []
# Wait for at least one queued line, then keep taking lines while fewer
# than five are held and the queue is not empty.
# NOTE(review): nothing visible appends `prio_line` to `lines`, so as
# written the loop condition can never become false — a statement appears
# to be missing; confirm against the upstream file.
while (not lines or
(len(lines) < 5 and self._write_queue.qsize() > 0)):
prio_line = await self._write_queue.get()
for line in lines:
# NOTE(review): the throttled block has no visible body (presumably the
# write to self._writer) — TODO confirm.
async with self.throttle:
await self._writer.drain()
# NOTE(review): this second loop has no visible body either.
for line in lines:
# Unwrap each SentLine to the line it carries.
return [l.line for l in lines]
# CAP-related
def cap_agreed(self, capability: ICapability) -> bool:
    """Return True when ``cap_available`` gives a truthy result for it."""
    available = self.cap_available(capability)
    return bool(available)
def cap_available(self, capability: ICapability) -> Optional[str]:
    """Ask ``capability`` whether it is present in ``self.agreed_caps``.

    Returns:
        The result of ``capability.available(self.agreed_caps)``.
    """
    lookup = capability.available
    return lookup(self.agreed_caps)
async def _cap_ls(self, emit: Emit):
    """Turn the emit's capability tokens into a dict and pass it to CAPContext.

    Each token is split at its first ``=``: the part before it becomes the
    key and the part after it the value (an empty string when there is no
    ``=``). Nothing happens when ``emit.tokens`` is None.
    """
    if emit.tokens is None:
        return

    tokens: Dict[str, str] = {
        name: value
        for name, _, value in (raw.partition("=") for raw in emit.tokens)
    }
    await CAPContext(self).on_ls(tokens)
# Run SASL authentication with `params` if the current state allows it.
# When the (partly missing) condition holds, the result of
# SASLContext(self).from_params(params) is stored in self.sasl_state and
# True is returned; otherwise False is returned.
# NOTE(review): the condition is cut off after `and` in this listing.
# CAP_SASL is imported at the top of the file but unused in the visible
# code, so the missing operand presumably involves it — TODO confirm.
async def sasl_auth(self, params: SASLParams) -> bool:
if (self.sasl_state == SASLResult.NONE and
res = await SASLContext(self).from_params(params)
self.sasl_state = res
# True means an attempt was made; the outcome itself is in self.sasl_state.
return True
return False
# /CAP-related