move more CAP related stuff to CAPContext
This commit is contained in:
parent
afe9ec359d
commit
f70932ac44
4 changed files with 36 additions and 46 deletions
|
@ -1,5 +1,5 @@
|
|||
from asyncio import Future
|
||||
from typing import Awaitable, Iterable, List, Optional
|
||||
from typing import Awaitable, Iterable, Set, Optional
|
||||
from enum import IntEnum
|
||||
|
||||
from ircstates import Server
|
||||
|
@ -33,7 +33,8 @@ class ICapability(object):
|
|||
pass
|
||||
|
||||
class IServer(Server):
|
||||
params: ConnectionParams
|
||||
params: ConnectionParams
|
||||
desired_caps: Set[ICapability]
|
||||
|
||||
async def send_raw(self, line: str, priority=SendPriority.DEFAULT):
|
||||
pass
|
||||
|
@ -49,9 +50,6 @@ class IServer(Server):
|
|||
async def connect(self, params: ConnectionParams):
|
||||
pass
|
||||
|
||||
async def queue_capability(self, cap: ICapability):
|
||||
pass
|
||||
|
||||
async def line_read(self, line: Line):
|
||||
pass
|
||||
async def line_send(self, line: Line):
|
||||
|
@ -65,8 +63,5 @@ class IServer(Server):
|
|||
def cap_available(self, capability: ICapability) -> Optional[str]:
|
||||
pass
|
||||
|
||||
def collect_caps(self) -> List[str]:
|
||||
pass
|
||||
|
||||
async def sasl_auth(self, sasl: SASLParams) -> bool:
|
||||
pass
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Iterable, List, Optional
|
||||
from typing import Dict, Iterable, List, Optional
|
||||
from irctokens import build
|
||||
|
||||
from .contexts import ServerContext
|
||||
|
@ -56,6 +56,33 @@ CAPS: List[ICapability] = [
|
|||
]
|
||||
|
||||
class CAPContext(ServerContext):
|
||||
async def on_ls(self, tokens: Dict[str, str]):
|
||||
caps = list(self.server.desired_caps)+CAPS
|
||||
|
||||
if (not self.server.params.sasl is None and
|
||||
not CAP_SASL in caps):
|
||||
caps.append(CAP_SASL)
|
||||
|
||||
matched = (c.available(tokens) for c in caps)
|
||||
cap_names = [name for name in matched if not name is None]
|
||||
|
||||
if cap_names:
|
||||
await self.server.send(build("CAP", ["REQ", " ".join(cap_names)]))
|
||||
|
||||
while cap_names:
|
||||
line = await self.server.wait_for(ResponseOr(
|
||||
Response("CAP", [ParamAny(), ParamLiteral("ACK")]),
|
||||
Response("CAP", [ParamAny(), ParamLiteral("NAK")])
|
||||
))
|
||||
|
||||
current_caps = line.params[2].split(" ")
|
||||
for cap in current_caps:
|
||||
if cap in cap_names:
|
||||
cap_names.remove(cap)
|
||||
if (self.server.cap_agreed(CAP_SASL) and
|
||||
not self.server.params.sasl is None):
|
||||
await self.server.sasl_auth(self.server.params.sasl)
|
||||
|
||||
async def handshake(self) -> bool:
|
||||
# improve this by being able to wait_for Emit objects
|
||||
line = await self.server.wait_for(ResponseOr(
|
||||
|
@ -67,26 +94,7 @@ class CAPContext(ServerContext):
|
|||
))
|
||||
|
||||
if line.command == "CAP":
|
||||
caps = self.server.collect_caps()
|
||||
if caps:
|
||||
await self.server.send(
|
||||
build("CAP", ["REQ", " ".join(caps)]))
|
||||
|
||||
while caps:
|
||||
line = await self.server.wait_for(ResponseOr(
|
||||
Response("CAP", [ParamAny(), ParamLiteral("ACK")]),
|
||||
Response("CAP", [ParamAny(), ParamLiteral("NAK")])
|
||||
))
|
||||
|
||||
current_caps = line.params[2].split(" ")
|
||||
for cap in current_caps:
|
||||
if cap in caps:
|
||||
caps.remove(cap)
|
||||
|
||||
if (self.server.cap_agreed(CAP_SASL) and
|
||||
not self.server.params.sasl is None):
|
||||
await self.server.sasl_auth(self.server.params.sasl)
|
||||
|
||||
await self.on_ls(self.server.available_caps)
|
||||
await self.server.send(build("CAP", ["END"]))
|
||||
return True
|
||||
else:
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import List, Optional
|
||||
from typing import List
|
||||
from enum import Enum
|
||||
from base64 import b64decode, b64encode
|
||||
from irctokens import build
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
import asyncio
|
||||
from ssl import SSLContext
|
||||
from asyncio import Future, PriorityQueue, Queue
|
||||
from typing import Awaitable, List, Optional, Set, Tuple
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
from asyncio_throttle import Throttler
|
||||
from ircstates import Emit
|
||||
from irctokens import build, Line, tokenise
|
||||
|
||||
from .ircv3 import CAPContext, CAPS, CAP_SASL
|
||||
from .ircv3 import CAPContext, CAP_SASL
|
||||
from .interface import (ConnectionParams, ICapability, IServer, SentLine,
|
||||
SendPriority, SASLParams)
|
||||
from .matching import BaseResponse
|
||||
|
@ -33,7 +33,7 @@ class Server(IServer):
|
|||
self._wait_for_cache: List[Tuple[Line, List[Emit]]] = []
|
||||
self._write_queue: PriorityQueue[SentLine] = PriorityQueue()
|
||||
self._read_queue: Queue[Tuple[Line, List[Emit]]] = Queue()
|
||||
self._cap_queue: Set[ICapability] = set([])
|
||||
self.desired_caps: Set[ICapability] = set([])
|
||||
|
||||
async def send_raw(self, line: str, priority=SendPriority.DEFAULT):
|
||||
await self.send(tokenise(line), priority)
|
||||
|
@ -140,24 +140,11 @@ class Server(IServer):
|
|||
return [l.line for l in lines]
|
||||
|
||||
# CAP-related
|
||||
async def queue_capability(self, cap: ICapability):
|
||||
self._cap_queue.add(cap)
|
||||
|
||||
def cap_agreed(self, capability: ICapability) -> bool:
|
||||
return bool(self.cap_available(capability))
|
||||
def cap_available(self, capability: ICapability) -> Optional[str]:
|
||||
return capability.available(self.agreed_caps)
|
||||
|
||||
def collect_caps(self) -> List[str]:
|
||||
caps = CAPS+list(self._cap_queue)
|
||||
self._cap_queue.clear()
|
||||
|
||||
if not self.params.sasl is None:
|
||||
caps.append(CAP_SASL)
|
||||
|
||||
matched = [c.available(self.available_caps) for c in caps]
|
||||
return [name for name in matched if not name is None]
|
||||
|
||||
async def _cap_new(self, emit: Emit):
|
||||
if not emit.tokens is None:
|
||||
tokens = [t.split("=", 1)[0] for t in emit.tokens]
|
||||
|
|
Loading…
Add table
Reference in a new issue