diff --git a/ircrobots/bot.py b/ircrobots/bot.py index 9a2634b..85efd2a 100644 --- a/ircrobots/bot.py +++ b/ircrobots/bot.py @@ -5,20 +5,28 @@ from typing import Dict from .server import ConnectionParams, Server from .transport import TCPTransport +from .interface import IBot, IServer RECONNECT_DELAY = 10.0 # ten seconds reconnect -class Bot(object): +class Bot(IBot): def __init__(self): self.servers: Dict[str, Server] = {} self._server_queue: asyncio.Queue[Server] = asyncio.Queue() # methods designed to be overridden def create_server(self, name: str): - return Server(name) - async def disconnected(self, server: Server): - await asyncio.sleep(RECONNECT_DELAY) - await self.add_server(server.name, server.params) + return Server(self, name) + async def disconnected(self, server: IServer): + if (server.name in self.servers and + server.disconnected): + await asyncio.sleep(RECONNECT_DELAY) + await self.add_server(server.name, server.params) + # /methods designed to be overridden + + async def disconnect(self, server: IServer): + await server.disconnect() + del self.servers[server.name] async def add_server(self, name: str, params: ConnectionParams) -> Server: server = self.create_server(name) @@ -31,7 +39,9 @@ class Bot(object): async with anyio.create_task_group() as tg: async def _read(): while not tg.cancel_scope.cancel_called: - line, emits = await server.next_line() + both = await server.next_line() + if both is None: + break await tg.cancel_scope.cancel() async def _write(): @@ -42,7 +52,6 @@ class Bot(object): await tg.spawn(_write) await tg.spawn(_read) - del self.servers[server.name] await self.disconnected(server) async def run(self): diff --git a/ircrobots/interface.py b/ircrobots/interface.py index 586f798..7bf4e43 100644 --- a/ircrobots/interface.py +++ b/ircrobots/interface.py @@ -5,7 +5,7 @@ from enum import IntEnum from ircstates import Server, Emit from irctokens import Line -from .params import ConnectionParams, SASLParams +from .params import ConnectionParams, SASLParams, STSPolicy class ITCPReader(object): async def read(self, byte_count: int): @@ -49,7 +49,7 @@ class ICapability(object): def available(self, capabilities: Iterable[str]) -> Optional[str]: pass - def match(self, capability: str) -> Optional[str]: + def match(self, capability: str) -> bool: pass def copy(self) -> "ICapability": @@ -63,6 +63,7 @@ class IMatchResponseParam(object): pass class IServer(Server): + disconnected: bool params: ConnectionParams desired_caps: Set[ICapability] @@ -83,13 +84,17 @@ class IServer(Server): transport: ITCPTransport, params: ConnectionParams): pass + async def disconnect(self): + pass async def line_read(self, line: Line): pass async def line_send(self, line: Line): pass + def sts_policy(self, sts: STSPolicy): + pass - async def next_line(self) -> Tuple[Line, List[Emit]]: + async def next_line(self) -> Optional[Tuple[Line, List[Emit]]]: pass def cap_agreed(self, capability: ICapability) -> bool: @@ -99,3 +104,18 @@ class IServer(Server): async def sasl_auth(self, sasl: SASLParams) -> bool: pass + +class IBot(object): + def create_server(self, name: str) -> IServer: + pass + async def disconnected(self, server: IServer): + pass + + async def disconnect(self, server: IServer): + pass + + async def add_server(self, name: str, params: ConnectionParams) -> IServer: + pass + + async def run(self): + pass diff --git a/ircrobots/ircv3.py b/ircrobots/ircv3.py index de1ea57..833b46b 100644 --- a/ircrobots/ircv3.py +++ b/ircrobots/ircv3.py @@ -6,7 +6,7 @@ from irctokens import build from .contexts import ServerContext from .matching import Response, ResponseOr, ParamAny, ParamLiteral from .interface import ICapability -from .params import STSPolicy +from .params import ConnectionParams, STSPolicy class Capability(ICapability): def __init__(self, @@ -19,16 +19,18 @@ class Capability(ICapability): self.alias = alias or ratified_name self.depends_on = depends_on.copy() - self._caps = set((ratified_name, draft_name)) + self._caps = [ratified_name, draft_name] + + def match(self, capability: str) -> bool: + return capability in self._caps def available(self, capabilities: Iterable[str] ) -> Optional[str]: - match = list(set(capabilities)&self._caps) - return match[0] if match else None - - def match(self, capability: str) -> Optional[str]: - cap = list(set([capability])&self._caps) - return cap[0] if cap else None + for cap in self._caps: + if not cap is None and cap in capabilities: + return cap + else: + return None def copy(self): return Capability( @@ -40,6 +42,7 @@ class Capability(ICapability): CAP_SASL = Capability("sasl") CAP_ECHO = Capability("echo-message") CAP_LABEL = Capability("labeled-response", "draft/labeled-response-0.2") +CAP_STS = Capability("sts", "draft/sts") LABEL_TAG = { "draft/labeled-response-0.2": "draft/label", @@ -65,6 +68,13 @@ CAPS: List[ICapability] = [ Capability("setname", "draft/setname") ] +def _cap_dict(s: str) -> Dict[str, str]: + d: Dict[str, str] = {} + for token in s.split(","): + key, _, value = token.partition("=") + d[key] = value + return d + class CAPContext(ServerContext): async def on_ls(self, tokens: Dict[str, str]): caps = list(self.server.desired_caps)+CAPS @@ -94,18 +104,34 @@ class CAPContext(ServerContext): await self.server.sasl_auth(self.server.params.sasl) async def handshake(self): + cap_sts = CAP_STS.available(self.server.available_caps) + if not cap_sts is None: + sts_dict = _cap_dict(self.server.available_caps[cap_sts]) + params = self.server.params + if not params.tls: + if "port" in sts_dict: + params.port = int(sts_dict["port"]) + params.tls = True + + await self.server.bot.disconnect(self.server) + await self.server.bot.add_server(self.server.name, params) + return + elif "duration" in sts_dict: + policy = STSPolicy( + int(time()), + params.port, + int(sts_dict["duration"]), + "preload" in sts_dict) + self.server.sts_policy(policy) + await self.on_ls(self.server.available_caps) await self.server.send(build("CAP", ["END"])) class STSContext(ServerContext): - async def transmute(self, - port: int, - tls: bool, - sts: Optional[STSPolicy]) -> Tuple[int, bool]: - if not sts is None: + async def transmute(self, params: ConnectionParams): + if not params.sts is None and not params.tls: now = time() - since = (now-sts.created) - if since <= sts.duration: - return sts.port, True - - return port, tls + since = (now-params.sts.created) + if since <= params.sts.duration: + params.port = params.sts.port + params.tls = True diff --git a/ircrobots/server.py b/ircrobots/server.py index 7cf5c97..7e2fb86 100644 --- a/ircrobots/server.py +++ b/ircrobots/server.py @@ -5,6 +5,7 @@ from collections import deque from asyncio_throttle import Throttler from ircstates import Emit, Channel from ircstates.numerics import * +from ircstates.server import ServerDisconnectedException from irctokens import build, Line, tokenise from .ircv3 import (CAPContext, STSContext, CAP_ECHO, CAP_SASL, CAP_LABEL, @@ -14,9 +15,9 @@ from .join_info import WHOContext from .matching import ResponseOr, Responses, Response, ParamAny, ParamFolded from .asyncs import MaybeAwait from .struct import Whois - -from .interface import (ConnectionParams, ICapability, IServer, SentLine, - SendPriority, SASLParams, IMatchResponse) +from .params import ConnectionParams, SASLParams, STSPolicy +from .interface import (IBot, ICapability, IServer, SentLine, SendPriority, + IMatchResponse) from .interface import ITCPTransport, ITCPReader, ITCPWriter THROTTLE_RATE = 4 # lines @@ -27,15 +28,17 @@ class Server(IServer): _writer: ITCPWriter params: ConnectionParams - def __init__(self, name: str): + def __init__(self, bot: IBot, name: str): super().__init__(name) + self.bot = bot + + self.disconnected = False self.throttle = Throttler( rate_limit=100, period=THROTTLE_TIME) self.sasl_state = SASLResult.NONE - self._sent_count: int = 0 self._wait_for: List[Tuple["Future[Line]", IMatchResponse]] = [] self._write_queue: PriorityQueue[SentLine] = PriorityQueue() @@ -81,13 +84,12 @@ class Server(IServer): async def connect(self, transport: ITCPTransport, params: ConnectionParams): - port, tls = await STSContext(self).transmute( - params.port, params.tls, params.sts) + await STSContext(self).transmute(params) reader, writer = await transport.connect( params.host, - port, - tls =tls, + params.port, + tls =params.tls, tls_verify=params.tls_verify, bindhost =params.bindhost) @@ -96,6 +98,11 @@ class Server(IServer): self.params = params await self.handshake() + async def disconnect(self): + if not self._writer is None: + await self._writer.close() + self._writer = None + self._read_queue.clear() async def handshake(self): nickname = self.params.nickname @@ -133,16 +140,18 @@ class Server(IServer): if line.command == "PING": await self.send(build("PONG", line.params)) - async def line_read(self, line: Line): - pass - - async def next_line(self) -> Tuple[Line, List[Emit]]: + async def next_line(self) -> Optional[Tuple[Line, List[Emit]]]: if self._read_queue: both = self._read_queue.popleft() else: - data = await self._reader.read(1024) while True: - lines = self.recv(data) + data = await self._reader.read(1024) + try: + lines = self.recv(data) + except ServerDisconnectedException: + self.disconnected = True + return None + if lines: self._read_queue.extend(lines[1:]) both = lines[0] @@ -161,18 +170,20 @@ class Server(IServer): self._wait_for.append((our_fut, response)) while self._wait_for: both = await self.next_line() - line, emits = both - for i, (fut, waiting) in enumerate(self._wait_for): - if waiting.match(self, line): - fut.set_result(line) - self._wait_for.pop(i) - break + if not both is None: + line, emits = both + + for i, (fut, waiting) in enumerate(self._wait_for): + if waiting.match(self, line): + fut.set_result(line) + self._wait_for.pop(i) + break + else: + fut.set_result(build("")) return await our_fut - async def line_send(self, line: Line): - pass async def _on_write_line(self, line: Line): if (line.command == "PRIVMSG" and not self.cap_agreed(CAP_ECHO)): diff --git a/ircrobots/transport.py b/ircrobots/transport.py index 107cc72..2df636a 100644 --- a/ircrobots/transport.py +++ b/ircrobots/transport.py @@ -22,6 +22,10 @@ class TCPWriter(ITCPWriter): async def drain(self): await self._writer.drain() + async def close(self): + self._writer.close() + await self._writer.wait_closed() + class TCPTransport(ITCPTransport): async def connect(self, hostname: str,