Skip to content
Snippets Groups Projects
Unverified Commit cc6e7af0 authored by dlorenc's avatar dlorenc Committed by GitHub
Browse files

Move the API context to a global middleware. (#77)

parent faa3d3f1
No related branches found
No related tags found
No related merge requests found
......@@ -22,6 +22,7 @@ require (
github.com/google/martian v2.1.0+incompatible
github.com/google/trillian v1.3.10
github.com/gorilla/handlers v1.5.1 // indirect
github.com/jessevdk/go-flags v1.4.0
github.com/kr/pretty v0.2.1 // indirect
github.com/magiconair/properties v1.8.4 // indirect
github.com/mitchellh/go-homedir v1.1.0
......
......@@ -83,11 +83,7 @@ type ctxKeyRekorAPI int
const rekorAPILookupKey ctxKeyRekorAPI = 0
func AddAPIToContext(ctx context.Context) (context.Context, error) {
api, err := NewAPI(ctx)
if err != nil {
return nil, err
}
func AddAPIToContext(ctx context.Context, api *API) (context.Context, error) {
return context.WithValue(ctx, rekorAPILookupKey, api), nil
}
......
......@@ -35,7 +35,7 @@ import (
func GetLogInfoHandler(params tlog.GetLogInfoParams) middleware.Responder {
httpReq := params.HTTPRequest
api, _ := NewAPI(httpReq.Context())
api := apiFromRequest(httpReq)
resp := api.client.getLatest(0)
if resp.status != codes.OK {
......@@ -69,7 +69,7 @@ func GetLogProofHandler(params tlog.GetLogProofParams) middleware.Responder {
if *params.FirstSize > params.LastSize {
return handleRekorAPIError(params, http.StatusBadRequest, nil, fmt.Sprintf(firstSizeLessThanLastSize, *params.FirstSize, params.LastSize))
}
api, _ := NewAPI(httpReq.Context())
api := apiFromRequest(httpReq)
resp := api.client.getConsistencyProof(*params.FirstSize, params.LastSize)
if resp.status != codes.OK {
......
......@@ -18,6 +18,7 @@ limitations under the License.
package restapi
import (
"context"
"crypto/tls"
"fmt"
"net/http"
......@@ -77,15 +78,6 @@ func configureAPI(api *operations.RekorServerAPI) http.Handler {
api.ServerShutdown = func() {}
//api object in context
api.AddMiddlewareFor("POST", "/api/v1/log/entries", addTrillianAPI)
api.AddMiddlewareFor("POST", "/api/v1/log/entries/retrieve", addTrillianAPI)
api.AddMiddlewareFor("GET", "/api/v1/log", addTrillianAPI)
api.AddMiddlewareFor("GET", "/api/v1/log/proof", addTrillianAPI)
api.AddMiddlewareFor("GET", "/api/v1/log/entries/{entryUUID}/proof", addTrillianAPI)
api.AddMiddlewareFor("GET", "/api/v1/log/entries", addTrillianAPI)
api.AddMiddlewareFor("GET", "/api/v1/log/entries/{entryUUID}", addTrillianAPI)
//not cacheable
api.AddMiddlewareFor("GET", "/api/v1/log", middleware.NoCache)
api.AddMiddlewareFor("GET", "/api/v1/log/proof", middleware.NoCache)
......@@ -122,6 +114,9 @@ func setupGlobalMiddleware(handler http.Handler) http.Handler {
returnHandler := middleware.Recoverer(handler)
returnHandler = middleware.Logger(returnHandler)
returnHandler = middleware.Heartbeat("/ping")(returnHandler)
// add the Trillian API object in context for all endpoints
returnHandler = addTrillianAPI(handler)
return middleware.RequestID(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
r = r.WithContext(log.WithRequestID(ctx, middleware.GetReqID(ctx)))
......@@ -146,8 +141,12 @@ func cacheForever(handler http.Handler) http.Handler {
}
func addTrillianAPI(handler http.Handler) http.Handler {
api, err := pkgapi.NewAPI(context.Background())
if err != nil {
log.Logger.Panic(err)
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
apiCtx, err := pkgapi.AddAPIToContext(r.Context())
apiCtx, err := pkgapi.AddAPIToContext(r.Context(), api)
if err != nil {
logAndServeError(w, r, fmt.Errorf("error adding trillian API object to request context: %v", err))
} else {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment