refactor: Use error enum + result type alias for failures

This makes the library slightly more "rusty". Instead of returning a
validation result which also represents potential success, use an enum
representing the error variants and the standard library's
`Result`-type to represent success/failure.
This commit is contained in:
Vincent Ambo 2018-09-04 10:58:59 +02:00
parent 0f8231e990
commit d3b200e820

View file

@ -57,6 +57,7 @@ use base64::{decode_config, URL_SAFE};
use openssl::bn::BigNum; use openssl::bn::BigNum;
use openssl::pkey::Public; use openssl::pkey::Public;
use openssl::rsa::{Rsa}; use openssl::rsa::{Rsa};
use openssl::error::ErrorStack;
/// JWT algorithm used. The only supported algorithm is currently /// JWT algorithm used. The only supported algorithm is currently
/// RS256. /// RS256.
@ -112,22 +113,33 @@ pub struct JWT {}
pub enum Validation {} pub enum Validation {}
/// Possible results of a token validation. /// Possible results of a token validation.
pub enum ValidationResult { #[derive(Debug)]
/// Signature and claim validation succeeded. pub enum ValidationError {
Valid, /// Token was malformed (various possible reasons!)
MalformedJWT,
/// Decoding of the provided JWK failed. /// Decoding of the provided JWK failed.
InvalidJWK(String), InvalidJWK,
/// Signature validation failed, i.e. because of a non-matching /// Signature validation failed, i.e. because of a non-matching
/// public key. /// public key.
InvalidSignature, InvalidSignature,
/// An OpenSSL operation failed along the way at a point at which
/// a more specific error variant could not be constructed.
OpenSSL(ErrorStack),
/// One or more claim validations failed. /// One or more claim validations failed.
// TODO: Provide reasons? // TODO: Provide reasons?
InvalidClaims, InvalidClaims,
} }
type JWTResult<T> = Result<T, ValidationError>;
impl From<ErrorStack> for ValidationError {
fn from(err: ErrorStack) -> Self { ValidationError::OpenSSL(err) }
}
/// Attempt to extract the `kid`-claim out of a JWT's header claims. /// Attempt to extract the `kid`-claim out of a JWT's header claims.
/// ///
/// This function is normally used when a token provider has multiple /// This function is normally used when a token provider has multiple
@ -147,7 +159,7 @@ pub fn token_kid(jwt: JWT) -> Option<String> {
/// ///
/// It is the user's task to ensure that the correct JWK is passed in /// It is the user's task to ensure that the correct JWK is passed in
/// for validation. /// for validation.
pub fn validate(jwt: JWT, jwk: JWK, validations: Vec<Validation>) -> ValidationResult { pub fn validate(jwt: JWT, jwk: JWK, validations: Vec<Validation>) -> JWTResult<()> {
unimplemented!() unimplemented!()
} }
@ -156,18 +168,21 @@ pub fn validate(jwt: JWT, jwk: JWK, validations: Vec<Validation>) -> ValidationR
// The functions in the following section are not part of the public // The functions in the following section are not part of the public
// API of this library. // API of this library.
/// Decode a single key fragment to an OpenSSL BigNum. /// Decode a single key fragment (base64-url encoded integer) to an
fn decode_fragment(fragment: &str) -> Option<BigNum> { /// OpenSSL BigNum.
let bytes = decode_config(fragment, URL_SAFE).ok()?; fn decode_fragment(fragment: &str) -> JWTResult<BigNum> {
BigNum::from_slice(&bytes).ok() let bytes = decode_config(fragment, URL_SAFE)
.map_err(|_| ValidationError::InvalidJWK)?;
BigNum::from_slice(&bytes).map_err(Into::into)
} }
/// Decode an RSA public key from a JWK by constructing it directly /// Decode an RSA public key from a JWK by constructing it directly
/// from the public RSA key fragments. /// from the public RSA key fragments.
fn public_key_from_jwk(jwk: &JWK) -> Option<Rsa<Public>> { fn public_key_from_jwk(jwk: &JWK) -> JWTResult<Rsa<Public>> {
let jwk_n = decode_fragment(&jwk.n)?; let jwk_n = decode_fragment(&jwk.n)?;
let jwk_e = decode_fragment(&jwk.e)?; let jwk_e = decode_fragment(&jwk.e)?;
Rsa::from_public_components(jwk_n, jwk_e).ok() Rsa::from_public_components(jwk_n, jwk_e).map_err(Into::into)
} }
#[cfg(test)] #[cfg(test)]