support requested disconnects, support STS connection transmutations

This commit is contained in:
jesopo 2020-04-19 23:04:56 +01:00
parent 1bfe06d2ea
commit ae01201d39
5 changed files with 121 additions and 51 deletions

View file

@ -5,20 +5,28 @@ from typing import Dict
from .server import ConnectionParams, Server
from .transport import TCPTransport
from .interface import IBot, IServer
RECONNECT_DELAY = 10.0 # ten seconds reconnect
class Bot(object):
class Bot(IBot):
def __init__(self):
self.servers: Dict[str, Server] = {}
self._server_queue: asyncio.Queue[Server] = asyncio.Queue()
# methods designed to be overridden
def create_server(self, name: str):
return Server(name)
async def disconnected(self, server: Server):
return Server(self, name)
async def disconnected(self, server: IServer):
if (server.name in self.servers and
server.disconnected):
await asyncio.sleep(RECONNECT_DELAY)
await self.add_server(server.name, server.params)
# /methods designed to be overridden
async def disconnect(self, server: IServer):
await server.disconnect()
del self.servers[server.name]
async def add_server(self, name: str, params: ConnectionParams) -> Server:
server = self.create_server(name)
@ -31,7 +39,9 @@ class Bot(object):
async with anyio.create_task_group() as tg:
async def _read():
while not tg.cancel_scope.cancel_called:
line, emits = await server.next_line()
both = await server.next_line()
if both is None:
break
await tg.cancel_scope.cancel()
async def _write():
@ -42,7 +52,6 @@ class Bot(object):
await tg.spawn(_write)
await tg.spawn(_read)
del self.servers[server.name]
await self.disconnected(server)
async def run(self):

View file

@ -5,7 +5,7 @@ from enum import IntEnum
from ircstates import Server, Emit
from irctokens import Line
from .params import ConnectionParams, SASLParams
from .params import ConnectionParams, SASLParams, STSPolicy
class ITCPReader(object):
async def read(self, byte_count: int):
@ -49,7 +49,7 @@ class ICapability(object):
def available(self, capabilities: Iterable[str]) -> Optional[str]:
pass
def match(self, capability: str) -> Optional[str]:
def match(self, capability: str) -> bool:
pass
def copy(self) -> "ICapability":
@ -63,6 +63,7 @@ class IMatchResponseParam(object):
pass
class IServer(Server):
disconnected: bool
params: ConnectionParams
desired_caps: Set[ICapability]
@ -83,13 +84,17 @@ class IServer(Server):
transport: ITCPTransport,
params: ConnectionParams):
pass
async def disconnect(self):
pass
async def line_read(self, line: Line):
pass
async def line_send(self, line: Line):
pass
def sts_policy(self, sts: STSPolicy):
pass
async def next_line(self) -> Tuple[Line, List[Emit]]:
async def next_line(self) -> Optional[Tuple[Line, List[Emit]]]:
pass
def cap_agreed(self, capability: ICapability) -> bool:
@ -99,3 +104,18 @@ class IServer(Server):
async def sasl_auth(self, sasl: SASLParams) -> bool:
pass
class IBot(object):
def create_server(self, name: str) -> IServer:
pass
async def disconnected(self, server: IServer):
pass
async def disconnect(self, server: IServer):
pass
async def add_server(self, name: str, params: ConnectionParams) -> IServer:
pass
async def run(self):
pass

View file

@ -6,7 +6,7 @@ from irctokens import build
from .contexts import ServerContext
from .matching import Response, ResponseOr, ParamAny, ParamLiteral
from .interface import ICapability
from .params import STSPolicy
from .params import ConnectionParams, STSPolicy
class Capability(ICapability):
def __init__(self,
@ -19,16 +19,18 @@ class Capability(ICapability):
self.alias = alias or ratified_name
self.depends_on = depends_on.copy()
self._caps = set((ratified_name, draft_name))
self._caps = [ratified_name, draft_name]
def match(self, capability: str) -> bool:
return capability in self._caps
def available(self, capabilities: Iterable[str]
) -> Optional[str]:
match = list(set(capabilities)&self._caps)
return match[0] if match else None
def match(self, capability: str) -> Optional[str]:
cap = list(set([capability])&self._caps)
return cap[0] if cap else None
for cap in self._caps:
if not cap is None and cap in capabilities:
return cap
else:
return None
def copy(self):
return Capability(
@ -40,6 +42,7 @@ class Capability(ICapability):
CAP_SASL = Capability("sasl")
CAP_ECHO = Capability("echo-message")
CAP_LABEL = Capability("labeled-response", "draft/labeled-response-0.2")
CAP_STS = Capability("sts", "draft/sts")
LABEL_TAG = {
"draft/labeled-response-0.2": "draft/label",
@ -65,6 +68,13 @@ CAPS: List[ICapability] = [
Capability("setname", "draft/setname")
]
def _cap_dict(s: str) -> Dict[str, str]:
d: Dict[str, str] = {}
for token in s.split(","):
key, _, value = token.partition("=")
d[key] = value
return d
class CAPContext(ServerContext):
async def on_ls(self, tokens: Dict[str, str]):
caps = list(self.server.desired_caps)+CAPS
@ -94,18 +104,34 @@ class CAPContext(ServerContext):
await self.server.sasl_auth(self.server.params.sasl)
async def handshake(self):
cap_sts = CAP_STS.available(self.server.available_caps)
if not cap_sts is None:
sts_dict = _cap_dict(self.server.available_caps[cap_sts])
params = self.server.params
if not params.tls:
if "port" in sts_dict:
params.port = int(sts_dict["port"])
params.tls = True
await self.server.bot.disconnect(self.server)
await self.server.bot.add_server(self.server.name, params)
return
elif "duration" in sts_dict:
policy = STSPolicy(
int(time()),
params.port,
int(sts_dict["duration"]),
"preload" in sts_dict)
self.server.sts_policy(policy)
await self.on_ls(self.server.available_caps)
await self.server.send(build("CAP", ["END"]))
class STSContext(ServerContext):
async def transmute(self,
port: int,
tls: bool,
sts: Optional[STSPolicy]) -> Tuple[int, bool]:
if not sts is None:
async def transmute(self, params: ConnectionParams):
if not params.sts is None and not params.tls:
now = time()
since = (now-sts.created)
if since <= sts.duration:
return sts.port, True
return port, tls
since = (now-params.sts.created)
if since <= params.sts.duration:
params.port = params.sts.port
params.tls = True

View file

@ -5,6 +5,7 @@ from collections import deque
from asyncio_throttle import Throttler
from ircstates import Emit, Channel
from ircstates.numerics import *
from ircstates.server import ServerDisconnectedException
from irctokens import build, Line, tokenise
from .ircv3 import (CAPContext, STSContext, CAP_ECHO, CAP_SASL, CAP_LABEL,
@ -14,9 +15,9 @@ from .join_info import WHOContext
from .matching import ResponseOr, Responses, Response, ParamAny, ParamFolded
from .asyncs import MaybeAwait
from .struct import Whois
from .interface import (ConnectionParams, ICapability, IServer, SentLine,
SendPriority, SASLParams, IMatchResponse)
from .params import ConnectionParams, SASLParams, STSPolicy
from .interface import (IBot, ICapability, IServer, SentLine, SendPriority,
IMatchResponse)
from .interface import ITCPTransport, ITCPReader, ITCPWriter
THROTTLE_RATE = 4 # lines
@ -27,15 +28,17 @@ class Server(IServer):
_writer: ITCPWriter
params: ConnectionParams
def __init__(self, name: str):
def __init__(self, bot: IBot, name: str):
super().__init__(name)
self.bot = bot
self.disconnected = False
self.throttle = Throttler(
rate_limit=100, period=THROTTLE_TIME)
self.sasl_state = SASLResult.NONE
self._sent_count: int = 0
self._wait_for: List[Tuple["Future[Line]", IMatchResponse]] = []
self._write_queue: PriorityQueue[SentLine] = PriorityQueue()
@ -81,13 +84,12 @@ class Server(IServer):
async def connect(self,
transport: ITCPTransport,
params: ConnectionParams):
port, tls = await STSContext(self).transmute(
params.port, params.tls, params.sts)
await STSContext(self).transmute(params)
reader, writer = await transport.connect(
params.host,
port,
tls =tls,
params.port,
tls =params.tls,
tls_verify=params.tls_verify,
bindhost =params.bindhost)
@ -96,6 +98,11 @@ class Server(IServer):
self.params = params
await self.handshake()
async def disconnect(self):
if not self._writer is None:
await self._writer.close()
self._writer = None
self._read_queue.clear()
async def handshake(self):
nickname = self.params.nickname
@ -133,16 +140,18 @@ class Server(IServer):
if line.command == "PING":
await self.send(build("PONG", line.params))
async def line_read(self, line: Line):
pass
async def next_line(self) -> Tuple[Line, List[Emit]]:
async def next_line(self) -> Optional[Tuple[Line, List[Emit]]]:
if self._read_queue:
both = self._read_queue.popleft()
else:
data = await self._reader.read(1024)
while True:
data = await self._reader.read(1024)
try:
lines = self.recv(data)
except ServerDisconnectedException:
self.disconnected = True
return None
if lines:
self._read_queue.extend(lines[1:])
both = lines[0]
@ -161,6 +170,8 @@ class Server(IServer):
self._wait_for.append((our_fut, response))
while self._wait_for:
both = await self.next_line()
if not both is None:
line, emits = both
for i, (fut, waiting) in enumerate(self._wait_for):
@ -168,11 +179,11 @@ class Server(IServer):
fut.set_result(line)
self._wait_for.pop(i)
break
else:
fut.set_result(build(""))
return await our_fut
async def line_send(self, line: Line):
pass
async def _on_write_line(self, line: Line):
if (line.command == "PRIVMSG" and
not self.cap_agreed(CAP_ECHO)):

View file

@ -22,6 +22,10 @@ class TCPWriter(ITCPWriter):
async def drain(self):
await self._writer.drain()
async def close(self):
self._writer.close()
await self._writer.wait_closed()
class TCPTransport(ITCPTransport):
async def connect(self,
hostname: str,