combine params.tls and .tls_verify, support pinned certs

This commit is contained in:
jesopo 2022-01-23 16:52:27 +00:00
parent 0a5c774965
commit 5b347f95c9
9 changed files with 85 additions and 34 deletions

View file

@ -154,8 +154,8 @@ async def main(hostname: str, channel: str, nickname: str):
params = ConnectionParams( params = ConnectionParams(
nickname, nickname,
hostname, hostname,
6697, 6697
tls=True) )
await bot.add_server("freenode", params) await bot.add_server("freenode", params)
await bot.run() await bot.run()

View file

@ -23,7 +23,6 @@ async def main():
"MyNickname", "MyNickname",
host = "chat.freenode.invalid", host = "chat.freenode.invalid",
port = 6697, port = 6697,
tls = True,
sasl = sasl_params) sasl = sasl_params)
await bot.add_server("freenode", params) await bot.add_server("freenode", params)

View file

@ -25,7 +25,7 @@ class Bot(BaseBot):
async def main(): async def main():
bot = Bot() bot = Bot()
for name, host in SERVERS: for name, host in SERVERS:
params = ConnectionParams("BitBotNewTest", host, 6697, True) params = ConnectionParams("BitBotNewTest", host, 6697)
await bot.add_server(name, params) await bot.add_server(name, params)
await bot.run() await bot.run()

View file

@ -6,6 +6,7 @@ from ircstates import Server, Emit
from irctokens import Line, Hostmask from irctokens import Line, Hostmask
from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy
from .security import TLS
class ITCPReader(object): class ITCPReader(object):
async def read(self, byte_count: int): async def read(self, byte_count: int):
@ -26,8 +27,7 @@ class ITCPTransport(object):
async def connect(self, async def connect(self,
hostname: str, hostname: str,
port: int, port: int,
tls: bool, tls: Optional[TLS],
tls_verify: bool=True,
bindhost: Optional[str]=None bindhost: Optional[str]=None
) -> Tuple[ITCPReader, ITCPWriter]: ) -> Tuple[ITCPReader, ITCPWriter]:
pass pass

View file

@ -8,6 +8,7 @@ from .contexts import ServerContext
from .matching import Response, ANY from .matching import Response, ANY
from .interface import ICapability from .interface import ICapability
from .params import ConnectionParams, STSPolicy, ResumePolicy from .params import ConnectionParams, STSPolicy, ResumePolicy
from .security import TLS_VERIFYCHAIN
class Capability(ICapability): class Capability(ICapability):
def __init__(self, def __init__(self,
@ -101,12 +102,12 @@ def _cap_dict(s: str) -> Dict[str, str]:
return d return d
async def sts_transmute(params: ConnectionParams): async def sts_transmute(params: ConnectionParams):
if not params.sts is None and not params.tls: if not params.sts is None and params.tls is None:
now = time() now = time()
since = (now-params.sts.created) since = (now-params.sts.created)
if since <= params.sts.duration: if since <= params.sts.duration:
params.port = params.sts.port params.port = params.sts.port
params.tls = True params.tls = TLS_VERIFYCHAIN
async def resume_transmute(params: ConnectionParams): async def resume_transmute(params: ConnectionParams):
if params.resume is not None: if params.resume is not None:
params.host = params.resume.address params.host = params.resume.address
@ -182,7 +183,7 @@ class CAPContext(ServerContext):
if not params.tls: if not params.tls:
if "port" in sts_dict: if "port" in sts_dict:
params.port = int(sts_dict["port"]) params.port = int(sts_dict["port"])
params.tls = True params.tls = TLS_VERIFYCHAIN
await self.server.bot.disconnect(self.server) await self.server.bot.disconnect(self.server)
await self.server.bot.add_server(self.server.name, params) await self.server.bot.add_server(self.server.name, params)

View file

@ -1,6 +1,9 @@
from re import compile as re_compile
from typing import List, Optional from typing import List, Optional
from dataclasses import dataclass, field from dataclasses import dataclass, field
from .security import TLS, TLS_NOVERIFY, TLS_VERIFYCHAIN
class SASLParams(object): class SASLParams(object):
mechanism: str mechanism: str
@ -28,19 +31,24 @@ class ResumePolicy(object):
address: str address: str
token: str token: str
RE_IPV6HOST = re_compile("\[([a-fA-F0-9:]+)\]")
_TLS_TYPES = {
"+": TLS_VERIFYCHAIN,
"~": TLS_NOVERIFY
}
@dataclass @dataclass
class ConnectionParams(object): class ConnectionParams(object):
nickname: str nickname: str
host: str host: str
port: int port: int
tls: bool tls: Optional[TLS] = TLS_VERIFYCHAIN
username: Optional[str] = None username: Optional[str] = None
realname: Optional[str] = None realname: Optional[str] = None
bindhost: Optional[str] = None bindhost: Optional[str] = None
password: Optional[str] = None password: Optional[str] = None
tls_verify: bool = True
sasl: Optional[SASLParams] = None sasl: Optional[SASLParams] = None
sts: Optional[STSPolicy] = None sts: Optional[STSPolicy] = None
@ -57,15 +65,19 @@ class ConnectionParams(object):
hoststring: str hoststring: str
) -> "ConnectionParams": ) -> "ConnectionParams":
ipv6host = RE_IPV6HOST.search(hoststring)
if ipv6host is not None and ipv6host.start() == 0:
host = ipv6host.group(1)
port_s = hoststring[ipv6host.end()+1:]
else:
host, _, port_s = hoststring.strip().partition(":") host, _, port_s = hoststring.strip().partition(":")
if port_s.startswith("+"): tls_type: Optional[TLS] = None
tls = True if not port_s:
port_s = port_s.lstrip("+") or "6697"
elif not port_s:
tls = False
port_s = "6667" port_s = "6667"
else: else:
tls = False tls_type = _TLS_TYPES.get(port_s[0], None)
if tls_type is not None:
port_s = port_s[1:] or "6697"
return ConnectionParams(nickname, host, int(port_s), tls) return ConnectionParams(nickname, host, int(port_s), tls_type)

