From cc1ee9c81d4f1c8620661cbe5945913e415e218f Mon Sep 17 00:00:00 2001 From: Vincent Ambo Date: Wed, 13 Feb 2019 14:06:34 +0100 Subject: [PATCH] fix: Allow trailing bits in base64 encodings After upgrading the base64 library, tests were failing because the new default of the library is to disallow trailing bits in JWTs. Some JWT provider implementations do however use this "forgiving" version of base64-encoding, hence it is required for token validation. This adds a base64::Config with the appropriate settings and also chains base64-errors separately from other token errors. --- src/lib.rs | 39 +++++++++++++++++++++++++++------------ 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index e62600e26..135b1df0f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -74,7 +74,7 @@ extern crate openssl; extern crate serde; extern crate serde_json; -use base64::{decode_config, URL_SAFE}; +use base64::{URL_SAFE_NO_PAD, Config, DecodeError}; use openssl::bn::BigNum; use openssl::error::ErrorStack; use openssl::hash::MessageDigest; @@ -88,6 +88,17 @@ use std::time::{UNIX_EPOCH, Duration, SystemTime}; #[cfg(test)] mod tests; + +/// URL-safe character set without padding that allows trailing bits, +/// which appear in some JWT implementations. +/// +/// Note: The functions on `base64::Config` are not marked `const`, +/// and the constructors are not exported, which is why this is +/// implemented as a function. +fn jwt_forgiving() -> Config { + URL_SAFE_NO_PAD.decode_allow_trailing_bits(true) +} + /// JWT algorithm used. The only supported algorithm is currently /// RS256. #[derive(Clone, Deserialize, Debug)] @@ -179,8 +190,11 @@ pub enum Validation { /// Possible results of a token validation. #[derive(Debug)] pub enum ValidationError { - /// Token was malformed (various possible reasons!) - MalformedJWT, + /// Invalid number of token components (not a JWT?) + InvalidComponents, + + /// Token segments had invalid base64-encoding. + InvalidBase64(DecodeError), /// Decoding of the provided JWK failed. InvalidJWK, @@ -211,6 +225,10 @@ impl From for ValidationError { fn from(err: serde_json::Error) -> Self { ValidationError::JSON(err) } } +impl From for ValidationError { + fn from(err: DecodeError) -> Self { ValidationError::InvalidBase64(err) } +} + /// Attempt to extract the `kid`-claim out of a JWT's header claims. /// /// This function is normally used when a token provider has multiple @@ -224,7 +242,7 @@ pub fn token_kid(token: &str) -> JWTResult> { // dismissing the rest. let parts: Vec<&str> = token.splitn(2, '.').collect(); if parts.len() != 2 { - return Err(ValidationError::MalformedJWT); + return Err(ValidationError::InvalidComponents); } // Decode only the first part of the token into a specialised @@ -262,7 +280,7 @@ pub fn validate(token: &str, if parts.len() != 3 { // This is unlikely considering that validation has already // been performed at this point, but better safe than sorry. - return Err(ValidationError::MalformedJWT) + return Err(ValidationError::InvalidComponents) } // Perform claim validations before constructing the valid token: @@ -284,7 +302,7 @@ pub fn validate(token: &str, /// Decode a single key fragment (base64-url encoded integer) to an /// OpenSSL BigNum. fn decode_fragment(fragment: &str) -> JWTResult { - let bytes = decode_config(fragment, URL_SAFE) + let bytes = base64::decode_config(fragment, jwt_forgiving()) .map_err(|_| ValidationError::InvalidJWK)?; BigNum::from_slice(&bytes).map_err(Into::into) @@ -301,9 +319,7 @@ fn public_key_from_jwk(jwk: &JWK) -> JWTResult> { /// Decode a base64-URL encoded string and deserialise the resulting /// JSON. fn deserialize_part(part: &str) -> JWTResult { - let json = base64::decode_config(part, URL_SAFE) - .map_err(|_| ValidationError::MalformedJWT)?; - + let json = base64::decode_config(part, jwt_forgiving())?; serde_json::from_slice(&json).map_err(Into::into) } @@ -321,7 +337,7 @@ fn validate_jwt_signature(jwt: &JWT, key: Rsa) -> JWTResult<()> { // splitting them is unnecessary. let token_parts: Vec<&str> = jwt.0.rsplitn(2, '.').collect(); if token_parts.len() != 2 { - return Err(ValidationError::MalformedJWT); + return Err(ValidationError::InvalidComponents); } // Second element of the vector will be the signed payload. @@ -329,8 +345,7 @@ fn validate_jwt_signature(jwt: &JWT, key: Rsa) -> JWTResult<()> { // First element of the vector will be the (encoded) signature. let sig_b64 = token_parts[0]; - let sig = base64::decode_config(sig_b64, URL_SAFE) - .map_err(|_| ValidationError::MalformedJWT)?; + let sig = base64::decode_config(sig_b64, jwt_forgiving())?; // Verify signature by inserting the payload data and checking it // against the decoded signature.