make handshake CAP dance happen in one async task. move to ircv3.py
This commit is contained in:
parent
1b3c537e0a
commit
06a4d20fc8
6 changed files with 98 additions and 42 deletions
|
@ -1,3 +1,4 @@
|
||||||
from .bot import Bot
|
from .bot import Bot
|
||||||
from .server import Server
|
from .server import Server
|
||||||
from .params import ConnectionParams, SASLUserPass, SASLExternal
|
from .params import ConnectionParams, SASLUserPass, SASLExternal
|
||||||
|
from .ircv3 import Capability
|
||||||
|
|
|
@ -56,6 +56,7 @@ class Bot(object):
|
||||||
print(e)
|
print(e)
|
||||||
await tg.cancel_scope.cancel()
|
await tg.cancel_scope.cancel()
|
||||||
|
|
||||||
|
await tg.spawn(server.handshake)
|
||||||
await tg.spawn(_read)
|
await tg.spawn(_read)
|
||||||
await tg.spawn(_write)
|
await tg.spawn(_write)
|
||||||
del self.servers[server.name]
|
del self.servers[server.name]
|
||||||
|
|
|
@ -1,10 +1,9 @@
|
||||||
from typing import Awaitable
|
from typing import Awaitable, Iterable, List, Optional
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
|
|
||||||
from ircstates import Server
|
from ircstates import Server
|
||||||
from irctokens import Line
|
from irctokens import Line
|
||||||
|
|
||||||
from .ircv3 import Capability
|
|
||||||
from .matching import BaseResponse
|
from .matching import BaseResponse
|
||||||
from .params import ConnectionParams
|
from .params import ConnectionParams
|
||||||
|
|
||||||
|
@ -21,6 +20,16 @@ class PriorityLine(object):
|
||||||
def __lt__(self, other: "PriorityLine") -> bool:
|
def __lt__(self, other: "PriorityLine") -> bool:
|
||||||
return self.priority < other.priority
|
return self.priority < other.priority
|
||||||
|
|
||||||
|
class ICapability(object):
|
||||||
|
def available(self, capabilities: Iterable[str]) -> Optional[str]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def match(self, capability: str) -> Optional[str]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def copy(self) -> "ICapability":
|
||||||
|
pass
|
||||||
|
|
||||||
class IServer(Server):
|
class IServer(Server):
|
||||||
params: ConnectionParams
|
params: ConnectionParams
|
||||||
|
|
||||||
|
@ -38,8 +47,19 @@ class IServer(Server):
|
||||||
async def connect(self, params: ConnectionParams):
|
async def connect(self, params: ConnectionParams):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def queue_capability(self, cap: Capability):
|
async def queue_capability(self, cap: ICapability):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def line_written(self, line: Line):
|
async def line_written(self, line: Line):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def cap_agreed(self, capability: ICapability) -> bool:
|
||||||
|
pass
|
||||||
|
def cap_available(self, capability: ICapability) -> Optional[str]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def collect_caps(self) -> List[str]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def maybe_sasl(self) -> bool:
|
||||||
|
pass
|
||||||
|
|
|
@ -1,6 +1,11 @@
|
||||||
from typing import Iterable, List, Optional
|
from typing import Iterable, List, Optional
|
||||||
|
from irctokens import build
|
||||||
|
|
||||||
class Capability(object):
|
from .contexts import ServerContext
|
||||||
|
from .matching import Response, ResponseOr, ParamAny, ParamNot, ParamLiteral
|
||||||
|
from .interface import ICapability
|
||||||
|
|
||||||
|
class Capability(ICapability):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
ratified_name: Optional[str],
|
ratified_name: Optional[str],
|
||||||
draft_name: Optional[str]=None,
|
draft_name: Optional[str]=None,
|
||||||
|
@ -30,7 +35,7 @@ class Capability(object):
|
||||||
depends_on=self.depends_on[:])
|
depends_on=self.depends_on[:])
|
||||||
|
|
||||||
CAP_SASL = Capability("sasl")
|
CAP_SASL = Capability("sasl")
|
||||||
CAPS = [
|
CAPS: List[ICapability] = [
|
||||||
Capability("multi-prefix"),
|
Capability("multi-prefix"),
|
||||||
Capability("chghost"),
|
Capability("chghost"),
|
||||||
Capability("away-notify"),
|
Capability("away-notify"),
|
||||||
|
@ -48,3 +53,37 @@ CAPS = [
|
||||||
Capability(None, "draft/rename", alias="rename"),
|
Capability(None, "draft/rename", alias="rename"),
|
||||||
Capability("setname", "draft/setname")
|
Capability("setname", "draft/setname")
|
||||||
]
|
]
|
||||||
|
|
||||||
|
class CAPContext(ServerContext):
|
||||||
|
async def handshake(self) -> bool:
|
||||||
|
# improve this by being able to wait_for Emit objects
|
||||||
|
line = await self.server.wait_for(Response(
|
||||||
|
"CAP",
|
||||||
|
[ParamAny(), ParamLiteral("LS"), ParamNot(ParamLiteral("*"))],
|
||||||
|
errors=["001"]))
|
||||||
|
|
||||||
|
if line.command == "CAP":
|
||||||
|
caps = self.server.collect_caps()
|
||||||
|
if caps:
|
||||||
|
await self.server.send(build("CAP",
|
||||||
|
["REQ", " ".join(caps)]))
|
||||||
|
|
||||||
|
while caps:
|
||||||
|
line = await self.server.wait_for(ResponseOr(
|
||||||
|
Response("CAP", [ParamAny(), ParamLiteral("ACK")]),
|
||||||
|
Response("CAP", [ParamAny(), ParamLiteral("NAK")])
|
||||||
|
))
|
||||||
|
|
||||||
|
current_caps = line.params[2].split(" ")
|
||||||
|
for cap in current_caps:
|
||||||
|
if cap in caps:
|
||||||
|
caps.remove(cap)
|
||||||
|
|
||||||
|
if self.server.cap_agreed(CAP_SASL):
|
||||||
|
await self.server.maybe_sasl()
|
||||||
|
|
||||||
|
await self.server.send(build("CAP", ["END"]))
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
|
@ -39,15 +39,25 @@ class Response(BaseResponse):
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
class ResponseOr(BaseResponse):
|
||||||
|
def __init__(self, *responses: BaseResponse):
|
||||||
|
self._responses = responses
|
||||||
|
def match(self, line: Line) -> bool:
|
||||||
|
for response in self._responses:
|
||||||
|
if response.match(line):
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
class ParamAny(ResponseParam):
|
class ParamAny(ResponseParam):
|
||||||
def match(self, arg: str) -> bool:
|
def match(self, arg: str) -> bool:
|
||||||
return True
|
return True
|
||||||
class Literal(ResponseParam):
|
class ParamLiteral(ResponseParam):
|
||||||
def __init__(self, value: str):
|
def __init__(self, value: str):
|
||||||
self._value = value
|
self._value = value
|
||||||
def match(self, arg: str) -> bool:
|
def match(self, arg: str) -> bool:
|
||||||
return self._value == arg
|
return self._value == arg
|
||||||
class Not(ResponseParam):
|
class ParamNot(ResponseParam):
|
||||||
def __init__(self, param: ResponseParam):
|
def __init__(self, param: ResponseParam):
|
||||||
self._param = param
|
self._param = param
|
||||||
def match(self, arg: str) -> bool:
|
def match(self, arg: str) -> bool:
|
||||||
|
|
|
@ -6,8 +6,9 @@ from asyncio_throttle import Throttler
|
||||||
from ircstates import Emit
|
from ircstates import Emit
|
||||||
from irctokens import build, Line, tokenise
|
from irctokens import build, Line, tokenise
|
||||||
|
|
||||||
from .ircv3 import Capability, CAPS, CAP_SASL
|
from .ircv3 import CAPContext, CAPS, CAP_SASL
|
||||||
from .interface import ConnectionParams, IServer, PriorityLine, SendPriority
|
from .interface import (ConnectionParams, ICapability, IServer, PriorityLine,
|
||||||
|
SendPriority)
|
||||||
from .matching import BaseResponse
|
from .matching import BaseResponse
|
||||||
from .sasl import SASLContext, SASLResult
|
from .sasl import SASLContext, SASLResult
|
||||||
|
|
||||||
|
@ -29,9 +30,7 @@ class Server(IServer):
|
||||||
self.sasl_state = SASLResult.NONE
|
self.sasl_state = SASLResult.NONE
|
||||||
|
|
||||||
self._write_queue: PriorityQueue[PriorityLine] = PriorityQueue()
|
self._write_queue: PriorityQueue[PriorityLine] = PriorityQueue()
|
||||||
|
self._cap_queue: Set[ICapability] = set([])
|
||||||
self._cap_queue: Set[Capability] = set([])
|
|
||||||
self._requested_caps: List[str] = []
|
|
||||||
|
|
||||||
self._wait_for: List[Tuple[BaseResponse, Future]] = []
|
self._wait_for: List[Tuple[BaseResponse, Future]] = []
|
||||||
|
|
||||||
|
@ -50,24 +49,22 @@ class Server(IServer):
|
||||||
params.host, params.port, ssl=cur_ssl)
|
params.host, params.port, ssl=cur_ssl)
|
||||||
self._reader = reader
|
self._reader = reader
|
||||||
self._writer = writer
|
self._writer = writer
|
||||||
|
self.params = params
|
||||||
|
|
||||||
nickname = params.nickname
|
async def handshake(self):
|
||||||
username = params.username or nickname
|
nickname = self.params.nickname
|
||||||
realname = params.realname or nickname
|
username = self.params.username or nickname
|
||||||
|
realname = self.params.realname or nickname
|
||||||
|
|
||||||
await self.send(build("CAP", ["LS"]))
|
await self.send(build("CAP", ["LS"]))
|
||||||
await self.send(build("NICK", [nickname]))
|
await self.send(build("NICK", [nickname]))
|
||||||
await self.send(build("USER", [username, "0", "*", realname]))
|
await self.send(build("USER", [username, "0", "*", realname]))
|
||||||
|
|
||||||
self.params = params
|
await CAPContext(self).handshake()
|
||||||
|
|
||||||
async def _on_read_emit(self, line: Line, emit: Emit):
|
async def _on_read_emit(self, line: Line, emit: Emit):
|
||||||
if emit.command == "CAP":
|
if emit.command == "CAP":
|
||||||
if emit.subcommand == "LS" and emit.finished:
|
if emit.subcommand == "NEW":
|
||||||
await self._cap_ls_done()
|
|
||||||
elif emit.subcommand in ["ACK", "NAK"]:
|
|
||||||
await self._cap_ack(emit)
|
|
||||||
elif emit.subcommand == "NEW":
|
|
||||||
await self._cap_new(emit)
|
await self._cap_new(emit)
|
||||||
|
|
||||||
async def _on_read_line(self, line: Line):
|
async def _on_read_line(self, line: Line):
|
||||||
|
@ -107,43 +104,31 @@ class Server(IServer):
|
||||||
return lines
|
return lines
|
||||||
|
|
||||||
# CAP-related
|
# CAP-related
|
||||||
async def queue_capability(self, cap: Capability):
|
async def queue_capability(self, cap: ICapability):
|
||||||
self._cap_queue.add(cap)
|
self._cap_queue.add(cap)
|
||||||
|
|
||||||
def cap_agreed(self, capability: Capability) -> bool:
|
def cap_agreed(self, capability: ICapability) -> bool:
|
||||||
return bool(self.cap_available(capability))
|
return bool(self.cap_available(capability))
|
||||||
def cap_available(self, capability: Capability) -> Optional[str]:
|
def cap_available(self, capability: ICapability) -> Optional[str]:
|
||||||
return capability.available(self.agreed_caps)
|
return capability.available(self.agreed_caps)
|
||||||
|
|
||||||
async def _cap_ls_done(self):
|
def collect_caps(self) -> List[str]:
|
||||||
caps = CAPS+list(self._cap_queue)
|
caps = CAPS+list(self._cap_queue)
|
||||||
self._cap_queue.clear()
|
self._cap_queue.clear()
|
||||||
|
|
||||||
if not self.params.sasl is None:
|
if not self.params.sasl is None:
|
||||||
caps.append(CAP_SASL)
|
caps.append(CAP_SASL)
|
||||||
|
|
||||||
matches = list(filter(bool,
|
matched = [c.available(self.available_caps) for c in caps]
|
||||||
(c.available(self.available_caps) for c in caps)))
|
return [name for name in matched if not name is None]
|
||||||
if matches:
|
|
||||||
self._requested_caps = matches
|
|
||||||
await self.send(build("CAP", ["REQ", " ".join(matches)]))
|
|
||||||
|
|
||||||
async def _cap_ack(self, emit: Emit):
|
|
||||||
await self._maybe_sasl()
|
|
||||||
|
|
||||||
for cap in (emit.tokens or []):
|
|
||||||
if cap in self._requested_caps:
|
|
||||||
self._requested_caps.remove(cap)
|
|
||||||
if not self._requested_caps:
|
|
||||||
await self.send(build("CAP", ["END"]))
|
|
||||||
|
|
||||||
async def _cap_new(self, emit: Emit):
|
async def _cap_new(self, emit: Emit):
|
||||||
if not emit.tokens is None:
|
if not emit.tokens is None:
|
||||||
tokens = [t.split("=", 1)[0] for t in emit.tokens]
|
tokens = [t.split("=", 1)[0] for t in emit.tokens]
|
||||||
if CAP_SASL.available(tokens):
|
if CAP_SASL.available(tokens):
|
||||||
await self._maybe_sasl()
|
await self.maybe_sasl()
|
||||||
|
|
||||||
async def _maybe_sasl(self) -> bool:
|
async def maybe_sasl(self) -> bool:
|
||||||
if (self.sasl_state == SASLResult.NONE and
|
if (self.sasl_state == SASLResult.NONE and
|
||||||
not self.params.sasl is None and
|
not self.params.sasl is None and
|
||||||
self.cap_agreed(CAP_SASL)):
|
self.cap_agreed(CAP_SASL)):
|
||||||
|
|
Loading…
Reference in a new issue