From 40d9419f4b294277789480c89e407066f9b63734 Mon Sep 17 00:00:00 2001
From: dlorenc <dlorenc@google.com>
Date: Wed, 29 Dec 2021 08:01:37 -0600
Subject: [PATCH] Refactor the shard map parsing so we can pass it down into
 the API object. (#564)

Right now the type itself is defined in the cli package, which means we can't
use it without an import cycle.

Signed-off-by: Dan Lorenc <lorenc.d@gmail.com>
---
 cmd/rekor-server/app/flags.go      | 51 +++++++++-----------------
 cmd/rekor-server/app/flags_test.go | 55 +++++-----------------------
 cmd/rekor-server/app/root.go       |  2 +-
 cmd/rekor-server/app/serve.go      |  2 +-
 pkg/api/api.go                     |  9 +++--
 pkg/api/ranges.go                  | 43 ++++++++++++++++++++++
 pkg/api/ranges_test.go             | 58 ++++++++++++++++++++++++++++++
 7 files changed, 133 insertions(+), 87 deletions(-)
 create mode 100644 pkg/api/ranges.go
 create mode 100644 pkg/api/ranges_test.go

diff --git a/cmd/rekor-server/app/flags.go b/cmd/rekor-server/app/flags.go
index 76705f2..f20412f 100644
--- a/cmd/rekor-server/app/flags.go
+++ b/cmd/rekor-server/app/flags.go
@@ -19,23 +19,20 @@ import (
 	"fmt"
 	"strconv"
 	"strings"
-)
 
-type LogRange struct {
-	TreeID     uint64
-	TreeLength uint64
-}
+	"github.com/sigstore/rekor/pkg/api"
+)
 
