SCRAM should take mechanisms as Enum values

This commit is contained in:
jesopo 2020-04-26 17:08:03 +01:00
parent e26190c283
commit c25f6d2a00
2 changed files with 28 additions and 16 deletions

View file

@ -7,7 +7,7 @@ from ircstates.numerics import *
from .matching import ResponseOr, Responses, Response, ANY from .matching import ResponseOr, Responses, Response, ANY
from .contexts import ServerContext from .contexts import ServerContext
from .params import SASLParams, SASLUserPass, SASLSCRAM, SASLExternal from .params import SASLParams, SASLUserPass, SASLSCRAM, SASLExternal
from .scram import SCRAMContext from .scram import SCRAMContext, SCRAMAlgorithm
SASL_SCRAM_MECHANISMS = [ SASL_SCRAM_MECHANISMS = [
"SCRAM-SHA-512", "SCRAM-SHA-512",
@ -153,8 +153,15 @@ class SASLContext(ServerContext):
return SASLResult.FAILURE return SASLResult.FAILURE
async def _scram(self, algo: str, username: str, password: str) -> str: async def _scram(self, algo_str: str,
algo = algo.replace("SCRAM-", "", 1) username: str,
password: str) -> str:
algo_str_prep = algo_str.replace("SCRAM-", "", 1
).replace("-", "").upper()
try:
algo = SCRAMAlgorithm(algo_str_prep)
except ValueError:
raise ValueError("Unknown SCRAM algorithm '%s'" % algo)
scram = SCRAMContext(algo, username, password) scram = SCRAMContext(algo, username, password)
client_first = _b64eb(scram.client_first()) client_first = _b64eb(scram.client_first())

View file

@ -1,12 +1,18 @@
import base64, enum, hashlib, hmac, os import base64, hashlib, hmac, os
from enum import Enum
from typing import Dict from typing import Dict
# IANA Hash Function Textual Names # IANA Hash Function Textual Names
# https://tools.ietf.org/html/rfc5802#section-4 # https://tools.ietf.org/html/rfc5802#section-4
# https://www.iana.org/assignments/hash-function-text-names/ # https://www.iana.org/assignments/hash-function-text-names/
# MD2 has been removed as it's unacceptably weak # MD2 has been removed as it's unacceptably weak
ALGORITHMS = [ class SCRAMAlgorithm(Enum):
"MD5", "SHA-1", "SHA-224", "SHA-256", "SHA-384", "SHA-512"] MD5 = "MD5"
SHA_1 = "SHA-1"
SHA_224 = "SHA224"
SHA_256 = "SHA256"
SHA_384 = "SHA384"
SHA_512 = "SHA512"
SCRAM_ERRORS = [ SCRAM_ERRORS = [
"invalid-encoding", "invalid-encoding",
@ -30,7 +36,7 @@ def _scram_unescape(s: bytes) -> bytes:
def _scram_xor(s1: bytes, s2: bytes) -> bytes: def _scram_xor(s1: bytes, s2: bytes) -> bytes:
return bytes(a ^ b for a, b in zip(s1, s2)) return bytes(a ^ b for a, b in zip(s1, s2))
class SCRAMState(enum.Enum): class SCRAMState(Enum):
NONE = 0 NONE = 0
CLIENT_FIRST = 1 CLIENT_FIRST = 1
CLIENT_FINAL = 2 CLIENT_FINAL = 2
@ -42,11 +48,10 @@ class SCRAMError(Exception):
pass pass
class SCRAMContext(object): class SCRAMContext(object):
def __init__(self, algo: str, username: str, password: str): def __init__(self, algo: SCRAMAlgorithm,
if not algo in ALGORITHMS: username: str,
raise ValueError("Unknown SCRAM algorithm '%s'" % algo) password: str):
self._algo = algo
self._algo = algo.replace("-", "") # SHA-1 -> SHA1
self._username = username.encode("utf8") self._username = username.encode("utf8")
self._password = password.encode("utf8") self._password = password.encode("utf8")
@ -65,9 +70,9 @@ class SCRAMContext(object):
return dict((piece[0], piece[1]) for piece in pieces) return dict((piece[0], piece[1]) for piece in pieces)
def _hmac(self, key: bytes, msg: bytes) -> bytes: def _hmac(self, key: bytes, msg: bytes) -> bytes:
return hmac.new(key, msg, self._algo).digest() return hmac.new(key, msg, self._algo.value).digest()
def _hash(self, msg: bytes) -> bytes: def _hash(self, msg: bytes) -> bytes:
return hashlib.new(self._algo, msg).digest() return hashlib.new(self._algo.value, msg).digest()
def _constant_time_compare(self, b1: bytes, b2: bytes): def _constant_time_compare(self, b1: bytes, b2: bytes):
return hmac.compare_digest(b1, b2) return hmac.compare_digest(b1, b2)
@ -113,8 +118,8 @@ class SCRAMContext(object):
salt = base64.b64decode(pieces[b"s"]) # salt is b64encoded salt = base64.b64decode(pieces[b"s"]) # salt is b64encoded
iterations = int(pieces[b"i"]) iterations = int(pieces[b"i"])
salted_password = hashlib.pbkdf2_hmac(self._algo, self._password, salted_password = hashlib.pbkdf2_hmac(self._algo.value,
salt, iterations, dklen=None) self._password, salt, iterations, dklen=None)
self._salted_password = salted_password self._salted_password = salted_password
client_key = self._hmac(salted_password, b"Client Key") client_key = self._hmac(salted_password, b"Client Key")