make handshake CAP dance happen in one async task. move to ircv3.py

This commit is contained in:
jesopo 2020-04-02 20:16:07 +01:00
parent 1b3c537e0a
commit 06a4d20fc8
6 changed files with 98 additions and 42 deletions

View file

@ -1,3 +1,4 @@
from .bot import Bot from .bot import Bot
from .server import Server from .server import Server
from .params import ConnectionParams, SASLUserPass, SASLExternal from .params import ConnectionParams, SASLUserPass, SASLExternal
from .ircv3 import Capability

View file

@ -56,6 +56,7 @@ class Bot(object):
print(e) print(e)
await tg.cancel_scope.cancel() await tg.cancel_scope.cancel()
await tg.spawn(server.handshake)
await tg.spawn(_read) await tg.spawn(_read)
await tg.spawn(_write) await tg.spawn(_write)
del self.servers[server.name] del self.servers[server.name]

View file

@ -1,10 +1,9 @@
from typing import Awaitable from typing import Awaitable, Iterable, List, Optional
from enum import IntEnum from enum import IntEnum
from ircstates import Server from ircstates import Server
from irctokens import Line from irctokens import Line
from .ircv3 import Capability
from .matching import BaseResponse from .matching import BaseResponse
from .params import ConnectionParams from .params import ConnectionParams
@ -21,6 +20,16 @@ class PriorityLine(object):
def __lt__(self, other: "PriorityLine") -> bool: def __lt__(self, other: "PriorityLine") -> bool:
return self.priority < other.priority return self.priority < other.priority
class ICapability(object):
def available(self, capabilities: Iterable[str]) -> Optional[str]:
pass
def match(self, capability: str) -> Optional[str]:
pass
def copy(self) -> "ICapability":
pass
class IServer(Server): class IServer(Server):
params: ConnectionParams params: ConnectionParams
@ -38,8 +47,19 @@ class IServer(Server):
async def connect(self, params: ConnectionParams): async def connect(self, params: ConnectionParams):
pass pass
async def queue_capability(self, cap: Capability): async def queue_capability(self, cap: ICapability):
pass pass
async def line_written(self, line: Line): async def line_written(self, line: Line):
pass pass
def cap_agreed(self, capability: ICapability) -> bool:
pass
def cap_available(self, capability: ICapability) -> Optional[str]:
pass
def collect_caps(self) -> List[str]:
pass
async def maybe_sasl(self) -> bool:
pass

View file

