Skip to content
Snippets Groups Projects
Unverified Commit cc6e7af0 authored by dlorenc's avatar dlorenc Committed by GitHub
Browse files

Move the API context to a global middleware. (#77)

parent faa3d3f1
No related branches found
No related tags found
No related merge requests found
...@@ -22,6 +22,7 @@ require ( ...@@ -22,6 +22,7 @@ require (
github.com/google/martian v2.1.0+incompatible github.com/google/martian v2.1.0+incompatible
github.com/google/trillian v1.3.10 github.com/google/trillian v1.3.10
github.com/gorilla/handlers v1.5.1 // indirect github.com/gorilla/handlers v1.5.1 // indirect
github.com/jessevdk/go-flags v1.4.0
github.com/kr/pretty v0.2.1 // indirect github.com/kr/pretty v0.2.1 // indirect
github.com/magiconair/properties v1.8.4 // indirect github.com/magiconair/properties v1.8.4 // indirect
github.com/mitchellh/go-homedir v1.1.0 github.com/mitchellh/go-homedir v1.1.0
......
...@@ -83,11 +83,7 @@ type ctxKeyRekorAPI int ...@@ -83,11 +83,7 @@ type ctxKeyRekorAPI int
const rekorAPILookupKey ctxKeyRekorAPI = 0 const rekorAPILookupKey ctxKeyRekorAPI = 0
func AddAPIToContext(ctx context.Context) (context.Context, error) { func AddAPIToContext(ctx context.Context, api *API) (context.Context, error) {
api, err := NewAPI(ctx)
if err != nil {
return nil, err
}
return context.WithValue(ctx, rekorAPILookupKey, api), nil return context.WithValue(ctx, rekorAPILookupKey, api), nil
} }
......
...@@ -35,7 +35,7 @@ import ( ...@@ -35,7 +35,7 @@ import (
func GetLogInfoHandler(params tlog.GetLogInfoParams) middleware.Responder { func GetLogInfoHandler(params tlog.GetLogInfoParams) middleware.Responder {
httpReq := params.HTTPRequest httpReq := params.HTTPRequest
api, _ := NewAPI(httpReq.Context()) api := apiFromRequest(httpReq)
resp := api.client.getLatest(0) resp := api.client.getLatest(0)
if resp.status != codes.OK { if resp.status != codes.OK {
...@@ -69,7 +69,7 @@ func GetLogProofHandler(params tlog.GetLogProofParams) middleware.Responder { ...@@ -69,7 +69,7 @@ func GetLogProofHandler(params tlog.GetLogProofParams) middleware.Responder {
if *params.FirstSize > params.LastSize { if *params.FirstSize > params.LastSize {
return handleRekorAPIError(params, http.StatusBadRequest, nil, fmt.Sprintf(firstSizeLessThanLastSize, *params.FirstSize, params.LastSize)) return handleRekorAPIError(params, http.StatusBadRequest, nil, fmt.Sprintf(firstSizeLessThanLastSize, *params.FirstSize, params.LastSize))
} }
api, _ := NewAPI(httpReq.Context()) api := apiFromRequest(httpReq)
resp := api.client.getConsistencyProof(*params.FirstSize, params.LastSize) resp := api.client.getConsistencyProof(*params.FirstSize, params.LastSize)
if resp.status != codes.OK { if resp.status != codes.OK {
......
...@@ -18,6 +18,7 @@ limitations under the License. ...@@ -18,6 +18,7 @@ limitations under the License.
package restapi package restapi
import ( import (
"context"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"net/http" "net/http"
...@@ -77,15 +78,6 @@ func configureAPI(api *operations.RekorServerAPI) http.Handler { ...@@ -77,15 +78,6 @@ func configureAPI(api *operations.RekorServerAPI) http.Handler {
api.ServerShutdown = func() {} api.ServerShutdown = func() {}
//api object in context
api.AddMiddlewareFor("POST", "/api/v1/log/entries", addTrillianAPI)
api.AddMiddlewareFor("POST", "/api/v1/log/entries/retrieve", addTrillianAPI)
api.AddMiddlewareFor("GET", "/api/v1/log", addTrillianAPI)
api.AddMiddlewareFor("GET", "/api/v1/log/proof", addTrillianAPI)
api.AddMiddlewareFor("GET", "/api/v1/log/entries/{entryUUID}/proof", addTrillianAPI)
api.AddMiddlewareFor("GET", "/api/v1/log/entries", addTrillianAPI)
api.AddMiddlewareFor("GET", "/api/v1/log/entries/{entryUUID}", addTrillianAPI)
//not cacheable //not cacheable
api.AddMiddlewareFor("GET", "/api/v1/log", middleware.NoCache) api.AddMiddlewareFor("GET", "/api/v1/log", middleware.NoCache)
api.AddMiddlewareFor("GET", "/api/v1/log/proof", middleware.NoCache) api.AddMiddlewareFor("GET", "/api/v1/log/proof", middleware.NoCache)
...@@ -122,6 +114,9 @@ func setupGlobalMiddleware(handler http.Handler) http.Handler { ...@@ -122,6 +114,9 @@ func setupGlobalMiddleware(handler http.Handler) http.Handler {
returnHandler := middleware.Recoverer(handler) returnHandler := middleware.Recoverer(handler)
returnHandler = middleware.Logger(returnHandler) returnHandler = middleware.Logger(returnHandler)
returnHandler = middleware.Heartbeat("/ping")(returnHandler) returnHandler = middleware.Heartbeat("/ping")(returnHandler)
// add the Trillian API object in context for all endpoints
returnHandler = addTrillianAPI(handler)
return middleware.RequestID(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return middleware.RequestID(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
r = r.WithContext(log.WithRequestID(ctx, middleware.GetReqID(ctx))) r = r.WithContext(log.WithRequestID(ctx, middleware.GetReqID(ctx)))
...@@ -146,8 +141,12 @@ func cacheForever(handler http.Handler) http.Handler { ...@@ -146,8 +141,12 @@ func cacheForever(handler http.Handler) http.Handler {
} }
func addTrillianAPI(handler http.Handler) http.Handler { func addTrillianAPI(handler http.Handler) http.Handler {
api, err := pkgapi.NewAPI(context.Background())
if err != nil {
log.Logger.Panic(err)
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
apiCtx, err := pkgapi.AddAPIToContext(r.Context()) apiCtx, err := pkgapi.AddAPIToContext(r.Context(), api)
if err != nil { if err != nil {
logAndServeError(w, r, fmt.Errorf("error adding trillian API object to request context: %v", err)) logAndServeError(w, r, fmt.Errorf("error adding trillian API object to request context: %v", err))
} else { } else {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment