Skip to content
Snippets Groups Projects
pflags.go 8.14 KiB
//
// Copyright 2021 The Sigstore Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package app

import (
	"fmt"
	"log"
	"strconv"
	"strings"
	"time"

	"github.com/sigstore/rekor/pkg/pki"
	"github.com/sigstore/rekor/pkg/sharding"
	"github.com/sigstore/rekor/pkg/util"

	"github.com/spf13/pflag"

	validator "github.com/go-playground/validator/v10"
	"github.com/pkg/errors"
)

// FlagType names one of the custom, validated flag kinds defined in this file.
// The string value is what baseValue.Type() returns for values built with it.
type FlagType string

// Supported flag types; each one has a constructor registered in initializePFlagMap.
const (
	uuidFlag      FlagType = "uuid"
	shaFlag       FlagType = "sha"
	emailFlag     FlagType = "email"
	logIndexFlag  FlagType = "logIndex"
	pkiFormatFlag FlagType = "pkiFormat"
	typeFlag      FlagType = "type"
	fileFlag      FlagType = "file"
	urlFlag       FlagType = "url"
	fileOrURLFlag FlagType = "fileOrURL"
	oidFlag       FlagType = "oid"
	formatFlag    FlagType = "format"
	timeoutFlag   FlagType = "timeout"
)

// newPFlagValueFunc constructs a new pflag.Value for a single FlagType.
type newPFlagValueFunc func() pflag.Value

// pflagValueFuncMap maps each FlagType to its constructor. It is nil until
// initializePFlagMap assigns it.
var pflagValueFuncMap map[FlagType]newPFlagValueFunc

// initializePFlagMap populates pflagValueFuncMap with one constructor per
// FlagType. Every constructor returns a new pflag.Value (built by valueFactory)
// whose Set method applies the validation rule for that flag type.
// NewFlagValue, isURL and validateFileOrURL look constructors up in this map,
// so it must be called before any of them are used.
//
// TODO: unit tests for all of this
func initializePFlagMap() {
	pflagValueFuncMap = map[FlagType]newPFlagValueFunc{
		uuidFlag: func() pflag.Value {
			// this validates a UUID with or without a prepended TreeID;
			// the UUID corresponds to the merkle leaf hash of entries,
			// which is represented by a 64 character hexadecimal string
			return valueFactory(uuidFlag, validateID, "")
		},
		shaFlag: func() pflag.Value {
			// this validates a valid sha256 checksum which is optionally prefixed with 'sha256:'
			return valueFactory(shaFlag, validateSHAValue, "")
		},
		emailFlag: func() pflag.Value {
			// this validates an email address
			return valueFactory(emailFlag, validateString("required,email"), "")
		},
		logIndexFlag: func() pflag.Value {
			// this checks for a valid integer >= 0
			return valueFactory(logIndexFlag, validateLogIndex, "")
		},
		pkiFormatFlag: func() pflag.Value {
			// this ensures a PKI implementation exists for the requested format
			return valueFactory(pkiFormatFlag, validateString(fmt.Sprintf("required,oneof=%v", strings.Join(pki.SupportedFormats(), " "))), "pgp")
		},
		typeFlag: func() pflag.Value {
			// this ensures the type of the log entry matches a type supported in the CLI
			return valueFactory(typeFlag, validateTypeFlag, "rekord")
		},
		fileFlag: func() pflag.Value {
			// this validates that the file exists and can be opened by the current uid
			return valueFactory(fileFlag, validateString("required,file"), "")
		},
		urlFlag: func() pflag.Value {
			// this validates that the string is a valid http/https URL
			return valueFactory(urlFlag, validateString("required,url,startswith=http|startswith=https"), "")
		},
		fileOrURLFlag: func() pflag.Value {
			// applies logic of fileFlag OR urlFlag validators from above
			return valueFactory(fileOrURLFlag, validateFileOrURL, "")
		},
		oidFlag: func() pflag.Value {
			// this validates for an OID, which is a sequence of non-negative integers separated by periods
			return valueFactory(oidFlag, validateOID, "")
		},
		formatFlag: func() pflag.Value {
			// this validates the output format requested
			return valueFactory(formatFlag, validateString("required,oneof=json default"), "")
		},
		timeoutFlag: func() pflag.Value {
			// this validates the timeout is >= 0; the value is tagged with
			// timeoutFlag (not formatFlag) so that Type() reports "timeout"
			return valueFactory(timeoutFlag, validateTimeout, "")
		},
	}
}

