implement STS policies; we're not parsing them yet though

This commit is contained in:
jesopo 2020-04-19 21:51:33 +01:00
parent de4ba754e7
commit 1bfe06d2ea
5 changed files with 37 additions and 7 deletions

View file

@ -1,4 +1,5 @@
from .bot import Bot
from .server import Server
from .params import ConnectionParams, SASLUserPass, SASLExternal, SASLSCRAM
from .params import (ConnectionParams, SASLUserPass, SASLExternal, SASLSCRAM,
STSPolicy)
from .ircv3 import Capability

View file

@ -1,9 +1,12 @@
from typing import Dict, Iterable, List, Optional
from time import time
from typing import Dict, Iterable, List, Optional, Tuple
from dataclasses import dataclass
from irctokens import build
from .contexts import ServerContext
from .matching import Response, ResponseOr, ParamAny, ParamLiteral
from .interface import ICapability
from .params import STSPolicy
class Capability(ICapability):
def __init__(self,
@ -93,3 +96,16 @@ class CAPContext(ServerContext):
async def handshake(self):
await self.on_ls(self.server.available_caps)
await self.server.send(build("CAP", ["END"]))
class STSContext(ServerContext):
async def transmute(self,
port: int,
tls: bool,
sts: Optional[STSPolicy]) -> Tuple[int, bool]:
if not sts is None:
now = time()
since = (now-sts.created)
if since <= sts.duration:
return sts.port, True
return port, tls

View file

@ -20,6 +20,13 @@ class SASLExternal(SASLParams):
def __init__(self):
super().__init__("EXTERNAL")
@dataclass
class STSPolicy(object):
created: int
port: int
duration: int
preload: bool
@dataclass
class ConnectionParams(object):
nickname: str
@ -34,3 +41,5 @@ class ConnectionParams(object):
password: Optional[str] = None
tls_verify: bool = True
sasl: Optional[SASLParams] = None
sts: Optional[STSPolicy] = None

View file

@ -7,7 +7,8 @@ from ircstates import Emit, Channel
from ircstates.numerics import *
from irctokens import build, Line, tokenise
from .ircv3 import CAPContext, CAP_ECHO, CAP_SASL, CAP_LABEL, LABEL_TAG
from .ircv3 import (CAPContext, STSContext, CAP_ECHO, CAP_SASL, CAP_LABEL,
LABEL_TAG)
from .sasl import SASLContext, SASLResult
from .join_info import WHOContext
from .matching import ResponseOr, Responses, Response, ParamAny, ParamFolded
@ -80,10 +81,13 @@ class Server(IServer):
async def connect(self,
transport: ITCPTransport,
params: ConnectionParams):
port, tls = await STSContext(self).transmute(
params.port, params.tls, params.sts)
reader, writer = await transport.connect(
params.host,
params.port,
tls =params.tls,
port,
tls =tls,
tls_verify=params.tls_verify,
bindhost =params.bindhost)