use Server.registerd to detect handshake CAP LS; handle non-handshake CAP LS
This commit is contained in:
parent
2a9d658207
commit
1c7caf9039
4 changed files with 25 additions and 30 deletions
|
@ -55,7 +55,6 @@ class Bot(object):
|
|||
|
||||
await tg.spawn(_write)
|
||||
await tg.spawn(_read)
|
||||
await server.handshake()
|
||||
await tg.spawn(_read_query)
|
||||
|
||||
del self.servers[server.name]
|
||||
|
|
|
@ -83,20 +83,6 @@ class CAPContext(ServerContext):
|
|||
not self.server.params.sasl is None):
|
||||
await self.server.sasl_auth(self.server.params.sasl)
|
||||
|
||||
async def handshake(self) -> bool:
|
||||
# improve this by being able to wait_for Emit objects
|
||||
line = await self.server.wait_for(ResponseOr(
|
||||
Response(
|
||||
"CAP",
|
||||
[ParamAny(), ParamLiteral("LS"), ParamNot(ParamLiteral("*"))]
|
||||
),
|
||||
Numerics(["RPL_WELCOME"])
|
||||
))
|
||||
|
||||
if line.command == "CAP":
|
||||
await self.on_ls(self.server.available_caps)
|
||||
await self.server.send(build("CAP", ["END"]))
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
async def handshake(self):
|
||||
await self.on_ls(self.server.available_caps)
|
||||
await self.server.send(build("CAP", ["END"]))
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import asyncio
|
||||
from ssl import SSLContext
|
||||
from asyncio import Future, PriorityQueue, Queue
|
||||
from typing import List, Optional, Set, Tuple
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
from asyncio_throttle import Throttler
|
||||
from ircstates import Emit
|
||||
|
@ -59,7 +59,9 @@ class Server(IServer):
|
|||
|
||||
self._reader = reader
|
||||
self._writer = writer
|
||||
|
||||
self.params = params
|
||||
await self.handshake()
|
||||
|
||||
async def handshake(self):
|
||||
nickname = self.params.nickname
|
||||
|
@ -70,14 +72,20 @@ class Server(IServer):
|
|||
await self.send(build("NICK", [nickname]))
|
||||
await self.send(build("USER", [username, "0", "*", realname]))
|
||||
|
||||
await CAPContext(self).handshake()
|
||||
|
||||
async def _on_read_emit(self, line: Line, emit: Emit):
|
||||
if emit.command == "001":
|
||||
if emit.command == "001":
|
||||
self.set_throttle(THROTTLE_RATE, THROTTLE_TIME)
|
||||
elif emit.command == "CAP":
|
||||
if emit.subcommand == "NEW":
|
||||
await self._cap_new(emit)
|
||||
|
||||
elif emit.command == "CAP":
|
||||
if emit.subcommand == "NEW":
|
||||
await self._cap_ls(emit)
|
||||
elif (emit.subcommand == "LS" and
|
||||
emit.finished):
|
||||
if not self.registered:
|
||||
await CAPContext(self).handshake()
|
||||
else:
|
||||
await self._cap_ls(emit)
|
||||
|
||||
elif emit.command == "JOIN":
|
||||
if emit.self and not emit.channel is None:
|
||||
await self.send(build("MODE", [emit.channel.name]))
|
||||
|
@ -145,11 +153,13 @@ class Server(IServer):
|
|||
def cap_available(self, capability: ICapability) -> Optional[str]:
|
||||
return capability.available(self.agreed_caps)
|
||||
|
||||
async def _cap_new(self, emit: Emit):
|
||||
async def _cap_ls(self, emit: Emit):
|
||||
if not emit.tokens is None:
|
||||
tokens = [t.split("=", 1)[0] for t in emit.tokens]
|
||||
if CAP_SASL.available(tokens) and not self.params.sasl is None:
|
||||
await self.sasl_auth(self.params.sasl)
|
||||
tokens: Dict[str, str] = {}
|
||||
for token in emit.tokens:
|
||||
key, _, value = token.partition("=")
|
||||
tokens[key] = value
|
||||
await CAPContext(self).on_ls(tokens)
|
||||
|
||||
async def sasl_auth(self, params: SASLParams) -> bool:
|
||||
if (self.sasl_state == SASLResult.NONE and
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
ircstates ==0.8.2
|
||||
ircstates ==0.8.3
|
||||
asyncio-throttle ==1.0.1
|
||||
dataclasses ==0.6
|
||||
|
|
Loading…
Reference in a new issue