diff --git a/ircrobots/bot.py b/ircrobots/bot.py index eba51fc..b177151 100644 --- a/ircrobots/bot.py +++ b/ircrobots/bot.py @@ -55,7 +55,6 @@ class Bot(object): await tg.spawn(_write) await tg.spawn(_read) - await server.handshake() await tg.spawn(_read_query) del self.servers[server.name] diff --git a/ircrobots/ircv3.py b/ircrobots/ircv3.py index 3d81a32..c4aa495 100644 --- a/ircrobots/ircv3.py +++ b/ircrobots/ircv3.py @@ -83,20 +83,6 @@ class CAPContext(ServerContext): not self.server.params.sasl is None): await self.server.sasl_auth(self.server.params.sasl) - async def handshake(self) -> bool: - # improve this by being able to wait_for Emit objects - line = await self.server.wait_for(ResponseOr( - Response( - "CAP", - [ParamAny(), ParamLiteral("LS"), ParamNot(ParamLiteral("*"))] - ), - Numerics(["RPL_WELCOME"]) - )) - - if line.command == "CAP": - await self.on_ls(self.server.available_caps) - await self.server.send(build("CAP", ["END"])) - return True - else: - return False - + async def handshake(self): + await self.on_ls(self.server.available_caps) + await self.server.send(build("CAP", ["END"])) diff --git a/ircrobots/server.py b/ircrobots/server.py index b664e37..cfedbe6 100644 --- a/ircrobots/server.py +++ b/ircrobots/server.py @@ -1,7 +1,7 @@ import asyncio from ssl import SSLContext from asyncio import Future, PriorityQueue, Queue -from typing import List, Optional, Set, Tuple +from typing import Dict, List, Optional, Set, Tuple from asyncio_throttle import Throttler from ircstates import Emit @@ -59,7 +59,9 @@ class Server(IServer): self._reader = reader self._writer = writer + self.params = params + await self.handshake() async def handshake(self): nickname = self.params.nickname @@ -70,14 +72,20 @@ class Server(IServer): await self.send(build("NICK", [nickname])) await self.send(build("USER", [username, "0", "*", realname])) - await CAPContext(self).handshake() - async def _on_read_emit(self, line: Line, emit: Emit): - if emit.command == "001": + if emit.command == "001": self.set_throttle(THROTTLE_RATE, THROTTLE_TIME) - elif emit.command == "CAP": - if emit.subcommand == "NEW": - await self._cap_new(emit) + + elif emit.command == "CAP": + if emit.subcommand == "NEW": + await self._cap_ls(emit) + elif (emit.subcommand == "LS" and + emit.finished): + if not self.registered: + await CAPContext(self).handshake() + else: + await self._cap_ls(emit) + elif emit.command == "JOIN": if emit.self and not emit.channel is None: await self.send(build("MODE", [emit.channel.name])) @@ -145,11 +153,13 @@ class Server(IServer): def cap_available(self, capability: ICapability) -> Optional[str]: return capability.available(self.agreed_caps) - async def _cap_new(self, emit: Emit): + async def _cap_ls(self, emit: Emit): if not emit.tokens is None: - tokens = [t.split("=", 1)[0] for t in emit.tokens] - if CAP_SASL.available(tokens) and not self.params.sasl is None: - await self.sasl_auth(self.params.sasl) + tokens: Dict[str, str] = {} + for token in emit.tokens: + key, _, value = token.partition("=") + tokens[key] = value + await CAPContext(self).on_ls(tokens) async def sasl_auth(self, params: SASLParams) -> bool: if (self.sasl_state == SASLResult.NONE and diff --git a/requirements.txt b/requirements.txt index ada0bb5..84423cd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ -ircstates ==0.8.2 +ircstates ==0.8.3 asyncio-throttle ==1.0.1 dataclasses ==0.6