From 770e8bf1ff285d0b6f3304feaf73a4d9a868cec7 Mon Sep 17 00:00:00 2001
From: dlorenc <dlorenc@google.com>
Date: Fri, 31 Dec 2021 07:12:49 -0600
Subject: [PATCH] Refactor helm type to remove intermediate state. (#575)

This should be the actual last one :)

Signed-off-by: Dan Lorenc <lorenc.d@gmail.com>
---
 pkg/types/helm/v0.0.1/entry.go      | 63 ++++++++++++-------------
 pkg/types/helm/v0.0.1/entry_test.go | 72 ++++++++++++++---------------
 2 files changed, 68 insertions(+), 67 deletions(-)

diff --git a/pkg/types/helm/v0.0.1/entry.go b/pkg/types/helm/v0.0.1/entry.go
index 545d3cc..7634040 100644
--- a/pkg/types/helm/v0.0.1/entry.go
+++ b/pkg/types/helm/v0.0.1/entry.go
@@ -33,7 +33,6 @@ import (
 	"github.com/go-openapi/swag"
 	"github.com/sigstore/rekor/pkg/generated/models"
 	"github.com/sigstore/rekor/pkg/log"
-	"github.com/sigstore/rekor/pkg/pki"
 	"github.com/sigstore/rekor/pkg/pki/pgp"
 	"github.com/sigstore/rekor/pkg/types"
 	"github.com/sigstore/rekor/pkg/types/helm"
@@ -52,10 +51,7 @@ func init() {
 }
 
 type V001Entry struct {
-	HelmObj       models.HelmV001Schema
-	keyObj        pki.PublicKey
-	sigObj        pki.Signature
-	provenanceObj *helm.Provenance
+	HelmObj models.HelmV001Schema
 }
 
 func (v V001Entry) APIVersion() string {
@@ -69,16 +65,26 @@ func NewEntry() types.EntryImpl {
 func (v V001Entry) IndexKeys() ([]string, error) {
 	var result []string
 
-	key, err := v.keyObj.CanonicalValue()
+	keyObj, err := pgp.NewPublicKey(bytes.NewReader(v.HelmObj.PublicKey.Content))
+	if err != nil {
+		return nil, err
+	}
+
+	provenance := helm.Provenance{}
+	if err := provenance.Unmarshal(bytes.NewReader(v.HelmObj.Chart.Provenance.Content)); err != nil {
+		return nil, err
+	}
+
+	key, err := keyObj.CanonicalValue()
 	if err != nil {
 		return nil, err
 	}
 	keyHash := sha256.Sum256(key)
 	result = append(result, strings.ToLower(hex.EncodeToString(keyHash[:])))
 
-	result = append(result, v.keyObj.EmailAddresses()...)
+	result = append(result, keyObj.EmailAddresses()...)
 
-	algorithm, chartHash, err := v.provenanceObj.GetChartAlgorithmHash()
+	algorithm, chartHash, err := provenance.GetChartAlgorithmHash()
 
 	if err != nil {
 		log.Logger.Error(err)
@@ -121,11 +127,7 @@ func (v V001Entry) hasExternalEntities() bool {
 	return false
 }
 
-func (v *V001Entry) fetchExternalEntities(ctx context.Context) error {
-	if err := v.validate(); err != nil {
-		return types.ValidationError(err)
-	}
-
+func (v *V001Entry) fetchExternalEntities(ctx context.Context) (*helm.Provenance, *pgp.PublicKey, *pgp.Signature, error) {
 	g, ctx := errgroup.WithContext(ctx)
 
 	provenanceR, provenanceW := io.Pipe()
@@ -160,7 +162,7 @@ func (v *V001Entry) fetchExternalEntities(ctx context.Context) error {
 		}
 		defer keyReadCloser.Close()
 
-		v.keyObj, err = pgp.NewPublicKey(keyReadCloser)
+		keyObj, err := pgp.NewPublicKey(keyReadCloser)
 		if err != nil {
 			return closePipesOnError(types.ValidationError(err))
 		}
@@ -168,25 +170,28 @@ func (v *V001Entry) fetchExternalEntities(ctx context.Context) error {
 		select {
 		case <-ctx.Done():
 			return ctx.Err()
-		case keyResult <- v.keyObj.(*pgp.PublicKey):
+		case keyResult <- keyObj:
 			return nil
 		}
 	})
 
+	var key *pgp.PublicKey
+	provenance := &helm.Provenance{}
+	var sig *pgp.Signature
 	g.Go(func() error {
 
-		provenance := helm.Provenance{}
 		if err := provenance.Unmarshal(provenanceR); err != nil {
 			return closePipesOnError(types.ValidationError(err))
 		}
 
-		key := <-keyResult
+		key = <-keyResult
 		if key == nil {
 			return closePipesOnError(errors.New("error processing public key"))
 		}
 
 		// Set signature
-		sig, err := pgp.NewSignature(provenance.Block.ArmoredSignature.Body)
+		var err error
+		sig, err = pgp.NewSignature(provenance.Block.ArmoredSignature.Body)
 		if err != nil {
 			return closePipesOnError(types.ValidationError(err))
 		}
@@ -196,9 +201,6 @@ func (v *V001Entry) fetchExternalEntities(ctx context.Context) error {
 			return closePipesOnError(types.ValidationError(err))
 		}
 
-		v.sigObj = sig
-		v.provenanceObj = &provenance
-
 		select {
 		case <-ctx.Done():
 			return ctx.Err()
@@ -208,27 +210,26 @@ func (v *V001Entry) fetchExternalEntities(ctx context.Context) error {
 	})
 
 	if err := g.Wait(); err != nil {
-		return err
+		return nil, nil, nil, err
 	}
 
-	return nil
+	return provenance, key, sig, nil
 }
 
 func (v *V001Entry) Canonicalize(ctx context.Context) ([]byte, error) {
-	if err := v.fetchExternalEntities(ctx); err != nil {
+	provenanceObj, keyObj, sigObj, err := v.fetchExternalEntities(ctx)
+	if err != nil {
 		return nil, err
 	}
 
-	if v.keyObj == nil {
+	if keyObj == nil {
 		return nil, errors.New("key object not initialized before canonicalization")
 	}
 
 	canonicalEntry := models.HelmV001Schema{}
 
-	var err error
-
 	canonicalEntry.PublicKey = &models.HelmV001SchemaPublicKey{}
-	keyContent, err := v.keyObj.CanonicalValue()
+	keyContent, err := keyObj.CanonicalValue()
 	if err != nil {
 		return nil, err
 	}
@@ -237,7 +238,7 @@ func (v *V001Entry) Canonicalize(ctx context.Context) ([]byte, error) {
 
 	canonicalEntry.Chart = &models.HelmV001SchemaChart{}
 
-	algorithm, chartHash, err := v.provenanceObj.GetChartAlgorithmHash()
+	algorithm, chartHash, err := provenanceObj.GetChartAlgorithmHash()
 
 	if err != nil {
 		return nil, err
@@ -250,7 +251,7 @@ func (v *V001Entry) Canonicalize(ctx context.Context) ([]byte, error) {
 	canonicalEntry.Chart.Provenance = &models.HelmV001SchemaChartProvenance{}
 	canonicalEntry.Chart.Provenance.Signature = &models.HelmV001SchemaChartProvenanceSignature{}
 
-	sigContent, err := v.sigObj.CanonicalValue()
+	sigContent, err := sigObj.CanonicalValue()
 	if err != nil {
 		return nil, err
 	}
@@ -350,7 +351,7 @@ func (v V001Entry) CreateFromArtifactProperties(ctx context.Context, props types
 	}
 
 	if re.hasExternalEntities() {
-		if err := re.fetchExternalEntities(ctx); err != nil {
+		if _, _, _, err := re.fetchExternalEntities(ctx); err != nil {
 			return nil, fmt.Errorf("error retrieving external entities: %v", err)
 		}
 	}
diff --git a/pkg/types/helm/v0.0.1/entry_test.go b/pkg/types/helm/v0.0.1/entry_test.go
index fbbc680..b053ecc 100644
--- a/pkg/types/helm/v0.0.1/entry_test.go
+++ b/pkg/types/helm/v0.0.1/entry_test.go
@@ -218,50 +218,50 @@ func TestCrossFieldValidation(t *testing.T) {
 	}
 
 	for _, tc := range testCases {
-		if err := tc.entry.validate(); (err == nil) != tc.expectUnmarshalSuccess {
-			t.Errorf("unexpected result in '%v': %v", tc.caseDesc, err)
-		}
+		t.Run(tc.caseDesc, func(t *testing.T) {
 
-		v := &V001Entry{}
-		r := models.Helm{
-			APIVersion: swag.String(tc.entry.APIVersion()),
-			Spec:       tc.entry.HelmObj,
-		}
-
-		unmarshalAndValidate := func() error {
-			if err := v.Unmarshal(&r); err != nil {
-				return err
+			if err := tc.entry.validate(); (err == nil) != tc.expectUnmarshalSuccess {
+				t.Errorf("unexpected result in '%v': %v", tc.caseDesc, err)
 			}
-			if err := v.validate(); err != nil {
-				return err
+
+			v := &V001Entry{}
+			r := models.Helm{
+				APIVersion: swag.String(tc.entry.APIVersion()),
+				Spec:       tc.entry.HelmObj,
 			}
-			return nil
-		}
 
-		if err := unmarshalAndValidate(); (err == nil) != tc.expectUnmarshalSuccess {
-			t.Errorf("unexpected result in '%v': %v", tc.caseDesc, err)
-		}
+			if err := v.Unmarshal(&r); (err == nil) != tc.expectUnmarshalSuccess {
+				t.Errorf("unexpected result in '%v': %v", tc.caseDesc, err)
+			}
 
-		if tc.entry.hasExternalEntities() != tc.hasExtEntities {
-			t.Errorf("unexpected result from HasExternalEntities for '%v'", tc.caseDesc)
-		}
+			if !tc.expectUnmarshalSuccess {
+				return
+			}
+			if err := v.validate(); err != nil {
+				return
+			}
 
-		b, err := v.Canonicalize(context.TODO())
-		if (err == nil) != tc.expectCanonicalizeSuccess {
-			t.Errorf("unexpected result from Canonicalize for '%v': %v", tc.caseDesc, err)
-		} else if err != nil {
-			if _, ok := err.(types.ValidationError); !ok {
-				t.Errorf("canonicalize returned an unexpected error that isn't of type types.ValidationError: %v", err)
+			if tc.entry.hasExternalEntities() != tc.hasExtEntities {
+				t.Errorf("unexpected result from HasExternalEntities for '%v'", tc.caseDesc)
 			}
-		}
-		if b != nil {
-			pe, err := models.UnmarshalProposedEntry(bytes.NewReader(b), runtime.JSONConsumer())
-			if err != nil {
-				t.Errorf("unexpected err from Unmarshalling canonicalized entry for '%v': %v", tc.caseDesc, err)
+
+			b, err := v.Canonicalize(context.TODO())
+			if (err == nil) != tc.expectCanonicalizeSuccess {
+				t.Errorf("unexpected result from Canonicalize for '%v': %v", tc.caseDesc, err)
+			} else if err != nil {
+				if _, ok := err.(types.ValidationError); !ok {
+					t.Errorf("canonicalize returned an unexpected error that isn't of type types.ValidationError: %v", err)
+				}
 			}
-			if _, err := types.NewEntry(pe); err != nil {
-				t.Errorf("unexpected err from type-specific unmarshalling for '%v': %v", tc.caseDesc, err)
+			if b != nil {
+				pe, err := models.UnmarshalProposedEntry(bytes.NewReader(b), runtime.JSONConsumer())
+				if err != nil {
+					t.Errorf("unexpected err from Unmarshalling canonicalized entry for '%v': %v", tc.caseDesc, err)
+				}
+				if _, err := types.NewEntry(pe); err != nil {
+					t.Errorf("unexpected err from type-specific unmarshalling for '%v': %v", tc.caseDesc, err)
+				}
 			}
-		}
+		})
 	}
 }
-- 
GitLab