From 73618b82f2cf731e4cb216fbfaf571b231f4faa6 Mon Sep 17 00:00:00 2001
From: dlorenc <dlorenc@google.com>
Date: Tue, 28 Dec 2021 13:44:43 -0600
Subject: [PATCH] Refactor the RPM type to remove more intermediate state.
 (#566)

This is required to make IndexKeys() work on stored types.

Signed-off-by: Dan Lorenc <lorenc.d@gmail.com>
---
 pkg/types/rpm/v0.0.1/entry.go      | 60 +++++++++++++++---------------
 pkg/types/rpm/v0.0.1/entry_test.go |  2 +-
 2 files changed, 32 insertions(+), 30 deletions(-)

diff --git a/pkg/types/rpm/v0.0.1/entry.go b/pkg/types/rpm/v0.0.1/entry.go
index 72e6695..13fb36d 100644
--- a/pkg/types/rpm/v0.0.1/entry.go
+++ b/pkg/types/rpm/v0.0.1/entry.go
@@ -16,6 +16,7 @@
 package rpm
 
 import (
+	"bytes"
 	"context"
 	"crypto/sha256"
 	"encoding/hex"
@@ -36,7 +37,6 @@ import (
 
 	"github.com/sigstore/rekor/pkg/generated/models"
 	"github.com/sigstore/rekor/pkg/log"
-	"github.com/sigstore/rekor/pkg/pki"
 	"github.com/sigstore/rekor/pkg/pki/pgp"
 	"github.com/sigstore/rekor/pkg/types"
 	"github.com/sigstore/rekor/pkg/types/rpm"
@@ -55,8 +55,6 @@ func init() {
 
 type V001Entry struct {
 	RPMModel models.RpmV001Schema
-	keyObj   pki.PublicKey
-	rpmObj   *rpmutils.PackageFile
 }
 
 func (v V001Entry) APIVersion() string {
@@ -70,14 +68,18 @@ func NewEntry() types.EntryImpl {
 func (v V001Entry) IndexKeys() ([]string, error) {
 	var result []string
 
-	key, err := v.keyObj.CanonicalValue()
+	keyObj, err := pgp.NewPublicKey(bytes.NewReader(v.RPMModel.PublicKey.Content))
+	if err != nil {
+		return nil, err
+	}
+	key, err := keyObj.CanonicalValue()
 	if err != nil {
 		return nil, err
 	}
 	keyHash := sha256.Sum256(key)
 	result = append(result, strings.ToLower(hex.EncodeToString(keyHash[:])))
 
-	result = append(result, v.keyObj.EmailAddresses()...)
+	result = append(result, keyObj.EmailAddresses()...)
 
 	if v.RPMModel.Package.Hash != nil {
 		hashKey := strings.ToLower(fmt.Sprintf("%s:%s", *v.RPMModel.Package.Hash.Algorithm, *v.RPMModel.Package.Hash.Value))
@@ -105,7 +107,7 @@ func (v *V001Entry) Unmarshal(pe models.ProposedEntry) error {
 	return v.validate()
 }
 
-func (v V001Entry) HasExternalEntities() bool {
+func (v V001Entry) hasExternalEntities() bool {
 
 	if v.RPMModel.Package != nil && v.RPMModel.Package.URL.String() != "" {
 		return true
@@ -116,10 +118,10 @@ func (v V001Entry) HasExternalEntities() bool {
 	return false
 }
 
-func (v *V001Entry) FetchExternalEntities(ctx context.Context) error {
+func (v *V001Entry) fetchExternalEntities(ctx context.Context) (*pgp.PublicKey, *rpmutils.PackageFile, error) {
 
 	if err := v.validate(); err != nil {
-		return types.ValidationError(err)
+		return nil, nil, types.ValidationError(err)
 	}
 
 	g, ctx := errgroup.WithContext(ctx)
@@ -179,6 +181,7 @@ func (v *V001Entry) FetchExternalEntities(ctx context.Context) error {
 		}
 	})
 
+	var keyObj *pgp.PublicKey
 	g.Go(func() error {
 		keyReadCloser, err := util.FileOrURLReadCloser(ctx, v.RPMModel.PublicKey.URL.String(),
 			v.RPMModel.PublicKey.Content)
@@ -187,12 +190,12 @@ func (v *V001Entry) FetchExternalEntities(ctx context.Context) error {
 		}
 		defer keyReadCloser.Close()
 
-		v.keyObj, err = pgp.NewPublicKey(keyReadCloser)
+		keyObj, err = pgp.NewPublicKey(keyReadCloser)
 		if err != nil {
 			return closePipesOnError(types.ValidationError(err))
 		}
 
-		keyring, err := v.keyObj.(*pgp.PublicKey).KeyRing()
+		keyring, err := keyObj.KeyRing()
 		if err != nil {
 			return closePipesOnError(types.ValidationError(err))
 		}
@@ -209,10 +212,11 @@ func (v *V001Entry) FetchExternalEntities(ctx context.Context) error {
 		}
 	})
 
+	var rpmObj *rpmutils.PackageFile
 	g.Go(func() error {
 
 		var err error
-		v.rpmObj, err = rpmutils.ReadPackageFile(rpmR)
+		rpmObj, err = rpmutils.ReadPackageFile(rpmR)
 		if err != nil {
 			return closePipesOnError(types.ValidationError(err))
 		}
@@ -232,7 +236,7 @@ func (v *V001Entry) FetchExternalEntities(ctx context.Context) error {
 	computedSHA := <-hashResult
 
 	if err := g.Wait(); err != nil {
-		return err
+		return nil, nil, err
 	}
 
 	// if we get here, all goroutines succeeded without error
@@ -242,23 +246,21 @@ func (v *V001Entry) FetchExternalEntities(ctx context.Context) error {
 		v.RPMModel.Package.Hash.Value = swag.String(computedSHA)
 	}
 
-	return nil
+	return keyObj, rpmObj, nil
 }
 
 func (v *V001Entry) Canonicalize(ctx context.Context) ([]byte, error) {
-	if err := v.FetchExternalEntities(ctx); err != nil {
+	keyObj, rpmObj, err := v.fetchExternalEntities(ctx)
+	if err != nil {
 		return nil, err
 	}
-	if v.keyObj == nil {
-		return nil, errors.New("key object not initialized before canonicalization")
-	}
 
 	canonicalEntry := models.RpmV001Schema{}
 
-	var err error
 	// need to canonicalize key content
+
 	canonicalEntry.PublicKey = &models.RpmV001SchemaPublicKey{}
-	canonicalEntry.PublicKey.Content, err = v.keyObj.CanonicalValue()
+	canonicalEntry.PublicKey.Content, err = keyObj.CanonicalValue()
 	if err != nil {
 		return nil, err
 	}
@@ -271,18 +273,18 @@ func (v *V001Entry) Canonicalize(ctx context.Context) ([]byte, error) {
 
 	// set NEVRA headers
 	canonicalEntry.Package.Headers = make(map[string]string)
-	canonicalEntry.Package.Headers["Name"] = v.rpmObj.Name()
-	canonicalEntry.Package.Headers["Epoch"] = strconv.Itoa(v.rpmObj.Epoch())
-	canonicalEntry.Package.Headers["Version"] = v.rpmObj.Version()
-	canonicalEntry.Package.Headers["Release"] = v.rpmObj.Release()
-	canonicalEntry.Package.Headers["Architecture"] = v.rpmObj.Architecture()
-	if md5sum := v.rpmObj.GetBytes(0, 1004); md5sum != nil {
+	canonicalEntry.Package.Headers["Name"] = rpmObj.Name()
+	canonicalEntry.Package.Headers["Epoch"] = strconv.Itoa(rpmObj.Epoch())
+	canonicalEntry.Package.Headers["Version"] = rpmObj.Version()
+	canonicalEntry.Package.Headers["Release"] = rpmObj.Release()
+	canonicalEntry.Package.Headers["Architecture"] = rpmObj.Architecture()
+	if md5sum := rpmObj.GetBytes(0, 1004); md5sum != nil {
 		canonicalEntry.Package.Headers["RPMSIGTAG_MD5"] = hex.EncodeToString(md5sum)
 	}
-	if sha1sum := v.rpmObj.GetBytes(0, 1012); sha1sum != nil {
+	if sha1sum := rpmObj.GetBytes(0, 1012); sha1sum != nil {
 		canonicalEntry.Package.Headers["RPMSIGTAG_SHA1"] = hex.EncodeToString(sha1sum)
 	}
-	if sha256sum := v.rpmObj.GetBytes(0, 1016); sha256sum != nil {
+	if sha256sum := rpmObj.GetBytes(0, 1016); sha256sum != nil {
 		canonicalEntry.Package.Headers["RPMSIGTAG_SHA256"] = hex.EncodeToString(sha256sum)
 	}
 
@@ -375,8 +377,8 @@ func (v V001Entry) CreateFromArtifactProperties(ctx context.Context, props types
 		return nil, err
 	}
 
-	if re.HasExternalEntities() {
-		if err := re.FetchExternalEntities(context.Background()); err != nil {
+	if re.hasExternalEntities() {
+		if _, _, err := re.fetchExternalEntities(context.Background()); err != nil {
 			return nil, fmt.Errorf("error retrieving external entities: %v", err)
 		}
 	}
diff --git a/pkg/types/rpm/v0.0.1/entry_test.go b/pkg/types/rpm/v0.0.1/entry_test.go
index 1588ed0..96d3c45 100644
--- a/pkg/types/rpm/v0.0.1/entry_test.go
+++ b/pkg/types/rpm/v0.0.1/entry_test.go
@@ -369,7 +369,7 @@ func TestCrossFieldValidation(t *testing.T) {
 			t.Errorf("unexpected result in '%v': %v", tc.caseDesc, err)
 		}
 
-		if tc.entry.HasExternalEntities() != tc.hasExtEntities {
+		if tc.entry.hasExternalEntities() != tc.hasExtEntities {
 			t.Errorf("unexpected result from HasExternalEntities for '%v'", tc.caseDesc)
 		}
 
-- 
GitLab