Skip to content
Snippets Groups Projects
Unverified Commit 39948696 authored by Bob Callaway's avatar Bob Callaway Committed by GitHub
Browse files

GRPC calls now use request-specific context (#78)


After #77, we have one global GRPC channel for the entire process. This
change causes each GRPC call to be made on the incoming HTTP request's
context such that if an HTTP client cancels prematurely, we will handle
the GRPC cleanup appropriately.

Signed-off-by: default avatarBob Callaway <bcallawa@redhat.com>
parent cc6e7af0
No related branches found
No related tags found
No related merge requests found
......@@ -19,7 +19,6 @@ package api
import (
"context"
"fmt"
"net/http"
"time"
"github.com/google/trillian"
......@@ -42,14 +41,16 @@ func dial(ctx context.Context, rpcServer string) (*grpc.ClientConn, error) {
}
type API struct {
client *TrillianClient
pubkey *keyspb.PublicKey
logClient trillian.TrillianLogClient
logID int64
pubkey *keyspb.PublicKey
}
func NewAPI(ctx context.Context) (*API, error) {
func NewAPI() (*API, error) {
logRPCServer := fmt.Sprintf("%s:%d",
viper.GetString("trillian_log_server.address"),
viper.GetUint("trillian_log_server.port"))
ctx := context.Background()
tConn, err := dial(ctx, logRPCServer)
if err != nil {
return nil, err
......@@ -74,8 +75,9 @@ func NewAPI(ctx context.Context) (*API, error) {
}
return &API{
client: TrillianClientInstance(logClient, tLogID, ctx),
pubkey: t.PublicKey,
logClient: logClient,
logID: tLogID,
pubkey: t.PublicKey,
}, nil
}
......@@ -83,10 +85,20 @@ type ctxKeyRekorAPI int
const rekorAPILookupKey ctxKeyRekorAPI = 0
func AddAPIToContext(ctx context.Context, api *API) (context.Context, error) {
return context.WithValue(ctx, rekorAPILookupKey, api), nil
func AddAPIToContext(ctx context.Context, api *API) context.Context {
return context.WithValue(ctx, rekorAPILookupKey, api)
}
func apiFromRequest(r *http.Request) *API {
return r.Context().Value(rekorAPILookupKey).(*API)
func NewTrillianClient(ctx context.Context) *TrillianClient {
api := ctx.Value(rekorAPILookupKey).(*API)
if api == nil {
return nil
}
return &TrillianClient{
client: api.logClient,
logID: api.logID,
context: ctx,
pubkey: api.pubkey,
}
}
......@@ -42,10 +42,12 @@ import (
)
func GetLogEntryByIndexHandler(params entries.GetLogEntryByIndexParams) middleware.Responder {
httpReq := params.HTTPRequest
api := apiFromRequest(httpReq)
tc := NewTrillianClient(params.HTTPRequest.Context())
if tc == nil {
return handleRekorAPIError(params, http.StatusInternalServerError, errors.New("unable to get client from request context"), trillianCommunicationError)
}
resp := api.client.getLeafByIndex(params.LogIndex)
resp := tc.getLeafByIndex(params.LogIndex)
switch resp.status {
case codes.OK:
case codes.NotFound, codes.OutOfRange:
......@@ -83,9 +85,12 @@ func CreateLogEntryHandler(params entries.CreateLogEntryParams) middleware.Respo
return handleRekorAPIError(params, http.StatusInternalServerError, err, failedToGenerateCanonicalEntry)
}
api := apiFromRequest(httpReq)
tc := NewTrillianClient(httpReq.Context())
if tc == nil {
return handleRekorAPIError(params, http.StatusInternalServerError, errors.New("unable to get client from request context"), trillianCommunicationError)
}
resp := api.client.addLeaf(leaf)
resp := tc.addLeaf(leaf)
switch resp.status {
case codes.OK:
case codes.AlreadyExists, codes.FailedPrecondition:
......@@ -110,12 +115,15 @@ func CreateLogEntryHandler(params entries.CreateLogEntryParams) middleware.Respo
}
func GetLogEntryByUUIDHandler(params entries.GetLogEntryByUUIDParams) middleware.Responder {
httpReq := params.HTTPRequest
api := apiFromRequest(httpReq)
hashValue, _ := hex.DecodeString(params.EntryUUID)
hashes := [][]byte{hashValue}
resp := api.client.getLeafByHash(hashes) // TODO: if this API is deprecated, we need to ask for inclusion proof and then use index in proof result to get leaf
tc := NewTrillianClient(params.HTTPRequest.Context())
if tc == nil {
return handleRekorAPIError(params, http.StatusInternalServerError, errors.New("unable to get client from request context"), trillianCommunicationError)
}
resp := tc.getLeafByHash(hashes) // TODO: if this API is deprecated, we need to ask for inclusion proof and then use index in proof result to get leaf
switch resp.status {
case codes.OK:
case codes.NotFound:
......@@ -144,11 +152,13 @@ func GetLogEntryByUUIDHandler(params entries.GetLogEntryByUUIDParams) middleware
}
func GetLogEntryProofHandler(params entries.GetLogEntryProofParams) middleware.Responder {
httpReq := params.HTTPRequest
api := apiFromRequest(httpReq)
hashValue, _ := hex.DecodeString(params.EntryUUID)
tc := NewTrillianClient(params.HTTPRequest.Context())
if tc == nil {
return handleRekorAPIError(params, http.StatusInternalServerError, errors.New("unable to get client from request context"), trillianCommunicationError)
}
resp := api.client.getProofByHash(hashValue)
resp := tc.getProofByHash(hashValue)
switch resp.status {
case codes.OK:
case codes.NotFound:
......@@ -159,7 +169,7 @@ func GetLogEntryProofHandler(params entries.GetLogEntryProofParams) middleware.R
result := resp.getProofResult
// validate result is signed with the key we're aware of
pub, err := x509.ParsePKIXPublicKey(api.pubkey.Der)
pub, err := x509.ParsePKIXPublicKey(tc.pubkey.Der)
if err != nil {
return handleRekorAPIError(params, http.StatusInternalServerError, err, "")
}
......@@ -189,9 +199,12 @@ func GetLogEntryProofHandler(params entries.GetLogEntryProofParams) middleware.R
}
func SearchLogQueryHandler(params entries.SearchLogQueryParams) middleware.Responder {
httpReqCtx := params.HTTPRequest.Context()
resultPayload := []models.LogEntry{}
httpReq := params.HTTPRequest
api := apiFromRequest(httpReq)
tc := NewTrillianClient(httpReqCtx)
if tc == nil {
return handleRekorAPIError(params, http.StatusInternalServerError, errors.New("unable to get client from request context"), trillianCommunicationError)
}
//TODO: parallelize this into different goroutines to speed up search
searchHashes := [][]byte{}
......@@ -211,12 +224,12 @@ func SearchLogQueryHandler(params entries.SearchLogQueryParams) middleware.Respo
}
if entry.HasExternalEntities() {
if err := entry.FetchExternalEntities(httpReq.Context()); err != nil {
if err := entry.FetchExternalEntities(httpReqCtx); err != nil {
return handleRekorAPIError(params, http.StatusBadRequest, err, err.Error())
}
}
leaf, err := entry.Canonicalize(httpReq.Context())
leaf, err := entry.Canonicalize(httpReqCtx)
if err != nil {
return handleRekorAPIError(params, http.StatusInternalServerError, err, err.Error())
}
......@@ -225,7 +238,7 @@ func SearchLogQueryHandler(params entries.SearchLogQueryParams) middleware.Respo
searchHashes = append(searchHashes, leafHash)
}
resp := api.client.getLeafByHash(searchHashes) // TODO: if this API is deprecated, we need to ask for inclusion proof and then use index in proof result to get leaf
resp := tc.getLeafByHash(searchHashes) // TODO: if this API is deprecated, we need to ask for inclusion proof and then use index in proof result to get leaf
switch resp.status {
case codes.OK, codes.NotFound:
default:
......@@ -246,7 +259,7 @@ func SearchLogQueryHandler(params entries.SearchLogQueryParams) middleware.Respo
if len(params.Entry.LogIndexes) > 0 {
leaves := []*trillian.LogLeaf{}
for _, logIndex := range params.Entry.LogIndexes {
resp := api.client.getLeafByIndex(swag.Int64Value(logIndex))
resp := tc.getLeafByIndex(swag.Int64Value(logIndex))
switch resp.status {
case codes.OK, codes.NotFound:
default:
......
......@@ -34,17 +34,19 @@ import (
)
func GetLogInfoHandler(params tlog.GetLogInfoParams) middleware.Responder {
httpReq := params.HTTPRequest
api := apiFromRequest(httpReq)
tc := NewTrillianClient(params.HTTPRequest.Context())
if tc == nil {
return handleRekorAPIError(params, http.StatusInternalServerError, errors.New("unable to get client from request context"), trillianCommunicationError)
}
resp := api.client.getLatest(0)
resp := tc.getLatest(0)
if resp.status != codes.OK {
return handleRekorAPIError(params, http.StatusInternalServerError, fmt.Errorf("grpc error: %w", resp.err), trillianCommunicationError)
}
result := resp.getLatestResult
// validate result is signed with the key we're aware of
pub, err := x509.ParsePKIXPublicKey(api.pubkey.Der)
pub, err := x509.ParsePKIXPublicKey(tc.pubkey.Der)
if err != nil {
return handleRekorAPIError(params, http.StatusInternalServerError, err, "")
}
......@@ -65,20 +67,22 @@ func GetLogInfoHandler(params tlog.GetLogInfoParams) middleware.Responder {
}
func GetLogProofHandler(params tlog.GetLogProofParams) middleware.Responder {
httpReq := params.HTTPRequest
if *params.FirstSize > params.LastSize {
return handleRekorAPIError(params, http.StatusBadRequest, nil, fmt.Sprintf(firstSizeLessThanLastSize, *params.FirstSize, params.LastSize))
}
api := apiFromRequest(httpReq)
tc := NewTrillianClient(params.HTTPRequest.Context())
if tc == nil {
return handleRekorAPIError(params, http.StatusInternalServerError, errors.New("unable to get client from request context"), trillianCommunicationError)
}
resp := api.client.getConsistencyProof(*params.FirstSize, params.LastSize)
resp := tc.getConsistencyProof(*params.FirstSize, params.LastSize)
if resp.status != codes.OK {
return handleRekorAPIError(params, http.StatusInternalServerError, fmt.Errorf("grpc error: %w", resp.err), trillianCommunicationError)
}
result := resp.getConsistencyProofResult
// validate result is signed with the key we're aware of
pub, err := x509.ParsePKIXPublicKey(api.pubkey.Der)
pub, err := x509.ParsePKIXPublicKey(tc.pubkey.Der)
if err != nil {
return handleRekorAPIError(params, http.StatusInternalServerError, err, "")
}
......
......@@ -37,6 +37,7 @@ type TrillianClient struct {
client trillian.TrillianLogClient
logID int64
context context.Context
pubkey *keyspb.PublicKey
}
type Response struct {
......@@ -50,14 +51,6 @@ type Response struct {
getConsistencyProofResult *trillian.GetConsistencyProofResponse
}
func TrillianClientInstance(client trillian.TrillianLogClient, tLogID int64, ctx context.Context) *TrillianClient {
return &TrillianClient{
client: client,
logID: tLogID,
context: ctx,
}
}
func (t *TrillianClient) root() (types.LogRootV1, error) {
rqst := &trillian.GetLatestSignedLogRootRequest{
LogId: t.logID,
......
......@@ -18,9 +18,7 @@ limitations under the License.
package restapi
import (
"context"
"crypto/tls"
"fmt"
"net/http"
"github.com/go-chi/chi/middleware"
......@@ -141,17 +139,13 @@ func cacheForever(handler http.Handler) http.Handler {
}
func addTrillianAPI(handler http.Handler) http.Handler {
api, err := pkgapi.NewAPI(context.Background())
api, err := pkgapi.NewAPI()
if err != nil {
log.Logger.Panic(err)
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
apiCtx, err := pkgapi.AddAPIToContext(r.Context(), api)
if err != nil {
logAndServeError(w, r, fmt.Errorf("error adding trillian API object to request context: %v", err))
} else {
handler.ServeHTTP(w, r.WithContext(apiCtx))
}
apiCtx := pkgapi.AddAPIToContext(r.Context(), api)
handler.ServeHTTP(w, r.WithContext(apiCtx))
})
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment