diff --git a/ircrobots/__init__.py b/ircrobots/__init__.py index d96b355..034c9d1 100644 --- a/ircrobots/__init__.py +++ b/ircrobots/__init__.py @@ -1,3 +1,4 @@ from .bot import Bot from .server import Server from .params import ConnectionParams, SASLUserPass, SASLExternal +from .ircv3 import Capability diff --git a/ircrobots/bot.py b/ircrobots/bot.py index 29c700b..a4bdc07 100644 --- a/ircrobots/bot.py +++ b/ircrobots/bot.py @@ -56,6 +56,7 @@ class Bot(object): print(e) await tg.cancel_scope.cancel() + await tg.spawn(server.handshake) await tg.spawn(_read) await tg.spawn(_write) del self.servers[server.name] diff --git a/ircrobots/interface.py b/ircrobots/interface.py index 4d2a734..6798518 100644 --- a/ircrobots/interface.py +++ b/ircrobots/interface.py @@ -1,10 +1,9 @@ -from typing import Awaitable +from typing import Awaitable, Iterable, List, Optional from enum import IntEnum from ircstates import Server from irctokens import Line -from .ircv3 import Capability from .matching import BaseResponse from .params import ConnectionParams @@ -21,6 +20,16 @@ class PriorityLine(object): def __lt__(self, other: "PriorityLine") -> bool: return self.priority < other.priority +class ICapability(object): + def available(self, capabilities: Iterable[str]) -> Optional[str]: + pass + + def match(self, capability: str) -> Optional[str]: + pass + + def copy(self) -> "ICapability": + pass + class IServer(Server): params: ConnectionParams @@ -38,8 +47,19 @@ class IServer(Server): async def connect(self, params: ConnectionParams): pass - async def queue_capability(self, cap: Capability): + async def queue_capability(self, cap: ICapability): pass async def line_written(self, line: Line): pass + + def cap_agreed(self, capability: ICapability) -> bool: + pass + def cap_available(self, capability: ICapability) -> Optional[str]: + pass + + def collect_caps(self) -> List[str]: + pass + + async def maybe_sasl(self) -> bool: + pass diff --git a/ircrobots/ircv3.py b/ircrobots/ircv3.py index 40abb0d..d35dd9e 100644 --- a/ircrobots/ircv3.py +++ b/ircrobots/ircv3.py @@ -1,6 +1,11 @@ -from typing import Iterable, List, Optional +from typing import Iterable, List, Optional +from irctokens import build -class Capability(object): +from .contexts import ServerContext +from .matching import Response, ResponseOr, ParamAny, ParamNot, ParamLiteral +from .interface import ICapability + +class Capability(ICapability): def __init__(self, ratified_name: Optional[str], draft_name: Optional[str]=None, @@ -30,7 +35,7 @@ class Capability(object): depends_on=self.depends_on[:]) CAP_SASL = Capability("sasl") -CAPS = [ +CAPS: List[ICapability] = [ Capability("multi-prefix"), Capability("chghost"), Capability("away-notify"), @@ -48,3 +53,37 @@ CAPS = [ Capability(None, "draft/rename", alias="rename"), Capability("setname", "draft/setname") ] + +class CAPContext(ServerContext): + async def handshake(self) -> bool: + # improve this by being able to wait_for Emit objects + line = await self.server.wait_for(Response( + "CAP", + [ParamAny(), ParamLiteral("LS"), ParamNot(ParamLiteral("*"))], + errors=["001"])) + + if line.command == "CAP": + caps = self.server.collect_caps() + if caps: + await self.server.send(build("CAP", + ["REQ", " ".join(caps)])) + + while caps: + line = await self.server.wait_for(ResponseOr( + Response("CAP", [ParamAny(), ParamLiteral("ACK")]), + Response("CAP", [ParamAny(), ParamLiteral("NAK")]) + )) + + current_caps = line.params[2].split(" ") + for cap in current_caps: + if cap in caps: + caps.remove(cap) + + if self.server.cap_agreed(CAP_SASL): + await self.server.maybe_sasl() + + await self.server.send(build("CAP", ["END"])) + return True + else: + return False + diff --git a/ircrobots/matching.py b/ircrobots/matching.py index 52fed7c..3fe43cd 100644 --- a/ircrobots/matching.py +++ b/ircrobots/matching.py @@ -39,15 +39,25 @@ class Response(BaseResponse): else: return False +class ResponseOr(BaseResponse): + def __init__(self, *responses: BaseResponse): + self._responses = responses + def match(self, line: Line) -> bool: + for response in self._responses: + if response.match(line): + return True + else: + return False + class ParamAny(ResponseParam): def match(self, arg: str) -> bool: return True -class Literal(ResponseParam): +class ParamLiteral(ResponseParam): def __init__(self, value: str): self._value = value def match(self, arg: str) -> bool: return self._value == arg -class Not(ResponseParam): +class ParamNot(ResponseParam): def __init__(self, param: ResponseParam): self._param = param def match(self, arg: str) -> bool: diff --git a/ircrobots/server.py b/ircrobots/server.py index f8a9bc9..e824b06 100644 --- a/ircrobots/server.py +++ b/ircrobots/server.py @@ -6,8 +6,9 @@ from asyncio_throttle import Throttler from ircstates import Emit from irctokens import build, Line, tokenise -from .ircv3 import Capability, CAPS, CAP_SASL -from .interface import ConnectionParams, IServer, PriorityLine, SendPriority +from .ircv3 import CAPContext, CAPS, CAP_SASL +from .interface import (ConnectionParams, ICapability, IServer, PriorityLine, + SendPriority) from .matching import BaseResponse from .sasl import SASLContext, SASLResult @@ -29,9 +30,7 @@ class Server(IServer): self.sasl_state = SASLResult.NONE self._write_queue: PriorityQueue[PriorityLine] = PriorityQueue() - - self._cap_queue: Set[Capability] = set([]) - self._requested_caps: List[str] = [] + self._cap_queue: Set[ICapability] = set([]) self._wait_for: List[Tuple[BaseResponse, Future]] = [] @@ -50,24 +49,22 @@ class Server(IServer): params.host, params.port, ssl=cur_ssl) self._reader = reader self._writer = writer + self.params = params - nickname = params.nickname - username = params.username or nickname - realname = params.realname or nickname + async def handshake(self): + nickname = self.params.nickname + username = self.params.username or nickname + realname = self.params.realname or nickname await self.send(build("CAP", ["LS"])) await self.send(build("NICK", [nickname])) await self.send(build("USER", [username, "0", "*", realname])) - self.params = params + await CAPContext(self).handshake() async def _on_read_emit(self, line: Line, emit: Emit): if emit.command == "CAP": - if emit.subcommand == "LS" and emit.finished: - await self._cap_ls_done() - elif emit.subcommand in ["ACK", "NAK"]: - await self._cap_ack(emit) - elif emit.subcommand == "NEW": + if emit.subcommand == "NEW": await self._cap_new(emit) async def _on_read_line(self, line: Line): @@ -107,43 +104,31 @@ class Server(IServer): return lines # CAP-related - async def queue_capability(self, cap: Capability): + async def queue_capability(self, cap: ICapability): self._cap_queue.add(cap) - def cap_agreed(self, capability: Capability) -> bool: + def cap_agreed(self, capability: ICapability) -> bool: return bool(self.cap_available(capability)) - def cap_available(self, capability: Capability) -> Optional[str]: + def cap_available(self, capability: ICapability) -> Optional[str]: return capability.available(self.agreed_caps) - async def _cap_ls_done(self): + def collect_caps(self) -> List[str]: caps = CAPS+list(self._cap_queue) self._cap_queue.clear() if not self.params.sasl is None: caps.append(CAP_SASL) - matches = list(filter(bool, - (c.available(self.available_caps) for c in caps))) - if matches: - self._requested_caps = matches - await self.send(build("CAP", ["REQ", " ".join(matches)])) - - async def _cap_ack(self, emit: Emit): - await self._maybe_sasl() - - for cap in (emit.tokens or []): - if cap in self._requested_caps: - self._requested_caps.remove(cap) - if not self._requested_caps: - await self.send(build("CAP", ["END"])) + matched = [c.available(self.available_caps) for c in caps] + return [name for name in matched if not name is None] async def _cap_new(self, emit: Emit): if not emit.tokens is None: tokens = [t.split("=", 1)[0] for t in emit.tokens] if CAP_SASL.available(tokens): - await self._maybe_sasl() + await self.maybe_sasl() - async def _maybe_sasl(self) -> bool: + async def maybe_sasl(self) -> bool: if (self.sasl_state == SASLResult.NONE and not self.params.sasl is None and self.cap_agreed(CAP_SASL)):