From 399486960213a99d798fb4d2b37bd5046ce9a44e Mon Sep 17 00:00:00 2001 From: Bob Callaway <bobcallaway@users.noreply.github.com> Date: Thu, 24 Dec 2020 17:22:51 -0500 Subject: [PATCH] GRPC calls now use request-specific context (#78) After #77, we have one global GRPC channel for the entire process. This change causes each GRPC call to be made on the incoming HTTP request's context such that if an HTTP client cancels prematurely, we will handle the GRPC cleanup appropriately. Signed-off-by: Bob Callaway <bcallawa@redhat.com> --- pkg/api/api.go | 32 ++++++++---- pkg/api/entries.go | 49 ++++++++++++------- pkg/api/tlog.go | 20 +++++--- pkg/api/trillian_client.go | 9 +--- .../restapi/configure_rekor_server.go | 12 ++--- 5 files changed, 69 insertions(+), 53 deletions(-) diff --git a/pkg/api/api.go b/pkg/api/api.go index c83c1f7..2cdd5c3 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -19,7 +19,6 @@ package api import ( "context" "fmt" - "net/http" "time" "github.com/google/trillian" @@ -42,14 +41,16 @@ func dial(ctx context.Context, rpcServer string) (*grpc.ClientConn, error) { } type API struct { - client *TrillianClient - pubkey *keyspb.PublicKey + logClient trillian.TrillianLogClient + logID int64 + pubkey *keyspb.PublicKey } -func NewAPI(ctx context.Context) (*API, error) { +func NewAPI() (*API, error) { logRPCServer := fmt.Sprintf("%s:%d", viper.GetString("trillian_log_server.address"), viper.GetUint("trillian_log_server.port")) + ctx := context.Background() tConn, err := dial(ctx, logRPCServer) if err != nil { return nil, err @@ -74,8 +75,9 @@ func NewAPI(ctx context.Context) (*API, error) { } return &API{ - client: TrillianClientInstance(logClient, tLogID, ctx), - pubkey: t.PublicKey, + logClient: logClient, + logID: tLogID, + pubkey: t.PublicKey, }, nil } @@ -83,10 +85,20 @@ type ctxKeyRekorAPI int const rekorAPILookupKey ctxKeyRekorAPI = 0 -func AddAPIToContext(ctx context.Context, api *API) (context.Context, error) { - return context.WithValue(ctx, rekorAPILookupKey, api), nil +func AddAPIToContext(ctx context.Context, api *API) context.Context { + return context.WithValue(ctx, rekorAPILookupKey, api) } -func apiFromRequest(r *http.Request) *API { - return r.Context().Value(rekorAPILookupKey).(*API) +func NewTrillianClient(ctx context.Context) *TrillianClient { + api := ctx.Value(rekorAPILookupKey).(*API) + if api == nil { + return nil + } + + return &TrillianClient{ + client: api.logClient, + logID: api.logID, + context: ctx, + pubkey: api.pubkey, + } } diff --git a/pkg/api/entries.go b/pkg/api/entries.go index 7a4ae2c..a65657b 100644 --- a/pkg/api/entries.go +++ b/pkg/api/entries.go @@ -42,10 +42,12 @@ import ( ) func GetLogEntryByIndexHandler(params entries.GetLogEntryByIndexParams) middleware.Responder { - httpReq := params.HTTPRequest - api := apiFromRequest(httpReq) + tc := NewTrillianClient(params.HTTPRequest.Context()) + if tc == nil { + return handleRekorAPIError(params, http.StatusInternalServerError, errors.New("unable to get client from request context"), trillianCommunicationError) + } - resp := api.client.getLeafByIndex(params.LogIndex) + resp := tc.getLeafByIndex(params.LogIndex) switch resp.status { case codes.OK: case codes.NotFound, codes.OutOfRange: @@ -83,9 +85,12 @@ func CreateLogEntryHandler(params entries.CreateLogEntryParams) middleware.Respo return handleRekorAPIError(params, http.StatusInternalServerError, err, failedToGenerateCanonicalEntry) } - api := apiFromRequest(httpReq) + tc := NewTrillianClient(httpReq.Context()) + if tc == nil { + return handleRekorAPIError(params, http.StatusInternalServerError, errors.New("unable to get client from request context"), trillianCommunicationError) + } - resp := api.client.addLeaf(leaf) + resp := tc.addLeaf(leaf) switch resp.status { case codes.OK: case codes.AlreadyExists, codes.FailedPrecondition: @@ -110,12 +115,15 @@ func CreateLogEntryHandler(params entries.CreateLogEntryParams) middleware.Respo } func GetLogEntryByUUIDHandler(params entries.GetLogEntryByUUIDParams) middleware.Responder { - httpReq := params.HTTPRequest - api := apiFromRequest(httpReq) hashValue, _ := hex.DecodeString(params.EntryUUID) hashes := [][]byte{hashValue} - resp := api.client.getLeafByHash(hashes) // TODO: if this API is deprecated, we need to ask for inclusion proof and then use index in proof result to get leaf + tc := NewTrillianClient(params.HTTPRequest.Context()) + if tc == nil { + return handleRekorAPIError(params, http.StatusInternalServerError, errors.New("unable to get client from request context"), trillianCommunicationError) + } + + resp := tc.getLeafByHash(hashes) // TODO: if this API is deprecated, we need to ask for inclusion proof and then use index in proof result to get leaf switch resp.status { case codes.OK: case codes.NotFound: @@ -144,11 +152,13 @@ func GetLogEntryByUUIDHandler(params entries.GetLogEntryByUUIDParams) middleware } func GetLogEntryProofHandler(params entries.GetLogEntryProofParams) middleware.Responder { - httpReq := params.HTTPRequest - api := apiFromRequest(httpReq) hashValue, _ := hex.DecodeString(params.EntryUUID) + tc := NewTrillianClient(params.HTTPRequest.Context()) + if tc == nil { + return handleRekorAPIError(params, http.StatusInternalServerError, errors.New("unable to get client from request context"), trillianCommunicationError) + } - resp := api.client.getProofByHash(hashValue) + resp := tc.getProofByHash(hashValue) switch resp.status { case codes.OK: case codes.NotFound: @@ -159,7 +169,7 @@ func GetLogEntryProofHandler(params entries.GetLogEntryProofParams) middleware.R result := resp.getProofResult // validate result is signed with the key we're aware of - pub, err := x509.ParsePKIXPublicKey(api.pubkey.Der) + pub, err := x509.ParsePKIXPublicKey(tc.pubkey.Der) if err != nil { return handleRekorAPIError(params, http.StatusInternalServerError, err, "") } @@ -189,9 +199,12 @@ func GetLogEntryProofHandler(params entries.GetLogEntryProofParams) middleware.R } func SearchLogQueryHandler(params entries.SearchLogQueryParams) middleware.Responder { + httpReqCtx := params.HTTPRequest.Context() resultPayload := []models.LogEntry{} - httpReq := params.HTTPRequest - api := apiFromRequest(httpReq) + tc := NewTrillianClient(httpReqCtx) + if tc == nil { + return handleRekorAPIError(params, http.StatusInternalServerError, errors.New("unable to get client from request context"), trillianCommunicationError) + } //TODO: parallelize this into different goroutines to speed up search searchHashes := [][]byte{} @@ -211,12 +224,12 @@ func SearchLogQueryHandler(params entries.SearchLogQueryParams) middleware.Respo } if entry.HasExternalEntities() { - if err := entry.FetchExternalEntities(httpReq.Context()); err != nil { + if err := entry.FetchExternalEntities(httpReqCtx); err != nil { return handleRekorAPIError(params, http.StatusBadRequest, err, err.Error()) } } - leaf, err := entry.Canonicalize(httpReq.Context()) + leaf, err := entry.Canonicalize(httpReqCtx) if err != nil { return handleRekorAPIError(params, http.StatusInternalServerError, err, err.Error()) } @@ -225,7 +238,7 @@ func SearchLogQueryHandler(params entries.SearchLogQueryParams) middleware.Respo searchHashes = append(searchHashes, leafHash) } - resp := api.client.getLeafByHash(searchHashes) // TODO: if this API is deprecated, we need to ask for inclusion proof and then use index in proof result to get leaf + resp := tc.getLeafByHash(searchHashes) // TODO: if this API is deprecated, we need to ask for inclusion proof and then use index in proof result to get leaf switch resp.status { case codes.OK, codes.NotFound: default: @@ -246,7 +259,7 @@ func SearchLogQueryHandler(params entries.SearchLogQueryParams) middleware.Respo if len(params.Entry.LogIndexes) > 0 { leaves := []*trillian.LogLeaf{} for _, logIndex := range params.Entry.LogIndexes { - resp := api.client.getLeafByIndex(swag.Int64Value(logIndex)) + resp := tc.getLeafByIndex(swag.Int64Value(logIndex)) switch resp.status { case codes.OK, codes.NotFound: default: diff --git a/pkg/api/tlog.go b/pkg/api/tlog.go index 395fbe5..142e1ab 100644 --- a/pkg/api/tlog.go +++ b/pkg/api/tlog.go @@ -34,17 +34,19 @@ import ( ) func GetLogInfoHandler(params tlog.GetLogInfoParams) middleware.Responder { - httpReq := params.HTTPRequest - api := apiFromRequest(httpReq) + tc := NewTrillianClient(params.HTTPRequest.Context()) + if tc == nil { + return handleRekorAPIError(params, http.StatusInternalServerError, errors.New("unable to get client from request context"), trillianCommunicationError) + } - resp := api.client.getLatest(0) + resp := tc.getLatest(0) if resp.status != codes.OK { return handleRekorAPIError(params, http.StatusInternalServerError, fmt.Errorf("grpc error: %w", resp.err), trillianCommunicationError) } result := resp.getLatestResult // validate result is signed with the key we're aware of - pub, err := x509.ParsePKIXPublicKey(api.pubkey.Der) + pub, err := x509.ParsePKIXPublicKey(tc.pubkey.Der) if err != nil { return handleRekorAPIError(params, http.StatusInternalServerError, err, "") } @@ -65,20 +67,22 @@ func GetLogInfoHandler(params tlog.GetLogInfoParams) middleware.Responder { } func GetLogProofHandler(params tlog.GetLogProofParams) middleware.Responder { - httpReq := params.HTTPRequest if *params.FirstSize > params.LastSize { return handleRekorAPIError(params, http.StatusBadRequest, nil, fmt.Sprintf(firstSizeLessThanLastSize, *params.FirstSize, params.LastSize)) } - api := apiFromRequest(httpReq) + tc := NewTrillianClient(params.HTTPRequest.Context()) + if tc == nil { + return handleRekorAPIError(params, http.StatusInternalServerError, errors.New("unable to get client from request context"), trillianCommunicationError) + } - resp := api.client.getConsistencyProof(*params.FirstSize, params.LastSize) + resp := tc.getConsistencyProof(*params.FirstSize, params.LastSize) if resp.status != codes.OK { return handleRekorAPIError(params, http.StatusInternalServerError, fmt.Errorf("grpc error: %w", resp.err), trillianCommunicationError) } result := resp.getConsistencyProofResult // validate result is signed with the key we're aware of - pub, err := x509.ParsePKIXPublicKey(api.pubkey.Der) + pub, err := x509.ParsePKIXPublicKey(tc.pubkey.Der) if err != nil { return handleRekorAPIError(params, http.StatusInternalServerError, err, "") } diff --git a/pkg/api/trillian_client.go b/pkg/api/trillian_client.go index 283cc74..0363d5f 100644 --- a/pkg/api/trillian_client.go +++ b/pkg/api/trillian_client.go @@ -37,6 +37,7 @@ type TrillianClient struct { client trillian.TrillianLogClient logID int64 context context.Context + pubkey *keyspb.PublicKey } type Response struct { @@ -50,14 +51,6 @@ type Response struct { getConsistencyProofResult *trillian.GetConsistencyProofResponse } -func TrillianClientInstance(client trillian.TrillianLogClient, tLogID int64, ctx context.Context) *TrillianClient { - return &TrillianClient{ - client: client, - logID: tLogID, - context: ctx, - } -} - func (t *TrillianClient) root() (types.LogRootV1, error) { rqst := &trillian.GetLatestSignedLogRootRequest{ LogId: t.logID, diff --git a/pkg/generated/restapi/configure_rekor_server.go b/pkg/generated/restapi/configure_rekor_server.go index e945e4a..11eb0f9 100644 --- a/pkg/generated/restapi/configure_rekor_server.go +++ b/pkg/generated/restapi/configure_rekor_server.go @@ -18,9 +18,7 @@ limitations under the License. package restapi import ( - "context" "crypto/tls" - "fmt" "net/http" "github.com/go-chi/chi/middleware" @@ -141,17 +139,13 @@ func cacheForever(handler http.Handler) http.Handler { } func addTrillianAPI(handler http.Handler) http.Handler { - api, err := pkgapi.NewAPI(context.Background()) + api, err := pkgapi.NewAPI() if err != nil { log.Logger.Panic(err) } return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - apiCtx, err := pkgapi.AddAPIToContext(r.Context(), api) - if err != nil { - logAndServeError(w, r, fmt.Errorf("error adding trillian API object to request context: %v", err)) - } else { - handler.ServeHTTP(w, r.WithContext(apiCtx)) - } + apiCtx := pkgapi.AddAPIToContext(r.Context(), api) + handler.ServeHTTP(w, r.WithContext(apiCtx)) }) } -- GitLab