@ -1,6 +1,11 @@
from typing import Iterable, List, Optional from typing import Iterable, List, Optional
from irctokens import build
class Capability(object): from .contexts import ServerContext
from .matching import Response, ResponseOr, ParamAny, ParamNot, ParamLiteral
from .interface import ICapability
class Capability(ICapability):
def __init__(self, def __init__(self,
ratified_name: Optional[str], ratified_name: Optional[str],
draft_name: Optional[str]=None, draft_name: Optional[str]=None,
@ -30,7 +35,7 @@ class Capability(object):
depends_on=self.depends_on[:]) depends_on=self.depends_on[:])
CAP_SASL = Capability("sasl") CAP_SASL = Capability("sasl")
CAPS = [ CAPS: List[ICapability] = [
Capability("multi-prefix"), Capability("multi-prefix"),
Capability("chghost"), Capability("chghost"),
Capability("away-notify"), Capability("away-notify"),
@ -48,3 +53,37 @@ CAPS = [
Capability(None, "draft/rename", alias="rename"), Capability(None, "draft/rename", alias="rename"),
Capability("setname", "draft/setname") Capability("setname", "draft/setname")
] ]
class CAPContext(ServerContext):
async def handshake(self) -> bool:
# improve this by being able to wait_for Emit objects
line = await self.server.wait_for(Response(
"CAP",
[ParamAny(), ParamLiteral("LS"), ParamNot(ParamLiteral("*"))],
errors=["001"]))
if line.command == "CAP":
caps = self.server.collect_caps()
if caps:
await self.server.send(build("CAP",
["REQ", " ".join(caps)]))
while caps:
line = await self.server.wait_for(ResponseOr(
Response("CAP", [ParamAny(), ParamLiteral("ACK")]),
Response("CAP", [ParamAny(), ParamLiteral("NAK")])
))
current_caps = line.params[2].split(" ")
for cap in current_caps:
if cap in caps:
caps.remove(cap)
if self.server.cap_agreed(CAP_SASL):
await self.server.maybe_sasl()
await self.server.send(build("CAP", ["END"]))
return True
else:
return False

View file

@ -39,15 +39,25 @@ class Response(BaseResponse):
else: else:
return False return False
class ResponseOr(BaseResponse):
def __init__(self, *responses: BaseResponse):
self._responses = responses
def match(self, line: Line) -> bool:
for response in self._responses:
if response.match(line):
return True
else:
return False
class ParamAny(ResponseParam): class ParamAny(ResponseParam):
def match(self, arg: str) -> bool: def match(self, arg: str) -> bool:
return True return True
class Literal(ResponseParam): class ParamLiteral(ResponseParam):
def __init__(self, value: str): def __init__(self, value: str):
self._value = value self._value = value
def match(self, arg: str) -> bool: def match(self, arg: str) -> bool:
return self._value == arg return self._value == arg
class Not(ResponseParam): class ParamNot(ResponseParam):
def __init__(self, param: ResponseParam): def __init__(self, param: ResponseParam):
self._param = param self._param = param
def match(self, arg: str) -> bool: def match(self, arg: str) -> bool:

View file

@ -6,8 +6,9 @@ from asyncio_throttle import Throttler
from ircstates import Emit from ircstates import Emit
from irctokens import build, Line, tokenise from irctokens import build, Line, tokenise
from .ircv3 import Capability, CAPS, CAP_SASL from .ircv3 import CAPContext, CAPS, CAP_SASL
from .interface import ConnectionParams, IServer, PriorityLine, SendPriority from .interface import (ConnectionParams, ICapability, IServer, PriorityLine,
SendPriority)
from .matching import BaseResponse from .matching import BaseResponse
from .sasl import SASLContext, SASLResult from .sasl import SASLContext, SASLResult
@ -29,9 +30,7 @@ class Server(IServer):
self.sasl_state = SASLResult.NONE self.sasl_state = SASLResult.NONE
self._write_queue: PriorityQueue[PriorityLine] = PriorityQueue() self._write_queue: PriorityQueue[PriorityLine] = PriorityQueue()
self._cap_queue: Set[ICapability] = set([])
self._cap_queue: Set[Capability] = set([])
self._requested_caps: List[str] = []
self._wait_for: List[Tuple[BaseResponse, Future]] = [] self._wait_for: List[Tuple[BaseResponse, Future]] = []
@ -50,24 +49,22 @@ class Server(IServer):
params.host, params.port, ssl=cur_ssl) params.host, params.port, ssl=cur_ssl)
self._reader = reader self._reader = reader
self._writer = writer self._writer = writer
self.params = params
nickname = params.nickname async def handshake(self):
username = params.username or nickname nickname = self.params.nickname
realname = params.realname or nickname username = self.params.username or nickname
realname = self.params.realname or nickname
await self.send(build("CAP", ["LS"])) await self.send(build("CAP", ["LS"]))
await self.send(build("NICK", [nickname])) await self.send(build("NICK", [nickname]))
await self.send(build("USER", [username, "0", "*", realname])) await self.send(build("USER", [username, "0", "*", realname]))
self.params = params await CAPContext(self).handshake()
async def _on_read_emit(self, line: Line, emit: Emit): async def _on_read_emit(self, line: Line, emit: Emit):
if emit.command == "CAP": if emit.command == "CAP":
if emit.subcommand == "LS" and emit.finished: if emit.subcommand == "NEW":
await self._cap_ls_done()
elif emit.subcommand in ["ACK", "NAK"]:
await self._cap_ack(emit)
elif emit.subcommand == "NEW":
await self._cap_new(emit) await self._cap_new(emit)
async def _on_read_line(self, line: Line): async def _on_read_line(self, line: Line):
@ -107,43 +104,31 @@ class Server(IServer):
return lines return lines
# CAP-related # CAP-related
async def queue_capability(self, cap: Capability): async def queue_capability(self, cap: ICapability):
self._cap_queue.add(cap) self._cap_queue.add(cap)
def cap_agreed(self, capability: Capability) -> bool: def cap_agreed(self, capability: ICapability) -> bool:
return bool(self.cap_available(capability)) return bool(self.cap_available(capability))
def cap_available(self, capability: Capability) -> Optional[str]: def cap_available(self, capability: ICapability) -> Optional[str]:
return capability.available(self.agreed_caps) return capability.available(self.agreed_caps)
async def _cap_ls_done(self): def collect_caps(self) -> List[str]:
caps = CAPS+list(self._cap_queue) caps = CAPS+list(self._cap_queue)
self._cap_queue.clear() self._cap_queue.clear()
if not self.params.sasl is None: if not self.params.sasl is None:
caps.append(CAP_SASL) caps.append(CAP_SASL)
matches = list(filter(bool, matched = [c.available(self.available_caps) for c in caps]
(c.available(self.available_caps) for c in caps))) return [name for name in matched if not name is None]
if matches:
self._requested_caps = matches
await self.send(build("CAP", ["REQ", " ".join(matches)]))
async def _cap_ack(self, emit: Emit):
await self._maybe_sasl()
for cap in (emit.tokens or []):
if cap in self._requested_caps:
self._requested_caps.remove(cap)
if not self._requested_caps:
await self.send(build("CAP", ["END"]))
async def _cap_new(self, emit: Emit): async def _cap_new(self, emit: Emit):
if not emit.tokens is None: if not emit.tokens is None:
tokens = [t.split("=", 1)[0] for t in emit.tokens] tokens = [t.split("=", 1)[0] for t in emit.tokens]
if CAP_SASL.available(tokens): if CAP_SASL.available(tokens):
await self._maybe_sasl() await self.maybe_sasl()
async def _maybe_sasl(self) -> bool: async def maybe_sasl(self) -> bool:
if (self.sasl_state == SASLResult.NONE and if (self.sasl_state == SASLResult.NONE and
not self.params.sasl is None and not self.params.sasl is None and
self.cap_agreed(CAP_SASL)): self.cap_agreed(CAP_SASL)):