ensure server nonce .startswith() our nonce

This commit is contained in:
jesopo 2020-04-03 00:07:48 +01:00
parent 99d55de170
commit 8cc3db5e58

View file

@ -31,12 +31,12 @@ def _scram_xor(s1: bytes, s2: bytes) -> bytes:
return bytes(a ^ b for a, b in zip(s1, s2)) return bytes(a ^ b for a, b in zip(s1, s2))
class SCRAMState(enum.Enum): class SCRAMState(enum.Enum):
Uninitialised = 0 NONE = 0
ClientFirst = 1 CLIENT_FIRST = 1
ClientFinal = 2 CLIENT_FINAL = 2
Success = 3 SUCCESS = 3
Failed = 4 FAILURE = 4
VerifyFailed = 5 VERIFY_FAILURE = 5
class SCRAMError(Exception): class SCRAMError(Exception):
pass pass
@ -50,11 +50,13 @@ class SCRAMContext(object):
self._username = username.encode("utf8") self._username = username.encode("utf8")
self._password = password.encode("utf8") self._password = password.encode("utf8")
self.state = SCRAMState.Uninitialised self.state = SCRAMState.NONE
self.error = "" self.error = ""
self.raw_error = "" self.raw_error = ""
self._client_first = b"" self._client_first = b""
self._client_nonce = b""
self._salted_password = b"" self._salted_password = b""
self._auth_message = b"" self._auth_message = b""
@ -70,10 +72,19 @@ class SCRAMContext(object):
def _constant_time_compare(self, b1: bytes, b2: bytes): def _constant_time_compare(self, b1: bytes, b2: bytes):
return hmac.compare_digest(b1, b2) return hmac.compare_digest(b1, b2)
def _fail(self, error: str):
self.raw_error = error
if error in SCRAM_ERRORS:
self.error = error
else:
self.error = "other-error"
self.state = SCRAMState.FAILURE
def client_first(self) -> bytes: def client_first(self) -> bytes:
self.state = SCRAMState.ClientFirst self.state = SCRAMState.CLIENT_FIRST
self._client_nonce = _scram_nonce()
self._client_first = b"n=%s,r=%s" % ( self._client_first = b"n=%s,r=%s" % (
_scram_escape(self._username), _scram_nonce()) _scram_escape(self._username), self._client_nonce)
# n,,n=<username>,r=<nonce> # n,,n=<username>,r=<nonce>
return b"n,,%s" % self._client_first return b"n,,%s" % self._client_first
@ -81,25 +92,24 @@ class SCRAMContext(object):
def _assert_error(self, pieces: Dict[bytes, bytes]) -> bool: def _assert_error(self, pieces: Dict[bytes, bytes]) -> bool:
if b"e" in pieces: if b"e" in pieces:
error = pieces[b"e"].decode("utf8") error = pieces[b"e"].decode("utf8")
self.raw_error = error self._fail(error)
if error in SCRAM_ERRORS:
self.error = error
else:
self.error = "other-error"
self.state = SCRAMState.Failed
return True return True
else: else:
return False return False
def server_first(self, data: bytes) -> bytes: def server_first(self, data: bytes) -> bytes:
self.state = SCRAMState.ClientFinal self.state = SCRAMState.CLIENT_FINAL
pieces = self._get_pieces(data) pieces = self._get_pieces(data)
if self._assert_error(pieces): if self._assert_error(pieces):
return b"" return b""
nonce = pieces[b"r"] # server combines your nonce with it's own nonce = pieces[b"r"] # server combines your nonce with it's own
if (not nonce.startswith(self._client_nonce) or
nonce == self._client_nonce):
self._fail("nonce-unacceptable")
return b""
salt = base64.b64decode(pieces[b"s"]) # salt is b64encoded salt = base64.b64decode(pieces[b"s"]) # salt is b64encoded
iterations = int(pieces[b"i"]) iterations = int(pieces[b"i"])
@ -133,8 +143,8 @@ class SCRAMContext(object):
server_signature = self._hmac(server_key, self._auth_message) server_signature = self._hmac(server_key, self._auth_message)
if server_signature == verifier: if server_signature == verifier:
self.state = SCRAMState.Success self.state = SCRAMState.SUCCESS
return True return True
else: else:
self.state = SCRAMState.VerifyFailed self.state = SCRAMState.VERIFY_FAILURE
return False return False