use Server.registerd to detect handshake CAP LS; handle non-handshake CAP LS
This commit is contained in:
parent
2a9d658207
commit
1c7caf9039
4 changed files with 25 additions and 30 deletions
|
@ -55,7 +55,6 @@ class Bot(object):
|
||||||
|
|
||||||
await tg.spawn(_write)
|
await tg.spawn(_write)
|
||||||
await tg.spawn(_read)
|
await tg.spawn(_read)
|
||||||
await server.handshake()
|
|
||||||
await tg.spawn(_read_query)
|
await tg.spawn(_read_query)
|
||||||
|
|
||||||
del self.servers[server.name]
|
del self.servers[server.name]
|
||||||
|
|
|
@ -83,20 +83,6 @@ class CAPContext(ServerContext):
|
||||||
not self.server.params.sasl is None):
|
not self.server.params.sasl is None):
|
||||||
await self.server.sasl_auth(self.server.params.sasl)
|
await self.server.sasl_auth(self.server.params.sasl)
|
||||||
|
|
||||||
async def handshake(self) -> bool:
|
async def handshake(self):
|
||||||
# improve this by being able to wait_for Emit objects
|
await self.on_ls(self.server.available_caps)
|
||||||
line = await self.server.wait_for(ResponseOr(
|
await self.server.send(build("CAP", ["END"]))
|
||||||
Response(
|
|
||||||
"CAP",
|
|
||||||
[ParamAny(), ParamLiteral("LS"), ParamNot(ParamLiteral("*"))]
|
|
||||||
),
|
|
||||||
Numerics(["RPL_WELCOME"])
|
|
||||||
))
|
|
||||||
|
|
||||||
if line.command == "CAP":
|
|
||||||
await self.on_ls(self.server.available_caps)
|
|
||||||
await self.server.send(build("CAP", ["END"]))
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from ssl import SSLContext
|
from ssl import SSLContext
|
||||||
from asyncio import Future, PriorityQueue, Queue
|
from asyncio import Future, PriorityQueue, Queue
|
||||||
from typing import List, Optional, Set, Tuple
|
from typing import Dict, List, Optional, Set, Tuple
|
||||||
|
|
||||||
from asyncio_throttle import Throttler
|
from asyncio_throttle import Throttler
|
||||||
from ircstates import Emit
|
from ircstates import Emit
|
||||||
|
@ -59,7 +59,9 @@ class Server(IServer):
|
||||||
|
|
||||||
self._reader = reader
|
self._reader = reader
|
||||||
self._writer = writer
|
self._writer = writer
|
||||||
|
|
||||||
self.params = params
|
self.params = params
|
||||||
|
await self.handshake()
|
||||||
|
|
||||||
async def handshake(self):
|
async def handshake(self):
|
||||||
nickname = self.params.nickname
|
nickname = self.params.nickname
|
||||||
|
@ -70,14 +72,20 @@ class Server(IServer):
|
||||||
await self.send(build("NICK", [nickname]))
|
await self.send(build("NICK", [nickname]))
|
||||||
await self.send(build("USER", [username, "0", "*", realname]))
|
await self.send(build("USER", [username, "0", "*", realname]))
|
||||||
|
|
||||||
await CAPContext(self).handshake()
|
|
||||||
|
|
||||||
async def _on_read_emit(self, line: Line, emit: Emit):
|
async def _on_read_emit(self, line: Line, emit: Emit):
|
||||||
if emit.command == "001":
|
if emit.command == "001":
|
||||||
self.set_throttle(THROTTLE_RATE, THROTTLE_TIME)
|
self.set_throttle(THROTTLE_RATE, THROTTLE_TIME)
|
||||||
elif emit.command == "CAP":
|
|
||||||
if emit.subcommand == "NEW":
|
elif emit.command == "CAP":
|
||||||
await self._cap_new(emit)
|
if emit.subcommand == "NEW":
|
||||||
|
await self._cap_ls(emit)
|
||||||
|
elif (emit.subcommand == "LS" and
|
||||||
|
emit.finished):
|
||||||
|
if not self.registered:
|
||||||
|
await CAPContext(self).handshake()
|
||||||
|
else:
|
||||||
|
await self._cap_ls(emit)
|
||||||
|
|
||||||
elif emit.command == "JOIN":
|
elif emit.command == "JOIN":
|
||||||
if emit.self and not emit.channel is None:
|
if emit.self and not emit.channel is None:
|
||||||
await self.send(build("MODE", [emit.channel.name]))
|
await self.send(build("MODE", [emit.channel.name]))
|
||||||
|
@ -145,11 +153,13 @@ class Server(IServer):
|
||||||
def cap_available(self, capability: ICapability) -> Optional[str]:
|
def cap_available(self, capability: ICapability) -> Optional[str]:
|
||||||
return capability.available(self.agreed_caps)
|
return capability.available(self.agreed_caps)
|
||||||
|
|
||||||
async def _cap_new(self, emit: Emit):
|
async def _cap_ls(self, emit: Emit):
|
||||||
if not emit.tokens is None:
|
if not emit.tokens is None:
|
||||||
tokens = [t.split("=", 1)[0] for t in emit.tokens]
|
tokens: Dict[str, str] = {}
|
||||||
if CAP_SASL.available(tokens) and not self.params.sasl is None:
|
for token in emit.tokens:
|
||||||
await self.sasl_auth(self.params.sasl)
|
key, _, value = token.partition("=")
|
||||||
|
tokens[key] = value
|
||||||
|
await CAPContext(self).on_ls(tokens)
|
||||||
|
|
||||||
async def sasl_auth(self, params: SASLParams) -> bool:
|
async def sasl_auth(self, params: SASLParams) -> bool:
|
||||||
if (self.sasl_state == SASLResult.NONE and
|
if (self.sasl_state == SASLResult.NONE and
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
ircstates ==0.8.2
|
ircstates ==0.8.3
|
||||||
asyncio-throttle ==1.0.1
|
asyncio-throttle ==1.0.1
|
||||||
dataclasses ==0.6
|
dataclasses ==0.6
|
||||||
|
|
Loading…
Reference in a new issue