implement SASL SCRAM

This commit is contained in:
jesopo 2020-04-02 22:37:51 +01:00
parent a4f5d8045f
commit 60e601aa81
4 changed files with 205 additions and 19 deletions

View file

@ -1,4 +1,4 @@
from .bot import Bot
from .server import Server
from .params import ConnectionParams, SASLUserPass, SASLExternal
from .params import ConnectionParams, SASLUserPass, SASLExternal, SASLSCRAM
from .ircv3 import Capability

View file

@ -13,6 +13,9 @@ class SASLParams(object):
class SASLUserPass(SASLParams):
def __init__(self, username: str, password: str):
super().__init__("USERPASS", username, password)
class SASLSCRAM(SASLParams):
def __init__(self, username: str, password: str):
super().__init__("SCRAM", username, password)
class SASLExternal(SASLParams):
def __init__(self):
super().__init__("EXTERNAL")

View file

@ -1,18 +1,19 @@
from typing import Optional
from typing import List, Optional
from enum import Enum
from base64 import b64encode
from base64 import b64decode, b64encode
from irctokens import build
from .matching import Response, ResponseOr, Numerics, ParamAny
from .contexts import ServerContext
from .params import SASLParams
from .scram import SCRAMContext
SASL_USERPASS_MECHANISMS = [
SASL_SCRAM_MECHANISMS = [
"SCRAM-SHA-512",
"SCRAM-SHA-256",
"SCRAM-SHA-1",
"PLAIN"
]
SASL_USERPASS_MECHANISMS = SASL_SCRAM_MECHANISMS+["PLAIN"]
class SASLResult(Enum):
NONE = 0
@ -25,14 +26,18 @@ class SASLError(Exception):
class SASLUnknownMechanismError(SASLError):
pass
AUTHENTICATE_ANY = Response("AUTHENTICATE", [ParamAny()])
NUMERICS_INITIAL = Numerics(
["ERR_SASLFAIL", "ERR_SASLALREADY", "RPL_SASLMECHS"])
NUMERICS_LAST = Numerics(["RPL_SASLSUCCESS", "ERR_SASLFAIL"])
class SASLContext(ServerContext):
async def from_params(self, params: SASLParams) -> SASLResult:
if params.mechanism == "USERPASS":
if params.mechanism == "USERPASS":
return await self.userpass(params.username, params.password)
elif params.mechanism == "SCRAM":
return await self.scram(params.username, params.password)
elif params.mechanism == "EXTERNAL":
return await self.external()
else:
@ -43,7 +48,7 @@ class SASLContext(ServerContext):
async def external(self) -> SASLResult:
await self.server.send(build("AUTHENTICATE", ["EXTERNAL"]))
line = await self.server.wait_for(ResponseOr(
Response("AUTHENTICATE", [ParamAny()]),
AUTHENTICATE_ANY,
NUMERICS_INITIAL
))
@ -63,17 +68,27 @@ class SASLContext(ServerContext):
return SASLResult.SUCCESS
return SASLResult.FAILURE
async def userpass(self, username: str, password: str) -> SASLResult:
async def plain(self, username: str, password: str) -> SASLResult:
return await self.userpass(username, password, ["PLAIN"])
async def scram(self, username: str, password: str) -> SASLResult:
return await self.userpass(username, password, SASL_SCRAM_MECHANISMS)
async def userpass(self,
username: str,
password: str,
mechanisms: List[str]=SASL_USERPASS_MECHANISMS
) -> SASLResult:
# this will, in the future, offer SCRAM support
def _common(server_mechs) -> str:
for our_mech in SASL_USERPASS_MECHANISMS:
for our_mech in mechanisms:
if our_mech in server_mechs:
return our_mech
else:
raise SASLUnknownMechanismError(
"No matching SASL mechanims. "
f"(we have: {SASL_USERPASS_MECHANISMS} "
f"(we want: {mechanisms} "
f"server has: {server_mechs})")
if not self.server.available_caps["sasl"] is None:
@ -83,11 +98,11 @@ class SASLContext(ServerContext):
else:
# CAP v3.1 does not. pick the pick and wait for 907 to inform us of
# what mechanisms are supported
match = SASL_USERPASS_MECHANISMS[0]
match = mechanisms[0]
await self.server.send(build("AUTHENTICATE", [match]))
line = await self.server.wait_for(ResponseOr(
Response("AUTHENTICATE", [ParamAny()]),
AUTHENTICATE_ANY,
NUMERICS_INITIAL
))
@ -100,19 +115,47 @@ class SASLContext(ServerContext):
match = _common(available)
await self.server.send(build("AUTHENTICATE", [match]))
line = await self.server.wait_for(
Response("AUTHENTICATE", [ParamAny()]))
line = await self.server.wait_for(AUTHENTICATE_ANY)
if line.command == "AUTHENTICATE" and line.params[0] == "+":
auth_text: Optional[str] = None
def _b64e(s: str):
return b64encode(s.encode("utf8")).decode("ascii")
def _b64eb(s: bytes) -> str:
# encode-from-bytes
return b64encode(s).decode("ascii")
def _b64db(s: str) -> bytes:
# decode-to-bytes
return b64decode(s)
auth_text = ""
if match == "PLAIN":
auth_text = f"{username}\0{username}\0{password}"
elif match.startswith("SCRAM-SHA-"):
algo = match.replace("SCRAM-", "", 1)
scram = SCRAMContext(algo, username, password)
if not auth_text is None:
auth_b64 = b64encode(
auth_text.encode("utf8")).decode("ascii")
await self.server.send(build("AUTHENTICATE", [auth_b64]))
client_first = _b64eb(scram.client_first())
await self.server.send(build("AUTHENTICATE", [client_first]))
line = await self.server.wait_for(AUTHENTICATE_ANY)
server_first = _b64db(line.params[0])
client_final = _b64eb(scram.server_first(server_first))
if not client_final == "":
await self.server.send(build("AUTHENTICATE", [client_final]))
line = await self.server.wait_for(AUTHENTICATE_ANY)
server_final = _b64db(line.params[0])
verified = scram.server_final(server_final)
#TODO PANIC if verified is false!
auth_text = "+"
else:
auth_text = ""
if not auth_text == "+":
auth_text = _b64e(auth_text)
if auth_text:
await self.server.send(build("AUTHENTICATE", [auth_text]))
line = await self.server.wait_for(NUMERICS_LAST)
if line.command == "903":
return SASLResult.SUCCESS

140
ircrobots/scram.py Normal file
View file

@ -0,0 +1,140 @@
import base64, enum, hashlib, hmac, os
from typing import Dict
# IANA Hash Function Textual Names
# https://tools.ietf.org/html/rfc5802#section-4
# https://www.iana.org/assignments/hash-function-text-names/
# MD2 has been removed as it's unacceptably weak
ALGORITHMS = [
"MD5", "SHA-1", "SHA-224", "SHA-256", "SHA-384", "SHA-512"]
SCRAM_ERRORS = [
"invalid-encoding",
"extensions-not-supported", # unrecognized 'm' value
"invalid-proof",
"channel-bindings-dont-match",
"server-does-support-channel-binding",
"channel-binding-not-supported",
"unsupported-channel-binding-type",
"unknown-user",
"invalid-username-encoding", # invalid utf8 or bad SASLprep
"no-resources"
]
def _scram_nonce() -> bytes:
return base64.b64encode(os.urandom(32))
def _scram_escape(s: bytes) -> bytes:
return s.replace(b"=", b"=3D").replace(b",", b"=2C")
def _scram_unescape(s: bytes) -> bytes:
return s.replace(b"=3D", b"=").replace(b"=2C", b",")
def _scram_xor(s1: bytes, s2: bytes) -> bytes:
return bytes(a ^ b for a, b in zip(s1, s2))
class SCRAMState(enum.Enum):
Uninitialised = 0
ClientFirst = 1
ClientFinal = 2
Success = 3
Failed = 4
VerifyFailed = 5
class SCRAMError(Exception):
pass
class SCRAMContext(object):
def __init__(self, algo: str, username: str, password: str):
if not algo in ALGORITHMS:
raise ValueError("Unknown SCRAM algorithm '%s'" % algo)
self._algo = algo.replace("-", "") # SHA-1 -> SHA1
self._username = username.encode("utf8")
self._password = password.encode("utf8")
self.state = SCRAMState.Uninitialised
self.error = ""
self.raw_error = ""
self._client_first = b""
self._salted_password = b""
self._auth_message = b""
def _get_pieces(self, data: bytes) -> Dict[bytes, bytes]:
pieces = (piece.split(b"=", 1) for piece in data.split(b","))
return dict((piece[0], piece[1]) for piece in pieces)
def _hmac(self, key: bytes, msg: bytes) -> bytes:
return hmac.new(key, msg, self._algo).digest()
def _hash(self, msg: bytes) -> bytes:
return hashlib.new(self._algo, msg).digest()
def _constant_time_compare(self, b1: bytes, b2: bytes):
return hmac.compare_digest(b1, b2)
def client_first(self) -> bytes:
self.state = SCRAMState.ClientFirst
self._client_first = b"n=%s,r=%s" % (
_scram_escape(self._username), _scram_nonce())
# n,,n=<username>,r=<nonce>
return b"n,,%s" % self._client_first
def _assert_error(self, pieces: Dict[bytes, bytes]) -> bool:
if b"e" in pieces:
error = pieces[b"e"].decode("utf8")
self.raw_error = error
if error in SCRAM_ERRORS:
self.error = error
else:
self.error = "other-error"
self.state = SCRAMState.Failed
return True
else:
return False
def server_first(self, data: bytes) -> bytes:
self.state = SCRAMState.ClientFinal
pieces = self._get_pieces(data)
if self._assert_error(pieces):
return b""
nonce = pieces[b"r"] # server combines your nonce with it's own
salt = base64.b64decode(pieces[b"s"]) # salt is b64encoded
iterations = int(pieces[b"i"])
salted_password = hashlib.pbkdf2_hmac(self._algo, self._password,
salt, iterations, dklen=None)
self._salted_password = salted_password
client_key = self._hmac(salted_password, b"Client Key")
stored_key = self._hash(client_key)
channel = base64.b64encode(b"n,,")
auth_noproof = b"c=%s,r=%s" % (channel, nonce)
auth_message = b"%s,%s,%s" % (self._client_first, data, auth_noproof)
self._auth_message = auth_message
client_signature = self._hmac(stored_key, auth_message)
client_proof_xor = _scram_xor(client_key, client_signature)
client_proof = base64.b64encode(client_proof_xor)
# c=<b64encode("n,,")>,r=<nonce>,p=<proof>
return b"%s,p=%s" % (auth_noproof, client_proof)
def server_final(self, data: bytes) -> bool:
pieces = self._get_pieces(data)
if self._assert_error(pieces):
return False
verifier = base64.b64decode(pieces[b"v"])
server_key = self._hmac(self._salted_password, b"Server Key")
server_signature = self._hmac(server_key, self._auth_message)
if server_signature == verifier:
self.state = SCRAMState.Success
return True
else:
self.state = SCRAMState.VerifyFailed
return False