Consume GoogleSignIn.validateJWT

TL;DR:
- Consume GoogleSignIn.validateJWT in the Handler for /verify
- Rename validation fn to validateJWT
- Prefer Text to String type
This commit is contained in:
William Carroll 2020-08-08 17:55:19 +01:00
parent 8a7a3b29a9
commit e8f35f0d10
4 changed files with 60 additions and 26 deletions

View file

@ -3,7 +3,7 @@
module GoogleSignIn where module GoogleSignIn where
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
import Data.String.Conversions (cs) import Data.String.Conversions (cs)
import Data.Text (Text) import Data.Text
import Web.JWT import Web.JWT
import Utils import Utils
@ -14,10 +14,16 @@ import qualified Data.Time.Clock.POSIX as POSIX
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
newtype EncodedJWT = EncodedJWT Text newtype EncodedJWT = EncodedJWT Text
deriving (Show)
newtype DecodedJWT = DecodedJWT (JWT UnverifiedJWT)
deriving (Show)
instance Eq DecodedJWT where
(DecodedJWT _) == (DecodedJWT _) = True
-- | Some of the errors that a JWT
data ValidationResult data ValidationResult
= Valid = Valid DecodedJWT
| DecodeError | DecodeError
| GoogleSaysInvalid Text | GoogleSaysInvalid Text
| NoMatchingClientIDs [StringOrURI] | NoMatchingClientIDs [StringOrURI]
@ -36,10 +42,10 @@ data ValidationResult
-- * The `exp` time has not passed -- * The `exp` time has not passed
-- --
-- Set `skipHTTP` to `True` to avoid making the network request for testing. -- Set `skipHTTP` to `True` to avoid making the network request for testing.
jwtIsValid :: Bool validateJWT :: Bool
-> EncodedJWT -> EncodedJWT
-> IO ValidationResult -> IO ValidationResult
jwtIsValid skipHTTP (EncodedJWT encodedJWT) = do validateJWT skipHTTP (EncodedJWT encodedJWT) = do
case encodedJWT |> decode of case encodedJWT |> decode of
Nothing -> pure DecodeError Nothing -> pure DecodeError
Just jwt -> do Just jwt -> do
@ -91,4 +97,16 @@ jwtIsValid skipHTTP (EncodedJWT encodedJWT) = do
if not $ currentTime <= jwtExpiry then if not $ currentTime <= jwtExpiry then
pure $ StaleExpiry jwtExpiry pure $ StaleExpiry jwtExpiry
else else
pure Valid pure $ jwt |> DecodedJWT |> Valid
-- | Attempt to explain the `ValidationResult` to a human.
explainResult :: ValidationResult -> String
explainResult (Valid _) = "Everything appears to be valid"
explainResult DecodeError = "We had difficulty decoding the provided JWT"
explainResult (GoogleSaysInvalid x) = "After checking with Google, they claimed that the provided JWT was invalid: " ++ cs x
explainResult (NoMatchingClientIDs audFields) = "None of the values in the `aud` field on the provided JWT match our client ID: " ++ show audFields
explainResult (WrongIssuer issuer) = "The `iss` field in the provided JWT does not match what we expect: " ++ show issuer
explainResult (StringOrURIParseFailure x) = "We had difficulty parsing values as URIs" ++ show x
explainResult TimeConversionFailure = "We had difficulty converting the current time to a value we can use to compare with the JWT's `exp` field"
explainResult (MissingRequiredClaim claim) = "Your JWT is missing the following claim: " ++ cs claim
explainResult (StaleExpiry x) = "The `exp` field on your JWT has expired" ++ x |> show |> cs

View file

