implement STS policies; we're not parsing them yet though
This commit is contained in:
parent
de4ba754e7
commit
1bfe06d2ea
5 changed files with 37 additions and 7 deletions
|
@ -1,4 +1,5 @@
|
||||||
from .bot import Bot
|
from .bot import Bot
|
||||||
from .server import Server
|
from .server import Server
|
||||||
from .params import ConnectionParams, SASLUserPass, SASLExternal, SASLSCRAM
|
from .params import (ConnectionParams, SASLUserPass, SASLExternal, SASLSCRAM,
|
||||||
|
STSPolicy)
|
||||||
from .ircv3 import Capability
|
from .ircv3 import Capability
|
||||||
|
|
|
@ -23,7 +23,7 @@ class Bot(object):
|
||||||
async def add_server(self, name: str, params: ConnectionParams) -> Server:
|
async def add_server(self, name: str, params: ConnectionParams) -> Server:
|
||||||
server = self.create_server(name)
|
server = self.create_server(name)
|
||||||
self.servers[name] = server
|
self.servers[name] = server
|
||||||
await server.connect(TCPTransport() ,params)
|
await server.connect(TCPTransport(), params)
|
||||||
await self._server_queue.put(server)
|
await self._server_queue.put(server)
|
||||||
return server
|
return server
|
||||||
|
|
||||||
|
|
|
@ -1,9 +1,12 @@
|
||||||
from typing import Dict, Iterable, List, Optional
|
from time import time
|
||||||
from irctokens import build
|
from typing import Dict, Iterable, List, Optional, Tuple
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from irctokens import build
|
||||||
|
|
||||||
from .contexts import ServerContext
|
from .contexts import ServerContext
|
||||||
from .matching import Response, ResponseOr, ParamAny, ParamLiteral
|
from .matching import Response, ResponseOr, ParamAny, ParamLiteral
|
||||||
from .interface import ICapability
|
from .interface import ICapability
|
||||||
|
from .params import STSPolicy
|
||||||
|
|
||||||
class Capability(ICapability):
|
class Capability(ICapability):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
|
@ -93,3 +96,16 @@ class CAPContext(ServerContext):
|
||||||
async def handshake(self):
|
async def handshake(self):
|
||||||
await self.on_ls(self.server.available_caps)
|
await self.on_ls(self.server.available_caps)
|
||||||
await self.server.send(build("CAP", ["END"]))
|
await self.server.send(build("CAP", ["END"]))
|
||||||
|
|
||||||
|
class STSContext(ServerContext):
|
||||||
|
async def transmute(self,
|
||||||
|
port: int,
|
||||||
|
tls: bool,
|
||||||
|
sts: Optional[STSPolicy]) -> Tuple[int, bool]:
|
||||||
|
if not sts is None:
|
||||||
|
now = time()
|
||||||
|
since = (now-sts.created)
|
||||||
|
if since <= sts.duration:
|
||||||
|
return sts.port, True
|
||||||
|
|
||||||
|
return port, tls
|
||||||
|
|
|
@ -20,6 +20,13 @@ class SASLExternal(SASLParams):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__("EXTERNAL")
|
super().__init__("EXTERNAL")
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class STSPolicy(object):
|
||||||
|
created: int
|
||||||
|
port: int
|
||||||
|
duration: int
|
||||||
|
preload: bool
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ConnectionParams(object):
|
class ConnectionParams(object):
|
||||||
nickname: str
|
nickname: str
|
||||||
|
@ -34,3 +41,5 @@ class ConnectionParams(object):
|
||||||
password: Optional[str] = None
|
password: Optional[str] = None
|
||||||
tls_verify: bool = True
|
tls_verify: bool = True
|
||||||
sasl: Optional[SASLParams] = None
|
sasl: Optional[SASLParams] = None
|
||||||
|
|
||||||
|
sts: Optional[STSPolicy] = None
|
||||||
|
|
|
@ -7,7 +7,8 @@ from ircstates import Emit, Channel
|
||||||
from ircstates.numerics import *
|
from ircstates.numerics import *
|
||||||
from irctokens import build, Line, tokenise
|
from irctokens import build, Line, tokenise
|
||||||
|
|
||||||
from .ircv3 import CAPContext, CAP_ECHO, CAP_SASL, CAP_LABEL, LABEL_TAG
|
from .ircv3 import (CAPContext, STSContext, CAP_ECHO, CAP_SASL, CAP_LABEL,
|
||||||
|
LABEL_TAG)
|
||||||
from .sasl import SASLContext, SASLResult
|
from .sasl import SASLContext, SASLResult
|
||||||
from .join_info import WHOContext
|
from .join_info import WHOContext
|
||||||
from .matching import ResponseOr, Responses, Response, ParamAny, ParamFolded
|
from .matching import ResponseOr, Responses, Response, ParamAny, ParamFolded
|
||||||
|
@ -80,10 +81,13 @@ class Server(IServer):
|
||||||
async def connect(self,
|
async def connect(self,
|
||||||
transport: ITCPTransport,
|
transport: ITCPTransport,
|
||||||
params: ConnectionParams):
|
params: ConnectionParams):
|
||||||
|
port, tls = await STSContext(self).transmute(
|
||||||
|
params.port, params.tls, params.sts)
|
||||||
|
|
||||||
reader, writer = await transport.connect(
|
reader, writer = await transport.connect(
|
||||||
params.host,
|
params.host,
|
||||||
params.port,
|
port,
|
||||||
tls =params.tls,
|
tls =tls,
|
||||||
tls_verify=params.tls_verify,
|
tls_verify=params.tls_verify,
|
||||||
bindhost =params.bindhost)
|
bindhost =params.bindhost)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue