make handshake CAP dance happen in one async task. move to ircv3.py

This commit is contained in:
jesopo 2020-04-02 20:16:07 +01:00
parent 1b3c537e0a
commit 06a4d20fc8
6 changed files with 98 additions and 42 deletions

View file

@ -1,3 +1,4 @@
from .bot import Bot
from .server import Server
from .params import ConnectionParams, SASLUserPass, SASLExternal
from .ircv3 import Capability

View file

@ -56,6 +56,7 @@ class Bot(object):
print(e)
await tg.cancel_scope.cancel()
await tg.spawn(server.handshake)
await tg.spawn(_read)
await tg.spawn(_write)
del self.servers[server.name]

View file

@ -1,10 +1,9 @@
from typing import Awaitable
from typing import Awaitable, Iterable, List, Optional
from enum import IntEnum
from ircstates import Server
from irctokens import Line
from .ircv3 import Capability
from .matching import BaseResponse
from .params import ConnectionParams
@ -21,6 +20,16 @@ class PriorityLine(object):
def __lt__(self, other: "PriorityLine") -> bool:
return self.priority < other.priority
class ICapability(object):
def available(self, capabilities: Iterable[str]) -> Optional[str]:
pass
def match(self, capability: str) -> Optional[str]:
pass
def copy(self) -> "ICapability":
pass
class IServer(Server):
params: ConnectionParams
@ -38,8 +47,19 @@ class IServer(Server):
async def connect(self, params: ConnectionParams):
pass
async def queue_capability(self, cap: Capability):
async def queue_capability(self, cap: ICapability):
pass
async def line_written(self, line: Line):
pass
def cap_agreed(self, capability: ICapability) -> bool:
pass
def cap_available(self, capability: ICapability) -> Optional[str]:
pass
def collect_caps(self) -> List[str]:
pass
async def maybe_sasl(self) -> bool:
pass

View file

