use Server.registerd to detect handshake CAP LS; handle non-handshake CAP LS

This commit is contained in:
jesopo 2020-04-05 13:18:23 +01:00
parent 2a9d658207
commit 1c7caf9039
4 changed files with 25 additions and 30 deletions

View file

@ -55,7 +55,6 @@ class Bot(object):
await tg.spawn(_write)
await tg.spawn(_read)
await server.handshake()
await tg.spawn(_read_query)
del self.servers[server.name]

View file

@ -83,20 +83,6 @@ class CAPContext(ServerContext):
not self.server.params.sasl is None):
await self.server.sasl_auth(self.server.params.sasl)
async def handshake(self) -> bool:
# improve this by being able to wait_for Emit objects
line = await self.server.wait_for(ResponseOr(
Response(
"CAP",
[ParamAny(), ParamLiteral("LS"), ParamNot(ParamLiteral("*"))]
),
Numerics(["RPL_WELCOME"])
))
if line.command == "CAP":
await self.on_ls(self.server.available_caps)
await self.server.send(build("CAP", ["END"]))
return True
else:
return False
async def handshake(self):
await self.on_ls(self.server.available_caps)
await self.server.send(build("CAP", ["END"]))

View file

@ -1,7 +1,7 @@
import asyncio
from ssl import SSLContext
from asyncio import Future, PriorityQueue, Queue
from typing import List, Optional, Set, Tuple
from typing import Dict, List, Optional, Set, Tuple
from asyncio_throttle import Throttler
from ircstates import Emit
@ -59,7 +59,9 @@ class Server(IServer):
self._reader = reader
self._writer = writer
self.params = params
await self.handshake()
async def handshake(self):
nickname = self.params.nickname
@ -70,14 +72,20 @@ class Server(IServer):
await self.send(build("NICK", [nickname]))
await self.send(build("USER", [username, "0", "*", realname]))
await CAPContext(self).handshake()
async def _on_read_emit(self, line: Line, emit: Emit):
if emit.command == "001":
if emit.command == "001":
self.set_throttle(THROTTLE_RATE, THROTTLE_TIME)
elif emit.command == "CAP":
if emit.subcommand == "NEW":
await self._cap_new(emit)
elif emit.command == "CAP":
if emit.subcommand == "NEW":
await self._cap_ls(emit)
elif (emit.subcommand == "LS" and
emit.finished):
if not self.registered:
await CAPContext(self).handshake()
else:
await self._cap_ls(emit)
elif emit.command == "JOIN":
if emit.self and not emit.channel is None:
await self.send(build("MODE", [emit.channel.name]))
@ -145,11 +153,13 @@ class Server(IServer):
def cap_available(self, capability: ICapability) -> Optional[str]:
return capability.available(self.agreed_caps)
async def _cap_new(self, emit: Emit):
async def _cap_ls(self, emit: Emit):
if not emit.tokens is None:
tokens = [t.split("=", 1)[0] for t in emit.tokens]
if CAP_SASL.available(tokens) and not self.params.sasl is None:
await self.sasl_auth(self.params.sasl)
tokens: Dict[str, str] = {}
for token in emit.tokens:
key, _, value = token.partition("=")
tokens[key] = value
await CAPContext(self).on_ls(tokens)
async def sasl_auth(self, params: SASLParams) -> bool:
if (self.sasl_state == SASLResult.NONE and

View file

@ -1,3 +1,3 @@
ircstates ==0.8.2
ircstates ==0.8.3
asyncio-throttle ==1.0.1
dataclasses ==0.6