@ -7,10 +7,14 @@ module Main where
import Servant import Servant
import API import API
import Control.Monad.IO.Class (liftIO) import Control.Monad.IO.Class (liftIO)
import GoogleSignIn (EncodedJWT(..), ValidationResult(..))
import Data.String.Conversions (cs)
import Utils
import qualified Network.Wai.Handler.Warp as Warp import qualified Network.Wai.Handler.Warp as Warp
import qualified Network.Wai.Middleware.Cors as Cors import qualified Network.Wai.Middleware.Cors as Cors
import qualified Types as T import qualified Types as T
import qualified GoogleSignIn
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
server :: Server API server :: Server API
@ -18,8 +22,13 @@ server = verifyGoogleSignIn
where where
verifyGoogleSignIn :: T.VerifyGoogleSignInRequest -> Handler NoContent verifyGoogleSignIn :: T.VerifyGoogleSignInRequest -> Handler NoContent
verifyGoogleSignIn T.VerifyGoogleSignInRequest{..} = do verifyGoogleSignIn T.VerifyGoogleSignInRequest{..} = do
liftIO $ putStrLn $ "Received: " ++ idToken validationResult <- liftIO $ GoogleSignIn.validateJWT False (EncodedJWT idToken)
case validationResult of
Valid _ -> do
liftIO $ putStrLn "Sign-in valid! Let's create a session"
pure NoContent pure NoContent
err -> do
throwError err401 { errBody = err |> GoogleSignIn.explainResult |> cs }
main :: IO () main :: IO ()
main = do main = do

View file