View file

@ -1,4 +1,28 @@
import ssl import ssl
class TLS:
pass
# tls without verification
class TLSNoVerify(TLS):
pass
TLS_NOVERIFY = TLSNoVerify()
# verify via CAs
class TLSVerifyChain(TLS):
pass
TLS_VERIFYCHAIN = TLSVerifyChain()
# verify by a pinned hash
class TLSVerifyHash(TLSNoVerify):
def __init__(self, sum: str):
self.sum = sum.lower()
class TLSVerifySHA512(TLSVerifyHash):
pass
def tls_context(verify: bool=True) -> ssl.SSLContext: def tls_context(verify: bool=True) -> ssl.SSLContext:
return ssl.create_default_context() ctx = ssl.create_default_context()
if not verify:
ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE
return ctx

View file

@ -125,7 +125,6 @@ class Server(IServer):
params.host, params.host,
params.port, params.port,
tls =params.tls, tls =params.tls,
tls_verify=params.tls_verify,
bindhost =params.bindhost) bindhost =params.bindhost)
self._reader = reader self._reader = reader

View file

@ -1,10 +1,12 @@
from hashlib import sha512
from ssl import SSLContext from ssl import SSLContext
from typing import Optional, Tuple from typing import Optional, Tuple
from asyncio import StreamReader, StreamWriter from asyncio import StreamReader, StreamWriter
from async_stagger import open_connection from async_stagger import open_connection
from .interface import ITCPTransport, ITCPReader, ITCPWriter from .interface import ITCPTransport, ITCPReader, ITCPWriter
from .security import tls_context from .security import (tls_context, TLS, TLSNoVerify, TLSVerifyHash,
TLSVerifySHA512)
class TCPReader(ITCPReader): class TCPReader(ITCPReader):
def __init__(self, reader: StreamReader): def __init__(self, reader: StreamReader):
@ -34,14 +36,13 @@ class TCPTransport(ITCPTransport):
async def connect(self, async def connect(self,
hostname: str, hostname: str,
port: int, port: int,
tls: bool, tls: Optional[TLS],
tls_verify: bool=True,
bindhost: Optional[str]=None bindhost: Optional[str]=None
) -> Tuple[ITCPReader, ITCPWriter]: ) -> Tuple[ITCPReader, ITCPWriter]:
cur_ssl: Optional[SSLContext] = None cur_ssl: Optional[SSLContext] = None
if tls: if tls is not None:
cur_ssl = tls_context(tls_verify) cur_ssl = tls_context(not isinstance(tls, TLSNoVerify))
local_addr: Optional[Tuple[str, int]] = None local_addr: Optional[Tuple[str, int]] = None
if not bindhost is None: if not bindhost is None:
@ -55,5 +56,20 @@ class TCPTransport(ITCPTransport):
server_hostname=server_hostname, server_hostname=server_hostname,
ssl =cur_ssl, ssl =cur_ssl,
local_addr =local_addr) local_addr =local_addr)
if isinstance(tls, TLSVerifyHash):
cert: bytes = writer.transport.get_extra_info(
"ssl_object"
).getpeercert(True)
if isinstance(tls, TLSVerifySHA512):
sum = sha512(cert).hexdigest()
else:
raise ValueError(f"unknown hash pinning {type(tls)}")
if not sum == tls.sum:
raise ValueError(
f"pinned hash for {hostname} does not match ({sum})"
)
return (TCPReader(reader), TCPWriter(writer)) return (TCPReader(reader), TCPWriter(writer))