From befbcc04aea3ad2149963575f300564949b00693 Mon Sep 17 00:00:00 2001
From: Lily Sturmann <lkatalin@users.noreply.github.com>
Date: Fri, 25 Mar 2022 22:52:04 -0400
Subject: [PATCH] Require tlog_id when inactive shard config file is passed in
 (#739)

tlog_id specifes the active shard and is kept for backwards compatibility.
To avoid replicating information, the shard config file is used only to
specify inactive shards and must be used in conjunction with a tlog_id flag.
Together, these build the logRanges type in the sharding module.

Signed-off-by: Lily Sturmann <lsturman@redhat.com>
---
 cmd/rekor-server/app/serve.go  |  2 +-
 pkg/api/api.go                 |  3 +-
 pkg/api/tlog.go                |  2 +-
 pkg/sharding/log_index.go      | 20 +++++++++---
 pkg/sharding/log_index_test.go | 42 +++++++++++-------------
 pkg/sharding/ranges.go         | 58 +++++++++++++++++++---------------
 pkg/sharding/ranges_test.go    | 21 ++++++------
 7 files changed, 78 insertions(+), 70 deletions(-)

diff --git a/cmd/rekor-server/app/serve.go b/cmd/rekor-server/app/serve.go
index 26f5d42..e55d901 100644
--- a/cmd/rekor-server/app/serve.go
+++ b/cmd/rekor-server/app/serve.go
@@ -106,7 +106,7 @@ var serveCmd = &cobra.Command{
 
 		// Update logRangeMap if flag was passed in
 		shardingConfig := viper.GetString("trillian_log_server.sharding_config")
-		treeID := viper.GetString("trillian_log_server.tlog_id")
+		treeID := viper.GetUint("trillian_log_server.tlog_id")
 
 		ranges, err := sharding.NewLogRanges(shardingConfig, treeID)
 		if err != nil {
diff --git a/pkg/api/api.go b/pkg/api/api.go
index 8469459..8d96e09 100644
--- a/pkg/api/api.go
+++ b/pkg/api/api.go
@@ -87,8 +87,7 @@ func NewAPI(ranges sharding.LogRanges) (*API, error) {
 		}
 		tLogID = t.TreeId
 	}
-	// append the active treeID to the API's logRangeMap for lookups
-	ranges.AppendRange(sharding.LogRange{TreeID: tLogID})
+	ranges.SetActive(tLogID)
 
 	rekorSigner, err := signer.New(ctx, viper.GetString("rekor_server.signer"))
 	if err != nil {
diff --git a/pkg/api/tlog.go b/pkg/api/tlog.go
index 6629a4a..60afaa8 100644
--- a/pkg/api/tlog.go
+++ b/pkg/api/tlog.go
@@ -42,7 +42,7 @@ func GetLogInfoHandler(params tlog.GetLogInfoParams) middleware.Responder {
 
 	// for each inactive shard, get the loginfo
 	var inactiveShards []*models.InactiveShardLogInfo
-	for _, shard := range tc.ranges.GetRanges() {
+	for _, shard := range tc.ranges.GetInactive() {
 		if shard.TreeID == tc.ranges.ActiveTreeID() {
 			break
 		}
diff --git a/pkg/sharding/log_index.go b/pkg/sharding/log_index.go
index 443fda5..dcdfc10 100644
--- a/pkg/sharding/log_index.go
+++ b/pkg/sharding/log_index.go
@@ -16,18 +16,28 @@ package sharding
 
 // VirtualLogIndex returns the virtual log index for a given leaf index
 func VirtualLogIndex(leafIndex int64, tid int64, ranges LogRanges) int64 {
-	// if we have no ranges, we have just one log! return the leafIndex as is
-	if ranges.Empty() {
-		return leafIndex
+	// if we have no inactive ranges, we have just one log! return the leafIndex as is
+	// as long as it matches the active tree ID
+	if ranges.NoInactive() {
+		if ranges.GetActive() == tid {
+			return leafIndex
+		}
+		return -1
 	}
 
 	var virtualIndex int64
-	for _, r := range ranges.GetRanges() {
+	for _, r := range ranges.GetInactive() {
 		if r.TreeID == tid {
 			return virtualIndex + leafIndex
 		}
 		virtualIndex += r.TreeLength
 	}
-	// this should never happen
+
+	// If no TreeID in Inactive matches the tid, the virtual index should be the active tree
+	if ranges.GetActive() == tid {
+		return virtualIndex + leafIndex
+	}
+
+	// Otherwise, the tid is invalid
 	return -1
 }
diff --git a/pkg/sharding/log_index_test.go b/pkg/sharding/log_index_test.go
index b99274a..039c4ef 100644
--- a/pkg/sharding/log_index_test.go
+++ b/pkg/sharding/log_index_test.go
@@ -33,50 +33,48 @@ func TestVirtualLogIndex(t *testing.T) {
 			expectedIndex: 5,
 		},
 		// Log 100: 0 1 2 3 4
-		// Log 300: 5 6 7
+		// Log 300: 5 6 7...
 		{
 			description: "two shards",
 			leafIndex:   2,
 			tid:         300,
 			ranges: LogRanges{
-				ranges: []LogRange{
+				inactive: []LogRange{
 					{
 						TreeID:     100,
 						TreeLength: 5,
-					}, {
-						TreeID: 300,
-					},
-				},
+					}},
+				active: 300,
 			},
 			expectedIndex: 7,
-		}, {
+		},
+		// Log 100: 0 1 2 3 4
+		// Log 300: 5 6 7 8
+		// Log 400: ...
+		{
 			description: "three shards",
 			leafIndex:   1,
 			tid:         300,
 			ranges: LogRanges{
-				ranges: []LogRange{
+				inactive: []LogRange{
 					{
 						TreeID:     100,
 						TreeLength: 5,
 					}, {
 						TreeID:     300,
 						TreeLength: 4,
-					}, {
-						TreeID: 400,
-					},
-				},
+					}},
+				active: 400,
 			},
 			expectedIndex: 6,
-		}, {
-			description: "ranges is empty but not-nil",
+		},
+		// Log 30: 1 2 3...
+		{
+			description: "only active tree",
 			leafIndex:   2,
 			tid:         30,
 			ranges: LogRanges{
-				ranges: []LogRange{
-					{
-						TreeID: 30,
-					},
-				},
+				active: 30,
 			},
 			expectedIndex: 2,
 		}, {
@@ -84,11 +82,7 @@ func TestVirtualLogIndex(t *testing.T) {
 			leafIndex:   2,
 			tid:         4,
 			ranges: LogRanges{
-				ranges: []LogRange{
-					{
-						TreeID: 30,
-					},
-				},
+				active: 30,
 			},
 			expectedIndex: -1,
 		},
diff --git a/pkg/sharding/ranges.go b/pkg/sharding/ranges.go
index edf3687..22750c3 100644
--- a/pkg/sharding/ranges.go
+++ b/pkg/sharding/ranges.go
@@ -18,7 +18,6 @@ package sharding
 import (
 	"fmt"
 	"io/ioutil"
-	"strconv"
 	"strings"
 
 	"github.com/ghodss/yaml"
@@ -26,7 +25,8 @@ import (
 )
 
 type LogRanges struct {
-	ranges Ranges
+	inactive Ranges
+	active   int64
 }
 
 type Ranges []LogRange
@@ -36,13 +36,12 @@ type LogRange struct {
 	TreeLength int64 `yaml:"treeLength"`
 }
 
-func NewLogRanges(path string, treeID string) (LogRanges, error) {
+func NewLogRanges(path string, treeID uint) (LogRanges, error) {
 	if path == "" {
 		return LogRanges{}, nil
 	}
-	id, err := strconv.Atoi(treeID)
-	if err != nil {
-		return LogRanges{}, errors.Wrapf(err, "%s is not a valid int64", treeID)
+	if treeID == 0 {
+		return LogRanges{}, errors.New("non-zero tlog_id required when passing in shard config filepath; please set the active tree ID via the `--trillian_log_server.tlog_id` flag")
 	}
 	// otherwise, try to read contents of the sharding config
 	var ranges Ranges
@@ -53,59 +52,68 @@ func NewLogRanges(path string, treeID string) (LogRanges, error) {
 	if err := yaml.Unmarshal(contents, &ranges); err != nil {
 		return LogRanges{}, err
 	}
-	ranges = append(ranges, LogRange{TreeID: int64(id)})
 	return LogRanges{
-		ranges: ranges,
+		inactive: ranges,
+		active:   int64(treeID),
 	}, nil
 }
 
 func (l *LogRanges) ResolveVirtualIndex(index int) (int64, int64) {
 	indexLeft := index
-	for _, l := range l.ranges {
+	for _, l := range l.inactive {
 		if indexLeft < int(l.TreeLength) {
 			return l.TreeID, int64(indexLeft)
 		}
 		indexLeft -= int(l.TreeLength)
 	}
 
-	// Return the last one!
-	return l.ranges[len(l.ranges)-1].TreeID, int64(indexLeft)
+	// If index not found in inactive trees, return the active tree
+	return l.active, int64(indexLeft)
 }
 
-// ActiveTreeID returns the active shard index, always the last shard in the range
 func (l *LogRanges) ActiveTreeID() int64 {
-	return l.ranges[len(l.ranges)-1].TreeID
+	return l.active
 }
 
-func (l *LogRanges) Empty() bool {
-	return l.ranges == nil
+func (l *LogRanges) NoInactive() bool {
+	return l.inactive == nil
 }
 
-// TotalLength returns the total length across all shards
-func (l *LogRanges) TotalLength() int64 {
+// TotalInactiveLength returns the total length across all inactive shards;
+// we don't know the length of the active shard.
+func (l *LogRanges) TotalInactiveLength() int64 {
 	var total int64
-	for _, r := range l.ranges {
+	for _, r := range l.inactive {
 		total += r.TreeLength
 	}
 	return total
 }
 
-func (l *LogRanges) SetRanges(r []LogRange) {
-	l.ranges = r
+func (l *LogRanges) SetInactive(r []LogRange) {
+	l.inactive = r
+}
+
+func (l *LogRanges) GetInactive() []LogRange {
+	return l.inactive
+}
+
+func (l *LogRanges) AppendInactive(r LogRange) {
+	l.inactive = append(l.inactive, r)
 }
 
-func (l *LogRanges) GetRanges() []LogRange {
-	return l.ranges
+func (l *LogRanges) SetActive(i int64) {
+	l.active = i
 }
 
-func (l *LogRanges) AppendRange(r LogRange) {
-	l.ranges = append(l.ranges, r)
+func (l *LogRanges) GetActive() int64 {
+	return l.active
 }
 
 func (l *LogRanges) String() string {
 	ranges := []string{}
-	for _, r := range l.ranges {
+	for _, r := range l.inactive {
 		ranges = append(ranges, fmt.Sprintf("%d=%d", r.TreeID, r.TreeLength))
 	}
+	ranges = append(ranges, fmt.Sprintf("active=%d", l.active))
 	return strings.Join(ranges, ",")
 }
diff --git a/pkg/sharding/ranges_test.go b/pkg/sharding/ranges_test.go
index 189380d..8ff228b 100644
--- a/pkg/sharding/ranges_test.go
+++ b/pkg/sharding/ranges_test.go
@@ -32,19 +32,17 @@ func TestNewLogRanges(t *testing.T) {
 	if err := ioutil.WriteFile(file, []byte(contents), 0644); err != nil {
 		t.Fatal(err)
 	}
-	treeID := "45"
+	treeID := uint(45)
 	expected := LogRanges{
-		ranges: []LogRange{
+		inactive: []LogRange{
 			{
 				TreeID:     1,
 				TreeLength: 3,
 			}, {
 				TreeID:     2,
 				TreeLength: 4,
-			}, {
-				TreeID: 45,
-			},
-		},
+			}},
+		active: int64(45),
 	}
 	got, err := NewLogRanges(file, treeID)
 	if err != nil {
@@ -53,19 +51,18 @@ func TestNewLogRanges(t *testing.T) {
 	if expected.ActiveTreeID() != got.ActiveTreeID() {
 		t.Fatalf("expected tree id %d got %d", expected.ActiveTreeID(), got.ActiveTreeID())
 	}
-	if !reflect.DeepEqual(expected.GetRanges(), got.GetRanges()) {
-		t.Fatalf("expected %v got %v", expected.GetRanges(), got.GetRanges())
+	if !reflect.DeepEqual(expected.GetInactive(), got.GetInactive()) {
+		t.Fatalf("expected %v got %v", expected.GetInactive(), got.GetInactive())
 	}
 }
 
 func TestLogRanges_ResolveVirtualIndex(t *testing.T) {
 	lrs := LogRanges{
-		ranges: []LogRange{
+		inactive: []LogRange{
 			{TreeID: 1, TreeLength: 17},
 			{TreeID: 2, TreeLength: 1},
-			{TreeID: 3, TreeLength: 100},
-			{TreeID: 4},
-		},
+			{TreeID: 3, TreeLength: 100}},
+		active: 4,
 	}
 
 	for _, tt := range []struct {
-- 
GitLab