-type LogRanges struct {
-	Ranges []LogRange
+type LogRangesFlag struct {
+	Ranges api.LogRanges
 }
 
-func (l *LogRanges) Set(s string) error {
+func (l *LogRangesFlag) Set(s string) error {
 	ranges := strings.Split(s, ",")
-	l.Ranges = []LogRange{}
+	l.Ranges = api.LogRanges{}
 
 	var err error
-	inputRanges := []LogRange{}
+	inputRanges := []api.LogRange{}
 
 	// Only go up to the second to last one, the last one is special cased beow
 	for _, r := range ranges[:len(ranges)-1] {
@@ -43,7 +40,7 @@ func (l *LogRanges) Set(s string) error {
 		if len(split) != 2 {
 			return fmt.Errorf("invalid range flag, expected two parts separated by an =, got %s", r)
 		}
-		lr := LogRange{}
+		lr := api.LogRange{}
 		lr.TreeID, err = strconv.ParseUint(split[0], 10, 64)
 		if err != nil {
 			return err
@@ -63,7 +60,7 @@ func (l *LogRanges) Set(s string) error {
 		return err
 	}
 
-	inputRanges = append(inputRanges, LogRange{
+	inputRanges = append(inputRanges, api.LogRange{
 		TreeID: lastTreeID,
 	})
 
@@ -76,36 +73,20 @@ func (l *LogRanges) Set(s string) error {
 		TreeIDs[lr.TreeID] = struct{}{}
 	}
 
-	l.Ranges = inputRanges
+	l.Ranges = api.LogRanges{
+		Ranges: inputRanges,
+	}
 	return nil
 }
 
-func (l *LogRanges) String() string {
+func (l *LogRangesFlag) String() string {
 	ranges := []string{}
-	for _, r := range l.Ranges {
+	for _, r := range l.Ranges.Ranges {
 		ranges = append(ranges, fmt.Sprintf("%d=%d", r.TreeID, r.TreeLength))
 	}
 	return strings.Join(ranges, ",")
 }
 
-func (l *LogRanges) Type() string {
-	return "LogRanges"
-}
-
-func (l *LogRanges) ResolveVirtualIndex(index int) (uint64, uint64) {
-	indexLeft := index
-	for _, l := range l.Ranges {
-		if indexLeft < int(l.TreeLength) {
-			return l.TreeID, uint64(indexLeft)
-		}
-		indexLeft -= int(l.TreeLength)
-	}
-
-	// Return the last one!
-	return l.Ranges[len(l.Ranges)-1].TreeID, uint64(indexLeft)
-}
-
-// ActiveIndex returns the active shard index, always the last shard in the range
-func (l *LogRanges) ActiveIndex() uint64 {
-	return l.Ranges[len(l.Ranges)-1].TreeID
+func (l *LogRangesFlag) Type() string {
+	return "LogRangesFlag"
 }
diff --git a/cmd/rekor-server/app/flags_test.go b/cmd/rekor-server/app/flags_test.go
index ab82b26..90fa8b4 100644
--- a/cmd/rekor-server/app/flags_test.go
+++ b/cmd/rekor-server/app/flags_test.go
@@ -19,19 +19,20 @@ import (
 	"testing"
 
 	"github.com/google/go-cmp/cmp"
+	"github.com/sigstore/rekor/pkg/api"
 )
 
 func TestLogRanges_Set(t *testing.T) {
 	tests := []struct {
 		name   string
 		arg    string
-		want   []LogRange
+		want   []api.LogRange
 		active uint64
 	}{
 		{
 			name: "one, no length",
 			arg:  "1234",
-			want: []LogRange{
+			want: []api.LogRange{
 				{
 					TreeID:     1234,
 					TreeLength: 0,
@@ -42,7 +43,7 @@ func TestLogRanges_Set(t *testing.T) {
 		{
 			name: "two",
 			arg:  "1234=10,7234",
-			want: []LogRange{
+			want: []api.LogRange{
 				{
 					TreeID:     1234,
 					TreeLength: 10,
@@ -57,16 +58,16 @@ func TestLogRanges_Set(t *testing.T) {
 	}
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
-			l := &LogRanges{}
+			l := &LogRangesFlag{}
 			if err := l.Set(tt.arg); err != nil {
 				t.Errorf("LogRanges.Set() expected no error, got %v", err)
 			}
 
-			if diff := cmp.Diff(tt.want, l.Ranges); diff != "" {
+			if diff := cmp.Diff(tt.want, l.Ranges.Ranges); diff != "" {
 				t.Errorf(diff)
 			}
 
-			active := l.ActiveIndex()
+			active := l.Ranges.ActiveIndex()
 			if active != tt.active {
 				t.Errorf("LogRanges.Active() expected %d no error, got %d", tt.active, active)
 			}
@@ -94,50 +95,10 @@ func TestLogRanges_SetErr(t *testing.T) {
 	}
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
-			l := &LogRanges{}
+			l := &LogRangesFlag{}
 			if err := l.Set(tt.arg); err == nil {
 				t.Error("LogRanges.Set() expected error but got none")
 			}
 		})
 	}
 }
-
-func TestLogRanges_ResolveVirtualIndex(t *testing.T) {
-	lrs := LogRanges{
-		Ranges: []LogRange{
-			{TreeID: 1, TreeLength: 17},
-			{TreeID: 2, TreeLength: 1},
-			{TreeID: 3, TreeLength: 100},
-			{TreeID: 4},
-		},
-	}
-
-	for _, tt := range []struct {
-		Index      int
-		WantTreeID uint64
-		WantIndex  uint64
-	}{
-		{
-			Index:      3,
-			WantTreeID: 1, WantIndex: 3,
-		},
-		// This is the first (0th) entry in the next tree
-		{
-			Index:      17,
-			WantTreeID: 2, WantIndex: 0,
-		},
-		// Overflow
-		{
-			Index:      3000,
-			WantTreeID: 4, WantIndex: 2882,
-		},
-	} {
-		tree, index := lrs.ResolveVirtualIndex(tt.Index)
-		if tree != tt.WantTreeID {
-			t.Errorf("LogRanges.ResolveVirtualIndex() tree = %v, want %v", tree, tt.WantTreeID)
-		}
-		if index != tt.WantIndex {
-			t.Errorf("LogRanges.ResolveVirtualIndex() index = %v, want %v", index, tt.WantIndex)
-		}
-	}
-}
diff --git a/cmd/rekor-server/app/root.go b/cmd/rekor-server/app/root.go
index 2c7b104..f681db5 100644
--- a/cmd/rekor-server/app/root.go
+++ b/cmd/rekor-server/app/root.go
@@ -34,7 +34,7 @@ var (
 	cfgFile     string
 	logType     string
 	enablePprof bool
-	logRangeMap LogRanges
+	logRangeMap LogRangesFlag
 )
 
 // rootCmd represents the base command when called without any subcommands
diff --git a/cmd/rekor-server/app/serve.go b/cmd/rekor-server/app/serve.go
index ae13a63..768617f 100644
--- a/cmd/rekor-server/app/serve.go
+++ b/cmd/rekor-server/app/serve.go
@@ -102,7 +102,7 @@ var serveCmd = &cobra.Command{
 		server.Port = int(viper.GetUint("port"))
 		server.EnabledListeners = []string{"http"}
 
-		api.ConfigureAPI()
+		api.ConfigureAPI(logRangeMap.Ranges)
 		server.ConfigureAPI()
 
 		http.Handle("/metrics", promhttp.Handler())
diff --git a/pkg/api/api.go b/pkg/api/api.go
index 10a5485..5211ade 100644
--- a/pkg/api/api.go
+++ b/pkg/api/api.go
@@ -56,6 +56,7 @@ func dial(ctx context.Context, rpcServer string) (*grpc.ClientConn, error) {
 type API struct {
 	logClient    trillian.TrillianLogClient
 	logID        int64
+	logRanges    *LogRanges
 	pubkey       string // PEM encoded public key
 	pubkeyHash   string // SHA256 hash of DER-encoded public key
 	signer       signature.Signer
@@ -64,7 +65,7 @@ type API struct {
 	certChainPem string              // PEM encoded timestamping cert chain
 }
 
-func NewAPI() (*API, error) {
+func NewAPI(ranges LogRanges) (*API, error) {
 	logRPCServer := fmt.Sprintf("%s:%d",
 		viper.GetString("trillian_log_server.address"),
 		viper.GetUint("trillian_log_server.port"))
@@ -137,6 +138,7 @@ func NewAPI() (*API, error) {
 		// Transparency Log Stuff
 		logClient: logClient,
 		logID:     tLogID,
+		logRanges: &ranges,
 		// Signing/verifying fields
 		pubkey:     string(pubkey),
 		pubkeyHash: hex.EncodeToString(pubkeyHashBytes[:]),
@@ -154,10 +156,11 @@ var (
 	storageClient storage.AttestationStorage
 )
 
-func ConfigureAPI() {
+func ConfigureAPI(ranges LogRanges) {
 	cfg := radix.PoolConfig{}
 	var err error
-	api, err = NewAPI()
+
+	api, err = NewAPI(ranges)
 	if err != nil {
 		log.Logger.Panic(err)
 	}
diff --git a/pkg/api/ranges.go b/pkg/api/ranges.go
new file mode 100644
index 0000000..9b30e84
--- /dev/null
+++ b/pkg/api/ranges.go
@@ -0,0 +1,43 @@
+//
+// Copyright 2021 The Sigstore Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package api
+
+type LogRanges struct {
+	Ranges []LogRange
+}
+
+type LogRange struct {
+	TreeID     uint64
+	TreeLength uint64
+}
+
+func (l *LogRanges) ResolveVirtualIndex(index int) (uint64, uint64) {
+	indexLeft := index
+	for _, l := range l.Ranges {
+		if indexLeft < int(l.TreeLength) {
+			return l.TreeID, uint64(indexLeft)
+		}
+		indexLeft -= int(l.TreeLength)
+	}
+
+	// Return the last one!
+	return l.Ranges[len(l.Ranges)-1].TreeID, uint64(indexLeft)
+}
+
+// ActiveIndex returns the active shard index, always the last shard in the range
+func (l *LogRanges) ActiveIndex() uint64 {
+	return l.Ranges[len(l.Ranges)-1].TreeID
+}
diff --git a/pkg/api/ranges_test.go b/pkg/api/ranges_test.go
new file mode 100644
index 0000000..aad6a66
--- /dev/null
+++ b/pkg/api/ranges_test.go
@@ -0,0 +1,58 @@
+//
+// Copyright 2021 The Sigstore Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package api
+
+import "testing"
+
+func TestLogRanges_ResolveVirtualIndex(t *testing.T) {
+	lrs := LogRanges{
+		Ranges: []LogRange{
+			{TreeID: 1, TreeLength: 17},
+			{TreeID: 2, TreeLength: 1},
+			{TreeID: 3, TreeLength: 100},
+			{TreeID: 4},
+		},
+	}
+
+	for _, tt := range []struct {
+		Index      int
+		WantTreeID uint64
+		WantIndex  uint64
+	}{
+		{
+			Index:      3,
+			WantTreeID: 1, WantIndex: 3,
+		},
+		// This is the first (0th) entry in the next tree
+		{
+			Index:      17,
+			WantTreeID: 2, WantIndex: 0,
+		},
+		// Overflow
+		{
+			Index:      3000,
+			WantTreeID: 4, WantIndex: 2882,
+		},
+	} {
+		tree, index := lrs.ResolveVirtualIndex(tt.Index)
+		if tree != tt.WantTreeID {
+			t.Errorf("LogRanges.ResolveVirtualIndex() tree = %v, want %v", tree, tt.WantTreeID)
+		}
+		if index != tt.WantIndex {
+			t.Errorf("LogRanges.ResolveVirtualIndex() index = %v, want %v", index, tt.WantIndex)
+		}
+	}
+}
-- 
GitLab