support requested disconnects, support STS connection transmutations
This commit is contained in:
parent
1bfe06d2ea
commit
ae01201d39
5 changed files with 121 additions and 51 deletions
|
@ -5,20 +5,28 @@ from typing import Dict
|
||||||
|
|
||||||
from .server import ConnectionParams, Server
|
from .server import ConnectionParams, Server
|
||||||
from .transport import TCPTransport
|
from .transport import TCPTransport
|
||||||
|
from .interface import IBot, IServer
|
||||||
|
|
||||||
RECONNECT_DELAY = 10.0 # ten seconds reconnect
|
RECONNECT_DELAY = 10.0 # ten seconds reconnect
|
||||||
|
|
||||||
class Bot(object):
|
class Bot(IBot):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.servers: Dict[str, Server] = {}
|
self.servers: Dict[str, Server] = {}
|
||||||
self._server_queue: asyncio.Queue[Server] = asyncio.Queue()
|
self._server_queue: asyncio.Queue[Server] = asyncio.Queue()
|
||||||
|
|
||||||
# methods designed to be overridden
|
# methods designed to be overridden
|
||||||
def create_server(self, name: str):
|
def create_server(self, name: str):
|
||||||
return Server(name)
|
return Server(self, name)
|
||||||
async def disconnected(self, server: Server):
|
async def disconnected(self, server: IServer):
|
||||||
|
if (server.name in self.servers and
|
||||||
|
server.disconnected):
|
||||||
await asyncio.sleep(RECONNECT_DELAY)
|
await asyncio.sleep(RECONNECT_DELAY)
|
||||||
await self.add_server(server.name, server.params)
|
await self.add_server(server.name, server.params)
|
||||||
|
# /methods designed to be overridden
|
||||||
|
|
||||||
|
async def disconnect(self, server: IServer):
|
||||||
|
await server.disconnect()
|
||||||
|
del self.servers[server.name]
|
||||||
|
|
||||||
async def add_server(self, name: str, params: ConnectionParams) -> Server:
|
async def add_server(self, name: str, params: ConnectionParams) -> Server:
|
||||||
server = self.create_server(name)
|
server = self.create_server(name)
|
||||||
|
@ -31,7 +39,9 @@ class Bot(object):
|
||||||
async with anyio.create_task_group() as tg:
|
async with anyio.create_task_group() as tg:
|
||||||
async def _read():
|
async def _read():
|
||||||
while not tg.cancel_scope.cancel_called:
|
while not tg.cancel_scope.cancel_called:
|
||||||
line, emits = await server.next_line()
|
both = await server.next_line()
|
||||||
|
if both is None:
|
||||||
|
break
|
||||||
await tg.cancel_scope.cancel()
|
await tg.cancel_scope.cancel()
|
||||||
|
|
||||||
async def _write():
|
async def _write():
|
||||||
|
@ -42,7 +52,6 @@ class Bot(object):
|
||||||
await tg.spawn(_write)
|
await tg.spawn(_write)
|
||||||
await tg.spawn(_read)
|
await tg.spawn(_read)
|
||||||
|
|
||||||
del self.servers[server.name]
|
|
||||||
await self.disconnected(server)
|
await self.disconnected(server)
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
|
|
|
@ -5,7 +5,7 @@ from enum import IntEnum
|
||||||
from ircstates import Server, Emit
|
from ircstates import Server, Emit
|
||||||
from irctokens import Line
|
from irctokens import Line
|
||||||
|
|
||||||
from .params import ConnectionParams, SASLParams
|
from .params import ConnectionParams, SASLParams, STSPolicy
|
||||||
|
|
||||||
class ITCPReader(object):
|
class ITCPReader(object):
|
||||||
async def read(self, byte_count: int):
|
async def read(self, byte_count: int):
|
||||||
|
@ -49,7 +49,7 @@ class ICapability(object):
|
||||||
def available(self, capabilities: Iterable[str]) -> Optional[str]:
|
def available(self, capabilities: Iterable[str]) -> Optional[str]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def match(self, capability: str) -> Optional[str]:
|
def match(self, capability: str) -> bool:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def copy(self) -> "ICapability":
|
def copy(self) -> "ICapability":
|
||||||
|
@ -63,6 +63,7 @@ class IMatchResponseParam(object):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class IServer(Server):
|
class IServer(Server):
|
||||||
|
disconnected: bool
|
||||||
params: ConnectionParams
|
params: ConnectionParams
|
||||||
desired_caps: Set[ICapability]
|
desired_caps: Set[ICapability]
|
||||||
|
|
||||||
|
@ -83,13 +84,17 @@ class IServer(Server):
|
||||||
transport: ITCPTransport,
|
transport: ITCPTransport,
|
||||||
params: ConnectionParams):
|
params: ConnectionParams):
|
||||||
pass
|
pass
|
||||||
|
async def disconnect(self):
|
||||||
|
pass
|
||||||
|
|
||||||
async def line_read(self, line: Line):
|
async def line_read(self, line: Line):
|
||||||
pass
|
pass
|
||||||
async def line_send(self, line: Line):
|
async def line_send(self, line: Line):
|
||||||
pass
|
pass
|
||||||
|
def sts_policy(self, sts: STSPolicy):
|
||||||
|
pass
|
||||||
|
|
||||||
async def next_line(self) -> Tuple[Line, List[Emit]]:
|
async def next_line(self) -> Optional[Tuple[Line, List[Emit]]]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def cap_agreed(self, capability: ICapability) -> bool:
|
def cap_agreed(self, capability: ICapability) -> bool:
|
||||||
|
@ -99,3 +104,18 @@ class IServer(Server):
|
||||||
|
|
||||||
async def sasl_auth(self, sasl: SASLParams) -> bool:
|
async def sasl_auth(self, sasl: SASLParams) -> bool:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
class IBot(object):
|
||||||
|
def create_server(self, name: str) -> IServer:
|
||||||
|
pass
|
||||||
|
async def disconnected(self, server: IServer):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def disconnect(self, server: IServer):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def add_server(self, name: str, params: ConnectionParams) -> IServer:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def run(self):
|
||||||
|
pass
|
||||||
|
|
|
@ -6,7 +6,7 @@ from irctokens import build
|
||||||
from .contexts import ServerContext
|
from .contexts import ServerContext
|
||||||
from .matching import Response, ResponseOr, ParamAny, ParamLiteral
|
from .matching import Response, ResponseOr, ParamAny, ParamLiteral
|
||||||
from .interface import ICapability
|
from .interface import ICapability
|
||||||
from .params import STSPolicy
|
from .params import ConnectionParams, STSPolicy
|
||||||
|
|
||||||
class Capability(ICapability):
|
class Capability(ICapability):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
|
@ -19,16 +19,18 @@ class Capability(ICapability):
|
||||||
self.alias = alias or ratified_name
|
self.alias = alias or ratified_name
|
||||||
self.depends_on = depends_on.copy()
|
self.depends_on = depends_on.copy()
|
||||||
|
|
||||||
self._caps = set((ratified_name, draft_name))
|
self._caps = [ratified_name, draft_name]
|
||||||
|
|
||||||
|
def match(self, capability: str) -> bool:
|
||||||
|
return capability in self._caps
|
||||||
|
|
||||||
def available(self, capabilities: Iterable[str]
|
def available(self, capabilities: Iterable[str]
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
match = list(set(capabilities)&self._caps)
|
for cap in self._caps:
|
||||||
return match[0] if match else None
|
if not cap is None and cap in capabilities:
|
||||||
|
return cap
|
||||||
def match(self, capability: str) -> Optional[str]:
|
else:
|
||||||
cap = list(set([capability])&self._caps)
|
return None
|
||||||
return cap[0] if cap else None
|
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
return Capability(
|
return Capability(
|
||||||
|
@ -40,6 +42,7 @@ class Capability(ICapability):
|
||||||
CAP_SASL = Capability("sasl")
|
CAP_SASL = Capability("sasl")
|
||||||
CAP_ECHO = Capability("echo-message")
|
CAP_ECHO = Capability("echo-message")
|
||||||
CAP_LABEL = Capability("labeled-response", "draft/labeled-response-0.2")
|
CAP_LABEL = Capability("labeled-response", "draft/labeled-response-0.2")
|
||||||
|
CAP_STS = Capability("sts", "draft/sts")
|
||||||
|
|
||||||
LABEL_TAG = {
|
LABEL_TAG = {
|
||||||
"draft/labeled-response-0.2": "draft/label",
|
"draft/labeled-response-0.2": "draft/label",
|
||||||
|
@ -65,6 +68,13 @@ CAPS: List[ICapability] = [
|
||||||
Capability("setname", "draft/setname")
|
Capability("setname", "draft/setname")
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def _cap_dict(s: str) -> Dict[str, str]:
|
||||||
|
d: Dict[str, str] = {}
|
||||||
|
for token in s.split(","):
|
||||||
|
key, _, value = token.partition("=")
|
||||||
|
d[key] = value
|
||||||
|
return d
|
||||||
|
|
||||||
class CAPContext(ServerContext):
|
class CAPContext(ServerContext):
|
||||||
async def on_ls(self, tokens: Dict[str, str]):
|
async def on_ls(self, tokens: Dict[str, str]):
|
||||||
caps = list(self.server.desired_caps)+CAPS
|
caps = list(self.server.desired_caps)+CAPS
|
||||||
|
@ -94,18 +104,34 @@ class CAPContext(ServerContext):
|
||||||
await self.server.sasl_auth(self.server.params.sasl)
|
await self.server.sasl_auth(self.server.params.sasl)
|
||||||
|
|
||||||
async def handshake(self):
|
async def handshake(self):
|
||||||
|
cap_sts = CAP_STS.available(self.server.available_caps)
|
||||||
|
if not cap_sts is None:
|
||||||
|
sts_dict = _cap_dict(self.server.available_caps[cap_sts])
|
||||||
|
params = self.server.params
|
||||||
|
if not params.tls:
|
||||||
|
if "port" in sts_dict:
|
||||||
|
params.port = int(sts_dict["port"])
|
||||||
|
params.tls = True
|
||||||
|
|
||||||
|
await self.server.bot.disconnect(self.server)
|
||||||
|
await self.server.bot.add_server(self.server.name, params)
|
||||||
|
return
|
||||||
|
elif "duration" in sts_dict:
|
||||||
|
policy = STSPolicy(
|
||||||
|
int(time()),
|
||||||
|
params.port,
|
||||||
|
int(sts_dict["duration"]),
|
||||||
|
"preload" in sts_dict)
|
||||||
|
self.server.sts_policy(policy)
|
||||||
|
|
||||||
await self.on_ls(self.server.available_caps)
|
await self.on_ls(self.server.available_caps)
|
||||||
await self.server.send(build("CAP", ["END"]))
|
await self.server.send(build("CAP", ["END"]))
|
||||||
|
|
||||||
class STSContext(ServerContext):
|
class STSContext(ServerContext):
|
||||||
async def transmute(self,
|
async def transmute(self, params: ConnectionParams):
|
||||||
port: int,
|
if not params.sts is None and not params.tls:
|
||||||
tls: bool,
|
|
||||||
sts: Optional[STSPolicy]) -> Tuple[int, bool]:
|
|
||||||
if not sts is None:
|
|
||||||
now = time()
|
now = time()
|
||||||
since = (now-sts.created)
|
since = (now-params.sts.created)
|
||||||
if since <= sts.duration:
|
if since <= params.sts.duration:
|
||||||
return sts.port, True
|
params.port = params.sts.port
|
||||||
|
params.tls = True
|
||||||
return port, tls
|
|
||||||
|
|
|
@ -5,6 +5,7 @@ from collections import deque
|
||||||
from asyncio_throttle import Throttler
|
from asyncio_throttle import Throttler
|
||||||
from ircstates import Emit, Channel
|
from ircstates import Emit, Channel
|
||||||
from ircstates.numerics import *
|
from ircstates.numerics import *
|
||||||
|
from ircstates.server import ServerDisconnectedException
|
||||||
from irctokens import build, Line, tokenise
|
from irctokens import build, Line, tokenise
|
||||||
|
|
||||||
from .ircv3 import (CAPContext, STSContext, CAP_ECHO, CAP_SASL, CAP_LABEL,
|
from .ircv3 import (CAPContext, STSContext, CAP_ECHO, CAP_SASL, CAP_LABEL,
|
||||||
|
@ -14,9 +15,9 @@ from .join_info import WHOContext
|
||||||
from .matching import ResponseOr, Responses, Response, ParamAny, ParamFolded
|
from .matching import ResponseOr, Responses, Response, ParamAny, ParamFolded
|
||||||
from .asyncs import MaybeAwait
|
from .asyncs import MaybeAwait
|
||||||
from .struct import Whois
|
from .struct import Whois
|
||||||
|
from .params import ConnectionParams, SASLParams, STSPolicy
|
||||||
from .interface import (ConnectionParams, ICapability, IServer, SentLine,
|
from .interface import (IBot, ICapability, IServer, SentLine, SendPriority,
|
||||||
SendPriority, SASLParams, IMatchResponse)
|
IMatchResponse)
|
||||||
from .interface import ITCPTransport, ITCPReader, ITCPWriter
|
from .interface import ITCPTransport, ITCPReader, ITCPWriter
|
||||||
|
|
||||||
THROTTLE_RATE = 4 # lines
|
THROTTLE_RATE = 4 # lines
|
||||||
|
@ -27,15 +28,17 @@ class Server(IServer):
|
||||||
_writer: ITCPWriter
|
_writer: ITCPWriter
|
||||||
params: ConnectionParams
|
params: ConnectionParams
|
||||||
|
|
||||||
def __init__(self, name: str):
|
def __init__(self, bot: IBot, name: str):
|
||||||
super().__init__(name)
|
super().__init__(name)
|
||||||
|
self.bot = bot
|
||||||
|
|
||||||
|
self.disconnected = False
|
||||||
|
|
||||||
self.throttle = Throttler(
|
self.throttle = Throttler(
|
||||||
rate_limit=100, period=THROTTLE_TIME)
|
rate_limit=100, period=THROTTLE_TIME)
|
||||||
|
|
||||||
self.sasl_state = SASLResult.NONE
|
self.sasl_state = SASLResult.NONE
|
||||||
|
|
||||||
|
|
||||||
self._sent_count: int = 0
|
self._sent_count: int = 0
|
||||||
self._wait_for: List[Tuple["Future[Line]", IMatchResponse]] = []
|
self._wait_for: List[Tuple["Future[Line]", IMatchResponse]] = []
|
||||||
self._write_queue: PriorityQueue[SentLine] = PriorityQueue()
|
self._write_queue: PriorityQueue[SentLine] = PriorityQueue()
|
||||||
|
@ -81,13 +84,12 @@ class Server(IServer):
|
||||||
async def connect(self,
|
async def connect(self,
|
||||||
transport: ITCPTransport,
|
transport: ITCPTransport,
|
||||||
params: ConnectionParams):
|
params: ConnectionParams):
|
||||||
port, tls = await STSContext(self).transmute(
|
await STSContext(self).transmute(params)
|
||||||
params.port, params.tls, params.sts)
|
|
||||||
|
|
||||||
reader, writer = await transport.connect(
|
reader, writer = await transport.connect(
|
||||||
params.host,
|
params.host,
|
||||||
port,
|
params.port,
|
||||||
tls =tls,
|
tls =params.tls,
|
||||||
tls_verify=params.tls_verify,
|
tls_verify=params.tls_verify,
|
||||||
bindhost =params.bindhost)
|
bindhost =params.bindhost)
|
||||||
|
|
||||||
|
@ -96,6 +98,11 @@ class Server(IServer):
|
||||||
|
|
||||||
self.params = params
|
self.params = params
|
||||||
await self.handshake()
|
await self.handshake()
|
||||||
|
async def disconnect(self):
|
||||||
|
if not self._writer is None:
|
||||||
|
await self._writer.close()
|
||||||
|
self._writer = None
|
||||||
|
self._read_queue.clear()
|
||||||
|
|
||||||
async def handshake(self):
|
async def handshake(self):
|
||||||
nickname = self.params.nickname
|
nickname = self.params.nickname
|
||||||
|
@ -133,16 +140,18 @@ class Server(IServer):
|
||||||
if line.command == "PING":
|
if line.command == "PING":
|
||||||
await self.send(build("PONG", line.params))
|
await self.send(build("PONG", line.params))
|
||||||
|
|
||||||
async def line_read(self, line: Line):
|
async def next_line(self) -> Optional[Tuple[Line, List[Emit]]]:
|
||||||
pass
|
|
||||||
|
|
||||||
async def next_line(self) -> Tuple[Line, List[Emit]]:
|
|
||||||
if self._read_queue:
|
if self._read_queue:
|
||||||
both = self._read_queue.popleft()
|
both = self._read_queue.popleft()
|
||||||
else:
|
else:
|
||||||
data = await self._reader.read(1024)
|
|
||||||
while True:
|
while True:
|
||||||
|
data = await self._reader.read(1024)
|
||||||
|
try:
|
||||||
lines = self.recv(data)
|
lines = self.recv(data)
|
||||||
|
except ServerDisconnectedException:
|
||||||
|
self.disconnected = True
|
||||||
|
return None
|
||||||
|
|
||||||
if lines:
|
if lines:
|
||||||
self._read_queue.extend(lines[1:])
|
self._read_queue.extend(lines[1:])
|
||||||
both = lines[0]
|
both = lines[0]
|
||||||
|
@ -161,6 +170,8 @@ class Server(IServer):
|
||||||
self._wait_for.append((our_fut, response))
|
self._wait_for.append((our_fut, response))
|
||||||
while self._wait_for:
|
while self._wait_for:
|
||||||
both = await self.next_line()
|
both = await self.next_line()
|
||||||
|
|
||||||
|
if not both is None:
|
||||||
line, emits = both
|
line, emits = both
|
||||||
|
|
||||||
for i, (fut, waiting) in enumerate(self._wait_for):
|
for i, (fut, waiting) in enumerate(self._wait_for):
|
||||||
|
@ -168,11 +179,11 @@ class Server(IServer):
|
||||||
fut.set_result(line)
|
fut.set_result(line)
|
||||||
self._wait_for.pop(i)
|
self._wait_for.pop(i)
|
||||||
break
|
break
|
||||||
|
else:
|
||||||
|
fut.set_result(build(""))
|
||||||
|
|
||||||
return await our_fut
|
return await our_fut
|
||||||
|
|
||||||
async def line_send(self, line: Line):
|
|
||||||
pass
|
|
||||||
async def _on_write_line(self, line: Line):
|
async def _on_write_line(self, line: Line):
|
||||||
if (line.command == "PRIVMSG" and
|
if (line.command == "PRIVMSG" and
|
||||||
not self.cap_agreed(CAP_ECHO)):
|
not self.cap_agreed(CAP_ECHO)):
|
||||||
|
|
|
@ -22,6 +22,10 @@ class TCPWriter(ITCPWriter):
|
||||||
async def drain(self):
|
async def drain(self):
|
||||||
await self._writer.drain()
|
await self._writer.drain()
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
self._writer.close()
|
||||||
|
await self._writer.wait_closed()
|
||||||
|
|
||||||
class TCPTransport(ITCPTransport):
|
class TCPTransport(ITCPTransport):
|
||||||
async def connect(self,
|
async def connect(self,
|
||||||
hostname: str,
|
hostname: str,
|
||||||
|
|
Loading…
Add table
Reference in a new issue