From d244798b10248223df79d8dbb4ae08f18eca611d Mon Sep 17 00:00:00 2001
From: Bob Callaway <bobcallaway@users.noreply.github.com>
Date: Fri, 8 Jan 2021 13:13:30 -0500
Subject: [PATCH] Print inclusion proof in human-readable format (#91)

also adds --log-index parameter to CLI verify command

Signed-off-by: Bob Callaway <bcallawa@redhat.com>
---
 cmd/cli/app/get.go                        | 11 ++--
 cmd/cli/app/log_proof.go                  |  2 +-
 cmd/cli/app/pflags.go                     | 53 ++++++++++++++-
 cmd/cli/app/pflags_test.go                | 46 ++++++++++++-
 cmd/cli/app/upload.go                     |  2 +-
 cmd/cli/app/verify.go                     | 78 ++++++++++++++---------
 openapi.yaml                              |  2 +-
 pkg/generated/models/consistency_proof.go |  2 +-
 pkg/generated/restapi/embedded_spec.go    |  4 +-
 9 files changed, 155 insertions(+), 45 deletions(-)

diff --git a/cmd/cli/app/get.go b/cmd/cli/app/get.go
index 532ebf5..6ef4442 100644
--- a/cmd/cli/app/get.go
+++ b/cmd/cli/app/get.go
@@ -16,7 +16,6 @@ limitations under the License.
 package app
 
 import (
-	"errors"
 	"fmt"
 	"strconv"
 
@@ -48,13 +47,11 @@ var getCmd = &cobra.Command{
 		logIndex := viper.GetString("log-index")
 		if logIndex != "" {
 			params := entries.NewGetLogEntryByIndexParams()
-			logIndexInt, err := strconv.Atoi(logIndex)
+			logIndexInt, err := strconv.ParseInt(logIndex, 10, 0)
 			if err != nil {
 				log.Fatal(fmt.Errorf("error parsing --log-index: %w", err))
-			} else if logIndexInt < 0 {
-				log.Fatal(errors.New("--log-index must be greater than or equal to 0"))
 			}
-			params.LogIndex = int64(logIndexInt)
+			params.LogIndex = logIndexInt
 
 			resp, err := rekorClient.Entries.GetLogEntryByIndex(params)
 			if err != nil {
@@ -99,7 +96,9 @@ func init() {
 	if err := addUUIDPFlags(getCmd, false); err != nil {
 		log.Logger.Fatal("Error parsing cmd line args:", err)
 	}
-	getCmd.Flags().String("log-index", "", "the index of the entry in the transparency log")
+	if err := addLogIndexFlag(getCmd, false); err != nil {
+		log.Logger.Fatal("Error parsing cmd line args:", err)
+	}
 
 	rootCmd.AddCommand(getCmd)
 }
diff --git a/cmd/cli/app/log_proof.go b/cmd/cli/app/log_proof.go
index c428965..120ee9e 100644
--- a/cmd/cli/app/log_proof.go
+++ b/cmd/cli/app/log_proof.go
@@ -68,7 +68,7 @@ var logProofCmd = &cobra.Command{
 		}
 
 		consistencyProof := result.GetPayload()
-		fmt.Printf("Root Hash: %v\n", *consistencyProof.RootHash)
+		fmt.Printf("Current Root Hash: %v\n", *consistencyProof.RootHash)
 		fmt.Printf("Hashes: [")
 		for i, hash := range consistencyProof.Hashes {
 			if i+1 == len(consistencyProof.Hashes) {
diff --git a/cmd/cli/app/pflags.go b/cmd/cli/app/pflags.go
index 190e131..f278286 100644
--- a/cmd/cli/app/pflags.go
+++ b/cmd/cli/app/pflags.go
@@ -26,6 +26,7 @@ import (
 	"net/url"
 	"os"
 	"path/filepath"
+	"strconv"
 
 	"github.com/go-openapi/strfmt"
 	"github.com/go-openapi/swag"
@@ -48,7 +49,7 @@ func addArtifactPFlags(cmd *cobra.Command) error {
 	return nil
 }
 
-func validateArtifactPFlags(uuidValid bool) error {
+func validateArtifactPFlags(uuidValid, indexValid bool) error {
 	uuidGiven := false
 	if uuidValid {
 		uuid := shaFlag{}
@@ -61,6 +62,18 @@ func validateArtifactPFlags(uuidValid bool) error {
 			uuidGiven = true
 		}
 	}
+	indexGiven := false
+	if indexValid {
+		logIndex := logIndexFlag{}
+		logIndexStr := viper.GetString("log-index")
+
+		if logIndexStr != "" {
+			if err := logIndex.Set(logIndexStr); err != nil {
+				return err
+			}
+			indexGiven = true
+		}
+	}
 	// we will need artifact, public-key, signature, and potentially SHA
 	rekord := viper.GetString("rekord")
 
@@ -77,7 +90,7 @@ func validateArtifactPFlags(uuidValid bool) error {
 	sha := viper.GetString("sha")
 
 	if rekord == "" && artifact.String() == "" {
-		if uuidGiven && uuidValid {
+		if (uuidGiven && uuidValid) || (indexGiven && indexValid) {
 			return nil
 		}
 		return errors.New("either 'rekord' or 'artifact' must be specified")
@@ -253,3 +266,39 @@ func addUUIDPFlags(cmd *cobra.Command, required bool) error {
 	}
 	return nil
 }
+
+type logIndexFlag struct {
+	index int64
+}
+
+func (l *logIndexFlag) String() string {
+	return fmt.Sprint(l.index)
+}
+
+func (l *logIndexFlag) Set(v string) error {
+	if v == "" {
+		return errors.New("flag must be specified")
+	}
+	logIndexInt, err := strconv.ParseInt(v, 10, 0)
+	if err != nil {
+		return fmt.Errorf("error parsing --log-index: %w", err)
+	} else if logIndexInt < 0 {
+		return errors.New("--log-index must be greater than or equal to 0")
+	}
+	l.index = logIndexInt
+	return nil
+}
+
+func (l *logIndexFlag) Type() string {
+	return "logIndex"
+}
+
+func addLogIndexFlag(cmd *cobra.Command, required bool) error {
+	cmd.Flags().Var(&logIndexFlag{}, "log-index", "the index of the entry in the transparency log")
+	if required {
+		if err := cmd.MarkFlagRequired("log-index"); err != nil {
+			return err
+		}
+	}
+	return nil
+}
diff --git a/cmd/cli/app/pflags_test.go b/cmd/cli/app/pflags_test.go
index 15bab85..f72f695 100644
--- a/cmd/cli/app/pflags_test.go
+++ b/cmd/cli/app/pflags_test.go
@@ -38,6 +38,8 @@ func TestArtifactPFlags(t *testing.T) {
 		sha                   string
 		uuid                  string
 		uuidRequired          bool
+		logIndex              string
+		logIndexRequired      bool
 		expectParseSuccess    bool
 		expectValidateSuccess bool
 	}
@@ -245,6 +247,34 @@ func TestArtifactPFlags(t *testing.T) {
 			expectParseSuccess:    true,
 			expectValidateSuccess: false,
 		},
+		{
+			caseDesc:              "valid log index",
+			logIndex:              "1",
+			logIndexRequired:      true,
+			expectParseSuccess:    true,
+			expectValidateSuccess: true,
+		},
+		{
+			caseDesc:              "invalid log index",
+			logIndex:              "not_a_int",
+			logIndexRequired:      true,
+			expectParseSuccess:    false,
+			expectValidateSuccess: false,
+		},
+		{
+			caseDesc:              "invalid log index - less than 0",
+			logIndex:              "-1",
+			logIndexRequired:      true,
+			expectParseSuccess:    false,
+			expectValidateSuccess: false,
+		},
+		{
+			caseDesc:              "unwanted log index",
+			logIndex:              "1",
+			logIndexRequired:      false,
+			expectParseSuccess:    true,
+			expectValidateSuccess: false,
+		},
 		{
 			caseDesc:              "no flags when either uuid, rekord, or artifact++ are needed",
 			uuidRequired:          false,
@@ -257,6 +287,12 @@ func TestArtifactPFlags(t *testing.T) {
 			expectParseSuccess:    true,
 			expectValidateSuccess: false,
 		},
+		{
+			caseDesc:              "missing log index flag when it is needed",
+			logIndexRequired:      true,
+			expectParseSuccess:    true,
+			expectValidateSuccess: false,
+		},
 	}
 
 	for _, tc := range tests {
@@ -267,6 +303,9 @@ func TestArtifactPFlags(t *testing.T) {
 		if err := addUUIDPFlags(blankCmd, tc.uuidRequired); err != nil {
 			t.Fatalf("unexpected error adding uuid flags in '%v': %v", tc.caseDesc, err)
 		}
+		if err := addLogIndexFlag(blankCmd, tc.logIndexRequired); err != nil {
+			t.Fatalf("unexpected error adding log index flags in '%v': %v", tc.caseDesc, err)
+		}
 
 		args := []string{}
 
@@ -288,6 +327,9 @@ func TestArtifactPFlags(t *testing.T) {
 		if tc.uuid != "" {
 			args = append(args, "--uuid", tc.uuid)
 		}
+		if tc.logIndex != "" {
+			args = append(args, "--log-index", tc.logIndex)
+		}
 
 		if err := blankCmd.ParseFlags(args); (err == nil) != tc.expectParseSuccess {
 			t.Errorf("unexpected result parsing '%v': %v", tc.caseDesc, err)
@@ -298,11 +340,11 @@ func TestArtifactPFlags(t *testing.T) {
 			if err := viper.BindPFlags(blankCmd.Flags()); err != nil {
 				t.Fatalf("unexpected result initializing viper in '%v': %v", tc.caseDesc, err)
 			}
-			if err := validateArtifactPFlags(tc.uuidRequired); (err == nil) != tc.expectValidateSuccess {
+			if err := validateArtifactPFlags(tc.uuidRequired, tc.logIndexRequired); (err == nil) != tc.expectValidateSuccess {
 				t.Errorf("unexpected result validating '%v': %v", tc.caseDesc, err)
 				continue
 			}
-			if !tc.uuidRequired {
+			if !tc.uuidRequired && !tc.logIndexRequired {
 				if _, err := CreateRekordFromPFlags(); err != nil {
 					t.Errorf("unexpected result in '%v' building Rekord: %v", tc.caseDesc, err)
 				}
diff --git a/cmd/cli/app/upload.go b/cmd/cli/app/upload.go
index cc0d129..3a8e789 100644
--- a/cmd/cli/app/upload.go
+++ b/cmd/cli/app/upload.go
@@ -34,7 +34,7 @@ var uploadCmd = &cobra.Command{
 		if err := viper.BindPFlags(cmd.Flags()); err != nil {
 			log.Logger.Fatal("Error initializing cmd line args: ", err)
 		}
-		if err := validateArtifactPFlags(false); err != nil {
+		if err := validateArtifactPFlags(false, false); err != nil {
 			log.Logger.Error(err)
 			_ = cmd.Help()
 			os.Exit(1)
diff --git a/cmd/cli/app/verify.go b/cmd/cli/app/verify.go
index 2bb01d2..8879faf 100644
--- a/cmd/cli/app/verify.go
+++ b/cmd/cli/app/verify.go
@@ -18,9 +18,10 @@ package app
 import (
 	"encoding/hex"
 	"fmt"
+	"math/bits"
 	"os"
+	"strconv"
 
-	"github.com/google/trillian/merkle"
 	"github.com/google/trillian/merkle/rfc6962"
 	"github.com/projectrekor/rekor/pkg/generated/client/entries"
 	"github.com/projectrekor/rekor/pkg/generated/models"
@@ -40,7 +41,7 @@ var verifyCmd = &cobra.Command{
 		if err := viper.BindPFlags(cmd.Flags()); err != nil {
 			log.Logger.Fatal("Error initializing cmd line args: ", err)
 		}
-		if err := validateArtifactPFlags(true); err != nil {
+		if err := validateArtifactPFlags(true, true); err != nil {
 			log.Logger.Error(err)
 			_ = cmd.Help()
 			os.Exit(1)
@@ -60,13 +61,22 @@ var verifyCmd = &cobra.Command{
 			searchParams := entries.NewSearchLogQueryParams()
 			searchLogQuery := models.SearchLogQuery{}
 
-			rekordEntry, err := CreateRekordFromPFlags()
-			if err != nil {
-				log.Fatal(err)
+			logIndex := viper.GetString("log-index")
+			if logIndex != "" {
+				logIndexInt, err := strconv.ParseInt(logIndex, 10, 0)
+				if err != nil {
+					log.Fatal(fmt.Errorf("error parsing --log-index: %w", err))
+				}
+				searchLogQuery.LogIndexes = []*int64{&logIndexInt}
+			} else {
+				rekordEntry, err := CreateRekordFromPFlags()
+				if err != nil {
+					log.Fatal(err)
+				}
+
+				entries := []models.ProposedEntry{rekordEntry}
+				searchLogQuery.SetEntries(entries)
 			}
-
-			entries := []models.ProposedEntry{rekordEntry}
-			searchLogQuery.SetEntries(entries)
 			searchParams.SetEntry(&searchLogQuery)
 
 			resp, err := rekorClient.Entries.SearchLogQuery(searchParams)
@@ -93,31 +103,38 @@ var verifyCmd = &cobra.Command{
 			log.Fatal(err)
 		}
 
-		inclusionProof := resp.GetPayload()
-		hashes := [][]byte{}
-
-		for _, hash := range inclusionProof.Hashes {
-			val, err := hex.DecodeString(hash)
-			if err != nil {
-				log.Fatal(err)
+		inclusionProof := resp.Payload
+		index := *inclusionProof.LogIndex
+		size := *inclusionProof.TreeSize
+		rootHash := *inclusionProof.RootHash
+		fmt.Printf("Current Root Hash: %v\n", rootHash)
+		fmt.Printf("Entry Hash: %v\n", params.EntryUUID)
+		fmt.Printf("Entry Index: %v\n", index)
+		fmt.Printf("Current Tree Size: %v\n\n", size)
+
+		hasher := rfc6962.DefaultHasher
+		inner := bits.Len64(uint64(index ^ (size - 1)))
+		var left, right []byte
+		result, _ := hex.DecodeString(params.EntryUUID)
+		fmt.Printf("Inclusion Proof:\n")
+		for i, h := range inclusionProof.Hashes {
+			if i < inner && (index>>uint(i))&1 == 0 {
+				left = result
+				right, _ = hex.DecodeString(h)
+			} else {
+				left, _ = hex.DecodeString(h)
+				right = result
 			}
-			hashes = append(hashes, val)
-		}
-
-		leafHash, err := hex.DecodeString(params.EntryUUID)
-		if err != nil {
-			log.Fatal(err)
-		}
-		rootHash, err := hex.DecodeString(*inclusionProof.RootHash)
-		if err != nil {
-			log.Fatal(err)
+			result = hasher.HashChildren(left, right)
+			fmt.Printf("SHA256(0x01 | %v | %v) =\n\t%v\n\n", hex.EncodeToString(left), hex.EncodeToString(right), hex.EncodeToString(result))
 		}
+		resultHash := hex.EncodeToString(result)
 
-		v := merkle.NewLogVerifier(rfc6962.DefaultHasher)
-		if err := v.VerifyInclusionProof(*inclusionProof.LogIndex, *inclusionProof.TreeSize, hashes, rootHash, leafHash); err != nil {
-			log.Fatal(err)
+		if resultHash == rootHash {
+			fmt.Printf("%v == %v, proof complete\n", resultHash, rootHash)
+		} else {
+			fmt.Printf("proof could not be correctly generated!")
 		}
-		log.Info("Proof correct!")
 	},
 }
 
@@ -128,6 +145,9 @@ func init() {
 	if err := addUUIDPFlags(verifyCmd, false); err != nil {
 		log.Logger.Fatal("Error parsing cmd line args:", err)
 	}
+	if err := addLogIndexFlag(verifyCmd, false); err != nil {
+		log.Logger.Fatal("Error parsing cmd line args:", err)
+	}
 
 	rootCmd.AddCommand(verifyCmd)
 }
diff --git a/openapi.yaml b/openapi.yaml
index 40e850e..9ca0f29 100644
--- a/openapi.yaml
+++ b/openapi.yaml
@@ -270,7 +270,7 @@ definitions:
     properties:
       rootHash:
         type: string
-        description: The hash value stored at the root of the merkle tree at time the proof was generated
+        description: The hash value stored at the root of the merkle tree at the time the proof was generated
         pattern: '^[0-9a-fA-F]{64}$'
       hashes:
         type: array
diff --git a/pkg/generated/models/consistency_proof.go b/pkg/generated/models/consistency_proof.go
index 6f590bf..5fb9af9 100644
--- a/pkg/generated/models/consistency_proof.go
+++ b/pkg/generated/models/consistency_proof.go
@@ -41,7 +41,7 @@ type ConsistencyProof struct {
 	// Required: true
 	Hashes []string `json:"hashes"`
 
-	// The hash value stored at the root of the merkle tree at time the proof was generated
+	// The hash value stored at the root of the merkle tree at the time the proof was generated
 	// Required: true
 	// Pattern: ^[0-9a-fA-F]{64}$
 	RootHash *string `json:"rootHash"`
diff --git a/pkg/generated/restapi/embedded_spec.go b/pkg/generated/restapi/embedded_spec.go
index f0f6b4e..2936862 100644
--- a/pkg/generated/restapi/embedded_spec.go
+++ b/pkg/generated/restapi/embedded_spec.go
@@ -314,7 +314,7 @@ func init() {
           }
         },
         "rootHash": {
-          "description": "The hash value stored at the root of the merkle tree at time the proof was generated",
+          "description": "The hash value stored at the root of the merkle tree at the time the proof was generated",
           "type": "string",
           "pattern": "^[0-9a-fA-F]{64}$"
         }
@@ -806,7 +806,7 @@ func init() {
           }
         },
         "rootHash": {
-          "description": "The hash value stored at the root of the merkle tree at time the proof was generated",
+          "description": "The hash value stored at the root of the merkle tree at the time the proof was generated",
           "type": "string",
           "pattern": "^[0-9a-fA-F]{64}$"
         }
-- 
GitLab