// NewFlagValue creates a new pflag.Value for the specified type with the specified default value.
// If a default value is not desired, pass "" for defaultVal.
//
// flagType must have a constructor in pflagValueFuncMap (see initializePFlagMap).
// An unregistered flag type is a programming error and causes a panic that
// names the offending type. A non-empty defaultVal that fails the type's
// validation terminates the process via log.Fatal.
func NewFlagValue(flagType FlagType, defaultVal string) pflag.Value {
	valFunc, ok := pflagValueFuncMap[flagType]
	if !ok {
		// Without this check the nil func below would be invoked and panic with
		// an opaque nil-pointer dereference that does not say which type was missing.
		panic(fmt.Sprintf("no pflag.Value constructor registered for flag type %q", flagType))
	}
	val := valFunc()
	if defaultVal != "" {
		if err := val.Set(defaultVal); err != nil {
			log.Fatal(errors.Wrap(err, "initializing flag"))
		}
	}
	return val
}

// validationFunc inspects a candidate flag value and returns a non-nil error
// when the value is not acceptable.
type validationFunc func(string) error

// valueFactory builds a *baseValue of the given flagType that validates input
// with v. defaultVal is stored as the initial value as-is; it is not run
// through v here.
func valueFactory(flagType FlagType, v validationFunc, defaultVal string) pflag.Value {
	bv := new(baseValue)
	bv.flagType = flagType
	bv.validationFunc = v
	bv.value = defaultVal
	return bv
}

// baseValue is the pflag.Value implementation shared by every FlagType in this
// file. It holds the flag's current raw string and the validation rule that
// Set applies to new input.
//
// All methods use pointer receivers: Set has to mutate the value, and Go
// convention is not to mix value and pointer receivers on one type.
// valueFactory hands out *baseValue, so every value built in this file still
// satisfies pflag.Value.
type baseValue struct {
	flagType       FlagType
	value          string
	validationFunc validationFunc
}

// Type returns the name of this value's FlagType.
func (b *baseValue) Type() string {
	return string(b.flagType)
}

// String returns the currently stored value (the value it was constructed
// with, until a successful Set replaces it).
func (b *baseValue) String() string {
	return b.value
}

// Set validates s with the rule for b.flagType. On success s becomes the
// stored value and nil is returned. On failure the validation error is
// returned and the stored value is left unchanged.
func (b *baseValue) Set(s string) error {
	if err := b.validationFunc(s); err != nil {
		return err
	}
	b.value = s
	return nil
}

// isURL reports whether v passes the urlFlag validation rule; any validation
// error is discarded and reported only as false.
func isURL(v string) bool {
	urlValue := pflagValueFuncMap[urlFlag]()
	err := urlValue.Set(v)
	return err == nil
}

// validateSHAValue accepts a digest in either of these forms:
//
//	[sha256:]<64 hexadecimal characters>
//	[sha1:]<40 hexadecimal characters>
//
// where the [sha256:] and [sha1:] prefixes are optional. The SHA-1 check runs
// first; if both checks fail, only the SHA-256 error is returned, wrapped with
// the flag name.
func validateSHAValue(v string) error {
	if util.ValidateSHA1Value(v) == nil {
		return nil
	}

	sha256Err := util.ValidateSHA256Value(v)
	if sha256Err == nil {
		return nil
	}
	return fmt.Errorf("error parsing %v flag: %w", shaFlag, sha256Err)
}

