SCRAM should take mechanisms as Enum values
This commit is contained in:
parent
e26190c283
commit
c25f6d2a00
2 changed files with 28 additions and 16 deletions
|
@ -7,7 +7,7 @@ from ircstates.numerics import *
|
||||||
from .matching import ResponseOr, Responses, Response, ANY
|
from .matching import ResponseOr, Responses, Response, ANY
|
||||||
from .contexts import ServerContext
|
from .contexts import ServerContext
|
||||||
from .params import SASLParams, SASLUserPass, SASLSCRAM, SASLExternal
|
from .params import SASLParams, SASLUserPass, SASLSCRAM, SASLExternal
|
||||||
from .scram import SCRAMContext
|
from .scram import SCRAMContext, SCRAMAlgorithm
|
||||||
|
|
||||||
SASL_SCRAM_MECHANISMS = [
|
SASL_SCRAM_MECHANISMS = [
|
||||||
"SCRAM-SHA-512",
|
"SCRAM-SHA-512",
|
||||||
|
@ -153,8 +153,15 @@ class SASLContext(ServerContext):
|
||||||
|
|
||||||
return SASLResult.FAILURE
|
return SASLResult.FAILURE
|
||||||
|
|
||||||
async def _scram(self, algo: str, username: str, password: str) -> str:
|
async def _scram(self, algo_str: str,
|
||||||
algo = algo.replace("SCRAM-", "", 1)
|
username: str,
|
||||||
|
password: str) -> str:
|
||||||
|
algo_str_prep = algo_str.replace("SCRAM-", "", 1
|
||||||
|
).replace("-", "").upper()
|
||||||
|
try:
|
||||||
|
algo = SCRAMAlgorithm(algo_str_prep)
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError("Unknown SCRAM algorithm '%s'" % algo)
|
||||||
scram = SCRAMContext(algo, username, password)
|
scram = SCRAMContext(algo, username, password)
|
||||||
|
|
||||||
client_first = _b64eb(scram.client_first())
|
client_first = _b64eb(scram.client_first())
|
||||||
|
|
|
@ -1,12 +1,18 @@
|
||||||
import base64, enum, hashlib, hmac, os
|
import base64, hashlib, hmac, os
|
||||||
|
from enum import Enum
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
# IANA Hash Function Textual Names
|
# IANA Hash Function Textual Names
|
||||||
# https://tools.ietf.org/html/rfc5802#section-4
|
# https://tools.ietf.org/html/rfc5802#section-4
|
||||||
# https://www.iana.org/assignments/hash-function-text-names/
|
# https://www.iana.org/assignments/hash-function-text-names/
|
||||||
# MD2 has been removed as it's unacceptably weak
|
# MD2 has been removed as it's unacceptably weak
|
||||||
ALGORITHMS = [
|
class SCRAMAlgorithm(Enum):
|
||||||
"MD5", "SHA-1", "SHA-224", "SHA-256", "SHA-384", "SHA-512"]
|
MD5 = "MD5"
|
||||||
|
SHA_1 = "SHA-1"
|
||||||
|
SHA_224 = "SHA224"
|
||||||
|
SHA_256 = "SHA256"
|
||||||
|
SHA_384 = "SHA384"
|
||||||
|
SHA_512 = "SHA512"
|
||||||
|
|
||||||
SCRAM_ERRORS = [
|
SCRAM_ERRORS = [
|
||||||
"invalid-encoding",
|
"invalid-encoding",
|
||||||
|
@ -30,7 +36,7 @@ def _scram_unescape(s: bytes) -> bytes:
|
||||||
def _scram_xor(s1: bytes, s2: bytes) -> bytes:
|
def _scram_xor(s1: bytes, s2: bytes) -> bytes:
|
||||||
return bytes(a ^ b for a, b in zip(s1, s2))
|
return bytes(a ^ b for a, b in zip(s1, s2))
|
||||||
|
|
||||||
class SCRAMState(enum.Enum):
|
class SCRAMState(Enum):
|
||||||
NONE = 0
|
NONE = 0
|
||||||
CLIENT_FIRST = 1
|
CLIENT_FIRST = 1
|
||||||
CLIENT_FINAL = 2
|
CLIENT_FINAL = 2
|
||||||
|
@ -42,11 +48,10 @@ class SCRAMError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class SCRAMContext(object):
|
class SCRAMContext(object):
|
||||||
def __init__(self, algo: str, username: str, password: str):
|
def __init__(self, algo: SCRAMAlgorithm,
|
||||||
if not algo in ALGORITHMS:
|
username: str,
|
||||||
raise ValueError("Unknown SCRAM algorithm '%s'" % algo)
|
password: str):
|
||||||
|
self._algo = algo
|
||||||
self._algo = algo.replace("-", "") # SHA-1 -> SHA1
|
|
||||||
self._username = username.encode("utf8")
|
self._username = username.encode("utf8")
|
||||||
self._password = password.encode("utf8")
|
self._password = password.encode("utf8")
|
||||||
|
|
||||||
|
@ -65,9 +70,9 @@ class SCRAMContext(object):
|
||||||
return dict((piece[0], piece[1]) for piece in pieces)
|
return dict((piece[0], piece[1]) for piece in pieces)
|
||||||
|
|
||||||
def _hmac(self, key: bytes, msg: bytes) -> bytes:
|
def _hmac(self, key: bytes, msg: bytes) -> bytes:
|
||||||
return hmac.new(key, msg, self._algo).digest()
|
return hmac.new(key, msg, self._algo.value).digest()
|
||||||
def _hash(self, msg: bytes) -> bytes:
|
def _hash(self, msg: bytes) -> bytes:
|
||||||
return hashlib.new(self._algo, msg).digest()
|
return hashlib.new(self._algo.value, msg).digest()
|
||||||
|
|
||||||
def _constant_time_compare(self, b1: bytes, b2: bytes):
|
def _constant_time_compare(self, b1: bytes, b2: bytes):
|
||||||
return hmac.compare_digest(b1, b2)
|
return hmac.compare_digest(b1, b2)
|
||||||
|
@ -113,8 +118,8 @@ class SCRAMContext(object):
|
||||||
salt = base64.b64decode(pieces[b"s"]) # salt is b64encoded
|
salt = base64.b64decode(pieces[b"s"]) # salt is b64encoded
|
||||||
iterations = int(pieces[b"i"])
|
iterations = int(pieces[b"i"])
|
||||||
|
|
||||||
salted_password = hashlib.pbkdf2_hmac(self._algo, self._password,
|
salted_password = hashlib.pbkdf2_hmac(self._algo.value,
|
||||||
salt, iterations, dklen=None)
|
self._password, salt, iterations, dklen=None)
|
||||||
self._salted_password = salted_password
|
self._salted_password = salted_password
|
||||||
|
|
||||||
client_key = self._hmac(salted_password, b"Client Key")
|
client_key = self._hmac(salted_password, b"Client Key")
|
||||||
|
|
Loading…
Reference in a new issue