combine params.tls and .tls_verify, support pinned certs
This commit is contained in:
parent
0a5c774965
commit
5b347f95c9
9 changed files with 85 additions and 34 deletions
|
@ -154,8 +154,8 @@ async def main(hostname: str, channel: str, nickname: str):
|
||||||
params = ConnectionParams(
|
params = ConnectionParams(
|
||||||
nickname,
|
nickname,
|
||||||
hostname,
|
hostname,
|
||||||
6697,
|
6697
|
||||||
tls=True)
|
)
|
||||||
await bot.add_server("freenode", params)
|
await bot.add_server("freenode", params)
|
||||||
await bot.run()
|
await bot.run()
|
||||||
|
|
||||||
|
|
|
@ -23,7 +23,6 @@ async def main():
|
||||||
"MyNickname",
|
"MyNickname",
|
||||||
host = "chat.freenode.invalid",
|
host = "chat.freenode.invalid",
|
||||||
port = 6697,
|
port = 6697,
|
||||||
tls = True,
|
|
||||||
sasl = sasl_params)
|
sasl = sasl_params)
|
||||||
|
|
||||||
await bot.add_server("freenode", params)
|
await bot.add_server("freenode", params)
|
||||||
|
|
|
@ -25,7 +25,7 @@ class Bot(BaseBot):
|
||||||
async def main():
|
async def main():
|
||||||
bot = Bot()
|
bot = Bot()
|
||||||
for name, host in SERVERS:
|
for name, host in SERVERS:
|
||||||
params = ConnectionParams("BitBotNewTest", host, 6697, True)
|
params = ConnectionParams("BitBotNewTest", host, 6697)
|
||||||
await bot.add_server(name, params)
|
await bot.add_server(name, params)
|
||||||
|
|
||||||
await bot.run()
|
await bot.run()
|
||||||
|
|
|
@ -6,6 +6,7 @@ from ircstates import Server, Emit
|
||||||
from irctokens import Line, Hostmask
|
from irctokens import Line, Hostmask
|
||||||
|
|
||||||
from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy
|
from .params import ConnectionParams, SASLParams, STSPolicy, ResumePolicy
|
||||||
|
from .security import TLS
|
||||||
|
|
||||||
class ITCPReader(object):
|
class ITCPReader(object):
|
||||||
async def read(self, byte_count: int):
|
async def read(self, byte_count: int):
|
||||||
|
@ -24,11 +25,10 @@ class ITCPWriter(object):
|
||||||
|
|
||||||
class ITCPTransport(object):
|
class ITCPTransport(object):
|
||||||
async def connect(self,
|
async def connect(self,
|
||||||
hostname: str,
|
hostname: str,
|
||||||
port: int,
|
port: int,
|
||||||
tls: bool,
|
tls: Optional[TLS],
|
||||||
tls_verify: bool=True,
|
bindhost: Optional[str]=None
|
||||||
bindhost: Optional[str]=None
|
|
||||||
) -> Tuple[ITCPReader, ITCPWriter]:
|
) -> Tuple[ITCPReader, ITCPWriter]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
@ -8,6 +8,7 @@ from .contexts import ServerContext
|
||||||
from .matching import Response, ANY
|
from .matching import Response, ANY
|
||||||
from .interface import ICapability
|
from .interface import ICapability
|
||||||
from .params import ConnectionParams, STSPolicy, ResumePolicy
|
from .params import ConnectionParams, STSPolicy, ResumePolicy
|
||||||
|
from .security import TLS_VERIFYCHAIN
|
||||||
|
|
||||||
class Capability(ICapability):
|
class Capability(ICapability):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
|
@ -101,12 +102,12 @@ def _cap_dict(s: str) -> Dict[str, str]:
|
||||||
return d
|
return d
|
||||||
|
|
||||||
async def sts_transmute(params: ConnectionParams):
|
async def sts_transmute(params: ConnectionParams):
|
||||||
if not params.sts is None and not params.tls:
|
if not params.sts is None and params.tls is None:
|
||||||
now = time()
|
now = time()
|
||||||
since = (now-params.sts.created)
|
since = (now-params.sts.created)
|
||||||
if since <= params.sts.duration:
|
if since <= params.sts.duration:
|
||||||
params.port = params.sts.port
|
params.port = params.sts.port
|
||||||
params.tls = True
|
params.tls = TLS_VERIFYCHAIN
|
||||||
async def resume_transmute(params: ConnectionParams):
|
async def resume_transmute(params: ConnectionParams):
|
||||||
if params.resume is not None:
|
if params.resume is not None:
|
||||||
params.host = params.resume.address
|
params.host = params.resume.address
|
||||||
|
@ -182,7 +183,7 @@ class CAPContext(ServerContext):
|
||||||
if not params.tls:
|
if not params.tls:
|
||||||
if "port" in sts_dict:
|
if "port" in sts_dict:
|
||||||
params.port = int(sts_dict["port"])
|
params.port = int(sts_dict["port"])
|
||||||
params.tls = True
|
params.tls = TLS_VERIFYCHAIN
|
||||||
|
|
||||||
await self.server.bot.disconnect(self.server)
|
await self.server.bot.disconnect(self.server)
|
||||||
await self.server.bot.add_server(self.server.name, params)
|
await self.server.bot.add_server(self.server.name, params)
|
||||||
|
|
|
@ -1,6 +1,9 @@
|
||||||
|
from re import compile as re_compile
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
from .security import TLS, TLS_NOVERIFY, TLS_VERIFYCHAIN
|
||||||
|
|
||||||
class SASLParams(object):
|
class SASLParams(object):
|
||||||
mechanism: str
|
mechanism: str
|
||||||
|
|
||||||
|
@ -28,19 +31,24 @@ class ResumePolicy(object):
|
||||||
address: str
|
address: str
|
||||||
token: str
|
token: str
|
||||||
|
|
||||||
|
RE_IPV6HOST = re_compile("\[([a-fA-F0-9:]+)\]")
|
||||||
|
|
||||||
|
_TLS_TYPES = {
|
||||||
|
"+": TLS_VERIFYCHAIN,
|
||||||
|
"~": TLS_NOVERIFY
|
||||||
|
}
|
||||||
@dataclass
|
@dataclass
|
||||||
class ConnectionParams(object):
|
class ConnectionParams(object):
|
||||||
nickname: str
|
nickname: str
|
||||||
host: str
|
host: str
|
||||||
port: int
|
port: int
|
||||||
tls: bool
|
tls: Optional[TLS] = TLS_VERIFYCHAIN
|
||||||
|
|
||||||
username: Optional[str] = None
|
username: Optional[str] = None
|
||||||
realname: Optional[str] = None
|
realname: Optional[str] = None
|
||||||
bindhost: Optional[str] = None
|
bindhost: Optional[str] = None
|
||||||
|
|
||||||
password: Optional[str] = None
|
password: Optional[str] = None
|
||||||
tls_verify: bool = True
|
|
||||||
sasl: Optional[SASLParams] = None
|
sasl: Optional[SASLParams] = None
|
||||||
|
|
||||||
sts: Optional[STSPolicy] = None
|
sts: Optional[STSPolicy] = None
|
||||||
|
@ -57,15 +65,19 @@ class ConnectionParams(object):
|
||||||
hoststring: str
|
hoststring: str
|
||||||
) -> "ConnectionParams":
|
) -> "ConnectionParams":
|
||||||
|
|
||||||
host, _, port_s = hoststring.strip().partition(":")
|
ipv6host = RE_IPV6HOST.search(hoststring)
|
||||||
|
if ipv6host is not None and ipv6host.start() == 0:
|
||||||
|
host = ipv6host.group(1)
|
||||||
|
port_s = hoststring[ipv6host.end()+1:]
|
||||||
|
else:
|
||||||
|
host, _, port_s = hoststring.strip().partition(":")
|
||||||
|
|
||||||
if port_s.startswith("+"):
|
tls_type: Optional[TLS] = None
|
||||||
tls = True
|
if not port_s:
|
||||||
port_s = port_s.lstrip("+") or "6697"
|
|
||||||
elif not port_s:
|
|
||||||
tls = False
|
|
||||||
port_s = "6667"
|
port_s = "6667"
|
||||||
else:
|
else:
|
||||||
tls = False
|
tls_type = _TLS_TYPES.get(port_s[0], None)
|
||||||
|
if tls_type is not None:
|
||||||
|
port_s = port_s[1:] or "6697"
|
||||||
|
|
||||||
return ConnectionParams(nickname, host, int(port_s), tls)
|
return ConnectionParams(nickname, host, int(port_s), tls_type)
|
||||||
|
|
|
@ -1,4 +1,28 @@
|
||||||
import ssl
|
import ssl
|
||||||
|
|
||||||
|
class TLS:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# tls without verification
|
||||||
|
class TLSNoVerify(TLS):
|
||||||
|
pass
|
||||||
|
TLS_NOVERIFY = TLSNoVerify()
|
||||||
|
|
||||||
|
# verify via CAs
|
||||||
|
class TLSVerifyChain(TLS):
|
||||||
|
pass
|
||||||
|
TLS_VERIFYCHAIN = TLSVerifyChain()
|
||||||
|
|
||||||
|
# verify by a pinned hash
|
||||||
|
class TLSVerifyHash(TLSNoVerify):
|
||||||
|
def __init__(self, sum: str):
|
||||||
|
self.sum = sum.lower()
|
||||||
|
class TLSVerifySHA512(TLSVerifyHash):
|
||||||
|
pass
|
||||||
|
|
||||||
def tls_context(verify: bool=True) -> ssl.SSLContext:
|
def tls_context(verify: bool=True) -> ssl.SSLContext:
|
||||||
return ssl.create_default_context()
|
ctx = ssl.create_default_context()
|
||||||
|
if not verify:
|
||||||
|
ctx.check_hostname = False
|
||||||
|
ctx.verify_mode = ssl.CERT_NONE
|
||||||
|
return ctx
|
||||||
|
|
|
@ -124,9 +124,8 @@ class Server(IServer):
|
||||||
reader, writer = await transport.connect(
|
reader, writer = await transport.connect(
|
||||||
params.host,
|
params.host,
|
||||||
params.port,
|
params.port,
|
||||||
tls =params.tls,
|
tls =params.tls,
|
||||||
tls_verify=params.tls_verify,
|
bindhost =params.bindhost)
|
||||||
bindhost =params.bindhost)
|
|
||||||
|
|
||||||
self._reader = reader
|
self._reader = reader
|
||||||
self._writer = writer
|
self._writer = writer
|
||||||
|
|
|
@ -1,10 +1,12 @@
|
||||||
|
from hashlib import sha512
|
||||||
from ssl import SSLContext
|
from ssl import SSLContext
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
from asyncio import StreamReader, StreamWriter
|
from asyncio import StreamReader, StreamWriter
|
||||||
from async_stagger import open_connection
|
from async_stagger import open_connection
|
||||||
|
|
||||||
from .interface import ITCPTransport, ITCPReader, ITCPWriter
|
from .interface import ITCPTransport, ITCPReader, ITCPWriter
|
||||||
from .security import tls_context
|
from .security import (tls_context, TLS, TLSNoVerify, TLSVerifyHash,
|
||||||
|
TLSVerifySHA512)
|
||||||
|
|
||||||
class TCPReader(ITCPReader):
|
class TCPReader(ITCPReader):
|
||||||
def __init__(self, reader: StreamReader):
|
def __init__(self, reader: StreamReader):
|
||||||
|
@ -32,16 +34,15 @@ class TCPWriter(ITCPWriter):
|
||||||
|
|
||||||
class TCPTransport(ITCPTransport):
|
class TCPTransport(ITCPTransport):
|
||||||
async def connect(self,
|
async def connect(self,
|
||||||
hostname: str,
|
hostname: str,
|
||||||
port: int,
|
port: int,
|
||||||
tls: bool,
|
tls: Optional[TLS],
|
||||||
tls_verify: bool=True,
|
bindhost: Optional[str]=None
|
||||||
bindhost: Optional[str]=None
|
|
||||||
) -> Tuple[ITCPReader, ITCPWriter]:
|
) -> Tuple[ITCPReader, ITCPWriter]:
|
||||||
|
|
||||||
cur_ssl: Optional[SSLContext] = None
|
cur_ssl: Optional[SSLContext] = None
|
||||||
if tls:
|
if tls is not None:
|
||||||
cur_ssl = tls_context(tls_verify)
|
cur_ssl = tls_context(not isinstance(tls, TLSNoVerify))
|
||||||
|
|
||||||
local_addr: Optional[Tuple[str, int]] = None
|
local_addr: Optional[Tuple[str, int]] = None
|
||||||
if not bindhost is None:
|
if not bindhost is None:
|
||||||
|
@ -55,5 +56,20 @@ class TCPTransport(ITCPTransport):
|
||||||
server_hostname=server_hostname,
|
server_hostname=server_hostname,
|
||||||
ssl =cur_ssl,
|
ssl =cur_ssl,
|
||||||
local_addr =local_addr)
|
local_addr =local_addr)
|
||||||
|
|
||||||
|
if isinstance(tls, TLSVerifyHash):
|
||||||
|
cert: bytes = writer.transport.get_extra_info(
|
||||||
|
"ssl_object"
|
||||||
|
).getpeercert(True)
|
||||||
|
if isinstance(tls, TLSVerifySHA512):
|
||||||
|
sum = sha512(cert).hexdigest()
|
||||||
|
else:
|
||||||
|
raise ValueError(f"unknown hash pinning {type(tls)}")
|
||||||
|
|
||||||
|
if not sum == tls.sum:
|
||||||
|
raise ValueError(
|
||||||
|
f"pinned hash for {hostname} does not match ({sum})"
|
||||||
|
)
|
||||||
|
|
||||||
return (TCPReader(reader), TCPWriter(writer))
|
return (TCPReader(reader), TCPWriter(writer))
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue