From cc6e7af0e03049acf2d8737dc9e2e87331f28a41 Mon Sep 17 00:00:00 2001
From: dlorenc <dlorenc@google.com>
Date: Wed, 23 Dec 2020 12:31:45 -0600
Subject: [PATCH] Move the API context to a global middleware. (#77)

---
 go.mod                                        |  1 +
 pkg/api/api.go                                |  6 +-----
 pkg/api/tlog.go                               |  4 ++--
 .../restapi/configure_rekor_server.go         | 19 +++++++++----------
 4 files changed, 13 insertions(+), 17 deletions(-)

diff --git a/go.mod b/go.mod
index 7e5271f..81b30d9 100644
--- a/go.mod
+++ b/go.mod
@@ -22,6 +22,7 @@ require (
 	github.com/google/martian v2.1.0+incompatible
 	github.com/google/trillian v1.3.10
 	github.com/gorilla/handlers v1.5.1 // indirect
+	github.com/jessevdk/go-flags v1.4.0
 	github.com/kr/pretty v0.2.1 // indirect
 	github.com/magiconair/properties v1.8.4 // indirect
 	github.com/mitchellh/go-homedir v1.1.0
diff --git a/pkg/api/api.go b/pkg/api/api.go
index 092d9dd..c83c1f7 100644
--- a/pkg/api/api.go
+++ b/pkg/api/api.go
@@ -83,11 +83,7 @@ type ctxKeyRekorAPI int
 
 const rekorAPILookupKey ctxKeyRekorAPI = 0
 
-func AddAPIToContext(ctx context.Context) (context.Context, error) {
-	api, err := NewAPI(ctx)
-	if err != nil {
-		return nil, err
-	}
+func AddAPIToContext(ctx context.Context, api *API) (context.Context, error) {
 	return context.WithValue(ctx, rekorAPILookupKey, api), nil
 }
 
diff --git a/pkg/api/tlog.go b/pkg/api/tlog.go
index c68a0b4..395fbe5 100644
--- a/pkg/api/tlog.go
+++ b/pkg/api/tlog.go
@@ -35,7 +35,7 @@ import (
 
 func GetLogInfoHandler(params tlog.GetLogInfoParams) middleware.Responder {
 	httpReq := params.HTTPRequest
-	api, _ := NewAPI(httpReq.Context())
+	api := apiFromRequest(httpReq)
 
 	resp := api.client.getLatest(0)
 	if resp.status != codes.OK {
@@ -69,7 +69,7 @@ func GetLogProofHandler(params tlog.GetLogProofParams) middleware.Responder {
 	if *params.FirstSize > params.LastSize {
 		return handleRekorAPIError(params, http.StatusBadRequest, nil, fmt.Sprintf(firstSizeLessThanLastSize, *params.FirstSize, params.LastSize))
 	}
-	api, _ := NewAPI(httpReq.Context())
+	api := apiFromRequest(httpReq)
 
 	resp := api.client.getConsistencyProof(*params.FirstSize, params.LastSize)
 	if resp.status != codes.OK {
diff --git a/pkg/generated/restapi/configure_rekor_server.go b/pkg/generated/restapi/configure_rekor_server.go
index 636712a..e945e4a 100644
--- a/pkg/generated/restapi/configure_rekor_server.go
+++ b/pkg/generated/restapi/configure_rekor_server.go
@@ -18,6 +18,7 @@ limitations under the License.
 package restapi
 
 import (
+	"context"
 	"crypto/tls"
 	"fmt"
 	"net/http"
@@ -77,15 +78,6 @@ func configureAPI(api *operations.RekorServerAPI) http.Handler {
 
 	api.ServerShutdown = func() {}
 
-	//api object in context
-	api.AddMiddlewareFor("POST", "/api/v1/log/entries", addTrillianAPI)
-	api.AddMiddlewareFor("POST", "/api/v1/log/entries/retrieve", addTrillianAPI)
-	api.AddMiddlewareFor("GET", "/api/v1/log", addTrillianAPI)
-	api.AddMiddlewareFor("GET", "/api/v1/log/proof", addTrillianAPI)
-	api.AddMiddlewareFor("GET", "/api/v1/log/entries/{entryUUID}/proof", addTrillianAPI)
-	api.AddMiddlewareFor("GET", "/api/v1/log/entries", addTrillianAPI)
-	api.AddMiddlewareFor("GET", "/api/v1/log/entries/{entryUUID}", addTrillianAPI)
-
 	//not cacheable
 	api.AddMiddlewareFor("GET", "/api/v1/log", middleware.NoCache)
 	api.AddMiddlewareFor("GET", "/api/v1/log/proof", middleware.NoCache)
@@ -122,6 +114,9 @@ func setupGlobalMiddleware(handler http.Handler) http.Handler {
 	returnHandler := middleware.Recoverer(handler)
 	returnHandler = middleware.Logger(returnHandler)
 	returnHandler = middleware.Heartbeat("/ping")(returnHandler)
+
+	// add the Trillian API object in context for all endpoints
+	returnHandler = addTrillianAPI(handler)
 	return middleware.RequestID(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		ctx := r.Context()
 		r = r.WithContext(log.WithRequestID(ctx, middleware.GetReqID(ctx)))
@@ -146,8 +141,12 @@ func cacheForever(handler http.Handler) http.Handler {
 }
 
 func addTrillianAPI(handler http.Handler) http.Handler {
+	api, err := pkgapi.NewAPI(context.Background())
+	if err != nil {
+		log.Logger.Panic(err)
+	}
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		apiCtx, err := pkgapi.AddAPIToContext(r.Context())
+		apiCtx, err := pkgapi.AddAPIToContext(r.Context(), api)
 		if err != nil {
 			logAndServeError(w, r, fmt.Errorf("error adding trillian API object to request context: %v", err))
 		} else {
-- 
GitLab