// validateFileOrURL accepts v if it passes the fileFlag validation rule or,
// failing that, the urlFlag validation rule. When both fail, the error from
// the URL check is the one returned.
func validateFileOrURL(v string) error {
	newFileValue := pflagValueFuncMap[fileFlag]
	fileValue := newFileValue()
	if fileValue.Set(v) == nil {
		return nil
	}

	newURLValue := pflagValueFuncMap[urlFlag]
	urlValue := newURLValue()
	return urlValue.Set(v)
}

// validateID ensures the ID is either an EntryID (TreeID + UUID) or a UUID.
// Its length must be exactly sharding.EntryIDHexStringLen or
// sharding.UUIDHexStringLen, and every character must be a hexadecimal digit
// (0-9, a-f, A-F).
func validateID(v string) error {
	if len(v) != sharding.EntryIDHexStringLen && len(v) != sharding.UUIDHexStringLen {
		return fmt.Errorf("ID len error, expected %v (EntryID) or %v (UUID) but got len %v for ID %v", sharding.EntryIDHexStringLen, sharding.UUIDHexStringLen, len(v), v)
	}

	// Characters are checked directly rather than with go-playground/validator's
	// "hexadecimal" tag. That tag also accepts an optional "0x"/"0X" prefix, so
	// "0x" followed by 62 hex digits would pass the length check above and be
	// accepted as a 64-character UUID.
	const hexDigits = "0123456789abcdefABCDEF"
	for _, r := range v {
		if !strings.ContainsRune(hexDigits, r) {
			return fmt.Errorf("invalid uuid: %v", v)
		}
	}

	return nil
}

// validateLogIndex ensures that the supplied string is a valid log index (integer >= 0)
func validateLogIndex(v string) error {
	i, err := strconv.Atoi(v)
	if err != nil {
		return err
	}
	l := struct {
		Index int `validate:"gte=0"`
	}{i}

	return useValidator(logIndexFlag, l)
}

// validateOID ensures that the supplied string is a valid ASN.1 object
// identifier: dot-separated arcs, each an unsigned decimal integer
// (e.g. "1.3.6.1.4.1").
func validateOID(v string) error {
	o := struct {
		// The validator's "numeric" tag also accepts a leading '+' or '-'
		// (so "1.-2.3" would pass). A signed arc is never valid in an OID,
		// so signs are rejected explicitly with excludesall.
		Oid []string `validate:"dive,numeric,excludesall=+-"`
	}{strings.Split(v, ".")}

	return useValidator(oidFlag, o)
}

// validateTimeout ensures that the supplied string is a valid time.Duration value >= 0
func validateTimeout(v string) error {
	duration, err := time.ParseDuration(v)
	if err != nil {
		return err
	}
	d := struct {
		Duration time.Duration `validate:"min=0"`
	}{duration}
	return useValidator(timeoutFlag, d)
}

// validateTypeFlag checks that v has the form type(\.version)? and names a
// type the CLI implements. The parsing is delegated to ParseTypeFlag, whose
// parsed type and version are discarded; only its error is returned.
func validateTypeFlag(v string) error {
	_, _, parseErr := ParseTypeFlag(v)
	return parseErr
}

// validateString returns a function that validates an input string against the specified tag,
// as defined in the format supported by go-playground/validator.
//
// One validator instance is created per returned function and reused on every
// call, instead of being rebuilt for each validation. The local name `validate`
// also avoids shadowing the imported validator package.
func validateString(tag string) validationFunc {
	validate := validator.New()
	return func(v string) error {
		return validate.Var(v, tag)
	}
}

// useValidator runs go-playground/validator struct-level validation on s,
// driven by the `validate` tags on its fields. A failure is wrapped as
// "error parsing <flagType> flag: ..." so the caller can tell which flag was
// rejected.
func useValidator(flagType FlagType, s interface{}) error {
	err := validator.New().Struct(s)
	if err == nil {
		return nil
	}
	return fmt.Errorf("error parsing %v flag: %w", flagType, err)
}