diff --git a/ircrobots/__init__.py b/ircrobots/__init__.py index 034c9d1..d32d6bc 100644 --- a/ircrobots/__init__.py +++ b/ircrobots/__init__.py @@ -1,4 +1,4 @@ from .bot import Bot from .server import Server -from .params import ConnectionParams, SASLUserPass, SASLExternal +from .params import ConnectionParams, SASLUserPass, SASLExternal, SASLSCRAM from .ircv3 import Capability diff --git a/ircrobots/params.py b/ircrobots/params.py index 3435f03..17c4fa0 100644 --- a/ircrobots/params.py +++ b/ircrobots/params.py @@ -13,6 +13,9 @@ class SASLParams(object): class SASLUserPass(SASLParams): def __init__(self, username: str, password: str): super().__init__("USERPASS", username, password) +class SASLSCRAM(SASLParams): + def __init__(self, username: str, password: str): + super().__init__("SCRAM", username, password) class SASLExternal(SASLParams): def __init__(self): super().__init__("EXTERNAL") diff --git a/ircrobots/sasl.py b/ircrobots/sasl.py index c1885ba..c46f3c9 100644 --- a/ircrobots/sasl.py +++ b/ircrobots/sasl.py @@ -1,18 +1,19 @@ -from typing import Optional +from typing import List, Optional from enum import Enum -from base64 import b64encode +from base64 import b64decode, b64encode from irctokens import build from .matching import Response, ResponseOr, Numerics, ParamAny from .contexts import ServerContext from .params import SASLParams +from .scram import SCRAMContext -SASL_USERPASS_MECHANISMS = [ +SASL_SCRAM_MECHANISMS = [ "SCRAM-SHA-512", "SCRAM-SHA-256", "SCRAM-SHA-1", - "PLAIN" ] +SASL_USERPASS_MECHANISMS = SASL_SCRAM_MECHANISMS+["PLAIN"] class SASLResult(Enum): NONE = 0 @@ -25,14 +26,18 @@ class SASLError(Exception): class SASLUnknownMechanismError(SASLError): pass + +AUTHENTICATE_ANY = Response("AUTHENTICATE", [ParamAny()]) NUMERICS_INITIAL = Numerics( ["ERR_SASLFAIL", "ERR_SASLALREADY", "RPL_SASLMECHS"]) NUMERICS_LAST = Numerics(["RPL_SASLSUCCESS", "ERR_SASLFAIL"]) class SASLContext(ServerContext): async def from_params(self, params: SASLParams) -> SASLResult: - if params.mechanism == "USERPASS": + if params.mechanism == "USERPASS": return await self.userpass(params.username, params.password) + elif params.mechanism == "SCRAM": + return await self.scram(params.username, params.password) elif params.mechanism == "EXTERNAL": return await self.external() else: @@ -43,7 +48,7 @@ class SASLContext(ServerContext): async def external(self) -> SASLResult: await self.server.send(build("AUTHENTICATE", ["EXTERNAL"])) line = await self.server.wait_for(ResponseOr( - Response("AUTHENTICATE", [ParamAny()]), + AUTHENTICATE_ANY, NUMERICS_INITIAL )) @@ -63,17 +68,27 @@ class SASLContext(ServerContext): return SASLResult.SUCCESS return SASLResult.FAILURE - async def userpass(self, username: str, password: str) -> SASLResult: + async def plain(self, username: str, password: str) -> SASLResult: + return await self.userpass(username, password, ["PLAIN"]) + + async def scram(self, username: str, password: str) -> SASLResult: + return await self.userpass(username, password, SASL_SCRAM_MECHANISMS) + + async def userpass(self, + username: str, + password: str, + mechanisms: List[str]=SASL_USERPASS_MECHANISMS + ) -> SASLResult: # this will, in the future, offer SCRAM support def _common(server_mechs) -> str: - for our_mech in SASL_USERPASS_MECHANISMS: + for our_mech in mechanisms: if our_mech in server_mechs: return our_mech else: raise SASLUnknownMechanismError( "No matching SASL mechanims. " - f"(we have: {SASL_USERPASS_MECHANISMS} " + f"(we want: {mechanisms} " f"server has: {server_mechs})") if not self.server.available_caps["sasl"] is None: @@ -83,11 +98,11 @@ class SASLContext(ServerContext): else: # CAP v3.1 does not. pick the pick and wait for 907 to inform us of # what mechanisms are supported - match = SASL_USERPASS_MECHANISMS[0] + match = mechanisms[0] await self.server.send(build("AUTHENTICATE", [match])) line = await self.server.wait_for(ResponseOr( - Response("AUTHENTICATE", [ParamAny()]), + AUTHENTICATE_ANY, NUMERICS_INITIAL )) @@ -100,19 +115,47 @@ class SASLContext(ServerContext): match = _common(available) await self.server.send(build("AUTHENTICATE", [match])) - line = await self.server.wait_for( - Response("AUTHENTICATE", [ParamAny()])) + line = await self.server.wait_for(AUTHENTICATE_ANY) if line.command == "AUTHENTICATE" and line.params[0] == "+": - auth_text: Optional[str] = None + def _b64e(s: str): + return b64encode(s.encode("utf8")).decode("ascii") + + def _b64eb(s: bytes) -> str: + # encode-from-bytes + return b64encode(s).decode("ascii") + def _b64db(s: str) -> bytes: + # decode-to-bytes + return b64decode(s) + + auth_text = "" if match == "PLAIN": auth_text = f"{username}\0{username}\0{password}" + elif match.startswith("SCRAM-SHA-"): + algo = match.replace("SCRAM-", "", 1) + scram = SCRAMContext(algo, username, password) - if not auth_text is None: - auth_b64 = b64encode( - auth_text.encode("utf8")).decode("ascii") - await self.server.send(build("AUTHENTICATE", [auth_b64])) + client_first = _b64eb(scram.client_first()) + await self.server.send(build("AUTHENTICATE", [client_first])) + line = await self.server.wait_for(AUTHENTICATE_ANY) + server_first = _b64db(line.params[0]) + client_final = _b64eb(scram.server_first(server_first)) + if not client_final == "": + await self.server.send(build("AUTHENTICATE", [client_final])) + line = await self.server.wait_for(AUTHENTICATE_ANY) + + server_final = _b64db(line.params[0]) + verified = scram.server_final(server_final) + #TODO PANIC if verified is false! + auth_text = "+" + else: + auth_text = "" + + if not auth_text == "+": + auth_text = _b64e(auth_text) + if auth_text: + await self.server.send(build("AUTHENTICATE", [auth_text])) line = await self.server.wait_for(NUMERICS_LAST) if line.command == "903": return SASLResult.SUCCESS diff --git a/ircrobots/scram.py b/ircrobots/scram.py new file mode 100644 index 0000000..d28ee53 --- /dev/null +++ b/ircrobots/scram.py @@ -0,0 +1,140 @@ +import base64, enum, hashlib, hmac, os +from typing import Dict + +# IANA Hash Function Textual Names +# https://tools.ietf.org/html/rfc5802#section-4 +# https://www.iana.org/assignments/hash-function-text-names/ +# MD2 has been removed as it's unacceptably weak +ALGORITHMS = [ + "MD5", "SHA-1", "SHA-224", "SHA-256", "SHA-384", "SHA-512"] + +SCRAM_ERRORS = [ + "invalid-encoding", + "extensions-not-supported", # unrecognized 'm' value + "invalid-proof", + "channel-bindings-dont-match", + "server-does-support-channel-binding", + "channel-binding-not-supported", + "unsupported-channel-binding-type", + "unknown-user", + "invalid-username-encoding", # invalid utf8 or bad SASLprep + "no-resources" +] + +def _scram_nonce() -> bytes: + return base64.b64encode(os.urandom(32)) +def _scram_escape(s: bytes) -> bytes: + return s.replace(b"=", b"=3D").replace(b",", b"=2C") +def _scram_unescape(s: bytes) -> bytes: + return s.replace(b"=3D", b"=").replace(b"=2C", b",") +def _scram_xor(s1: bytes, s2: bytes) -> bytes: + return bytes(a ^ b for a, b in zip(s1, s2)) + +class SCRAMState(enum.Enum): + Uninitialised = 0 + ClientFirst = 1 + ClientFinal = 2 + Success = 3 + Failed = 4 + VerifyFailed = 5 + +class SCRAMError(Exception): + pass + +class SCRAMContext(object): + def __init__(self, algo: str, username: str, password: str): + if not algo in ALGORITHMS: + raise ValueError("Unknown SCRAM algorithm '%s'" % algo) + + self._algo = algo.replace("-", "") # SHA-1 -> SHA1 + self._username = username.encode("utf8") + self._password = password.encode("utf8") + + self.state = SCRAMState.Uninitialised + self.error = "" + self.raw_error = "" + + self._client_first = b"" + self._salted_password = b"" + self._auth_message = b"" + + def _get_pieces(self, data: bytes) -> Dict[bytes, bytes]: + pieces = (piece.split(b"=", 1) for piece in data.split(b",")) + return dict((piece[0], piece[1]) for piece in pieces) + + def _hmac(self, key: bytes, msg: bytes) -> bytes: + return hmac.new(key, msg, self._algo).digest() + def _hash(self, msg: bytes) -> bytes: + return hashlib.new(self._algo, msg).digest() + + def _constant_time_compare(self, b1: bytes, b2: bytes): + return hmac.compare_digest(b1, b2) + + def client_first(self) -> bytes: + self.state = SCRAMState.ClientFirst + self._client_first = b"n=%s,r=%s" % ( + _scram_escape(self._username), _scram_nonce()) + + # n,,n=,r= + return b"n,,%s" % self._client_first + + def _assert_error(self, pieces: Dict[bytes, bytes]) -> bool: + if b"e" in pieces: + error = pieces[b"e"].decode("utf8") + self.raw_error = error + if error in SCRAM_ERRORS: + self.error = error + else: + self.error = "other-error" + + self.state = SCRAMState.Failed + return True + else: + return False + + def server_first(self, data: bytes) -> bytes: + self.state = SCRAMState.ClientFinal + + pieces = self._get_pieces(data) + if self._assert_error(pieces): + return b"" + + nonce = pieces[b"r"] # server combines your nonce with it's own + salt = base64.b64decode(pieces[b"s"]) # salt is b64encoded + iterations = int(pieces[b"i"]) + + salted_password = hashlib.pbkdf2_hmac(self._algo, self._password, + salt, iterations, dklen=None) + self._salted_password = salted_password + + client_key = self._hmac(salted_password, b"Client Key") + stored_key = self._hash(client_key) + + channel = base64.b64encode(b"n,,") + auth_noproof = b"c=%s,r=%s" % (channel, nonce) + auth_message = b"%s,%s,%s" % (self._client_first, data, auth_noproof) + self._auth_message = auth_message + + client_signature = self._hmac(stored_key, auth_message) + client_proof_xor = _scram_xor(client_key, client_signature) + client_proof = base64.b64encode(client_proof_xor) + + # c=,r=,p= + return b"%s,p=%s" % (auth_noproof, client_proof) + + def server_final(self, data: bytes) -> bool: + pieces = self._get_pieces(data) + if self._assert_error(pieces): + return False + + verifier = base64.b64decode(pieces[b"v"]) + + server_key = self._hmac(self._salted_password, b"Server Key") + server_signature = self._hmac(server_key, self._auth_message) + + if server_signature == verifier: + self.state = SCRAMState.Success + return True + else: + self.state = SCRAMState.VerifyFailed + return False