diff --git a/src/lib.rs b/src/lib.rs index e62600e26..135b1df0f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -74,7 +74,7 @@ extern crate openssl; extern crate serde; extern crate serde_json; -use base64::{decode_config, URL_SAFE}; +use base64::{URL_SAFE_NO_PAD, Config, DecodeError}; use openssl::bn::BigNum; use openssl::error::ErrorStack; use openssl::hash::MessageDigest; @@ -88,6 +88,17 @@ use std::time::{UNIX_EPOCH, Duration, SystemTime}; #[cfg(test)] mod tests; + +/// URL-safe character set without padding that allows trailing bits, +/// which appear in some JWT implementations. +/// +/// Note: The functions on `base64::Config` are not marked `const`, +/// and the constructors are not exported, which is why this is +/// implemented as a function. +fn jwt_forgiving() -> Config { + URL_SAFE_NO_PAD.decode_allow_trailing_bits(true) +} + /// JWT algorithm used. The only supported algorithm is currently /// RS256. #[derive(Clone, Deserialize, Debug)] @@ -179,8 +190,11 @@ pub enum Validation { /// Possible results of a token validation. #[derive(Debug)] pub enum ValidationError { - /// Token was malformed (various possible reasons!) - MalformedJWT, + /// Invalid number of token components (not a JWT?) + InvalidComponents, + + /// Token segments had invalid base64-encoding. + InvalidBase64(DecodeError), /// Decoding of the provided JWK failed. InvalidJWK, @@ -211,6 +225,10 @@ impl From for ValidationError { fn from(err: serde_json::Error) -> Self { ValidationError::JSON(err) } } +impl From for ValidationError { + fn from(err: DecodeError) -> Self { ValidationError::InvalidBase64(err) } +} + /// Attempt to extract the `kid`-claim out of a JWT's header claims. /// /// This function is normally used when a token provider has multiple @@ -224,7 +242,7 @@ pub fn token_kid(token: &str) -> JWTResult> { // dismissing the rest. let parts: Vec<&str> = token.splitn(2, '.').collect(); if parts.len() != 2 { - return Err(ValidationError::MalformedJWT); + return Err(ValidationError::InvalidComponents); } // Decode only the first part of the token into a specialised @@ -262,7 +280,7 @@ pub fn validate(token: &str, if parts.len() != 3 { // This is unlikely considering that validation has already // been performed at this point, but better safe than sorry. - return Err(ValidationError::MalformedJWT) + return Err(ValidationError::InvalidComponents) } // Perform claim validations before constructing the valid token: @@ -284,7 +302,7 @@ pub fn validate(token: &str, /// Decode a single key fragment (base64-url encoded integer) to an /// OpenSSL BigNum. fn decode_fragment(fragment: &str) -> JWTResult { - let bytes = decode_config(fragment, URL_SAFE) + let bytes = base64::decode_config(fragment, jwt_forgiving()) .map_err(|_| ValidationError::InvalidJWK)?; BigNum::from_slice(&bytes).map_err(Into::into) @@ -301,9 +319,7 @@ fn public_key_from_jwk(jwk: &JWK) -> JWTResult> { /// Decode a base64-URL encoded string and deserialise the resulting /// JSON. fn deserialize_part(part: &str) -> JWTResult { - let json = base64::decode_config(part, URL_SAFE) - .map_err(|_| ValidationError::MalformedJWT)?; - + let json = base64::decode_config(part, jwt_forgiving())?; serde_json::from_slice(&json).map_err(Into::into) } @@ -321,7 +337,7 @@ fn validate_jwt_signature(jwt: &JWT, key: Rsa) -> JWTResult<()> { // splitting them is unnecessary. let token_parts: Vec<&str> = jwt.0.rsplitn(2, '.').collect(); if token_parts.len() != 2 { - return Err(ValidationError::MalformedJWT); + return Err(ValidationError::InvalidComponents); } // Second element of the vector will be the signed payload. @@ -329,8 +345,7 @@ fn validate_jwt_signature(jwt: &JWT, key: Rsa) -> JWTResult<()> { // First element of the vector will be the (encoded) signature. let sig_b64 = token_parts[0]; - let sig = base64::decode_config(sig_b64, URL_SAFE) - .map_err(|_| ValidationError::MalformedJWT)?; + let sig = base64::decode_config(sig_b64, jwt_forgiving())?; // Verify signature by inserting the payload data and checking it // against the decoded signature.