grab WHO on JOIN but make sure we only do one at once
This commit is contained in:
parent
d15e9ee361
commit
ebef439e0a
3 changed files with 45 additions and 1 deletions
42
ircrobots/join_info.py
Normal file
42
ircrobots/join_info.py
Normal file
|
@ -0,0 +1,42 @@
|
||||||
|
from typing import Dict, Iterable, List, Optional
|
||||||
|
from irctokens import build
|
||||||
|
from ircstates.numerics import *
|
||||||
|
|
||||||
|
from .contexts import ServerContext
|
||||||
|
from .matching import Response, ResponseOr, ParamAny, ParamFolded
|
||||||
|
|
||||||
|
"""
|
||||||
|
class JoinContext(ServerContext):
|
||||||
|
async def enlighten(self, channels: List[str]):
|
||||||
|
folded = [self.server.casefold(c) for c in channels]
|
||||||
|
waiting = len(folded)
|
||||||
|
while waiting:
|
||||||
|
line = await self.server.wait_for(ResponseOr(
|
||||||
|
Response("JOIN", [ParamAny()])
|
||||||
|
))
|
||||||
|
|
||||||
|
if (line.command == "JOIN" and
|
||||||
|
self.server.casefold(line.params[0]) in folded):
|
||||||
|
waiting -= 1
|
||||||
|
|
||||||
|
for channel in folded:
|
||||||
|
await self.server.send(build("WHO", [channel]))
|
||||||
|
line = await self.wait_for(
|
||||||
|
Response(RPL_ENDOFWHO, [ParamAny(), ParamFolded(channel)])
|
||||||
|
)
|
||||||
|
|
||||||
|
return [self.server.channels[c] for c in folded]
|
||||||
|
"""
|
||||||
|
|
||||||
|
class WHOContext(ServerContext):
|
||||||
|
async def ensure(self, channel: str):
|
||||||
|
folded = self.server.casefold(channel)
|
||||||
|
|
||||||
|
if self.server.isupport.whox:
|
||||||
|
await self.server.send(self.server.prepare_whox(channel))
|
||||||
|
else:
|
||||||
|
await self.server.send(build("WHO", [channel]))
|
||||||
|
|
||||||
|
line = await self.server.wait_for(
|
||||||
|
Response(RPL_ENDOFWHO, [ParamAny(), ParamFolded(folded)])
|
||||||
|
)
|
|
@ -9,6 +9,7 @@ from irctokens import build, Line, tokenise
|
||||||
|
|
||||||
from .ircv3 import CAPContext, CAP_ECHO, CAP_SASL, CAP_LABEL, LABEL_TAG
|
from .ircv3 import CAPContext, CAP_ECHO, CAP_SASL, CAP_LABEL, LABEL_TAG
|
||||||
from .sasl import SASLContext, SASLResult
|
from .sasl import SASLContext, SASLResult
|
||||||
|
from .join_info import WHOContext
|
||||||
from .matching import ResponseOr, Responses, Response, ParamAny, ParamFolded
|
from .matching import ResponseOr, Responses, Response, ParamAny, ParamFolded
|
||||||
from .asyncs import MaybeAwait
|
from .asyncs import MaybeAwait
|
||||||
from .struct import Whois
|
from .struct import Whois
|
||||||
|
@ -122,6 +123,7 @@ class Server(IServer):
|
||||||
elif emit.command == "JOIN":
|
elif emit.command == "JOIN":
|
||||||
if emit.self and not emit.channel is None:
|
if emit.self and not emit.channel is None:
|
||||||
await self.send(build("MODE", [emit.channel.name]))
|
await self.send(build("MODE", [emit.channel.name]))
|
||||||
|
await WHOContext(self).ensure(emit.channel.name)
|
||||||
|
|
||||||
async def _on_read_line(self, line: Line):
|
async def _on_read_line(self, line: Line):
|
||||||
if line.command == "PING":
|
if line.command == "PING":
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
anyio ==1.3.0
|
anyio ==1.3.0
|
||||||
asyncio-throttle ==1.0.1
|
asyncio-throttle ==1.0.1
|
||||||
dataclasses ==0.6
|
dataclasses ==0.6
|
||||||
ircstates ==0.9.1
|
ircstates ==0.9.2
|
||||||
async_stagger ==0.3.0
|
async_stagger ==0.3.0
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue