support requested disconnects, support STS connection transmutations

This commit is contained in:
jesopo 2020-04-19 23:04:56 +01:00
parent 1bfe06d2ea
commit ae01201d39
5 changed files with 121 additions and 51 deletions

View file

@ -5,20 +5,28 @@ from typing import Dict
from .server import ConnectionParams, Server from .server import ConnectionParams, Server
from .transport import TCPTransport from .transport import TCPTransport
from .interface import IBot, IServer
RECONNECT_DELAY = 10.0 # ten seconds reconnect RECONNECT_DELAY = 10.0 # ten seconds reconnect
class Bot(object): class Bot(IBot):
def __init__(self): def __init__(self):
self.servers: Dict[str, Server] = {} self.servers: Dict[str, Server] = {}
self._server_queue: asyncio.Queue[Server] = asyncio.Queue() self._server_queue: asyncio.Queue[Server] = asyncio.Queue()
# methods designed to be overridden # methods designed to be overridden
def create_server(self, name: str): def create_server(self, name: str):
return Server(name) return Server(self, name)
async def disconnected(self, server: Server): async def disconnected(self, server: IServer):
await asyncio.sleep(RECONNECT_DELAY) if (server.name in self.servers and
await self.add_server(server.name, server.params) server.disconnected):
await asyncio.sleep(RECONNECT_DELAY)
await self.add_server(server.name, server.params)
# /methods designed to be overridden
async def disconnect(self, server: IServer):
await server.disconnect()
del self.servers[server.name]
async def add_server(self, name: str, params: ConnectionParams) -> Server: async def add_server(self, name: str, params: ConnectionParams) -> Server:
server = self.create_server(name) server = self.create_server(name)
@ -31,7 +39,9 @@ class Bot(object):
async with anyio.create_task_group() as tg: async with anyio.create_task_group() as tg:
async def _read(): async def _read():
while not tg.cancel_scope.cancel_called: while not tg.cancel_scope.cancel_called:
line, emits = await server.next_line() both = await server.next_line()
if both is None:
break
await tg.cancel_scope.cancel() await tg.cancel_scope.cancel()
async def _write(): async def _write():
@ -42,7 +52,6 @@ class Bot(object):
await tg.spawn(_write) await tg.spawn(_write)
await tg.spawn(_read) await tg.spawn(_read)
del self.servers[server.name]
await self.disconnected(server) await self.disconnected(server)
async def run(self): async def run(self):

View file

@ -5,7 +5,7 @@ from enum import IntEnum
from ircstates import Server, Emit from ircstates import Server, Emit
from irctokens import Line from irctokens import Line
from .params import ConnectionParams, SASLParams from .params import ConnectionParams, SASLParams, STSPolicy
class ITCPReader(object): class ITCPReader(object):
async def read(self, byte_count: int): async def read(self, byte_count: int):
@ -49,7 +49,7 @@ class ICapability(object):
def available(self, capabilities: Iterable[str]) -> Optional[str]: def available(self, capabilities: Iterable[str]) -> Optional[str]:
pass pass
def match(self, capability: str) -> Optional[str]: def match(self, capability: str) -> bool:
pass pass
def copy(self) -> "ICapability": def copy(self) -> "ICapability":
@ -63,6 +63,7 @@ class IMatchResponseParam(object):
pass pass
class IServer(Server): class IServer(Server):
disconnected: bool
params: ConnectionParams params: ConnectionParams
desired_caps: Set[ICapability] desired_caps: Set[ICapability]
@ -83,13 +84,17 @@ class IServer(Server):
transport: ITCPTransport, transport: ITCPTransport,
params: ConnectionParams): params: ConnectionParams):
pass pass
async def disconnect(self):
pass
async def line_read(self, line: Line): async def line_read(self, line: Line):
pass pass
async def line_send(self, line: Line): async def line_send(self, line: Line):
pass pass
def sts_policy(self, sts: STSPolicy):
pass
async def next_line(self) -> Tuple[Line, List[Emit]]: async def next_line(self) -> Optional[Tuple[Line, List[Emit]]]:
pass pass
def cap_agreed(self, capability: ICapability) -> bool: def cap_agreed(self, capability: ICapability) -> bool:
@ -99,3 +104,18 @@ class IServer(Server):
async def sasl_auth(self, sasl: SASLParams) -> bool: async def sasl_auth(self, sasl: SASLParams) -> bool:
pass pass
class IBot(object):
def create_server(self, name: str) -> IServer:
pass
async def disconnected(self, server: IServer):
pass
async def disconnect(self, server: IServer):
pass
async def add_server(self, name: str, params: ConnectionParams) -> IServer:
pass
async def run(self):
pass

View file

@ -6,7 +6,7 @@ from irctokens import build
from .contexts import ServerContext from .contexts import ServerContext
from .matching import Response, ResponseOr, ParamAny, ParamLiteral from .matching import Response, ResponseOr, ParamAny, ParamLiteral
from .interface import ICapability from .interface import ICapability
from .params import STSPolicy from .params import ConnectionParams, STSPolicy
class Capability(ICapability): class Capability(ICapability):
def __init__(self, def __init__(self,
@ -19,16 +19,18 @@ class Capability(ICapability):
self.alias = alias or ratified_name self.alias = alias or ratified_name
self.depends_on = depends_on.copy() self.depends_on = depends_on.copy()
self._caps = set((ratified_name, draft_name)) self._caps = [ratified_name, draft_name]
def match(self, capability: str) -> bool:
return capability in self._caps
def available(self, capabilities: Iterable[str] def available(self, capabilities: Iterable[str]
) -> Optional[str]: ) -> Optional[str]:
match = list(set(capabilities)&self._caps) for cap in self._caps:
return match[0] if match else None if not cap is None and cap in capabilities:
return cap
def match(self, capability: str) -> Optional[str]: else:
cap = list(set([capability])&self._caps) return None
return cap[0] if cap else None
def copy(self): def copy(self):
return Capability( return Capability(
@ -40,6 +42,7 @@ class Capability(ICapability):
CAP_SASL = Capability("sasl") CAP_SASL = Capability("sasl")
CAP_ECHO = Capability("echo-message") CAP_ECHO = Capability("echo-message")
CAP_LABEL = Capability("labeled-response", "draft/labeled-response-0.2") CAP_LABEL = Capability("labeled-response", "draft/labeled-response-0.2")
CAP_STS = Capability("sts", "draft/sts")
LABEL_TAG = { LABEL_TAG = {
"draft/labeled-response-0.2": "draft/label", "draft/labeled-response-0.2": "draft/label",
@ -65,6 +68,13 @@ CAPS: List[ICapability] = [
Capability("setname", "draft/setname") Capability("setname", "draft/setname")
] ]
def _cap_dict(s: str) -> Dict[str, str]:
d: Dict[str, str] = {}
for token in s.split(","):
key, _, value = token.partition("=")
d[key] = value
return d
class CAPContext(ServerContext): class CAPContext(ServerContext):
async def on_ls(self, tokens: Dict[str, str]): async def on_ls(self, tokens: Dict[str, str]):
caps = list(self.server.desired_caps)+CAPS caps = list(self.server.desired_caps)+CAPS
@ -94,18 +104,34 @@ class CAPContext(ServerContext):
await self.server.sasl_auth(self.server.params.sasl) await self.server.sasl_auth(self.server.params.sasl)
async def handshake(self): async def handshake(self):
cap_sts = CAP_STS.available(self.server.available_caps)
if not cap_sts is None:
sts_dict = _cap_dict(self.server.available_caps[cap_sts])
params = self.server.params
if not params.tls:
if "port" in sts_dict:
params.port = int(sts_dict["port"])
params.tls = True
await self.server.bot.disconnect(self.server)
await self.server.bot.add_server(self.server.name, params)
return
elif "duration" in sts_dict:
policy = STSPolicy(
int(time()),
params.port,
int(sts_dict["duration"]),
"preload" in sts_dict)
self.server.sts_policy(policy)
await self.on_ls(self.server.available_caps) await self.on_ls(self.server.available_caps)
await self.server.send(build("CAP", ["END"])) await self.server.send(build("CAP", ["END"]))
class STSContext(ServerContext): class STSContext(ServerContext):
async def transmute(self, async def transmute(self, params: ConnectionParams):
port: int, if not params.sts is None and not params.tls:
tls: bool,
sts: Optional[STSPolicy]) -> Tuple[int, bool]:
if not sts is None:
now = time() now = time()
since = (now-sts.created) since = (now-params.sts.created)
if since <= sts.duration: if since <= params.sts.duration:
return sts.port, True params.port = params.sts.port
params.tls = True
return port, tls

View file

@ -5,6 +5,7 @@ from collections import deque
from asyncio_throttle import Throttler from asyncio_throttle import Throttler
from ircstates import Emit, Channel from ircstates import Emit, Channel
from ircstates.numerics import * from ircstates.numerics import *
from ircstates.server import ServerDisconnectedException
from irctokens import build, Line, tokenise from irctokens import build, Line, tokenise
from .ircv3 import (CAPContext, STSContext, CAP_ECHO, CAP_SASL, CAP_LABEL, from .ircv3 import (CAPContext, STSContext, CAP_ECHO, CAP_SASL, CAP_LABEL,
@ -14,9 +15,9 @@ from .join_info import WHOContext
from .matching import ResponseOr, Responses, Response, ParamAny, ParamFolded from .matching import ResponseOr, Responses, Response, ParamAny, ParamFolded
from .asyncs import MaybeAwait from .asyncs import MaybeAwait
from .struct import Whois from .struct import Whois
from .params import ConnectionParams, SASLParams, STSPolicy
from .interface import (ConnectionParams, ICapability, IServer, SentLine, from .interface import (IBot, ICapability, IServer, SentLine, SendPriority,
SendPriority, SASLParams, IMatchResponse) IMatchResponse)
from .interface import ITCPTransport, ITCPReader, ITCPWriter from .interface import ITCPTransport, ITCPReader, ITCPWriter
THROTTLE_RATE = 4 # lines THROTTLE_RATE = 4 # lines
@ -27,15 +28,17 @@ class Server(IServer):
_writer: ITCPWriter _writer: ITCPWriter
params: ConnectionParams params: ConnectionParams
def __init__(self, name: str): def __init__(self, bot: IBot, name: str):
super().__init__(name) super().__init__(name)
self.bot = bot
self.disconnected = False
self.throttle = Throttler( self.throttle = Throttler(
rate_limit=100, period=THROTTLE_TIME) rate_limit=100, period=THROTTLE_TIME)
self.sasl_state = SASLResult.NONE self.sasl_state = SASLResult.NONE
self._sent_count: int = 0 self._sent_count: int = 0
self._wait_for: List[Tuple["Future[Line]", IMatchResponse]] = [] self._wait_for: List[Tuple["Future[Line]", IMatchResponse]] = []
self._write_queue: PriorityQueue[SentLine] = PriorityQueue() self._write_queue: PriorityQueue[SentLine] = PriorityQueue()
@ -81,13 +84,12 @@ class Server(IServer):
async def connect(self, async def connect(self,
transport: ITCPTransport, transport: ITCPTransport,
params: ConnectionParams): params: ConnectionParams):
port, tls = await STSContext(self).transmute( await STSContext(self).transmute(params)
params.port, params.tls, params.sts)
reader, writer = await transport.connect( reader, writer = await transport.connect(
params.host, params.host,
port, params.port,
tls =tls, tls =params.tls,
tls_verify=params.tls_verify, tls_verify=params.tls_verify,
bindhost =params.bindhost) bindhost =params.bindhost)
@ -96,6 +98,11 @@ class Server(IServer):
self.params = params self.params = params
await self.handshake() await self.handshake()
async def disconnect(self):
if not self._writer is None:
await self._writer.close()
self._writer = None
self._read_queue.clear()
async def handshake(self): async def handshake(self):
nickname = self.params.nickname nickname = self.params.nickname
@ -133,16 +140,18 @@ class Server(IServer):
if line.command == "PING": if line.command == "PING":
await self.send(build("PONG", line.params)) await self.send(build("PONG", line.params))
async def line_read(self, line: Line): async def next_line(self) -> Optional[Tuple[Line, List[Emit]]]:
pass
async def next_line(self) -> Tuple[Line, List[Emit]]:
if self._read_queue: if self._read_queue:
both = self._read_queue.popleft() both = self._read_queue.popleft()
else: else:
data = await self._reader.read(1024)
while True: while True:
lines = self.recv(data) data = await self._reader.read(1024)
try:
lines = self.recv(data)
except ServerDisconnectedException:
self.disconnected = True
return None
if lines: if lines:
self._read_queue.extend(lines[1:]) self._read_queue.extend(lines[1:])
both = lines[0] both = lines[0]
@ -161,18 +170,20 @@ class Server(IServer):
self._wait_for.append((our_fut, response)) self._wait_for.append((our_fut, response))
while self._wait_for: while self._wait_for:
both = await self.next_line() both = await self.next_line()
line, emits = both
for i, (fut, waiting) in enumerate(self._wait_for): if not both is None:
if waiting.match(self, line): line, emits = both
fut.set_result(line)
self._wait_for.pop(i) for i, (fut, waiting) in enumerate(self._wait_for):
break if waiting.match(self, line):
fut.set_result(line)
self._wait_for.pop(i)
break
else:
fut.set_result(build(""))
return await our_fut return await our_fut
async def line_send(self, line: Line):
pass
async def _on_write_line(self, line: Line): async def _on_write_line(self, line: Line):
if (line.command == "PRIVMSG" and if (line.command == "PRIVMSG" and
not self.cap_agreed(CAP_ECHO)): not self.cap_agreed(CAP_ECHO)):

View file

@ -22,6 +22,10 @@ class TCPWriter(ITCPWriter):
async def drain(self): async def drain(self):
await self._writer.drain() await self._writer.drain()
async def close(self):
self._writer.close()
await self._writer.wait_closed()
class TCPTransport(ITCPTransport): class TCPTransport(ITCPTransport):
async def connect(self, async def connect(self,
hostname: str, hostname: str,