@ -4,8 +4,8 @@ module Spec where
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
import Test.Hspec import Test.Hspec
import Utils import Utils
import Web.JWT (numericDate) import Web.JWT (numericDate, decode)
import GoogleSignIn (ValidationResult(..)) import GoogleSignIn (EncodedJWT(..), DecodedJWT(..), ValidationResult(..))
import qualified GoogleSignIn import qualified GoogleSignIn
import qualified Fixtures as F import qualified Fixtures as F
@ -16,36 +16,40 @@ import qualified Data.Time.Clock.POSIX as POSIX
main :: IO () main :: IO ()
main = hspec $ do main = hspec $ do
describe "GoogleSignIn" $ describe "GoogleSignIn" $
describe "jwtIsValid" $ do describe "validateJWT" $ do
let jwtIsValid' = GoogleSignIn.jwtIsValid True let validateJWT' = GoogleSignIn.validateJWT True
it "returns a decode error when an incorrectly encoded JWT is used" $ do it "returns a decode error when an incorrectly encoded JWT is used" $ do
jwtIsValid' (GoogleSignIn.EncodedJWT "rubbish") `shouldReturn` DecodeError validateJWT' (GoogleSignIn.EncodedJWT "rubbish") `shouldReturn` DecodeError
it "returns validation error when the aud field doesn't match my client ID" $ do it "returns validation error when the aud field doesn't match my client ID" $ do
let auds = ["wrong-client-id"] let auds = ["wrong-client-id"]
|> fmap TestUtils.unsafeStringOrURI |> fmap TestUtils.unsafeStringOrURI
encodedJWT = F.defaultJWTFields { F.overwriteAuds = auds } encodedJWT = F.defaultJWTFields { F.overwriteAuds = auds }
|> F.googleJWT |> F.googleJWT
jwtIsValid' encodedJWT `shouldReturn` NoMatchingClientIDs auds validateJWT' encodedJWT `shouldReturn` NoMatchingClientIDs auds
it "returns validation success when one of the aud fields matches my client ID" $ do it "returns validation success when one of the aud fields matches my client ID" $ do
let auds = ["wrong-client-id", "771151720060-buofllhed98fgt0j22locma05e7rpngl.apps.googleusercontent.com"] let auds = ["wrong-client-id", "771151720060-buofllhed98fgt0j22locma05e7rpngl.apps.googleusercontent.com"]
|> fmap TestUtils.unsafeStringOrURI |> fmap TestUtils.unsafeStringOrURI
encodedJWT = F.defaultJWTFields { F.overwriteAuds = auds } encodedJWT@(EncodedJWT jwt) =
F.defaultJWTFields { F.overwriteAuds = auds }
|> F.googleJWT |> F.googleJWT
jwtIsValid' encodedJWT `shouldReturn` Valid decodedJWT = jwt |> decode |> TestUtils.unsafeJust |> DecodedJWT
validateJWT' encodedJWT `shouldReturn` Valid decodedJWT
it "returns validation error when one of the iss field doesn't match accounts.google.com or https://accounts.google.com" $ do it "returns validation error when one of the iss field doesn't match accounts.google.com or https://accounts.google.com" $ do
let erroneousIssuer = TestUtils.unsafeStringOrURI "not-accounts.google.com" let erroneousIssuer = TestUtils.unsafeStringOrURI "not-accounts.google.com"
encodedJWT = F.defaultJWTFields { F.overwriteIss = erroneousIssuer } encodedJWT = F.defaultJWTFields { F.overwriteIss = erroneousIssuer }
|> F.googleJWT |> F.googleJWT
jwtIsValid' encodedJWT `shouldReturn` WrongIssuer erroneousIssuer validateJWT' encodedJWT `shouldReturn` WrongIssuer erroneousIssuer
it "returns validation success when the iss field matches accounts.google.com or https://accounts.google.com" $ do it "returns validation success when the iss field matches accounts.google.com or https://accounts.google.com" $ do
let erroneousIssuer = TestUtils.unsafeStringOrURI "https://accounts.google.com" let erroneousIssuer = TestUtils.unsafeStringOrURI "https://accounts.google.com"
encodedJWT = F.defaultJWTFields { F.overwriteIss = erroneousIssuer } encodedJWT@(EncodedJWT jwt) =
F.defaultJWTFields { F.overwriteIss = erroneousIssuer }
|> F.googleJWT |> F.googleJWT
jwtIsValid' encodedJWT `shouldReturn` Valid decodedJWT = jwt |> decode |> TestUtils.unsafeJust |> DecodedJWT
validateJWT' encodedJWT `shouldReturn` Valid decodedJWT
it "fails validation when the exp field has expired" $ do it "fails validation when the exp field has expired" $ do
let mErroneousExp = numericDate 0 let mErroneousExp = numericDate 0
@ -54,7 +58,7 @@ main = hspec $ do
Just erroneousExp -> do Just erroneousExp -> do
let encodedJWT = F.defaultJWTFields { F.overwriteExp = erroneousExp } let encodedJWT = F.defaultJWTFields { F.overwriteExp = erroneousExp }
|> F.googleJWT |> F.googleJWT
jwtIsValid' encodedJWT `shouldReturn` StaleExpiry erroneousExp validateJWT' encodedJWT `shouldReturn` StaleExpiry erroneousExp
it "passes validation when the exp field is current" $ do it "passes validation when the exp field is current" $ do
mFreshExp <- POSIX.getPOSIXTime mFreshExp <- POSIX.getPOSIXTime
@ -63,6 +67,8 @@ main = hspec $ do
case mFreshExp of case mFreshExp of
Nothing -> True `shouldBe` False Nothing -> True `shouldBe` False
Just freshExp -> do Just freshExp -> do
let encodedJWT = F.defaultJWTFields { F.overwriteExp = freshExp } let encodedJWT@(EncodedJWT jwt) =
F.defaultJWTFields { F.overwriteExp = freshExp }
|> F.googleJWT |> F.googleJWT
jwtIsValid' encodedJWT `shouldReturn` Valid decodedJWT = jwt |> decode |> TestUtils.unsafeJust |> DecodedJWT
validateJWT' encodedJWT `shouldReturn` Valid decodedJWT

View file

@ -4,10 +4,11 @@
module Types where module Types where
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
import Data.Aeson import Data.Aeson
import Data.Text
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
data VerifyGoogleSignInRequest = VerifyGoogleSignInRequest data VerifyGoogleSignInRequest = VerifyGoogleSignInRequest
{ idToken :: String { idToken :: Text
} deriving (Eq, Show) } deriving (Eq, Show)
instance FromJSON VerifyGoogleSignInRequest where instance FromJSON VerifyGoogleSignInRequest where