
How can I validate and get info from a JWT received from Amazon Cognito?

I have set up Google authentication in Cognito and set the redirect URI to hit API Gateway. I then receive a code, which I POST to this endpoint:

https://docs.aws.amazon.com/cognito/latest/developerguide/token-endpoint.html

This returns a JWT signed with RS256. I am now struggling to validate and parse the token in Golang. I’ve tried to parse it using jwt-go, but it appears to support HMAC by default, and I read somewhere that they recommend validating on the frontend instead. I tried a few other packages and had similar problems.

I came across this answer: Go Language and Verify JWT, but I assume the code is outdated, as it just fails with panic: unable to find key.

jwt.io can easily decode the token, and probably verify it too. I’m not sure where the public/secret keys are, as Amazon generated the token, but from what I understand I need to use a JWK URL to validate it? I’ve found a few AWS-specific solutions, but they all seem to be hundreds of lines long. Surely it isn’t that complicated in Golang, is it?


6 Answers

8
votes

Public keys for Amazon Cognito

As you already guessed, you'll need the public key in order to verify the JWT token.

https://docs.aws.amazon.com/cognito/latest/developerguide/amazon-cognito-user-pools-using-tokens-verifying-a-jwt.html#amazon-cognito-user-pools-using-tokens-step-2

Download and store the corresponding public JSON Web Key (JWK) for your user pool. It is available as part of a JSON Web Key Set (JWKS). You can locate it at https://cognito-idp.{region}.amazonaws.com/{userPoolId}/.well-known/jwks.json
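If you want to see what that document contains before reaching for a library, here is a minimal standard-library sketch. It assumes the RSA fields (kid, kty, alg, use, n, e) that Cognito's JWKS entries carry; the helper names are my own, the region and user pool ID in main are placeholders, and the actual download (for example an http.Get on the URL) is left out so the example runs offline.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// jwksURL builds the JWKS location for a Cognito user pool.
func jwksURL(region, userPoolID string) string {
	return fmt.Sprintf("https://cognito-idp.%s.amazonaws.com/%s/.well-known/jwks.json", region, userPoolID)
}

// jsonWebKey holds the fields of an RSA JWK.
type jsonWebKey struct {
	Kid string `json:"kid"` // key ID, matched against the JWT header's "kid"
	Kty string `json:"kty"` // key type, "RSA" for Cognito
	Alg string `json:"alg"` // signing algorithm, e.g. "RS256"
	Use string `json:"use"` // "sig" for signature keys
	N   string `json:"n"`   // base64url-encoded modulus
	E   string `json:"e"`   // base64url-encoded public exponent
}

// jsonWebKeySet mirrors the top-level {"keys": [...]} document.
type jsonWebKeySet struct {
	Keys []jsonWebKey `json:"keys"`
}

// parseJWKS decodes the body downloaded from jwksURL.
func parseJWKS(body []byte) (jsonWebKeySet, error) {
	var set jsonWebKeySet
	if err := json.Unmarshal(body, &set); err != nil {
		return jsonWebKeySet{}, err
	}
	return set, nil
}

func main() {
	// Placeholder region and pool ID; substitute your own.
	fmt.Println(jwksURL("us-east-1", "us-east-1_EXAMPLE"))

	// A stand-in for a real JWKS body (the n value is not a real modulus).
	body := []byte(`{"keys":[{"kid":"example-kid","kty":"RSA","alg":"RS256","use":"sig","n":"AQAB","e":"AQAB"}]}`)
	set, err := parseJWKS(body)
	if err != nil {
		panic(err)
	}
	for _, k := range set.Keys {
		fmt.Println(k.Kid, k.Kty, k.Alg)
	}
}
```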

Parse keys and verify token

That JSON file structure is documented on the web, so you could parse it manually, generate the public keys yourself, etc.
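For reference, the manual route is not much code for RSA keys. A sketch, assuming the key is an RSA JWK whose n and e members are base64url-encoded without padding (as RFC 7518 specifies): decode both into big integers and build an *rsa.PublicKey. The function name rsaPublicKeyFromJWK is hypothetical; main round-trips a freshly generated key to show the conversion.

```go
package main

import (
	"crypto/rand"
	"crypto/rsa"
	"encoding/base64"
	"errors"
	"fmt"
	"math/big"
)

// rsaPublicKeyFromJWK turns the "n" and "e" members of an RSA JWK
// (base64url without padding) into an *rsa.PublicKey.
func rsaPublicKeyFromJWK(n, e string) (*rsa.PublicKey, error) {
	nBytes, err := base64.RawURLEncoding.DecodeString(n)
	if err != nil {
		return nil, fmt.Errorf("decoding n: %w", err)
	}
	eBytes, err := base64.RawURLEncoding.DecodeString(e)
	if err != nil {
		return nil, fmt.Errorf("decoding e: %w", err)
	}
	exp := new(big.Int).SetBytes(eBytes)
	// Reject exponents that don't fit rsa.PublicKey.E (an int) or are implausibly small.
	if exp.BitLen() > 31 || exp.Int64() < 3 {
		return nil, errors.New("unsupported RSA exponent")
	}
	return &rsa.PublicKey{N: new(big.Int).SetBytes(nBytes), E: int(exp.Int64())}, nil
}

func main() {
	// Round trip: encode a generated key the way a JWKS does, then decode it again.
	priv, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		panic(err)
	}
	n := base64.RawURLEncoding.EncodeToString(priv.N.Bytes())
	e := base64.RawURLEncoding.EncodeToString(big.NewInt(int64(priv.E)).Bytes())

	pub, err := rsaPublicKeyFromJWK(n, e)
	if err != nil {
		panic(err)
	}
	fmt.Println(pub.N.Cmp(priv.N) == 0 && pub.E == priv.E) // true
}
```

The common "e": "AQAB" value in Cognito's JWKS decodes to 65537.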

But it'd probably be easier to just use a library, for example this one: https://github.com/lestrrat-go/jwx

And then jwt-go to deal with the JWT part: https://github.com/dgrijalva/jwt-go

You can then:

  1. Download and parse the public keys JSON using the first library

     keySet, err := jwk.Fetch(THE_COGNITO_URL_DESCRIBED_ABOVE)
    
  2. When parsing the token with jwt-go, use the "kid" field from the JWT header to find the right key to use

     token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
         // jwt.SigningMethodRS256 is a variable of type *jwt.SigningMethodRSA,
         // so assert on that type and then check the algorithm name.
         if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok || token.Method.Alg() != "RS256" {
             return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
         }
         kid, ok := token.Header["kid"].(string)
         if !ok {
             return nil, errors.New("kid header not found")
         }
         keys := keySet.LookupKeyID(kid)
         if len(keys) == 0 {
             return nil, fmt.Errorf("key %q is not present in the JWKS", kid)
         }
         var publicKey interface{}
         if err := keys[0].Raw(&publicKey); err != nil {
             return nil, fmt.Errorf("could not parse public key: %w", err)
         }
         return publicKey, nil
     })
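The kid lookup in step 2 works because the JWT header is just base64url-encoded JSON. Here is a standard-library sketch of reading it (peekHeader is a hypothetical helper). It rejects anything other than RS256 up front, and the header must still be treated as untrusted until the signature has been verified with the matching key.

```go
package main

import (
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"strings"
)

// jwtHeader holds the header fields needed to pick a key from the JWKS.
type jwtHeader struct {
	Alg string `json:"alg"`
	Kid string `json:"kid"`
}

// peekHeader decodes the JWT header WITHOUT verifying anything. Use the result
// only to choose a key; trust nothing else until the signature checks out.
func peekHeader(token string) (jwtHeader, error) {
	parts := strings.Split(token, ".")
	if len(parts) != 3 {
		return jwtHeader{}, errors.New("a JWT has three dot-separated parts")
	}
	raw, err := base64.RawURLEncoding.DecodeString(parts[0])
	if err != nil {
		return jwtHeader{}, fmt.Errorf("decoding header: %w", err)
	}
	var h jwtHeader
	if err := json.Unmarshal(raw, &h); err != nil {
		return jwtHeader{}, fmt.Errorf("parsing header: %w", err)
	}
	// Cognito user pool tokens use RS256; refusing other algorithms up front
	// avoids algorithm-confusion tricks such as "none" or HS256.
	if h.Alg != "RS256" {
		return jwtHeader{}, fmt.Errorf("unexpected signing method %q", h.Alg)
	}
	if h.Kid == "" {
		return jwtHeader{}, errors.New("kid header not found")
	}
	return h, nil
}

func main() {
	// Header {"kid":"f55d9a4e","typ":"JWT","alg":"RS256"}, empty payload {}, dummy signature.
	h, err := peekHeader("eyJraWQiOiJmNTVkOWE0ZSIsInR5cCI6IkpXVCIsImFsZyI6IlJTMjU2In0.e30.sig")
	if err != nil {
		panic(err)
	}
	fmt.Println(h.Kid, h.Alg) // f55d9a4e RS256
}
```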
    
4
votes

The type assertion token.Method.(*jwt.SigningMethodRS256) in the code eugenioy and Kevin Wydler originally posted did not work for me: *jwt.SigningMethodRS256 is not a type.

jwt.SigningMethodRS256 was a type in jwt-go's initial commit. From the second commit on (back in July 2014), it was replaced by a global variable of type *jwt.SigningMethodRSA.
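To make the distinction concrete, here is a simplified, hypothetical model of that layout (not jwt-go's actual source): one concrete type per algorithm family and one package-level variable per algorithm. A type assertion has to name the type (*SigningMethodRSA); checking for RS256 specifically means comparing Alg() as well.

```go
package main

import "fmt"

// SigningMethod stands in for jwt-go's interface of the same name (trimmed to Alg).
type SigningMethod interface {
	Alg() string
}

// SigningMethodRSA is the concrete type shared by RS256, RS384 and RS512.
type SigningMethodRSA struct{ Name string }

func (m *SigningMethodRSA) Alg() string { return m.Name }

// SigningMethodHMAC is the concrete type shared by the HS* algorithms.
type SigningMethodHMAC struct{ Name string }

func (m *SigningMethodHMAC) Alg() string { return m.Name }

// One package-level variable per algorithm: these are values, not types.
var (
	SigningMethodRS256 = &SigningMethodRSA{Name: "RS256"}
	SigningMethodHS256 = &SigningMethodHMAC{Name: "HS256"}
)

// isRS256 reports whether m is the RSA family and specifically RS256.
// Writing m.(*SigningMethodRS256) would not compile, because
// SigningMethodRS256 names a variable rather than a type.
func isRS256(m SigningMethod) bool {
	if _, ok := m.(*SigningMethodRSA); !ok {
		return false
	}
	return m.Alg() == "RS256"
}

func main() {
	fmt.Println(isRS256(SigningMethodRS256)) // true
	fmt.Println(isRS256(SigningMethodHS256)) // false
}
```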

The following code, using jwt-go v4 and jwx, works for me:

func verify(tokenString string, keySet *jwk.Set) (*jwt.Token, error) {
  return jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
    if token.Method.Alg() != "RS256" { // jwa.RS256.String() works as well
      return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
    }
    kid, ok := token.Header["kid"].(string)
    if !ok {
      return nil, errors.New("kid header not found")
    }
    keys := keySet.LookupKeyID(kid)
    if len(keys) == 0 {
      return nil, fmt.Errorf("key %v not found", kid)
    }
    // Fill raw before returning it: in "return raw, keys[0].Raw(&raw)" the
    // order in which raw is read relative to the call is unspecified.
    var raw interface{}
    if err := keys[0].Raw(&raw); err != nil {
      return nil, err
    }
    return raw, nil
  })
}

Using the following dependency versions:

github.com/dgrijalva/jwt-go/v4 v4.0.0-preview1
github.com/lestrrat-go/jwx v1.0.4
3
votes

eugenioy's answer stopped working for me after a refactor in lestrrat-go/jwx (Key.Materialize() was replaced by Key.Raw()). I ended up fixing it with something like this:

token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
    if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok || token.Method.Alg() != "RS256" {
        return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
    }
    kid, ok := token.Header["kid"].(string)
    if !ok {
        return nil, errors.New("kid header not found")
    }
    keys := keySet.LookupKeyID(kid)
    if len(keys) == 0 {
        return nil, fmt.Errorf("key %v not found", kid)
    }
    // keys[0].Materialize() doesn't exist anymore; Raw fills in the key instead
    var raw interface{}
    if err := keys[0].Raw(&raw); err != nil {
        return nil, err
    }
    return raw, nil
})

3
votes

This is what I did using only github.com/lestrrat-go/jwx (v1.0.8, the latest version at the time). Note that github.com/dgrijalva/jwt-go does not seem to be maintained anymore, and people are forking it to make the updates they need.

package main

import (
    ...
    "github.com/lestrrat-go/jwx/jwk"
    "github.com/lestrrat-go/jwx/jwt"
)
    ...

    keyset, err := jwk.Fetch("https://cognito-idp." + region + ".amazonaws.com/" + userPoolID + "/.well-known/jwks.json")
    //check err here too, before it is overwritten below

    parsedToken, err := jwt.Parse(
        bytes.NewReader(token), //token is a []byte
        jwt.WithKeySet(keyset),
        jwt.WithValidate(true),
        jwt.WithIssuer(...),
        jwt.WithClaimValue("key", value),
    )

    //check err as usual
    //here you can call methods on the parsedToken to get the claim values
    ...

The parsed jwt.Token has methods for reading the claims, such as Subject(), Issuer(), Expiration(), and Get(name).

0
votes

This is what worked for me:

import (
    "errors"
    "fmt"
    "net/http"

    "github.com/dgrijalva/jwt-go"
    "github.com/gin-gonic/gin"
    "github.com/lestrrat-go/jwx/jwk"
)

func verifyToken(token *jwt.Token) (interface{}, error) {
    // Cognito signs its tokens with RSA (RS256), so reject any other method.
    if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
        return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
    }

    // make sure to replace this with your actual URL
    // https://docs.aws.amazon.com/cognito/latest/developerguide/amazon-cognito-user-pools-using-tokens-verifying-a-jwt.html#amazon-cognito-user-pools-using-tokens-step-2
    // Note: this downloads the JWKS on every request; consider caching it.
    jwksURL := "COGNITO_JWKS_URL"
    set, err := jwk.FetchHTTP(jwksURL)
    if err != nil {
        return nil, err
    }

    keyID, ok := token.Header["kid"].(string)
    if !ok {
        return nil, errors.New("expecting JWT header to have string kid")
    }

    keys := set.LookupKeyID(keyID)
    if len(keys) == 0 {
        return nil, fmt.Errorf("key %v not found", keyID)
    }

    // Materialize converts the JWK into a Go key value (an *rsa.PublicKey here).
    return keys[0].Materialize()
}

I am calling it like this (using gin on AWS Lambda). If you manage requests differently, replace the gin handler with one based on http.Request or whatever framework you are using:

func JWTVerify() gin.HandlerFunc {
    return func(c *gin.Context) {
        tokenString := c.GetHeader("AccessToken")
        _, err := jwt.Parse(tokenString, verifyToken)
        if err != nil {
            c.AbortWithStatus(http.StatusUnauthorized)
        }
    }
}

This is my go.mod:

module MY_MODULE_NAME
go 1.12

require (
    github.com/aws/aws-lambda-go v1.20.0
    github.com/aws/aws-sdk-go v1.36.0
    github.com/awslabs/aws-lambda-go-api-proxy v0.9.0
    github.com/dgrijalva/jwt-go v3.2.0+incompatible
    github.com/gin-gonic/gin v1.6.3
    github.com/google/uuid v1.1.2
    github.com/lestrrat-go/jwx v0.9.2
    github.com/onsi/ginkgo v1.14.2 // indirect
    github.com/onsi/gomega v1.10.3 // indirect
    golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
)
0
votes

Here's an example using github.com/golang-jwt/jwt (formerly known as github.com/dgrijalva/jwt-go) and a JWKs like the one AWS Cognito provides.

It refreshes the AWS Cognito JWKs once every hour and whenever a JWT signed with an unknown kid comes in, with a global rate limit of one HTTP refresh request every 5 minutes.

package main

import (
    "fmt"
    "log"
    "time"

    "github.com/golang-jwt/jwt"

    "github.com/MicahParks/keyfunc"
)

func main() {

    // Get the JWKs URL from your AWS region and userPoolId.
    //
    // See the AWS docs here:
    // https://docs.aws.amazon.com/cognito/latest/developerguide/amazon-cognito-user-pools-using-tokens-verifying-a-jwt.html
    regionID := ""   // TODO Get the region ID for your AWS Cognito instance.
    userPoolID := "" // TODO Get the user pool ID of your AWS Cognito instance.
    jwksURL := fmt.Sprintf("https://cognito-idp.%s.amazonaws.com/%s/.well-known/jwks.json", regionID, userPoolID)

    // Create the keyfunc options. Use an error handler that logs. Refresh the JWKs when a JWT signed by an unknown KID
    // is found or at the specified interval. Rate limit these refreshes. Timeout the initial JWKs refresh request after
    // 10 seconds. This timeout is also used to create the initial context.Context for keyfunc.Get.
    refreshInterval := time.Hour
    refreshRateLimit := time.Minute * 5
    refreshTimeout := time.Second * 10
    refreshUnknownKID := true
    options := keyfunc.Options{
        RefreshErrorHandler: func(err error) {
            log.Printf("There was an error with the jwt.KeyFunc\nError:%s\n", err.Error())
        },
        RefreshInterval:   &refreshInterval,
        RefreshRateLimit:  &refreshRateLimit,
        RefreshTimeout:    &refreshTimeout,
        RefreshUnknownKID: &refreshUnknownKID,
    }

    // Create the JWKs from the resource at the given URL.
    jwks, err := keyfunc.Get(jwksURL, options)
    if err != nil {
        log.Fatalf("Failed to create JWKs from resource at the given URL.\nError:%s\n", err.Error())
    }

    // Get a JWT to parse. This sample token was not issued by Cognito; replace it with one from your user pool.
    jwtB64 := "eyJraWQiOiJmNTVkOWE0ZSIsInR5cCI6IkpXVCIsImFsZyI6IlJTMjU2In0.eyJzdWIiOiJLZXNoYSIsImF1ZCI6IlRhc2h1YW4iLCJpc3MiOiJqd2tzLXNlcnZpY2UuYXBwc3BvdC5jb20iLCJleHAiOjE2MTkwMjUyMTEsImlhdCI6MTYxOTAyNTE3NywianRpIjoiMWY3MTgwNzAtZTBiOC00OGNmLTlmMDItMGE1M2ZiZWNhYWQwIn0.vetsI8W0c4Z-bs2YCVcPb9HsBm1BrMhxTBSQto1koG_lV-2nHwksz8vMuk7J7Q1sMa7WUkXxgthqu9RGVgtGO2xor6Ub0WBhZfIlFeaRGd6ZZKiapb-ASNK7EyRIeX20htRf9MzFGwpWjtrS5NIGvn1a7_x9WcXU9hlnkXaAWBTUJ2H73UbjDdVtlKFZGWM5VGANY4VG7gSMaJqCIKMxRPn2jnYbvPIYz81sjjbd-sc2-ePRjso7Rk6s382YdOm-lDUDl2APE-gqkLWdOJcj68fc6EBIociradX_ADytj-JYEI6v0-zI-8jSckYIGTUF5wjamcDfF5qyKpjsmdrZJA"

    // Parse the JWT.
    token, err := jwt.Parse(jwtB64, jwks.KeyFunc)
    if err != nil {
        log.Fatalf("Failed to parse the JWT.\nError:%s\n", err.Error())
    }

    // Check if the token is valid.
    if !token.Valid {
        log.Fatalf("The token is not valid.")
    }

    log.Println("The token is valid.")
}