diff --git a/pkg/api/api.go b/pkg/api/api.go index c83c1f72b36f9df680f972a3f629e74afbd8a355..2cdd5c37a916e7d7c01cd312b38353133c4f82cd 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -19,7 +19,6 @@ package api import ( "context" "fmt" - "net/http" "time" "github.com/google/trillian" @@ -42,14 +41,16 @@ func dial(ctx context.Context, rpcServer string) (*grpc.ClientConn, error) { } type API struct { - client *TrillianClient - pubkey *keyspb.PublicKey + logClient trillian.TrillianLogClient + logID int64 + pubkey *keyspb.PublicKey } -func NewAPI(ctx context.Context) (*API, error) { +func NewAPI() (*API, error) { logRPCServer := fmt.Sprintf("%s:%d", viper.GetString("trillian_log_server.address"), viper.GetUint("trillian_log_server.port")) + ctx := context.Background() tConn, err := dial(ctx, logRPCServer) if err != nil { return nil, err @@ -74,8 +75,9 @@ func NewAPI(ctx context.Context) (*API, error) { } return &API{ - client: TrillianClientInstance(logClient, tLogID, ctx), - pubkey: t.PublicKey, + logClient: logClient, + logID: tLogID, + pubkey: t.PublicKey, }, nil } @@ -83,10 +85,20 @@ type ctxKeyRekorAPI int const rekorAPILookupKey ctxKeyRekorAPI = 0 -func AddAPIToContext(ctx context.Context, api *API) (context.Context, error) { - return context.WithValue(ctx, rekorAPILookupKey, api), nil +func AddAPIToContext(ctx context.Context, api *API) context.Context { + return context.WithValue(ctx, rekorAPILookupKey, api) } -func apiFromRequest(r *http.Request) *API { - return r.Context().Value(rekorAPILookupKey).(*API) +func NewTrillianClient(ctx context.Context) *TrillianClient { + api := ctx.Value(rekorAPILookupKey).(*API) + if api == nil { + return nil + } + + return &TrillianClient{ + client: api.logClient, + logID: api.logID, + context: ctx, + pubkey: api.pubkey, + } } diff --git a/pkg/api/entries.go b/pkg/api/entries.go index 7a4ae2c51045843f386c04037d73eec86e32faa4..a65657b3fc5066607df278033bfe49734ccf670c 100644 --- a/pkg/api/entries.go +++ b/pkg/api/entries.go @@ -42,10 +42,12 @@ import ( ) func GetLogEntryByIndexHandler(params entries.GetLogEntryByIndexParams) middleware.Responder { - httpReq := params.HTTPRequest - api := apiFromRequest(httpReq) + tc := NewTrillianClient(params.HTTPRequest.Context()) + if tc == nil { + return handleRekorAPIError(params, http.StatusInternalServerError, errors.New("unable to get client from request context"), trillianCommunicationError) + } - resp := api.client.getLeafByIndex(params.LogIndex) + resp := tc.getLeafByIndex(params.LogIndex) switch resp.status { case codes.OK: case codes.NotFound, codes.OutOfRange: @@ -83,9 +85,12 @@ func CreateLogEntryHandler(params entries.CreateLogEntryParams) middleware.Respo return handleRekorAPIError(params, http.StatusInternalServerError, err, failedToGenerateCanonicalEntry) } - api := apiFromRequest(httpReq) + tc := NewTrillianClient(httpReq.Context()) + if tc == nil { + return handleRekorAPIError(params, http.StatusInternalServerError, errors.New("unable to get client from request context"), trillianCommunicationError) + } - resp := api.client.addLeaf(leaf) + resp := tc.addLeaf(leaf) switch resp.status { case codes.OK: case codes.AlreadyExists, codes.FailedPrecondition: @@ -110,12 +115,15 @@ func CreateLogEntryHandler(params entries.CreateLogEntryParams) middleware.Respo } func GetLogEntryByUUIDHandler(params entries.GetLogEntryByUUIDParams) middleware.Responder { - httpReq := params.HTTPRequest - api := apiFromRequest(httpReq) hashValue, _ := hex.DecodeString(params.EntryUUID) hashes := [][]byte{hashValue} - resp := api.client.getLeafByHash(hashes) // TODO: if this API is deprecated, we need to ask for inclusion proof and then use index in proof result to get leaf + tc := NewTrillianClient(params.HTTPRequest.Context()) + if tc == nil { + return handleRekorAPIError(params, http.StatusInternalServerError, errors.New("unable to get client from request context"), trillianCommunicationError) + } + + resp := tc.getLeafByHash(hashes) // TODO: if this API is deprecated, we need to ask for inclusion proof and then use index in proof result to get leaf switch resp.status { case codes.OK: case codes.NotFound: @@ -144,11 +152,13 @@ func GetLogEntryByUUIDHandler(params entries.GetLogEntryByUUIDParams) middleware } func GetLogEntryProofHandler(params entries.GetLogEntryProofParams) middleware.Responder { - httpReq := params.HTTPRequest - api := apiFromRequest(httpReq) hashValue, _ := hex.DecodeString(params.EntryUUID) + tc := NewTrillianClient(params.HTTPRequest.Context()) + if tc == nil { + return handleRekorAPIError(params, http.StatusInternalServerError, errors.New("unable to get client from request context"), trillianCommunicationError) + } - resp := api.client.getProofByHash(hashValue) + resp := tc.getProofByHash(hashValue) switch resp.status { case codes.OK: case codes.NotFound: @@ -159,7 +169,7 @@ func GetLogEntryProofHandler(params entries.GetLogEntryProofParams) middleware.R result := resp.getProofResult // validate result is signed with the key we're aware of - pub, err := x509.ParsePKIXPublicKey(api.pubkey.Der) + pub, err := x509.ParsePKIXPublicKey(tc.pubkey.Der) if err != nil { return handleRekorAPIError(params, http.StatusInternalServerError, err, "") } @@ -189,9 +199,12 @@ func GetLogEntryProofHandler(params entries.GetLogEntryProofParams) middleware.R } func SearchLogQueryHandler(params entries.SearchLogQueryParams) middleware.Responder { + httpReqCtx := params.HTTPRequest.Context() resultPayload := []models.LogEntry{} - httpReq := params.HTTPRequest - api := apiFromRequest(httpReq) + tc := NewTrillianClient(httpReqCtx) + if tc == nil { + return handleRekorAPIError(params, http.StatusInternalServerError, errors.New("unable to get client from request context"), trillianCommunicationError) + } //TODO: parallelize this into different goroutines to speed up search searchHashes := [][]byte{} @@ -211,12 +224,12 @@ func SearchLogQueryHandler(params entries.SearchLogQueryParams) middleware.Respo } if entry.HasExternalEntities() { - if err := entry.FetchExternalEntities(httpReq.Context()); err != nil { + if err := entry.FetchExternalEntities(httpReqCtx); err != nil { return handleRekorAPIError(params, http.StatusBadRequest, err, err.Error()) } } - leaf, err := entry.Canonicalize(httpReq.Context()) + leaf, err := entry.Canonicalize(httpReqCtx) if err != nil { return handleRekorAPIError(params, http.StatusInternalServerError, err, err.Error()) } @@ -225,7 +238,7 @@ func SearchLogQueryHandler(params entries.SearchLogQueryParams) middleware.Respo searchHashes = append(searchHashes, leafHash) } - resp := api.client.getLeafByHash(searchHashes) // TODO: if this API is deprecated, we need to ask for inclusion proof and then use index in proof result to get leaf + resp := tc.getLeafByHash(searchHashes) // TODO: if this API is deprecated, we need to ask for inclusion proof and then use index in proof result to get leaf switch resp.status { case codes.OK, codes.NotFound: default: @@ -246,7 +259,7 @@ func SearchLogQueryHandler(params entries.SearchLogQueryParams) middleware.Respo if len(params.Entry.LogIndexes) > 0 { leaves := []*trillian.LogLeaf{} for _, logIndex := range params.Entry.LogIndexes { - resp := api.client.getLeafByIndex(swag.Int64Value(logIndex)) + resp := tc.getLeafByIndex(swag.Int64Value(logIndex)) switch resp.status { case codes.OK, codes.NotFound: default: diff --git a/pkg/api/tlog.go b/pkg/api/tlog.go index 395fbe58093854dd17e577666123905a68377b15..142e1abdbf94a22fe3a21beee42b7b6cdef85e73 100644 --- a/pkg/api/tlog.go +++ b/pkg/api/tlog.go @@ -34,17 +34,19 @@ import ( ) func GetLogInfoHandler(params tlog.GetLogInfoParams) middleware.Responder { - httpReq := params.HTTPRequest - api := apiFromRequest(httpReq) + tc := NewTrillianClient(params.HTTPRequest.Context()) + if tc == nil { + return handleRekorAPIError(params, http.StatusInternalServerError, errors.New("unable to get client from request context"), trillianCommunicationError) + } - resp := api.client.getLatest(0) + resp := tc.getLatest(0) if resp.status != codes.OK { return handleRekorAPIError(params, http.StatusInternalServerError, fmt.Errorf("grpc error: %w", resp.err), trillianCommunicationError) } result := resp.getLatestResult // validate result is signed with the key we're aware of - pub, err := x509.ParsePKIXPublicKey(api.pubkey.Der) + pub, err := x509.ParsePKIXPublicKey(tc.pubkey.Der) if err != nil { return handleRekorAPIError(params, http.StatusInternalServerError, err, "") } @@ -65,20 +67,22 @@ func GetLogInfoHandler(params tlog.GetLogInfoParams) middleware.Responder { } func GetLogProofHandler(params tlog.GetLogProofParams) middleware.Responder { - httpReq := params.HTTPRequest if *params.FirstSize > params.LastSize { return handleRekorAPIError(params, http.StatusBadRequest, nil, fmt.Sprintf(firstSizeLessThanLastSize, *params.FirstSize, params.LastSize)) } - api := apiFromRequest(httpReq) + tc := NewTrillianClient(params.HTTPRequest.Context()) + if tc == nil { + return handleRekorAPIError(params, http.StatusInternalServerError, errors.New("unable to get client from request context"), trillianCommunicationError) + } - resp := api.client.getConsistencyProof(*params.FirstSize, params.LastSize) + resp := tc.getConsistencyProof(*params.FirstSize, params.LastSize) if resp.status != codes.OK { return handleRekorAPIError(params, http.StatusInternalServerError, fmt.Errorf("grpc error: %w", resp.err), trillianCommunicationError) } result := resp.getConsistencyProofResult // validate result is signed with the key we're aware of - pub, err := x509.ParsePKIXPublicKey(api.pubkey.Der) + pub, err := x509.ParsePKIXPublicKey(tc.pubkey.Der) if err != nil { return handleRekorAPIError(params, http.StatusInternalServerError, err, "") } diff --git a/pkg/api/trillian_client.go b/pkg/api/trillian_client.go index 283cc74826701f772dd49ee53cfab2c122f388c8..0363d5f236e61a97b1282f4391c5d459cb9fd562 100644 --- a/pkg/api/trillian_client.go +++ b/pkg/api/trillian_client.go @@ -37,6 +37,7 @@ type TrillianClient struct { client trillian.TrillianLogClient logID int64 context context.Context + pubkey *keyspb.PublicKey } type Response struct { @@ -50,14 +51,6 @@ type Response struct { getConsistencyProofResult *trillian.GetConsistencyProofResponse } -func TrillianClientInstance(client trillian.TrillianLogClient, tLogID int64, ctx context.Context) *TrillianClient { - return &TrillianClient{ - client: client, - logID: tLogID, - context: ctx, - } -} - func (t *TrillianClient) root() (types.LogRootV1, error) { rqst := &trillian.GetLatestSignedLogRootRequest{ LogId: t.logID, diff --git a/pkg/generated/restapi/configure_rekor_server.go b/pkg/generated/restapi/configure_rekor_server.go index e945e4a3a92bb6b5a5a957e77326ec28245f996d..11eb0f96ed2de5daec5b1739798bc0ede7784000 100644 --- a/pkg/generated/restapi/configure_rekor_server.go +++ b/pkg/generated/restapi/configure_rekor_server.go @@ -18,9 +18,7 @@ limitations under the License. package restapi import ( - "context" "crypto/tls" - "fmt" "net/http" "github.com/go-chi/chi/middleware" @@ -141,17 +139,13 @@ func cacheForever(handler http.Handler) http.Handler { } func addTrillianAPI(handler http.Handler) http.Handler { - api, err := pkgapi.NewAPI(context.Background()) + api, err := pkgapi.NewAPI() if err != nil { log.Logger.Panic(err) } return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - apiCtx, err := pkgapi.AddAPIToContext(r.Context(), api) - if err != nil { - logAndServeError(w, r, fmt.Errorf("error adding trillian API object to request context: %v", err)) - } else { - handler.ServeHTTP(w, r.WithContext(apiCtx)) - } + apiCtx := pkgapi.AddAPIToContext(r.Context(), api) + handler.ServeHTTP(w, r.WithContext(apiCtx)) }) }