make handshake CAP dance happen in one async task. move to ircv3.py
This commit is contained in:
parent
1b3c537e0a
commit
06a4d20fc8
6 changed files with 98 additions and 42 deletions
|
@ -1,3 +1,4 @@
|
|||
from .bot import Bot
|
||||
from .server import Server
|
||||
from .params import ConnectionParams, SASLUserPass, SASLExternal
|
||||
from .ircv3 import Capability
|
||||
|
|
|
@ -56,6 +56,7 @@ class Bot(object):
|
|||
print(e)
|
||||
await tg.cancel_scope.cancel()
|
||||
|
||||
await tg.spawn(server.handshake)
|
||||
await tg.spawn(_read)
|
||||
await tg.spawn(_write)
|
||||
del self.servers[server.name]
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
from typing import Awaitable
|
||||
from typing import Awaitable, Iterable, List, Optional
|
||||
from enum import IntEnum
|
||||
|
||||
from ircstates import Server
|
||||
from irctokens import Line
|
||||
|
||||
from .ircv3 import Capability
|
||||
from .matching import BaseResponse
|
||||
from .params import ConnectionParams
|
||||
|
||||
|
@ -21,6 +20,16 @@ class PriorityLine(object):
|
|||
def __lt__(self, other: "PriorityLine") -> bool:
|
||||
return self.priority < other.priority
|
||||
|
||||
class ICapability(object):
|
||||
def available(self, capabilities: Iterable[str]) -> Optional[str]:
|
||||
pass
|
||||
|
||||
def match(self, capability: str) -> Optional[str]:
|
||||
pass
|
||||
|
||||
def copy(self) -> "ICapability":
|
||||
pass
|
||||
|
||||
class IServer(Server):
|
||||
params: ConnectionParams
|
||||
|
||||
|
@ -38,8 +47,19 @@ class IServer(Server):
|
|||
async def connect(self, params: ConnectionParams):
|
||||
pass
|
||||
|
||||
async def queue_capability(self, cap: Capability):
|
||||
async def queue_capability(self, cap: ICapability):
|
||||
pass
|
||||
|
||||
async def line_written(self, line: Line):
|
||||
pass
|
||||
|
||||
def cap_agreed(self, capability: ICapability) -> bool:
|
||||
pass
|
||||
def cap_available(self, capability: ICapability) -> Optional[str]:
|
||||
pass
|
||||
|
||||
def collect_caps(self) -> List[str]:
|
||||
pass
|
||||
|
||||
async def maybe_sasl(self) -> bool:
|
||||
pass
|
||||
|
|
|
@ -1,6 +1,11 @@
|
|||
from typing import Iterable, List, Optional
|
||||
from typing import Iterable, List, Optional
|
||||
from irctokens import build
|
||||
|
||||
class Capability(object):
|
||||
from .contexts import ServerContext
|
||||
from .matching import Response, ResponseOr, ParamAny, ParamNot, ParamLiteral
|
||||
from .interface import ICapability
|
||||
|
||||
class Capability(ICapability):
|
||||
def __init__(self,
|
||||
ratified_name: Optional[str],
|
||||
draft_name: Optional[str]=None,
|
||||
|
@ -30,7 +35,7 @@ class Capability(object):
|
|||
depends_on=self.depends_on[:])
|
||||
|
||||
CAP_SASL = Capability("sasl")
|
||||
CAPS = [
|
||||
CAPS: List[ICapability] = [
|
||||
Capability("multi-prefix"),
|
||||
Capability("chghost"),
|
||||
Capability("away-notify"),
|
||||
|
@ -48,3 +53,37 @@ CAPS = [
|
|||
Capability(None, "draft/rename", alias="rename"),
|
||||
Capability("setname", "draft/setname")
|
||||
]
|
||||
|
||||
class CAPContext(ServerContext):
|
||||
async def handshake(self) -> bool:
|
||||
# improve this by being able to wait_for Emit objects
|
||||
line = await self.server.wait_for(Response(
|
||||
"CAP",
|
||||
[ParamAny(), ParamLiteral("LS"), ParamNot(ParamLiteral("*"))],
|
||||
errors=["001"]))
|
||||
|
||||
if line.command == "CAP":
|
||||
caps = self.server.collect_caps()
|
||||
if caps:
|
||||
await self.server.send(build("CAP",
|
||||
["REQ", " ".join(caps)]))
|
||||
|
||||
while caps:
|
||||
line = await self.server.wait_for(ResponseOr(
|
||||
Response("CAP", [ParamAny(), ParamLiteral("ACK")]),
|
||||
Response("CAP", [ParamAny(), ParamLiteral("NAK")])
|
||||
))
|
||||
|
||||
current_caps = line.params[2].split(" ")
|
||||
for cap in current_caps:
|
||||
if cap in caps:
|
||||
caps.remove(cap)
|
||||
|
||||
if self.server.cap_agreed(CAP_SASL):
|
||||
await self.server.maybe_sasl()
|
||||
|
||||
await self.server.send(build("CAP", ["END"]))
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
|
|
@ -39,15 +39,25 @@ class Response(BaseResponse):
|
|||
else:
|
||||
return False
|
||||
|
||||
class ResponseOr(BaseResponse):
|
||||
def __init__(self, *responses: BaseResponse):
|
||||
self._responses = responses
|
||||
def match(self, line: Line) -> bool:
|
||||
for response in self._responses:
|
||||
if response.match(line):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
class ParamAny(ResponseParam):
|
||||
def match(self, arg: str) -> bool:
|
||||
return True
|
||||
class Literal(ResponseParam):
|
||||
class ParamLiteral(ResponseParam):
|
||||
def __init__(self, value: str):
|
||||
self._value = value
|
||||
def match(self, arg: str) -> bool:
|
||||
return self._value == arg
|
||||
class Not(ResponseParam):
|
||||
class ParamNot(ResponseParam):
|
||||
def __init__(self, param: ResponseParam):
|
||||
self._param = param
|
||||
def match(self, arg: str) -> bool:
|
||||
|
|
|
@ -6,8 +6,9 @@ from asyncio_throttle import Throttler
|
|||
from ircstates import Emit
|
||||
from irctokens import build, Line, tokenise
|
||||
|
||||
from .ircv3 import Capability, CAPS, CAP_SASL
|
||||
from .interface import ConnectionParams, IServer, PriorityLine, SendPriority
|
||||
from .ircv3 import CAPContext, CAPS, CAP_SASL
|
||||
from .interface import (ConnectionParams, ICapability, IServer, PriorityLine,
|
||||
SendPriority)
|
||||
from .matching import BaseResponse
|
||||
from .sasl import SASLContext, SASLResult
|
||||
|
||||
|
@ -29,9 +30,7 @@ class Server(IServer):
|
|||
self.sasl_state = SASLResult.NONE
|
||||
|
||||
self._write_queue: PriorityQueue[PriorityLine] = PriorityQueue()
|
||||
|
||||
self._cap_queue: Set[Capability] = set([])
|
||||
self._requested_caps: List[str] = []
|
||||
self._cap_queue: Set[ICapability] = set([])
|
||||
|
||||
self._wait_for: List[Tuple[BaseResponse, Future]] = []
|
||||
|
||||
|
@ -50,24 +49,22 @@ class Server(IServer):
|
|||
params.host, params.port, ssl=cur_ssl)
|
||||
self._reader = reader
|
||||
self._writer = writer
|
||||
self.params = params
|
||||
|
||||
nickname = params.nickname
|
||||
username = params.username or nickname
|
||||
realname = params.realname or nickname
|
||||
async def handshake(self):
|
||||
nickname = self.params.nickname
|
||||
username = self.params.username or nickname
|
||||
realname = self.params.realname or nickname
|
||||
|
||||
await self.send(build("CAP", ["LS"]))
|
||||
await self.send(build("NICK", [nickname]))
|
||||
await self.send(build("USER", [username, "0", "*", realname]))
|
||||
|
||||
self.params = params
|
||||
await CAPContext(self).handshake()
|
||||
|
||||
async def _on_read_emit(self, line: Line, emit: Emit):
|
||||
if emit.command == "CAP":
|
||||
if emit.subcommand == "LS" and emit.finished:
|
||||
await self._cap_ls_done()
|
||||
elif emit.subcommand in ["ACK", "NAK"]:
|
||||
await self._cap_ack(emit)
|
||||
elif emit.subcommand == "NEW":
|
||||
if emit.subcommand == "NEW":
|
||||
await self._cap_new(emit)
|
||||
|
||||
async def _on_read_line(self, line: Line):
|
||||
|
@ -107,43 +104,31 @@ class Server(IServer):
|
|||
return lines
|
||||
|
||||
# CAP-related
|
||||
async def queue_capability(self, cap: Capability):
|
||||
async def queue_capability(self, cap: ICapability):
|
||||
self._cap_queue.add(cap)
|
||||
|
||||
def cap_agreed(self, capability: Capability) -> bool:
|
||||
def cap_agreed(self, capability: ICapability) -> bool:
|
||||
return bool(self.cap_available(capability))
|
||||
def cap_available(self, capability: Capability) -> Optional[str]:
|
||||
def cap_available(self, capability: ICapability) -> Optional[str]:
|
||||
return capability.available(self.agreed_caps)
|
||||
|
||||
async def _cap_ls_done(self):
|
||||
def collect_caps(self) -> List[str]:
|
||||
caps = CAPS+list(self._cap_queue)
|
||||
self._cap_queue.clear()
|
||||
|
||||
if not self.params.sasl is None:
|
||||
caps.append(CAP_SASL)
|
||||
|
||||
matches = list(filter(bool,
|
||||
(c.available(self.available_caps) for c in caps)))
|
||||
if matches:
|
||||
self._requested_caps = matches
|
||||
await self.send(build("CAP", ["REQ", " ".join(matches)]))
|
||||
|
||||
async def _cap_ack(self, emit: Emit):
|
||||
await self._maybe_sasl()
|
||||
|
||||
for cap in (emit.tokens or []):
|
||||
if cap in self._requested_caps:
|
||||
self._requested_caps.remove(cap)
|
||||
if not self._requested_caps:
|
||||
await self.send(build("CAP", ["END"]))
|
||||
matched = [c.available(self.available_caps) for c in caps]
|
||||
return [name for name in matched if not name is None]
|
||||
|
||||
async def _cap_new(self, emit: Emit):
|
||||
if not emit.tokens is None:
|
||||
tokens = [t.split("=", 1)[0] for t in emit.tokens]
|
||||
if CAP_SASL.available(tokens):
|
||||
await self._maybe_sasl()
|
||||
await self.maybe_sasl()
|
||||
|
||||
async def _maybe_sasl(self) -> bool:
|
||||
async def maybe_sasl(self) -> bool:
|
||||
if (self.sasl_state == SASLResult.NONE and
|
||||
not self.params.sasl is None and
|
||||
self.cap_agreed(CAP_SASL)):
|
||||
|
|
Loading…
Reference in a new issue