ensure server nonce .startswith() our nonce

This commit is contained in:
jesopo 2020-04-03 00:07:48 +01:00
parent 99d55de170
commit 8cc3db5e58

View file

@ -31,12 +31,12 @@ def _scram_xor(s1: bytes, s2: bytes) -> bytes:
return bytes(a ^ b for a, b in zip(s1, s2))
class SCRAMState(enum.Enum):
Uninitialised = 0
ClientFirst = 1
ClientFinal = 2
Success = 3
Failed = 4
VerifyFailed = 5
NONE = 0
CLIENT_FIRST = 1
CLIENT_FINAL = 2
SUCCESS = 3
FAILURE = 4
VERIFY_FAILURE = 5
class SCRAMError(Exception):
pass
@ -50,13 +50,15 @@ class SCRAMContext(object):
self._username = username.encode("utf8")
self._password = password.encode("utf8")
self.state = SCRAMState.Uninitialised
self.state = SCRAMState.NONE
self.error = ""
self.raw_error = ""
self._client_first = b""
self._client_first = b""
self._client_nonce = b""
self._salted_password = b""
self._auth_message = b""
self._auth_message = b""
def _get_pieces(self, data: bytes) -> Dict[bytes, bytes]:
pieces = (piece.split(b"=", 1) for piece in data.split(b","))
@ -70,10 +72,19 @@ class SCRAMContext(object):
def _constant_time_compare(self, b1: bytes, b2: bytes):
return hmac.compare_digest(b1, b2)
def _fail(self, error: str):
self.raw_error = error
if error in SCRAM_ERRORS:
self.error = error
else:
self.error = "other-error"
self.state = SCRAMState.FAILURE
def client_first(self) -> bytes:
self.state = SCRAMState.ClientFirst
self.state = SCRAMState.CLIENT_FIRST
self._client_nonce = _scram_nonce()
self._client_first = b"n=%s,r=%s" % (
_scram_escape(self._username), _scram_nonce())
_scram_escape(self._username), self._client_nonce)
# n,,n=<username>,r=<nonce>
return b"n,,%s" % self._client_first
@ -81,25 +92,24 @@ class SCRAMContext(object):
def _assert_error(self, pieces: Dict[bytes, bytes]) -> bool:
if b"e" in pieces:
error = pieces[b"e"].decode("utf8")
self.raw_error = error
if error in SCRAM_ERRORS:
self.error = error
else:
self.error = "other-error"
self.state = SCRAMState.Failed
self._fail(error)
return True
else:
return False
def server_first(self, data: bytes) -> bytes:
self.state = SCRAMState.ClientFinal
self.state = SCRAMState.CLIENT_FINAL
pieces = self._get_pieces(data)
if self._assert_error(pieces):
return b""
nonce = pieces[b"r"] # server combines your nonce with it's own
if (not nonce.startswith(self._client_nonce) or
nonce == self._client_nonce):
self._fail("nonce-unacceptable")
return b""
salt = base64.b64decode(pieces[b"s"]) # salt is b64encoded
iterations = int(pieces[b"i"])
@ -133,8 +143,8 @@ class SCRAMContext(object):
server_signature = self._hmac(server_key, self._auth_message)
if server_signature == verifier:
self.state = SCRAMState.Success
self.state = SCRAMState.SUCCESS
return True
else:
self.state = SCRAMState.VerifyFailed
self.state = SCRAMState.VERIFY_FAILURE
return False