diff --git a/pkg/types/rpm/v0.0.1/entry.go b/pkg/types/rpm/v0.0.1/entry.go index 72e66952b1c24a572524caccf38efec94342f57c..13fb36d8af0df67a565c4120861835b3f3b8d383 100644 --- a/pkg/types/rpm/v0.0.1/entry.go +++ b/pkg/types/rpm/v0.0.1/entry.go @@ -16,6 +16,7 @@ package rpm import ( + "bytes" "context" "crypto/sha256" "encoding/hex" @@ -36,7 +37,6 @@ import ( "github.com/sigstore/rekor/pkg/generated/models" "github.com/sigstore/rekor/pkg/log" - "github.com/sigstore/rekor/pkg/pki" "github.com/sigstore/rekor/pkg/pki/pgp" "github.com/sigstore/rekor/pkg/types" "github.com/sigstore/rekor/pkg/types/rpm" @@ -55,8 +55,6 @@ func init() { type V001Entry struct { RPMModel models.RpmV001Schema - keyObj pki.PublicKey - rpmObj *rpmutils.PackageFile } func (v V001Entry) APIVersion() string { @@ -70,14 +68,18 @@ func NewEntry() types.EntryImpl { func (v V001Entry) IndexKeys() ([]string, error) { var result []string - key, err := v.keyObj.CanonicalValue() + keyObj, err := pgp.NewPublicKey(bytes.NewReader(v.RPMModel.PublicKey.Content)) + if err != nil { + return nil, err + } + key, err := keyObj.CanonicalValue() if err != nil { return nil, err } keyHash := sha256.Sum256(key) result = append(result, strings.ToLower(hex.EncodeToString(keyHash[:]))) - result = append(result, v.keyObj.EmailAddresses()...) + result = append(result, keyObj.EmailAddresses()...) if v.RPMModel.Package.Hash != nil { hashKey := strings.ToLower(fmt.Sprintf("%s:%s", *v.RPMModel.Package.Hash.Algorithm, *v.RPMModel.Package.Hash.Value)) @@ -105,7 +107,7 @@ func (v *V001Entry) Unmarshal(pe models.ProposedEntry) error { return v.validate() } -func (v V001Entry) HasExternalEntities() bool { +func (v V001Entry) hasExternalEntities() bool { if v.RPMModel.Package != nil && v.RPMModel.Package.URL.String() != "" { return true @@ -116,10 +118,10 @@ func (v V001Entry) HasExternalEntities() bool { return false } -func (v *V001Entry) FetchExternalEntities(ctx context.Context) error { +func (v *V001Entry) fetchExternalEntities(ctx context.Context) (*pgp.PublicKey, *rpmutils.PackageFile, error) { if err := v.validate(); err != nil { - return types.ValidationError(err) + return nil, nil, types.ValidationError(err) } g, ctx := errgroup.WithContext(ctx) @@ -179,6 +181,7 @@ func (v *V001Entry) FetchExternalEntities(ctx context.Context) error { } }) + var keyObj *pgp.PublicKey g.Go(func() error { keyReadCloser, err := util.FileOrURLReadCloser(ctx, v.RPMModel.PublicKey.URL.String(), v.RPMModel.PublicKey.Content) @@ -187,12 +190,12 @@ func (v *V001Entry) FetchExternalEntities(ctx context.Context) error { } defer keyReadCloser.Close() - v.keyObj, err = pgp.NewPublicKey(keyReadCloser) + keyObj, err = pgp.NewPublicKey(keyReadCloser) if err != nil { return closePipesOnError(types.ValidationError(err)) } - keyring, err := v.keyObj.(*pgp.PublicKey).KeyRing() + keyring, err := keyObj.KeyRing() if err != nil { return closePipesOnError(types.ValidationError(err)) } @@ -209,10 +212,11 @@ func (v *V001Entry) FetchExternalEntities(ctx context.Context) error { } }) + var rpmObj *rpmutils.PackageFile g.Go(func() error { var err error - v.rpmObj, err = rpmutils.ReadPackageFile(rpmR) + rpmObj, err = rpmutils.ReadPackageFile(rpmR) if err != nil { return closePipesOnError(types.ValidationError(err)) } @@ -232,7 +236,7 @@ func (v *V001Entry) FetchExternalEntities(ctx context.Context) error { computedSHA := <-hashResult if err := g.Wait(); err != nil { - return err + return nil, nil, err } // if we get here, all goroutines succeeded without error @@ -242,23 +246,21 @@ func (v *V001Entry) FetchExternalEntities(ctx context.Context) error { v.RPMModel.Package.Hash.Value = swag.String(computedSHA) } - return nil + return keyObj, rpmObj, nil } func (v *V001Entry) Canonicalize(ctx context.Context) ([]byte, error) { - if err := v.FetchExternalEntities(ctx); err != nil { + keyObj, rpmObj, err := v.fetchExternalEntities(ctx) + if err != nil { return nil, err } - if v.keyObj == nil { - return nil, errors.New("key object not initialized before canonicalization") - } canonicalEntry := models.RpmV001Schema{} - var err error // need to canonicalize key content + canonicalEntry.PublicKey = &models.RpmV001SchemaPublicKey{} - canonicalEntry.PublicKey.Content, err = v.keyObj.CanonicalValue() + canonicalEntry.PublicKey.Content, err = keyObj.CanonicalValue() if err != nil { return nil, err } @@ -271,18 +273,18 @@ func (v *V001Entry) Canonicalize(ctx context.Context) ([]byte, error) { // set NEVRA headers canonicalEntry.Package.Headers = make(map[string]string) - canonicalEntry.Package.Headers["Name"] = v.rpmObj.Name() - canonicalEntry.Package.Headers["Epoch"] = strconv.Itoa(v.rpmObj.Epoch()) - canonicalEntry.Package.Headers["Version"] = v.rpmObj.Version() - canonicalEntry.Package.Headers["Release"] = v.rpmObj.Release() - canonicalEntry.Package.Headers["Architecture"] = v.rpmObj.Architecture() - if md5sum := v.rpmObj.GetBytes(0, 1004); md5sum != nil { + canonicalEntry.Package.Headers["Name"] = rpmObj.Name() + canonicalEntry.Package.Headers["Epoch"] = strconv.Itoa(rpmObj.Epoch()) + canonicalEntry.Package.Headers["Version"] = rpmObj.Version() + canonicalEntry.Package.Headers["Release"] = rpmObj.Release() + canonicalEntry.Package.Headers["Architecture"] = rpmObj.Architecture() + if md5sum := rpmObj.GetBytes(0, 1004); md5sum != nil { canonicalEntry.Package.Headers["RPMSIGTAG_MD5"] = hex.EncodeToString(md5sum) } - if sha1sum := v.rpmObj.GetBytes(0, 1012); sha1sum != nil { + if sha1sum := rpmObj.GetBytes(0, 1012); sha1sum != nil { canonicalEntry.Package.Headers["RPMSIGTAG_SHA1"] = hex.EncodeToString(sha1sum) } - if sha256sum := v.rpmObj.GetBytes(0, 1016); sha256sum != nil { + if sha256sum := rpmObj.GetBytes(0, 1016); sha256sum != nil { canonicalEntry.Package.Headers["RPMSIGTAG_SHA256"] = hex.EncodeToString(sha256sum) } @@ -375,8 +377,8 @@ func (v V001Entry) CreateFromArtifactProperties(ctx context.Context, props types return nil, err } - if re.HasExternalEntities() { - if err := re.FetchExternalEntities(context.Background()); err != nil { + if re.hasExternalEntities() { + if _, _, err := re.fetchExternalEntities(context.Background()); err != nil { return nil, fmt.Errorf("error retrieving external entities: %v", err) } } diff --git a/pkg/types/rpm/v0.0.1/entry_test.go b/pkg/types/rpm/v0.0.1/entry_test.go index 1588ed0a435e04c99116cb260cd6ea80c405603c..96d3c45ec8db55de6f93d5d412bb35276ea58ce1 100644 --- a/pkg/types/rpm/v0.0.1/entry_test.go +++ b/pkg/types/rpm/v0.0.1/entry_test.go @@ -369,7 +369,7 @@ func TestCrossFieldValidation(t *testing.T) { t.Errorf("unexpected result in '%v': %v", tc.caseDesc, err) } - if tc.entry.HasExternalEntities() != tc.hasExtEntities { + if tc.entry.hasExternalEntities() != tc.hasExtEntities { t.Errorf("unexpected result from HasExternalEntities for '%v'", tc.caseDesc) }