@ -1,6 +1,11 @@
from typing import Iterable, List, Optional
from typing import Iterable, List, Optional
from irctokens import build
class Capability(object):
from .contexts import ServerContext
from .matching import Response, ResponseOr, ParamAny, ParamNot, ParamLiteral
from .interface import ICapability
class Capability(ICapability):
def __init__(self,
ratified_name: Optional[str],
draft_name: Optional[str]=None,
@ -30,7 +35,7 @@ class Capability(object):
depends_on=self.depends_on[:])
CAP_SASL = Capability("sasl")
CAPS = [
CAPS: List[ICapability] = [
Capability("multi-prefix"),
Capability("chghost"),
Capability("away-notify"),
@ -48,3 +53,37 @@ CAPS = [
Capability(None, "draft/rename", alias="rename"),
Capability("setname", "draft/setname")
]
class CAPContext(ServerContext):
async def handshake(self) -> bool:
# improve this by being able to wait_for Emit objects
line = await self.server.wait_for(Response(
"CAP",
[ParamAny(), ParamLiteral("LS"), ParamNot(ParamLiteral("*"))],
errors=["001"]))
if line.command == "CAP":
caps = self.server.collect_caps()
if caps:
await self.server.send(build("CAP",
["REQ", " ".join(caps)]))
while caps:
line = await self.server.wait_for(ResponseOr(
Response("CAP", [ParamAny(), ParamLiteral("ACK")]),
Response("CAP", [ParamAny(), ParamLiteral("NAK")])
))
current_caps = line.params[2].split(" ")
for cap in current_caps:
if cap in caps:
caps.remove(cap)
if self.server.cap_agreed(CAP_SASL):
await self.server.maybe_sasl()
await self.server.send(build("CAP", ["END"]))
return True
else:
return False

View file

@ -39,15 +39,25 @@ class Response(BaseResponse):
else:
return False
class ResponseOr(BaseResponse):
def __init__(self, *responses: BaseResponse):
self._responses = responses
def match(self, line: Line) -> bool:
for response in self._responses:
if response.match(line):
return True
else:
return False
class ParamAny(ResponseParam):
def match(self, arg: str) -> bool:
return True
class Literal(ResponseParam):
class ParamLiteral(ResponseParam):
def __init__(self, value: str):
self._value = value
def match(self, arg: str) -> bool:
return self._value == arg
class Not(ResponseParam):
class ParamNot(ResponseParam):
def __init__(self, param: ResponseParam):
self._param = param
def match(self, arg: str) -> bool:

View file

@ -6,8 +6,9 @@ from asyncio_throttle import Throttler
from ircstates import Emit
from irctokens import build, Line, tokenise
from .ircv3 import Capability, CAPS, CAP_SASL
from .interface import ConnectionParams, IServer, PriorityLine, SendPriority
from .ircv3 import CAPContext, CAPS, CAP_SASL
from .interface import (ConnectionParams, ICapability, IServer, PriorityLine,
SendPriority)
from .matching import BaseResponse
from .sasl import SASLContext, SASLResult
@ -29,9 +30,7 @@ class Server(IServer):
self.sasl_state = SASLResult.NONE
self._write_queue: PriorityQueue[PriorityLine] = PriorityQueue()
self._cap_queue: Set[Capability] = set([])
self._requested_caps: List[str] = []
self._cap_queue: Set[ICapability] = set([])
self._wait_for: List[Tuple[BaseResponse, Future]] = []
@ -50,24 +49,22 @@ class Server(IServer):
params.host, params.port, ssl=cur_ssl)
self._reader = reader
self._writer = writer
self.params = params
nickname = params.nickname
username = params.username or nickname
realname = params.realname or nickname
async def handshake(self):
nickname = self.params.nickname
username = self.params.username or nickname
realname = self.params.realname or nickname
await self.send(build("CAP", ["LS"]))
await self.send(build("NICK", [nickname]))
await self.send(build("USER", [username, "0", "*", realname]))
self.params = params
await CAPContext(self).handshake()
async def _on_read_emit(self, line: Line, emit: Emit):
if emit.command == "CAP":
if emit.subcommand == "LS" and emit.finished:
await self._cap_ls_done()
elif emit.subcommand in ["ACK", "NAK"]:
await self._cap_ack(emit)
elif emit.subcommand == "NEW":
if emit.subcommand == "NEW":
await self._cap_new(emit)
async def _on_read_line(self, line: Line):
@ -107,43 +104,31 @@ class Server(IServer):
return lines
# CAP-related
async def queue_capability(self, cap: Capability):
async def queue_capability(self, cap: ICapability):
self._cap_queue.add(cap)
def cap_agreed(self, capability: Capability) -> bool:
def cap_agreed(self, capability: ICapability) -> bool:
return bool(self.cap_available(capability))
def cap_available(self, capability: Capability) -> Optional[str]:
def cap_available(self, capability: ICapability) -> Optional[str]:
return capability.available(self.agreed_caps)
async def _cap_ls_done(self):
def collect_caps(self) -> List[str]:
caps = CAPS+list(self._cap_queue)
self._cap_queue.clear()
if not self.params.sasl is None:
caps.append(CAP_SASL)
matches = list(filter(bool,
(c.available(self.available_caps) for c in caps)))
if matches:
self._requested_caps = matches
await self.send(build("CAP", ["REQ", " ".join(matches)]))
async def _cap_ack(self, emit: Emit):
await self._maybe_sasl()
for cap in (emit.tokens or []):
if cap in self._requested_caps:
self._requested_caps.remove(cap)
if not self._requested_caps:
await self.send(build("CAP", ["END"]))
matched = [c.available(self.available_caps) for c in caps]
return [name for name in matched if not name is None]
async def _cap_new(self, emit: Emit):
if not emit.tokens is None:
tokens = [t.split("=", 1)[0] for t in emit.tokens]
if CAP_SASL.available(tokens):
await self._maybe_sasl()
await self.maybe_sasl()
async def _maybe_sasl(self) -> bool:
async def maybe_sasl(self) -> bool:
if (self.sasl_state == SASLResult.NONE and
not self.params.sasl is None and
self.cap_agreed(CAP_SASL)):