From 8cc3db5e5809e74d91b0db1f83a5cad1f8bf6069 Mon Sep 17 00:00:00 2001 From: jesopo Date: Fri, 3 Apr 2020 00:07:48 +0100 Subject: [PATCH] ensure server nonce .startswith() our nonce --- ircrobots/scram.py | 52 +++++++++++++++++++++++++++------------------- 1 file changed, 31 insertions(+), 21 deletions(-) diff --git a/ircrobots/scram.py b/ircrobots/scram.py index d28ee53..699b694 100644 --- a/ircrobots/scram.py +++ b/ircrobots/scram.py @@ -31,12 +31,12 @@ def _scram_xor(s1: bytes, s2: bytes) -> bytes: return bytes(a ^ b for a, b in zip(s1, s2)) class SCRAMState(enum.Enum): - Uninitialised = 0 - ClientFirst = 1 - ClientFinal = 2 - Success = 3 - Failed = 4 - VerifyFailed = 5 + NONE = 0 + CLIENT_FIRST = 1 + CLIENT_FINAL = 2 + SUCCESS = 3 + FAILURE = 4 + VERIFY_FAILURE = 5 class SCRAMError(Exception): pass @@ -50,13 +50,15 @@ class SCRAMContext(object): self._username = username.encode("utf8") self._password = password.encode("utf8") - self.state = SCRAMState.Uninitialised + self.state = SCRAMState.NONE self.error = "" self.raw_error = "" - self._client_first = b"" + self._client_first = b"" + self._client_nonce = b"" + self._salted_password = b"" - self._auth_message = b"" + self._auth_message = b"" def _get_pieces(self, data: bytes) -> Dict[bytes, bytes]: pieces = (piece.split(b"=", 1) for piece in data.split(b",")) @@ -70,10 +72,19 @@ class SCRAMContext(object): def _constant_time_compare(self, b1: bytes, b2: bytes): return hmac.compare_digest(b1, b2) + def _fail(self, error: str): + self.raw_error = error + if error in SCRAM_ERRORS: + self.error = error + else: + self.error = "other-error" + self.state = SCRAMState.FAILURE + def client_first(self) -> bytes: - self.state = SCRAMState.ClientFirst + self.state = SCRAMState.CLIENT_FIRST + self._client_nonce = _scram_nonce() self._client_first = b"n=%s,r=%s" % ( - _scram_escape(self._username), _scram_nonce()) + _scram_escape(self._username), self._client_nonce) # n,,n=,r= return b"n,,%s" % self._client_first @@ -81,25 +92,24 @@ class SCRAMContext(object): def _assert_error(self, pieces: Dict[bytes, bytes]) -> bool: if b"e" in pieces: error = pieces[b"e"].decode("utf8") - self.raw_error = error - if error in SCRAM_ERRORS: - self.error = error - else: - self.error = "other-error" - - self.state = SCRAMState.Failed + self._fail(error) return True else: return False def server_first(self, data: bytes) -> bytes: - self.state = SCRAMState.ClientFinal + self.state = SCRAMState.CLIENT_FINAL pieces = self._get_pieces(data) if self._assert_error(pieces): return b"" nonce = pieces[b"r"] # server combines your nonce with it's own + if (not nonce.startswith(self._client_nonce) or + nonce == self._client_nonce): + self._fail("nonce-unacceptable") + return b"" + salt = base64.b64decode(pieces[b"s"]) # salt is b64encoded iterations = int(pieces[b"i"]) @@ -133,8 +143,8 @@ class SCRAMContext(object): server_signature = self._hmac(server_key, self._auth_message) if server_signature == verifier: - self.state = SCRAMState.Success + self.state = SCRAMState.SUCCESS return True else: - self.state = SCRAMState.VerifyFailed + self.state = SCRAMState.VERIFY_FAILURE return False