support requested disconnects, support STS connection transmutations
This commit is contained in:
parent
1bfe06d2ea
commit
ae01201d39
5 changed files with 121 additions and 51 deletions
|
@ -5,20 +5,28 @@ from typing import Dict
|
|||
|
||||
from .server import ConnectionParams, Server
|
||||
from .transport import TCPTransport
|
||||
from .interface import IBot, IServer
|
||||
|
||||
RECONNECT_DELAY = 10.0 # ten seconds reconnect
|
||||
|
||||
class Bot(object):
|
||||
class Bot(IBot):
|
||||
def __init__(self):
|
||||
self.servers: Dict[str, Server] = {}
|
||||
self._server_queue: asyncio.Queue[Server] = asyncio.Queue()
|
||||
|
||||
# methods designed to be overridden
|
||||
def create_server(self, name: str):
|
||||
return Server(name)
|
||||
async def disconnected(self, server: Server):
|
||||
return Server(self, name)
|
||||
async def disconnected(self, server: IServer):
|
||||
if (server.name in self.servers and
|
||||
server.disconnected):
|
||||
await asyncio.sleep(RECONNECT_DELAY)
|
||||
await self.add_server(server.name, server.params)
|
||||
# /methods designed to be overridden
|
||||
|
||||
async def disconnect(self, server: IServer):
|
||||
await server.disconnect()
|
||||
del self.servers[server.name]
|
||||
|
||||
async def add_server(self, name: str, params: ConnectionParams) -> Server:
|
||||
server = self.create_server(name)
|
||||
|
@ -31,7 +39,9 @@ class Bot(object):
|
|||
async with anyio.create_task_group() as tg:
|
||||
async def _read():
|
||||
while not tg.cancel_scope.cancel_called:
|
||||
line, emits = await server.next_line()
|
||||
both = await server.next_line()
|
||||
if both is None:
|
||||
break
|
||||
await tg.cancel_scope.cancel()
|
||||
|
||||
async def _write():
|
||||
|
@ -42,7 +52,6 @@ class Bot(object):
|
|||
await tg.spawn(_write)
|
||||
await tg.spawn(_read)
|
||||
|
||||
del self.servers[server.name]
|
||||
await self.disconnected(server)
|
||||
|
||||
async def run(self):
|
||||
|
|
|
@ -5,7 +5,7 @@ from enum import IntEnum
|
|||
from ircstates import Server, Emit
|
||||
from irctokens import Line
|
||||
|
||||
from .params import ConnectionParams, SASLParams
|
||||
from .params import ConnectionParams, SASLParams, STSPolicy
|
||||
|
||||
class ITCPReader(object):
|
||||
async def read(self, byte_count: int):
|
||||
|
@ -49,7 +49,7 @@ class ICapability(object):
|
|||
def available(self, capabilities: Iterable[str]) -> Optional[str]:
|
||||
pass
|
||||
|
||||
def match(self, capability: str) -> Optional[str]:
|
||||
def match(self, capability: str) -> bool:
|
||||
pass
|
||||
|
||||
def copy(self) -> "ICapability":
|
||||
|
@ -63,6 +63,7 @@ class IMatchResponseParam(object):
|
|||
pass
|
||||
|
||||
class IServer(Server):
|
||||
disconnected: bool
|
||||
params: ConnectionParams
|
||||
desired_caps: Set[ICapability]
|
||||
|
||||
|
@ -83,13 +84,17 @@ class IServer(Server):
|
|||
transport: ITCPTransport,
|
||||
params: ConnectionParams):
|
||||
pass
|
||||
async def disconnect(self):
|
||||
pass
|
||||
|
||||
async def line_read(self, line: Line):
|
||||
pass
|
||||
async def line_send(self, line: Line):
|
||||
pass
|
||||
def sts_policy(self, sts: STSPolicy):
|
||||
pass
|
||||
|
||||
async def next_line(self) -> Tuple[Line, List[Emit]]:
|
||||
async def next_line(self) -> Optional[Tuple[Line, List[Emit]]]:
|
||||
pass
|
||||
|
||||
def cap_agreed(self, capability: ICapability) -> bool:
|
||||
|
@ -99,3 +104,18 @@ class IServer(Server):
|
|||
|
||||
async def sasl_auth(self, sasl: SASLParams) -> bool:
|
||||
pass
|
||||
|
||||
class IBot(object):
|
||||
def create_server(self, name: str) -> IServer:
|
||||
pass
|
||||
async def disconnected(self, server: IServer):
|
||||
pass
|
||||
|
||||
async def disconnect(self, server: IServer):
|
||||
pass
|
||||
|
||||
async def add_server(self, name: str, params: ConnectionParams) -> IServer:
|
||||
pass
|
||||
|
||||
async def run(self):
|
||||
pass
|
||||
|
|
|
@ -6,7 +6,7 @@ from irctokens import build
|
|||
from .contexts import ServerContext
|
||||
from .matching import Response, ResponseOr, ParamAny, ParamLiteral
|
||||
from .interface import ICapability
|
||||
from .params import STSPolicy
|
||||
from .params import ConnectionParams, STSPolicy
|
||||
|
||||
class Capability(ICapability):
|
||||
def __init__(self,
|
||||
|
@ -19,16 +19,18 @@ class Capability(ICapability):
|
|||
self.alias = alias or ratified_name
|
||||
self.depends_on = depends_on.copy()
|
||||
|
||||
self._caps = set((ratified_name, draft_name))
|
||||
self._caps = [ratified_name, draft_name]
|
||||
|
||||
def match(self, capability: str) -> bool:
|
||||
return capability in self._caps
|
||||
|
||||
def available(self, capabilities: Iterable[str]
|
||||
) -> Optional[str]:
|
||||
match = list(set(capabilities)&self._caps)
|
||||
return match[0] if match else None
|
||||
|
||||
def match(self, capability: str) -> Optional[str]:
|
||||
cap = list(set([capability])&self._caps)
|
||||
return cap[0] if cap else None
|
||||
for cap in self._caps:
|
||||
if not cap is None and cap in capabilities:
|
||||
return cap
|
||||
else:
|
||||
return None
|
||||
|
||||
def copy(self):
|
||||
return Capability(
|
||||
|
@ -40,6 +42,7 @@ class Capability(ICapability):
|
|||
CAP_SASL = Capability("sasl")
|
||||
CAP_ECHO = Capability("echo-message")
|
||||
CAP_LABEL = Capability("labeled-response", "draft/labeled-response-0.2")
|
||||
CAP_STS = Capability("sts", "draft/sts")
|
||||
|
||||
LABEL_TAG = {
|
||||
"draft/labeled-response-0.2": "draft/label",
|
||||
|
@ -65,6 +68,13 @@ CAPS: List[ICapability] = [
|
|||
Capability("setname", "draft/setname")
|
||||
]
|
||||
|
||||
def _cap_dict(s: str) -> Dict[str, str]:
|
||||
d: Dict[str, str] = {}
|
||||
for token in s.split(","):
|
||||
key, _, value = token.partition("=")
|
||||
d[key] = value
|
||||
return d
|
||||
|
||||
class CAPContext(ServerContext):
|
||||
async def on_ls(self, tokens: Dict[str, str]):
|
||||
caps = list(self.server.desired_caps)+CAPS
|
||||
|
@ -94,18 +104,34 @@ class CAPContext(ServerContext):
|
|||
await self.server.sasl_auth(self.server.params.sasl)
|
||||
|
||||
async def handshake(self):
|
||||
cap_sts = CAP_STS.available(self.server.available_caps)
|
||||
if not cap_sts is None:
|
||||
sts_dict = _cap_dict(self.server.available_caps[cap_sts])
|
||||
params = self.server.params
|
||||
if not params.tls:
|
||||
if "port" in sts_dict:
|
||||
params.port = int(sts_dict["port"])
|
||||
params.tls = True
|
||||
|
||||
await self.server.bot.disconnect(self.server)
|
||||
await self.server.bot.add_server(self.server.name, params)
|
||||
return
|
||||
elif "duration" in sts_dict:
|
||||
policy = STSPolicy(
|
||||
int(time()),
|
||||
params.port,
|
||||
int(sts_dict["duration"]),
|
||||
"preload" in sts_dict)
|
||||
self.server.sts_policy(policy)
|
||||
|
||||
await self.on_ls(self.server.available_caps)
|
||||
await self.server.send(build("CAP", ["END"]))
|
||||
|
||||
class STSContext(ServerContext):
|
||||
async def transmute(self,
|
||||
port: int,
|
||||
tls: bool,
|
||||
sts: Optional[STSPolicy]) -> Tuple[int, bool]:
|
||||
if not sts is None:
|
||||
async def transmute(self, params: ConnectionParams):
|
||||
if not params.sts is None and not params.tls:
|
||||
now = time()
|
||||
since = (now-sts.created)
|
||||
if since <= sts.duration:
|
||||
return sts.port, True
|
||||
|
||||
return port, tls
|
||||
since = (now-params.sts.created)
|
||||
if since <= params.sts.duration:
|
||||
params.port = params.sts.port
|
||||
params.tls = True
|
||||
|
|
|
@ -5,6 +5,7 @@ from collections import deque
|
|||
from asyncio_throttle import Throttler
|
||||
from ircstates import Emit, Channel
|
||||
from ircstates.numerics import *
|
||||
from ircstates.server import ServerDisconnectedException
|
||||
from irctokens import build, Line, tokenise
|
||||
|
||||
from .ircv3 import (CAPContext, STSContext, CAP_ECHO, CAP_SASL, CAP_LABEL,
|
||||
|
@ -14,9 +15,9 @@ from .join_info import WHOContext
|
|||
from .matching import ResponseOr, Responses, Response, ParamAny, ParamFolded
|
||||
from .asyncs import MaybeAwait
|
||||
from .struct import Whois
|
||||
|
||||
from .interface import (ConnectionParams, ICapability, IServer, SentLine,
|
||||
SendPriority, SASLParams, IMatchResponse)
|
||||
from .params import ConnectionParams, SASLParams, STSPolicy
|
||||
from .interface import (IBot, ICapability, IServer, SentLine, SendPriority,
|
||||
IMatchResponse)
|
||||
from .interface import ITCPTransport, ITCPReader, ITCPWriter
|
||||
|
||||
THROTTLE_RATE = 4 # lines
|
||||
|
@ -27,15 +28,17 @@ class Server(IServer):
|
|||
_writer: ITCPWriter
|
||||
params: ConnectionParams
|
||||
|
||||
def __init__(self, name: str):
|
||||
def __init__(self, bot: IBot, name: str):
|
||||
super().__init__(name)
|
||||
self.bot = bot
|
||||
|
||||
self.disconnected = False
|
||||
|
||||
self.throttle = Throttler(
|
||||
rate_limit=100, period=THROTTLE_TIME)
|
||||
|
||||
self.sasl_state = SASLResult.NONE
|
||||
|
||||
|
||||
self._sent_count: int = 0
|
||||
self._wait_for: List[Tuple["Future[Line]", IMatchResponse]] = []
|
||||
self._write_queue: PriorityQueue[SentLine] = PriorityQueue()
|
||||
|
@ -81,13 +84,12 @@ class Server(IServer):
|
|||
async def connect(self,
|
||||
transport: ITCPTransport,
|
||||
params: ConnectionParams):
|
||||
port, tls = await STSContext(self).transmute(
|
||||
params.port, params.tls, params.sts)
|
||||
await STSContext(self).transmute(params)
|
||||
|
||||
reader, writer = await transport.connect(
|
||||
params.host,
|
||||
port,
|
||||
tls =tls,
|
||||
params.port,
|
||||
tls =params.tls,
|
||||
tls_verify=params.tls_verify,
|
||||
bindhost =params.bindhost)
|
||||
|
||||
|
@ -96,6 +98,11 @@ class Server(IServer):
|
|||
|
||||
self.params = params
|
||||
await self.handshake()
|
||||
async def disconnect(self):
|
||||
if not self._writer is None:
|
||||
await self._writer.close()
|
||||
self._writer = None
|
||||
self._read_queue.clear()
|
||||
|
||||
async def handshake(self):
|
||||
nickname = self.params.nickname
|
||||
|
@ -133,16 +140,18 @@ class Server(IServer):
|
|||
if line.command == "PING":
|
||||
await self.send(build("PONG", line.params))
|
||||
|
||||
async def line_read(self, line: Line):
|
||||
pass
|
||||
|
||||
async def next_line(self) -> Tuple[Line, List[Emit]]:
|
||||
async def next_line(self) -> Optional[Tuple[Line, List[Emit]]]:
|
||||
if self._read_queue:
|
||||
both = self._read_queue.popleft()
|
||||
else:
|
||||
data = await self._reader.read(1024)
|
||||
while True:
|
||||
data = await self._reader.read(1024)
|
||||
try:
|
||||
lines = self.recv(data)
|
||||
except ServerDisconnectedException:
|
||||
self.disconnected = True
|
||||
return None
|
||||
|
||||
if lines:
|
||||
self._read_queue.extend(lines[1:])
|
||||
both = lines[0]
|
||||
|
@ -161,6 +170,8 @@ class Server(IServer):
|
|||
self._wait_for.append((our_fut, response))
|
||||
while self._wait_for:
|
||||
both = await self.next_line()
|
||||
|
||||
if not both is None:
|
||||
line, emits = both
|
||||
|
||||
for i, (fut, waiting) in enumerate(self._wait_for):
|
||||
|
@ -168,11 +179,11 @@ class Server(IServer):
|
|||
fut.set_result(line)
|
||||
self._wait_for.pop(i)
|
||||
break
|
||||
else:
|
||||
fut.set_result(build(""))
|
||||
|
||||
return await our_fut
|
||||
|
||||
async def line_send(self, line: Line):
|
||||
pass
|
||||
async def _on_write_line(self, line: Line):
|
||||
if (line.command == "PRIVMSG" and
|
||||
not self.cap_agreed(CAP_ECHO)):
|
||||
|
|
|
@ -22,6 +22,10 @@ class TCPWriter(ITCPWriter):
|
|||
async def drain(self):
|
||||
await self._writer.drain()
|
||||
|
||||
async def close(self):
|
||||
self._writer.close()
|
||||
await self._writer.wait_closed()
|
||||
|
||||
class TCPTransport(ITCPTransport):
|
||||
async def connect(self,
|
||||
hostname: str,
|
||||
|
|
Loading…
Reference in a new issue