use Server.registerd to detect handshake CAP LS; handle non-handshake CAP LS

This commit is contained in:
jesopo 2020-04-05 13:18:23 +01:00
parent 2a9d658207
commit 1c7caf9039
4 changed files with 25 additions and 30 deletions

View file

@ -55,7 +55,6 @@ class Bot(object):
await tg.spawn(_write) await tg.spawn(_write)
await tg.spawn(_read) await tg.spawn(_read)
await server.handshake()
await tg.spawn(_read_query) await tg.spawn(_read_query)
del self.servers[server.name] del self.servers[server.name]

View file

@ -83,20 +83,6 @@ class CAPContext(ServerContext):
not self.server.params.sasl is None): not self.server.params.sasl is None):
await self.server.sasl_auth(self.server.params.sasl) await self.server.sasl_auth(self.server.params.sasl)
async def handshake(self) -> bool: async def handshake(self):
# improve this by being able to wait_for Emit objects await self.on_ls(self.server.available_caps)
line = await self.server.wait_for(ResponseOr( await self.server.send(build("CAP", ["END"]))
Response(
"CAP",
[ParamAny(), ParamLiteral("LS"), ParamNot(ParamLiteral("*"))]
),
Numerics(["RPL_WELCOME"])
))
if line.command == "CAP":
await self.on_ls(self.server.available_caps)
await self.server.send(build("CAP", ["END"]))
return True
else:
return False

View file

@ -1,7 +1,7 @@
import asyncio import asyncio
from ssl import SSLContext from ssl import SSLContext
from asyncio import Future, PriorityQueue, Queue from asyncio import Future, PriorityQueue, Queue
from typing import List, Optional, Set, Tuple from typing import Dict, List, Optional, Set, Tuple
from asyncio_throttle import Throttler from asyncio_throttle import Throttler
from ircstates import Emit from ircstates import Emit
@ -59,7 +59,9 @@ class Server(IServer):
self._reader = reader self._reader = reader
self._writer = writer self._writer = writer
self.params = params self.params = params
await self.handshake()
async def handshake(self): async def handshake(self):
nickname = self.params.nickname nickname = self.params.nickname
@ -70,14 +72,20 @@ class Server(IServer):
await self.send(build("NICK", [nickname])) await self.send(build("NICK", [nickname]))
await self.send(build("USER", [username, "0", "*", realname])) await self.send(build("USER", [username, "0", "*", realname]))
await CAPContext(self).handshake()
async def _on_read_emit(self, line: Line, emit: Emit): async def _on_read_emit(self, line: Line, emit: Emit):
if emit.command == "001": if emit.command == "001":
self.set_throttle(THROTTLE_RATE, THROTTLE_TIME) self.set_throttle(THROTTLE_RATE, THROTTLE_TIME)
elif emit.command == "CAP":
if emit.subcommand == "NEW": elif emit.command == "CAP":
await self._cap_new(emit) if emit.subcommand == "NEW":
await self._cap_ls(emit)
elif (emit.subcommand == "LS" and
emit.finished):
if not self.registered:
await CAPContext(self).handshake()
else:
await self._cap_ls(emit)
elif emit.command == "JOIN": elif emit.command == "JOIN":
if emit.self and not emit.channel is None: if emit.self and not emit.channel is None:
await self.send(build("MODE", [emit.channel.name])) await self.send(build("MODE", [emit.channel.name]))
@ -145,11 +153,13 @@ class Server(IServer):
def cap_available(self, capability: ICapability) -> Optional[str]: def cap_available(self, capability: ICapability) -> Optional[str]:
return capability.available(self.agreed_caps) return capability.available(self.agreed_caps)
async def _cap_new(self, emit: Emit): async def _cap_ls(self, emit: Emit):
if not emit.tokens is None: if not emit.tokens is None:
tokens = [t.split("=", 1)[0] for t in emit.tokens] tokens: Dict[str, str] = {}
if CAP_SASL.available(tokens) and not self.params.sasl is None: for token in emit.tokens:
await self.sasl_auth(self.params.sasl) key, _, value = token.partition("=")
tokens[key] = value
await CAPContext(self).on_ls(tokens)
async def sasl_auth(self, params: SASLParams) -> bool: async def sasl_auth(self, params: SASLParams) -> bool:
if (self.sasl_state == SASLResult.NONE and if (self.sasl_state == SASLResult.NONE and

View file

@ -1,3 +1,3 @@
ircstates ==0.8.2 ircstates ==0.8.3
asyncio-throttle ==1.0.1 asyncio-throttle ==1.0.1
dataclasses ==0.6 dataclasses ==0.6