diff --git a/ircrobots/join_info.py b/ircrobots/join_info.py new file mode 100644 index 0000000..ab4cf01 --- /dev/null +++ b/ircrobots/join_info.py @@ -0,0 +1,42 @@ +from typing import Dict, Iterable, List, Optional +from irctokens import build +from ircstates.numerics import * + +from .contexts import ServerContext +from .matching import Response, ResponseOr, ParamAny, ParamFolded + +""" +class JoinContext(ServerContext): + async def enlighten(self, channels: List[str]): + folded = [self.server.casefold(c) for c in channels] + waiting = len(folded) + while waiting: + line = await self.server.wait_for(ResponseOr( + Response("JOIN", [ParamAny()]) + )) + + if (line.command == "JOIN" and + self.server.casefold(line.params[0]) in folded): + waiting -= 1 + + for channel in folded: + await self.server.send(build("WHO", [channel])) + line = await self.wait_for( + Response(RPL_ENDOFWHO, [ParamAny(), ParamFolded(channel)]) + ) + + return [self.server.channels[c] for c in folded] +""" + +class WHOContext(ServerContext): + async def ensure(self, channel: str): + folded = self.server.casefold(channel) + + if self.server.isupport.whox: + await self.server.send(self.server.prepare_whox(channel)) + else: + await self.server.send(build("WHO", [channel])) + + line = await self.server.wait_for( + Response(RPL_ENDOFWHO, [ParamAny(), ParamFolded(folded)]) + ) diff --git a/ircrobots/server.py b/ircrobots/server.py index acf4093..f7da968 100644 --- a/ircrobots/server.py +++ b/ircrobots/server.py @@ -9,6 +9,7 @@ from irctokens import build, Line, tokenise from .ircv3 import CAPContext, CAP_ECHO, CAP_SASL, CAP_LABEL, LABEL_TAG from .sasl import SASLContext, SASLResult +from .join_info import WHOContext from .matching import ResponseOr, Responses, Response, ParamAny, ParamFolded from .asyncs import MaybeAwait from .struct import Whois @@ -122,6 +123,7 @@ class Server(IServer): elif emit.command == "JOIN": if emit.self and not emit.channel is None: await self.send(build("MODE", [emit.channel.name])) + await WHOContext(self).ensure(emit.channel.name) async def _on_read_line(self, line: Line): if line.command == "PING": diff --git a/requirements.txt b/requirements.txt index 740e364..26aca75 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ anyio ==1.3.0 asyncio-throttle ==1.0.1 dataclasses ==0.6 -ircstates ==0.9.1 +ircstates ==0.9.2 async_stagger ==0.3.0