From cc6e7af0e03049acf2d8737dc9e2e87331f28a41 Mon Sep 17 00:00:00 2001 From: dlorenc <dlorenc@google.com> Date: Wed, 23 Dec 2020 12:31:45 -0600 Subject: [PATCH] Move the API context to a global middleware. (#77) --- go.mod | 1 + pkg/api/api.go | 6 +----- pkg/api/tlog.go | 4 ++-- .../restapi/configure_rekor_server.go | 19 +++++++++---------- 4 files changed, 13 insertions(+), 17 deletions(-) diff --git a/go.mod b/go.mod index 7e5271f..81b30d9 100644 --- a/go.mod +++ b/go.mod @@ -22,6 +22,7 @@ require ( github.com/google/martian v2.1.0+incompatible github.com/google/trillian v1.3.10 github.com/gorilla/handlers v1.5.1 // indirect + github.com/jessevdk/go-flags v1.4.0 github.com/kr/pretty v0.2.1 // indirect github.com/magiconair/properties v1.8.4 // indirect github.com/mitchellh/go-homedir v1.1.0 diff --git a/pkg/api/api.go b/pkg/api/api.go index 092d9dd..c83c1f7 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -83,11 +83,7 @@ type ctxKeyRekorAPI int const rekorAPILookupKey ctxKeyRekorAPI = 0 -func AddAPIToContext(ctx context.Context) (context.Context, error) { - api, err := NewAPI(ctx) - if err != nil { - return nil, err - } +func AddAPIToContext(ctx context.Context, api *API) (context.Context, error) { return context.WithValue(ctx, rekorAPILookupKey, api), nil } diff --git a/pkg/api/tlog.go b/pkg/api/tlog.go index c68a0b4..395fbe5 100644 --- a/pkg/api/tlog.go +++ b/pkg/api/tlog.go @@ -35,7 +35,7 @@ import ( func GetLogInfoHandler(params tlog.GetLogInfoParams) middleware.Responder { httpReq := params.HTTPRequest - api, _ := NewAPI(httpReq.Context()) + api := apiFromRequest(httpReq) resp := api.client.getLatest(0) if resp.status != codes.OK { @@ -69,7 +69,7 @@ func GetLogProofHandler(params tlog.GetLogProofParams) middleware.Responder { if *params.FirstSize > params.LastSize { return handleRekorAPIError(params, http.StatusBadRequest, nil, fmt.Sprintf(firstSizeLessThanLastSize, *params.FirstSize, params.LastSize)) } - api, _ := NewAPI(httpReq.Context()) + api := apiFromRequest(httpReq) resp := api.client.getConsistencyProof(*params.FirstSize, params.LastSize) if resp.status != codes.OK { diff --git a/pkg/generated/restapi/configure_rekor_server.go b/pkg/generated/restapi/configure_rekor_server.go index 636712a..e945e4a 100644 --- a/pkg/generated/restapi/configure_rekor_server.go +++ b/pkg/generated/restapi/configure_rekor_server.go @@ -18,6 +18,7 @@ limitations under the License. package restapi import ( + "context" "crypto/tls" "fmt" "net/http" @@ -77,15 +78,6 @@ func configureAPI(api *operations.RekorServerAPI) http.Handler { api.ServerShutdown = func() {} - //api object in context - api.AddMiddlewareFor("POST", "/api/v1/log/entries", addTrillianAPI) - api.AddMiddlewareFor("POST", "/api/v1/log/entries/retrieve", addTrillianAPI) - api.AddMiddlewareFor("GET", "/api/v1/log", addTrillianAPI) - api.AddMiddlewareFor("GET", "/api/v1/log/proof", addTrillianAPI) - api.AddMiddlewareFor("GET", "/api/v1/log/entries/{entryUUID}/proof", addTrillianAPI) - api.AddMiddlewareFor("GET", "/api/v1/log/entries", addTrillianAPI) - api.AddMiddlewareFor("GET", "/api/v1/log/entries/{entryUUID}", addTrillianAPI) - //not cacheable api.AddMiddlewareFor("GET", "/api/v1/log", middleware.NoCache) api.AddMiddlewareFor("GET", "/api/v1/log/proof", middleware.NoCache) @@ -122,6 +114,9 @@ func setupGlobalMiddleware(handler http.Handler) http.Handler { returnHandler := middleware.Recoverer(handler) returnHandler = middleware.Logger(returnHandler) returnHandler = middleware.Heartbeat("/ping")(returnHandler) + + // add the Trillian API object in context for all endpoints + returnHandler = addTrillianAPI(handler) return middleware.RequestID(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() r = r.WithContext(log.WithRequestID(ctx, middleware.GetReqID(ctx))) @@ -146,8 +141,12 @@ func cacheForever(handler http.Handler) http.Handler { } func addTrillianAPI(handler http.Handler) http.Handler { + api, err := pkgapi.NewAPI(context.Background()) + if err != nil { + log.Logger.Panic(err) + } return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - apiCtx, err := pkgapi.AddAPIToContext(r.Context()) + apiCtx, err := pkgapi.AddAPIToContext(r.Context(), api) if err != nil { logAndServeError(w, r, fmt.Errorf("error adding trillian API object to request context: %v", err)) } else { -- GitLab