From 399486960213a99d798fb4d2b37bd5046ce9a44e Mon Sep 17 00:00:00 2001
From: Bob Callaway <bobcallaway@users.noreply.github.com>
Date: Thu, 24 Dec 2020 17:22:51 -0500
Subject: [PATCH] GRPC calls now use request-specific context (#78)

After #77, we have one global GRPC channel for the entire process. This
change causes each GRPC call to be made on the incoming HTTP request's
context such that if an HTTP client cancels prematurely, we will handle
the GRPC cleanup appropriately.

Signed-off-by: Bob Callaway <bcallawa@redhat.com>
---
 pkg/api/api.go                                | 32 ++++++++----
 pkg/api/entries.go                            | 49 ++++++++++++-------
 pkg/api/tlog.go                               | 20 +++++---
 pkg/api/trillian_client.go                    |  9 +---
 .../restapi/configure_rekor_server.go         | 12 ++---
 5 files changed, 69 insertions(+), 53 deletions(-)

diff --git a/pkg/api/api.go b/pkg/api/api.go
index c83c1f7..2cdd5c3 100644
--- a/pkg/api/api.go
+++ b/pkg/api/api.go
@@ -19,7 +19,6 @@ package api
 import (
 	"context"
 	"fmt"
-	"net/http"
 	"time"
 
 	"github.com/google/trillian"
@@ -42,14 +41,16 @@ func dial(ctx context.Context, rpcServer string) (*grpc.ClientConn, error) {
 }
 
 type API struct {
-	client *TrillianClient
-	pubkey *keyspb.PublicKey
+	logClient trillian.TrillianLogClient
+	logID     int64
+	pubkey    *keyspb.PublicKey
 }
 
-func NewAPI(ctx context.Context) (*API, error) {
+func NewAPI() (*API, error) {
 	logRPCServer := fmt.Sprintf("%s:%d",
 		viper.GetString("trillian_log_server.address"),
 		viper.GetUint("trillian_log_server.port"))
+	ctx := context.Background()
 	tConn, err := dial(ctx, logRPCServer)
 	if err != nil {
 		return nil, err
@@ -74,8 +75,9 @@ func NewAPI(ctx context.Context) (*API, error) {
 	}
 
 	return &API{
-		client: TrillianClientInstance(logClient, tLogID, ctx),
-		pubkey: t.PublicKey,
+		logClient: logClient,
+		logID:     tLogID,
+		pubkey:    t.PublicKey,
 	}, nil
 }
 
@@ -83,10 +85,20 @@ type ctxKeyRekorAPI int
 
 const rekorAPILookupKey ctxKeyRekorAPI = 0
 
-func AddAPIToContext(ctx context.Context, api *API) (context.Context, error) {
-	return context.WithValue(ctx, rekorAPILookupKey, api), nil
+func AddAPIToContext(ctx context.Context, api *API) context.Context {
+	return context.WithValue(ctx, rekorAPILookupKey, api)
 }
 
-func apiFromRequest(r *http.Request) *API {
-	return r.Context().Value(rekorAPILookupKey).(*API)
+func NewTrillianClient(ctx context.Context) *TrillianClient {
+	api := ctx.Value(rekorAPILookupKey).(*API)
+	if api == nil {
+		return nil
+	}
+
+	return &TrillianClient{
+		client:  api.logClient,
+		logID:   api.logID,
+		context: ctx,
+		pubkey:  api.pubkey,
+	}
 }
diff --git a/pkg/api/entries.go b/pkg/api/entries.go
index 7a4ae2c..a65657b 100644
--- a/pkg/api/entries.go
+++ b/pkg/api/entries.go
@@ -42,10 +42,12 @@ import (
 )
 
 func GetLogEntryByIndexHandler(params entries.GetLogEntryByIndexParams) middleware.Responder {
-	httpReq := params.HTTPRequest
-	api := apiFromRequest(httpReq)
+	tc := NewTrillianClient(params.HTTPRequest.Context())
+	if tc == nil {
+		return handleRekorAPIError(params, http.StatusInternalServerError, errors.New("unable to get client from request context"), trillianCommunicationError)
+	}
 
-	resp := api.client.getLeafByIndex(params.LogIndex)
+	resp := tc.getLeafByIndex(params.LogIndex)
 	switch resp.status {
 	case codes.OK:
 	case codes.NotFound, codes.OutOfRange:
@@ -83,9 +85,12 @@ func CreateLogEntryHandler(params entries.CreateLogEntryParams) middleware.Respo
 		return handleRekorAPIError(params, http.StatusInternalServerError, err, failedToGenerateCanonicalEntry)
 	}
 
-	api := apiFromRequest(httpReq)
+	tc := NewTrillianClient(httpReq.Context())
+	if tc == nil {
+		return handleRekorAPIError(params, http.StatusInternalServerError, errors.New("unable to get client from request context"), trillianCommunicationError)
+	}
 
-	resp := api.client.addLeaf(leaf)
+	resp := tc.addLeaf(leaf)
 	switch resp.status {
 	case codes.OK:
 	case codes.AlreadyExists, codes.FailedPrecondition:
@@ -110,12 +115,15 @@ func CreateLogEntryHandler(params entries.CreateLogEntryParams) middleware.Respo
 }
 
 func GetLogEntryByUUIDHandler(params entries.GetLogEntryByUUIDParams) middleware.Responder {
-	httpReq := params.HTTPRequest
-	api := apiFromRequest(httpReq)
 	hashValue, _ := hex.DecodeString(params.EntryUUID)
 	hashes := [][]byte{hashValue}
 
-	resp := api.client.getLeafByHash(hashes) // TODO: if this API is deprecated, we need to ask for inclusion proof and then use index in proof result to get leaf
+	tc := NewTrillianClient(params.HTTPRequest.Context())
+	if tc == nil {
+		return handleRekorAPIError(params, http.StatusInternalServerError, errors.New("unable to get client from request context"), trillianCommunicationError)
+	}
+
+	resp := tc.getLeafByHash(hashes) // TODO: if this API is deprecated, we need to ask for inclusion proof and then use index in proof result to get leaf
 	switch resp.status {
 	case codes.OK:
 	case codes.NotFound:
@@ -144,11 +152,13 @@ func GetLogEntryByUUIDHandler(params entries.GetLogEntryByUUIDParams) middleware
 }
 
 func GetLogEntryProofHandler(params entries.GetLogEntryProofParams) middleware.Responder {
-	httpReq := params.HTTPRequest
-	api := apiFromRequest(httpReq)
 	hashValue, _ := hex.DecodeString(params.EntryUUID)
+	tc := NewTrillianClient(params.HTTPRequest.Context())
+	if tc == nil {
+		return handleRekorAPIError(params, http.StatusInternalServerError, errors.New("unable to get client from request context"), trillianCommunicationError)
+	}
 
-	resp := api.client.getProofByHash(hashValue)
+	resp := tc.getProofByHash(hashValue)
 	switch resp.status {
 	case codes.OK:
 	case codes.NotFound:
@@ -159,7 +169,7 @@ func GetLogEntryProofHandler(params entries.GetLogEntryProofParams) middleware.R
 	result := resp.getProofResult
 
 	// validate result is signed with the key we're aware of
-	pub, err := x509.ParsePKIXPublicKey(api.pubkey.Der)
+	pub, err := x509.ParsePKIXPublicKey(tc.pubkey.Der)
 	if err != nil {
 		return handleRekorAPIError(params, http.StatusInternalServerError, err, "")
 	}
@@ -189,9 +199,12 @@ func GetLogEntryProofHandler(params entries.GetLogEntryProofParams) middleware.R
 }
 
 func SearchLogQueryHandler(params entries.SearchLogQueryParams) middleware.Responder {
+	httpReqCtx := params.HTTPRequest.Context()
 	resultPayload := []models.LogEntry{}
-	httpReq := params.HTTPRequest
-	api := apiFromRequest(httpReq)
+	tc := NewTrillianClient(httpReqCtx)
+	if tc == nil {
+		return handleRekorAPIError(params, http.StatusInternalServerError, errors.New("unable to get client from request context"), trillianCommunicationError)
+	}
 
 	//TODO: parallelize this into different goroutines to speed up search
 	searchHashes := [][]byte{}
@@ -211,12 +224,12 @@ func SearchLogQueryHandler(params entries.SearchLogQueryParams) middleware.Respo
 			}
 
 			if entry.HasExternalEntities() {
-				if err := entry.FetchExternalEntities(httpReq.Context()); err != nil {
+				if err := entry.FetchExternalEntities(httpReqCtx); err != nil {
 					return handleRekorAPIError(params, http.StatusBadRequest, err, err.Error())
 				}
 			}
 
-			leaf, err := entry.Canonicalize(httpReq.Context())
+			leaf, err := entry.Canonicalize(httpReqCtx)
 			if err != nil {
 				return handleRekorAPIError(params, http.StatusInternalServerError, err, err.Error())
 			}
@@ -225,7 +238,7 @@ func SearchLogQueryHandler(params entries.SearchLogQueryParams) middleware.Respo
 			searchHashes = append(searchHashes, leafHash)
 		}
 
-		resp := api.client.getLeafByHash(searchHashes) // TODO: if this API is deprecated, we need to ask for inclusion proof and then use index in proof result to get leaf
+		resp := tc.getLeafByHash(searchHashes) // TODO: if this API is deprecated, we need to ask for inclusion proof and then use index in proof result to get leaf
 		switch resp.status {
 		case codes.OK, codes.NotFound:
 		default:
@@ -246,7 +259,7 @@ func SearchLogQueryHandler(params entries.SearchLogQueryParams) middleware.Respo
 	if len(params.Entry.LogIndexes) > 0 {
 		leaves := []*trillian.LogLeaf{}
 		for _, logIndex := range params.Entry.LogIndexes {
-			resp := api.client.getLeafByIndex(swag.Int64Value(logIndex))
+			resp := tc.getLeafByIndex(swag.Int64Value(logIndex))
 			switch resp.status {
 			case codes.OK, codes.NotFound:
 			default:
diff --git a/pkg/api/tlog.go b/pkg/api/tlog.go
index 395fbe5..142e1ab 100644
--- a/pkg/api/tlog.go
+++ b/pkg/api/tlog.go
@@ -34,17 +34,19 @@ import (
 )
 
 func GetLogInfoHandler(params tlog.GetLogInfoParams) middleware.Responder {
-	httpReq := params.HTTPRequest
-	api := apiFromRequest(httpReq)
+	tc := NewTrillianClient(params.HTTPRequest.Context())
+	if tc == nil {
+		return handleRekorAPIError(params, http.StatusInternalServerError, errors.New("unable to get client from request context"), trillianCommunicationError)
+	}
 
-	resp := api.client.getLatest(0)
+	resp := tc.getLatest(0)
 	if resp.status != codes.OK {
 		return handleRekorAPIError(params, http.StatusInternalServerError, fmt.Errorf("grpc error: %w", resp.err), trillianCommunicationError)
 	}
 	result := resp.getLatestResult
 
 	// validate result is signed with the key we're aware of
-	pub, err := x509.ParsePKIXPublicKey(api.pubkey.Der)
+	pub, err := x509.ParsePKIXPublicKey(tc.pubkey.Der)
 	if err != nil {
 		return handleRekorAPIError(params, http.StatusInternalServerError, err, "")
 	}
@@ -65,20 +67,22 @@ func GetLogInfoHandler(params tlog.GetLogInfoParams) middleware.Responder {
 }
 
 func GetLogProofHandler(params tlog.GetLogProofParams) middleware.Responder {
-	httpReq := params.HTTPRequest
 	if *params.FirstSize > params.LastSize {
 		return handleRekorAPIError(params, http.StatusBadRequest, nil, fmt.Sprintf(firstSizeLessThanLastSize, *params.FirstSize, params.LastSize))
 	}
-	api := apiFromRequest(httpReq)
+	tc := NewTrillianClient(params.HTTPRequest.Context())
+	if tc == nil {
+		return handleRekorAPIError(params, http.StatusInternalServerError, errors.New("unable to get client from request context"), trillianCommunicationError)
+	}
 
-	resp := api.client.getConsistencyProof(*params.FirstSize, params.LastSize)
+	resp := tc.getConsistencyProof(*params.FirstSize, params.LastSize)
 	if resp.status != codes.OK {
 		return handleRekorAPIError(params, http.StatusInternalServerError, fmt.Errorf("grpc error: %w", resp.err), trillianCommunicationError)
 	}
 	result := resp.getConsistencyProofResult
 
 	// validate result is signed with the key we're aware of
-	pub, err := x509.ParsePKIXPublicKey(api.pubkey.Der)
+	pub, err := x509.ParsePKIXPublicKey(tc.pubkey.Der)
 	if err != nil {
 		return handleRekorAPIError(params, http.StatusInternalServerError, err, "")
 	}
diff --git a/pkg/api/trillian_client.go b/pkg/api/trillian_client.go
index 283cc74..0363d5f 100644
--- a/pkg/api/trillian_client.go
+++ b/pkg/api/trillian_client.go
@@ -37,6 +37,7 @@ type TrillianClient struct {
 	client  trillian.TrillianLogClient
 	logID   int64
 	context context.Context
+	pubkey  *keyspb.PublicKey
 }
 
 type Response struct {
@@ -50,14 +51,6 @@ type Response struct {
 	getConsistencyProofResult *trillian.GetConsistencyProofResponse
 }
 
-func TrillianClientInstance(client trillian.TrillianLogClient, tLogID int64, ctx context.Context) *TrillianClient {
-	return &TrillianClient{
-		client:  client,
-		logID:   tLogID,
-		context: ctx,
-	}
-}
-
 func (t *TrillianClient) root() (types.LogRootV1, error) {
 	rqst := &trillian.GetLatestSignedLogRootRequest{
 		LogId: t.logID,
diff --git a/pkg/generated/restapi/configure_rekor_server.go b/pkg/generated/restapi/configure_rekor_server.go
index e945e4a..11eb0f9 100644
--- a/pkg/generated/restapi/configure_rekor_server.go
+++ b/pkg/generated/restapi/configure_rekor_server.go
@@ -18,9 +18,7 @@ limitations under the License.
 package restapi
 
 import (
-	"context"
 	"crypto/tls"
-	"fmt"
 	"net/http"
 
 	"github.com/go-chi/chi/middleware"
@@ -141,17 +139,13 @@ func cacheForever(handler http.Handler) http.Handler {
 }
 
 func addTrillianAPI(handler http.Handler) http.Handler {
-	api, err := pkgapi.NewAPI(context.Background())
+	api, err := pkgapi.NewAPI()
 	if err != nil {
 		log.Logger.Panic(err)
 	}
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		apiCtx, err := pkgapi.AddAPIToContext(r.Context(), api)
-		if err != nil {
-			logAndServeError(w, r, fmt.Errorf("error adding trillian API object to request context: %v", err))
-		} else {
-			handler.ServeHTTP(w, r.WithContext(apiCtx))
-		}
+		apiCtx := pkgapi.AddAPIToContext(r.Context(), api)
+		handler.ServeHTTP(w, r.WithContext(apiCtx))
 	})
 }
 
-- 
GitLab