Skip to content
Snippets Groups Projects
Unverified Commit 40d9419f authored by dlorenc's avatar dlorenc Committed by GitHub
Browse files

Refactor the shard map parsing so we can pass it down into the API object. (#564)


Right now the type itself is defined in the cli package, which means we can't
use it without an import cycle.

Signed-off-by: default avatarDan Lorenc <lorenc.d@gmail.com>
parent be91b55e
No related branches found
No related tags found
No related merge requests found
...@@ -19,23 +19,20 @@ import ( ...@@ -19,23 +19,20 @@ import (
"fmt" "fmt"
"strconv" "strconv"
"strings" "strings"
)
type LogRange struct { "github.com/sigstore/rekor/pkg/api"
TreeID uint64 )
TreeLength uint64
}
type LogRanges struct { type LogRangesFlag struct {
Ranges []LogRange Ranges api.LogRanges
} }
func (l *LogRanges) Set(s string) error { func (l *LogRangesFlag) Set(s string) error {
ranges := strings.Split(s, ",") ranges := strings.Split(s, ",")
l.Ranges = []LogRange{} l.Ranges = api.LogRanges{}
var err error var err error
inputRanges := []LogRange{} inputRanges := []api.LogRange{}
// Only go up to the second to last one, the last one is special cased beow // Only go up to the second to last one, the last one is special cased beow
for _, r := range ranges[:len(ranges)-1] { for _, r := range ranges[:len(ranges)-1] {
...@@ -43,7 +40,7 @@ func (l *LogRanges) Set(s string) error { ...@@ -43,7 +40,7 @@ func (l *LogRanges) Set(s string) error {
if len(split) != 2 { if len(split) != 2 {
return fmt.Errorf("invalid range flag, expected two parts separated by an =, got %s", r) return fmt.Errorf("invalid range flag, expected two parts separated by an =, got %s", r)
} }
lr := LogRange{} lr := api.LogRange{}
lr.TreeID, err = strconv.ParseUint(split[0], 10, 64) lr.TreeID, err = strconv.ParseUint(split[0], 10, 64)
if err != nil { if err != nil {
return err return err
...@@ -63,7 +60,7 @@ func (l *LogRanges) Set(s string) error { ...@@ -63,7 +60,7 @@ func (l *LogRanges) Set(s string) error {
return err return err
} }
inputRanges = append(inputRanges, LogRange{ inputRanges = append(inputRanges, api.LogRange{
TreeID: lastTreeID, TreeID: lastTreeID,
}) })
...@@ -76,36 +73,20 @@ func (l *LogRanges) Set(s string) error { ...@@ -76,36 +73,20 @@ func (l *LogRanges) Set(s string) error {
TreeIDs[lr.TreeID] = struct{}{} TreeIDs[lr.TreeID] = struct{}{}
} }
l.Ranges = inputRanges l.Ranges = api.LogRanges{
Ranges: inputRanges,
}
return nil return nil
} }
func (l *LogRanges) String() string { func (l *LogRangesFlag) String() string {
ranges := []string{} ranges := []string{}
for _, r := range l.Ranges { for _, r := range l.Ranges.Ranges {
ranges = append(ranges, fmt.Sprintf("%d=%d", r.TreeID, r.TreeLength)) ranges = append(ranges, fmt.Sprintf("%d=%d", r.TreeID, r.TreeLength))
} }
return strings.Join(ranges, ",") return strings.Join(ranges, ",")
} }
func (l *LogRanges) Type() string { func (l *LogRangesFlag) Type() string {
return "LogRanges" return "LogRangesFlag"
}
func (l *LogRanges) ResolveVirtualIndex(index int) (uint64, uint64) {
indexLeft := index
for _, l := range l.Ranges {
if indexLeft < int(l.TreeLength) {
return l.TreeID, uint64(indexLeft)
}
indexLeft -= int(l.TreeLength)
}
// Return the last one!
return l.Ranges[len(l.Ranges)-1].TreeID, uint64(indexLeft)
}
// ActiveIndex returns the active shard index, always the last shard in the range
func (l *LogRanges) ActiveIndex() uint64 {
return l.Ranges[len(l.Ranges)-1].TreeID
} }
...@@ -19,19 +19,20 @@ import ( ...@@ -19,19 +19,20 @@ import (
"testing" "testing"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/sigstore/rekor/pkg/api"
) )
func TestLogRanges_Set(t *testing.T) { func TestLogRanges_Set(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
arg string arg string
want []LogRange want []api.LogRange
active uint64 active uint64
}{ }{
{ {
name: "one, no length", name: "one, no length",
arg: "1234", arg: "1234",
want: []LogRange{ want: []api.LogRange{
{ {
TreeID: 1234, TreeID: 1234,
TreeLength: 0, TreeLength: 0,
...@@ -42,7 +43,7 @@ func TestLogRanges_Set(t *testing.T) { ...@@ -42,7 +43,7 @@ func TestLogRanges_Set(t *testing.T) {
{ {
name: "two", name: "two",
arg: "1234=10,7234", arg: "1234=10,7234",
want: []LogRange{ want: []api.LogRange{
{ {
TreeID: 1234, TreeID: 1234,
TreeLength: 10, TreeLength: 10,
...@@ -57,16 +58,16 @@ func TestLogRanges_Set(t *testing.T) { ...@@ -57,16 +58,16 @@ func TestLogRanges_Set(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
l := &LogRanges{} l := &LogRangesFlag{}
if err := l.Set(tt.arg); err != nil { if err := l.Set(tt.arg); err != nil {
t.Errorf("LogRanges.Set() expected no error, got %v", err) t.Errorf("LogRanges.Set() expected no error, got %v", err)
} }
if diff := cmp.Diff(tt.want, l.Ranges); diff != "" { if diff := cmp.Diff(tt.want, l.Ranges.Ranges); diff != "" {
t.Errorf(diff) t.Errorf(diff)
} }
active := l.ActiveIndex() active := l.Ranges.ActiveIndex()
if active != tt.active { if active != tt.active {
t.Errorf("LogRanges.Active() expected %d no error, got %d", tt.active, active) t.Errorf("LogRanges.Active() expected %d no error, got %d", tt.active, active)
} }
...@@ -94,50 +95,10 @@ func TestLogRanges_SetErr(t *testing.T) { ...@@ -94,50 +95,10 @@ func TestLogRanges_SetErr(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
l := &LogRanges{} l := &LogRangesFlag{}
if err := l.Set(tt.arg); err == nil { if err := l.Set(tt.arg); err == nil {
t.Error("LogRanges.Set() expected error but got none") t.Error("LogRanges.Set() expected error but got none")
} }
}) })
} }
} }
func TestLogRanges_ResolveVirtualIndex(t *testing.T) {
lrs := LogRanges{
Ranges: []LogRange{
{TreeID: 1, TreeLength: 17},
{TreeID: 2, TreeLength: 1},
{TreeID: 3, TreeLength: 100},
{TreeID: 4},
},
}
for _, tt := range []struct {
Index int
WantTreeID uint64
WantIndex uint64
}{
{
Index: 3,
WantTreeID: 1, WantIndex: 3,
},
// This is the first (0th) entry in the next tree
{
Index: 17,
WantTreeID: 2, WantIndex: 0,
},
// Overflow
{
Index: 3000,
WantTreeID: 4, WantIndex: 2882,
},
} {
tree, index := lrs.ResolveVirtualIndex(tt.Index)
if tree != tt.WantTreeID {
t.Errorf("LogRanges.ResolveVirtualIndex() tree = %v, want %v", tree, tt.WantTreeID)
}
if index != tt.WantIndex {
t.Errorf("LogRanges.ResolveVirtualIndex() index = %v, want %v", index, tt.WantIndex)
}
}
}
...@@ -34,7 +34,7 @@ var ( ...@@ -34,7 +34,7 @@ var (
cfgFile string cfgFile string
logType string logType string
enablePprof bool enablePprof bool
logRangeMap LogRanges logRangeMap LogRangesFlag
) )
// rootCmd represents the base command when called without any subcommands // rootCmd represents the base command when called without any subcommands
......
...@@ -102,7 +102,7 @@ var serveCmd = &cobra.Command{ ...@@ -102,7 +102,7 @@ var serveCmd = &cobra.Command{
server.Port = int(viper.GetUint("port")) server.Port = int(viper.GetUint("port"))
server.EnabledListeners = []string{"http"} server.EnabledListeners = []string{"http"}
api.ConfigureAPI() api.ConfigureAPI(logRangeMap.Ranges)
server.ConfigureAPI() server.ConfigureAPI()
http.Handle("/metrics", promhttp.Handler()) http.Handle("/metrics", promhttp.Handler())
......
...@@ -56,6 +56,7 @@ func dial(ctx context.Context, rpcServer string) (*grpc.ClientConn, error) { ...@@ -56,6 +56,7 @@ func dial(ctx context.Context, rpcServer string) (*grpc.ClientConn, error) {
type API struct { type API struct {
logClient trillian.TrillianLogClient logClient trillian.TrillianLogClient
logID int64 logID int64
logRanges *LogRanges
pubkey string // PEM encoded public key pubkey string // PEM encoded public key
pubkeyHash string // SHA256 hash of DER-encoded public key pubkeyHash string // SHA256 hash of DER-encoded public key
signer signature.Signer signer signature.Signer
...@@ -64,7 +65,7 @@ type API struct { ...@@ -64,7 +65,7 @@ type API struct {
certChainPem string // PEM encoded timestamping cert chain certChainPem string // PEM encoded timestamping cert chain
} }
func NewAPI() (*API, error) { func NewAPI(ranges LogRanges) (*API, error) {
logRPCServer := fmt.Sprintf("%s:%d", logRPCServer := fmt.Sprintf("%s:%d",
viper.GetString("trillian_log_server.address"), viper.GetString("trillian_log_server.address"),
viper.GetUint("trillian_log_server.port")) viper.GetUint("trillian_log_server.port"))
...@@ -137,6 +138,7 @@ func NewAPI() (*API, error) { ...@@ -137,6 +138,7 @@ func NewAPI() (*API, error) {
// Transparency Log Stuff // Transparency Log Stuff
logClient: logClient, logClient: logClient,
logID: tLogID, logID: tLogID,
logRanges: &ranges,
// Signing/verifying fields // Signing/verifying fields
pubkey: string(pubkey), pubkey: string(pubkey),
pubkeyHash: hex.EncodeToString(pubkeyHashBytes[:]), pubkeyHash: hex.EncodeToString(pubkeyHashBytes[:]),
...@@ -154,10 +156,11 @@ var ( ...@@ -154,10 +156,11 @@ var (
storageClient storage.AttestationStorage storageClient storage.AttestationStorage
) )
func ConfigureAPI() { func ConfigureAPI(ranges LogRanges) {
cfg := radix.PoolConfig{} cfg := radix.PoolConfig{}
var err error var err error
api, err = NewAPI()
api, err = NewAPI(ranges)
if err != nil { if err != nil {
log.Logger.Panic(err) log.Logger.Panic(err)
} }
......
//
// Copyright 2021 The Sigstore Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package api
type LogRanges struct {
Ranges []LogRange
}
type LogRange struct {
TreeID uint64
TreeLength uint64
}
func (l *LogRanges) ResolveVirtualIndex(index int) (uint64, uint64) {
indexLeft := index
for _, l := range l.Ranges {
if indexLeft < int(l.TreeLength) {
return l.TreeID, uint64(indexLeft)
}
indexLeft -= int(l.TreeLength)
}
// Return the last one!
return l.Ranges[len(l.Ranges)-1].TreeID, uint64(indexLeft)
}
// ActiveIndex returns the active shard index, always the last shard in the range
func (l *LogRanges) ActiveIndex() uint64 {
return l.Ranges[len(l.Ranges)-1].TreeID
}
//
// Copyright 2021 The Sigstore Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package api
import "testing"
func TestLogRanges_ResolveVirtualIndex(t *testing.T) {
lrs := LogRanges{
Ranges: []LogRange{
{TreeID: 1, TreeLength: 17},
{TreeID: 2, TreeLength: 1},
{TreeID: 3, TreeLength: 100},
{TreeID: 4},
},
}
for _, tt := range []struct {
Index int
WantTreeID uint64
WantIndex uint64
}{
{
Index: 3,
WantTreeID: 1, WantIndex: 3,
},
// This is the first (0th) entry in the next tree
{
Index: 17,
WantTreeID: 2, WantIndex: 0,
},
// Overflow
{
Index: 3000,
WantTreeID: 4, WantIndex: 2882,
},
} {
tree, index := lrs.ResolveVirtualIndex(tt.Index)
if tree != tt.WantTreeID {
t.Errorf("LogRanges.ResolveVirtualIndex() tree = %v, want %v", tree, tt.WantTreeID)
}
if index != tt.WantIndex {
t.Errorf("LogRanges.ResolveVirtualIndex() index = %v, want %v", index, tt.WantIndex)
}
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment