combine params.tls and .tls_verify, support pinned certs

This commit is contained in:
jesopo 2022-01-23 16:52:27 +00:00
parent 0a5c774965
commit 5b347f95c9
9 changed files with 85 additions and 34 deletions

View file

@ -154,8 +154,8 @@ async def main(hostname: str, channel: str, nickname: str):
params = ConnectionParams(
nickname,
hostname,
6697,
tls=True)
6697
)
await bot.add_server("freenode", params)
await bot.run()

View file

@ -23,7 +23,6 @@ async def main():
"MyNickname",
host = "chat.freenode.invalid",
port = 6697,
tls = True,
sasl = sasl_params)
await bot.add_server("freenode", params)

View file

@ -25,7 +25,7 @@ class Bot(BaseBot):
async def main():
bot = Bot()
for name, host in SERVERS:
params = ConnectionParams("BitBotNewTest", host, 6697, True)
params = ConnectionParams("BitBotNewTest", host, 6697)
await bot.add_server(name, params)
await bot.run()

View file

@ -6,6 +6,7 @@ from ircstates import Server, Emit
from irctokens import Line, Hostmask
from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy
from .security import TLS
class ITCPReader(object):
async def read(self, byte_count: int):
@ -24,11 +25,10 @@ class ITCPWriter(object):
class ITCPTransport(object):
async def connect(self,
hostname: str,
port: int,
tls: bool,
tls_verify: bool=True,
bindhost: Optional[str]=None
hostname: str,
port: int,
tls: Optional[TLS],
bindhost: Optional[str]=None
) -> Tuple[ITCPReader, ITCPWriter]:
pass

View file

@ -8,6 +8,7 @@ from .contexts import ServerContext
from .matching import Response, ANY
from .interface import ICapability
from .params import ConnectionParams, STSPolicy, ResumePolicy
from .security import TLS_VERIFYCHAIN
class Capability(ICapability):
def __init__(self,
@ -101,12 +102,12 @@ def _cap_dict(s: str) -> Dict[str, str]:
return d
async def sts_transmute(params: ConnectionParams):
if not params.sts is None and not params.tls:
if not params.sts is None and params.tls is None:
now = time()
since = (now-params.sts.created)
if since <= params.sts.duration:
params.port = params.sts.port
params.tls = True
params.tls = TLS_VERIFYCHAIN
async def resume_transmute(params: ConnectionParams):
if params.resume is not None:
params.host = params.resume.address
@ -182,7 +183,7 @@ class CAPContext(ServerContext):
if not params.tls:
if "port" in sts_dict:
params.port = int(sts_dict["port"])
params.tls = True
params.tls = TLS_VERIFYCHAIN
await self.server.bot.disconnect(self.server)
await self.server.bot.add_server(self.server.name, params)

View file

@ -1,6 +1,9 @@
from re import compile as re_compile
from typing import List, Optional
from dataclasses import dataclass, field
from .security import TLS, TLS_NOVERIFY, TLS_VERIFYCHAIN
class SASLParams(object):
mechanism: str
@ -28,19 +31,24 @@ class ResumePolicy(object):
address: str
token: str
RE_IPV6HOST = re_compile("\[([a-fA-F0-9:]+)\]")
_TLS_TYPES = {
"+": TLS_VERIFYCHAIN,
"~": TLS_NOVERIFY
}
@dataclass
class ConnectionParams(object):
nickname: str
host: str
port: int
tls: bool
tls: Optional[TLS] = TLS_VERIFYCHAIN
username: Optional[str] = None
realname: Optional[str] = None
bindhost: Optional[str] = None
password: Optional[str] = None
tls_verify: bool = True
sasl: Optional[SASLParams] = None
sts: Optional[STSPolicy] = None
@ -57,15 +65,19 @@ class ConnectionParams(object):
hoststring: str
) -> "ConnectionParams":
host, _, port_s = hoststring.strip().partition(":")
ipv6host = RE_IPV6HOST.search(hoststring)
if ipv6host is not None and ipv6host.start() == 0:
host = ipv6host.group(1)
port_s = hoststring[ipv6host.end()+1:]
else:
host, _, port_s = hoststring.strip().partition(":")
if port_s.startswith("+"):
tls = True
port_s = port_s.lstrip("+") or "6697"
elif not port_s:
tls = False
tls_type: Optional[TLS] = None
if not port_s:
port_s = "6667"
else:
tls = False
tls_type = _TLS_TYPES.get(port_s[0], None)
if tls_type is not None:
port_s = port_s[1:] or "6697"
return ConnectionParams(nickname, host, int(port_s), tls)
return ConnectionParams(nickname, host, int(port_s), tls_type)

View file

@ -1,4 +1,28 @@
import ssl
class TLS:
pass
# tls without verification
class TLSNoVerify(TLS):
pass
TLS_NOVERIFY = TLSNoVerify()
# verify via CAs
class TLSVerifyChain(TLS):
pass
TLS_VERIFYCHAIN = TLSVerifyChain()
# verify by a pinned hash
class TLSVerifyHash(TLSNoVerify):
def __init__(self, sum: str):
self.sum = sum.lower()
class TLSVerifySHA512(TLSVerifyHash):
pass
def tls_context(verify: bool=True) -> ssl.SSLContext:
return ssl.create_default_context()
ctx = ssl.create_default_context()
if not verify:
ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE
return ctx

View file

@ -124,9 +124,8 @@ class Server(IServer):
reader, writer = await transport.connect(
params.host,
params.port,
tls =params.tls,
tls_verify=params.tls_verify,
bindhost =params.bindhost)
tls =params.tls,
bindhost =params.bindhost)
self._reader = reader
self._writer = writer

View file

@ -1,10 +1,12 @@
from hashlib import sha512
from ssl import SSLContext
from typing import Optional, Tuple
from asyncio import StreamReader, StreamWriter
from async_stagger import open_connection
from .interface import ITCPTransport, ITCPReader, ITCPWriter
from .security import tls_context
from .security import (tls_context, TLS, TLSNoVerify, TLSVerifyHash,
TLSVerifySHA512)
class TCPReader(ITCPReader):
def __init__(self, reader: StreamReader):
@ -32,16 +34,15 @@ class TCPWriter(ITCPWriter):
class TCPTransport(ITCPTransport):
async def connect(self,
hostname: str,
port: int,
tls: bool,
tls_verify: bool=True,
bindhost: Optional[str]=None
hostname: str,
port: int,
tls: Optional[TLS],
bindhost: Optional[str]=None
) -> Tuple[ITCPReader, ITCPWriter]:
cur_ssl: Optional[SSLContext] = None
if tls:
cur_ssl = tls_context(tls_verify)
if tls is not None:
cur_ssl = tls_context(not isinstance(tls, TLSNoVerify))
local_addr: Optional[Tuple[str, int]] = None
if not bindhost is None:
@ -55,5 +56,20 @@ class TCPTransport(ITCPTransport):
server_hostname=server_hostname,
ssl =cur_ssl,
local_addr =local_addr)
if isinstance(tls, TLSVerifyHash):
cert: bytes = writer.transport.get_extra_info(
"ssl_object"
).getpeercert(True)
if isinstance(tls, TLSVerifySHA512):
sum = sha512(cert).hexdigest()
else:
raise ValueError(f"unknown hash pinning {type(tls)}")
if not sum == tls.sum:
raise ValueError(
f"pinned hash for {hostname} does not match ({sum})"
)
return (TCPReader(reader), TCPWriter(writer))