implement STS policies; we're not parsing them yet though

This commit is contained in:
jesopo 2020-04-19 21:51:33 +01:00
parent de4ba754e7
commit 1bfe06d2ea
5 changed files with 37 additions and 7 deletions

View file

@ -1,4 +1,5 @@
from .bot import Bot from .bot import Bot
from .server import Server from .server import Server
from .params import ConnectionParams, SASLUserPass, SASLExternal, SASLSCRAM from .params import (ConnectionParams, SASLUserPass, SASLExternal, SASLSCRAM,
STSPolicy)
from .ircv3 import Capability from .ircv3 import Capability

View file

@ -23,7 +23,7 @@ class Bot(object):
async def add_server(self, name: str, params: ConnectionParams) -> Server: async def add_server(self, name: str, params: ConnectionParams) -> Server:
server = self.create_server(name) server = self.create_server(name)
self.servers[name] = server self.servers[name] = server
await server.connect(TCPTransport() ,params) await server.connect(TCPTransport(), params)
await self._server_queue.put(server) await self._server_queue.put(server)
return server return server

View file

@ -1,9 +1,12 @@
from typing import Dict, Iterable, List, Optional from time import time
from irctokens import build from typing import Dict, Iterable, List, Optional, Tuple
from dataclasses import dataclass
from irctokens import build
from .contexts import ServerContext from .contexts import ServerContext
from .matching import Response, ResponseOr, ParamAny, ParamLiteral from .matching import Response, ResponseOr, ParamAny, ParamLiteral
from .interface import ICapability from .interface import ICapability
from .params import STSPolicy
class Capability(ICapability): class Capability(ICapability):
def __init__(self, def __init__(self,
@ -93,3 +96,16 @@ class CAPContext(ServerContext):
async def handshake(self): async def handshake(self):
await self.on_ls(self.server.available_caps) await self.on_ls(self.server.available_caps)
await self.server.send(build("CAP", ["END"])) await self.server.send(build("CAP", ["END"]))
class STSContext(ServerContext):
async def transmute(self,
port: int,
tls: bool,
sts: Optional[STSPolicy]) -> Tuple[int, bool]:
if not sts is None:
now = time()
since = (now-sts.created)
if since <= sts.duration:
return sts.port, True
return port, tls

View file

@ -20,6 +20,13 @@ class SASLExternal(SASLParams):
def __init__(self): def __init__(self):
super().__init__("EXTERNAL") super().__init__("EXTERNAL")
@dataclass
class STSPolicy(object):
created: int
port: int
duration: int
preload: bool
@dataclass @dataclass
class ConnectionParams(object): class ConnectionParams(object):
nickname: str nickname: str
@ -34,3 +41,5 @@ class ConnectionParams(object):
password: Optional[str] = None password: Optional[str] = None
tls_verify: bool = True tls_verify: bool = True
sasl: Optional[SASLParams] = None sasl: Optional[SASLParams] = None
sts: Optional[STSPolicy] = None

View file

@ -7,7 +7,8 @@ from ircstates import Emit, Channel
from ircstates.numerics import * from ircstates.numerics import *
from irctokens import build, Line, tokenise from irctokens import build, Line, tokenise
from .ircv3 import CAPContext, CAP_ECHO, CAP_SASL, CAP_LABEL, LABEL_TAG from .ircv3 import (CAPContext, STSContext, CAP_ECHO, CAP_SASL, CAP_LABEL,
LABEL_TAG)
from .sasl import SASLContext, SASLResult from .sasl import SASLContext, SASLResult
from .join_info import WHOContext from .join_info import WHOContext
from .matching import ResponseOr, Responses, Response, ParamAny, ParamFolded from .matching import ResponseOr, Responses, Response, ParamAny, ParamFolded
@ -80,10 +81,13 @@ class Server(IServer):
async def connect(self, async def connect(self,
transport: ITCPTransport, transport: ITCPTransport,
params: ConnectionParams): params: ConnectionParams):
port, tls = await STSContext(self).transmute(
params.port, params.tls, params.sts)
reader, writer = await transport.connect( reader, writer = await transport.connect(
params.host, params.host,
params.port, port,
tls =params.tls, tls =tls,
tls_verify=params.tls_verify, tls_verify=params.tls_verify,
bindhost =params.bindhost) bindhost =params.bindhost)