Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitaly.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--internal/validator/error.go77
-rw-r--r--internal/validator/validate.go81
-rw-r--r--internal/validator/validator.go97
-rw-r--r--internal/validator/validator_test.go210
4 files changed, 465 insertions, 0 deletions
diff --git a/internal/validator/error.go b/internal/validator/error.go
new file mode 100644
index 000000000..f9cee3cde
--- /dev/null
+++ b/internal/validator/error.go
@@ -0,0 +1,77 @@
+package validator
+
+import (
+ "errors"
+ "fmt"
+ "strings"
+)
+
+// Error combines a key with an erorr. The Key is the path to the field
+// within the struct.
+type Error struct {
+ Key []string
+ Cause error
+}
+
+// NewError wraps the error with a key so it's location in the struct
+// can be propagated upwards.
+func NewError(key string, err error) error {
+ if key == "" {
+ // There's no reason to annotate something with an empty key. This
+ // handles not adding a key for the root node.
+ return err
+ }
+
+ var keyedErr Error
+ if errors.As(err, &keyedErr) {
+ // If the error is a KeyedError, prepend the key to the existing key.
+ keyedErr.Key = append([]string{key}, keyedErr.Key...)
+ return keyedErr
+ }
+
+ return Error{Key: []string{key}, Cause: err}
+}
+
+// Error returns the error message.
+func (err Error) Error() string {
+ return fmt.Sprintf("%s: %s", strings.Join(err.Key, "."), err.Cause)
+}
+
+// Unwrap the underlying error.
+func (err Error) Unwrap() error {
+ return err.Cause
+}
+
+// Errors is a collection of multiple errors.
+type Errors []error
+
+// Append adds an error to errs. If the err is Errors type itself,
+// it's flattened and each error appended to errs.
+func (errs Errors) Append(err error) Errors {
+ if otherErrs, ok := err.(Errors); ok {
+ return append(errs, otherErrs...)
+ }
+
+ return append(errs, err)
+}
+
+// Error returns the error message.
+func (errs Errors) Error() string {
+ var str []string
+ for _, err := range errs {
+ str = append(str, err.Error())
+ }
+
+ return strings.Join(str, "\n")
+}
+
+// ErrorOrNil returns an error if there are some errors, nil otherwise.
+// This should be used instead of returning a typed error to avoid
+// nil interface comparison problems.
+func (errs Errors) ErrorOrNil() error {
+ if len(errs) == 0 {
+ return nil
+ }
+
+ return errs
+}
diff --git a/internal/validator/validate.go b/internal/validator/validate.go
new file mode 100644
index 000000000..5162b7020
--- /dev/null
+++ b/internal/validator/validate.go
@@ -0,0 +1,81 @@
+package validator
+
+import (
+ "reflect"
+ "strconv"
+)
+
+// Func is the type of a validation function that validates a value
+// and returns an error if the validation fails.
+type Func[T any] func(T) error
+
+// Validator is an interface that provides validation functionality.
+type Validator interface {
+ // Validate the value and returns an error describing the validation
+ // failure if any.
+ Validate() error
+}
+
+// Validate validates the given value by invoking the value's Validate()
+// method. If the value is a struct or a slice, each field and element are
+// also recursed into and Validate() invoked on each field and element.
+// If Validate() return an error, it is wrapped into an Error to annotate
+// it with the field's or element's path in the hierarchy.
+func Validate(value any) error {
+ return validate("", reflect.ValueOf(value))
+}
+
+type skipper interface {
+ // skipTo returns the field that the validation walk should consider next.
+ // Useful for skipping some private fields from the walk.
+ skipTo() reflect.Value
+}
+
+func validate(key string, value reflect.Value) error {
+ var errs Errors
+ if v, ok := value.Interface().(Validator); ok {
+ if err := v.Validate(); err != nil {
+ errs = errs.Append(err)
+ }
+ }
+
+ if v, ok := value.Interface().(skipper); ok {
+ value = v.skipTo()
+ }
+
+ switch value.Kind() {
+ case reflect.Struct:
+ for i := 0; i < value.NumField(); i++ {
+ if err := validate(
+ tomlName(value.Type().Field(i)),
+ value.Field(i),
+ ); err != nil {
+ errs = errs.Append(err)
+ }
+ }
+ case reflect.Slice:
+ for i := 0; i < value.Len(); i++ {
+ if err := validate(
+ strconv.FormatInt(int64(i), 10),
+ value.Index(i),
+ ); err != nil {
+ errs = errs.Append(err)
+ }
+ }
+ }
+
+ for i := range errs {
+ errs[i] = NewError(key, errs[i])
+ }
+
+ return errs.ErrorOrNil()
+}
+
+func tomlName(value reflect.StructField) string {
+ name, ok := value.Tag.Lookup("toml")
+ if !ok {
+ name = value.Name
+ }
+
+ return name
+}
diff --git a/internal/validator/validator.go b/internal/validator/validator.go
new file mode 100644
index 000000000..bff307d87
--- /dev/null
+++ b/internal/validator/validator.go
@@ -0,0 +1,97 @@
+package validator
+
+import (
+ "errors"
+ "fmt"
+ "reflect"
+
+ "golang.org/x/exp/constraints"
+)
+
+// Combine combines the passed validation functions into a single function.
+// The returned function runs each validation function in order on the value
+// and returns the first error it encounters.
+func Combine[T any](validationFuncs ...Func[T]) Func[T] {
+ return func(value T) error {
+ for _, validate := range validationFuncs {
+ if err := validate(value); err != nil {
+ return err
+ }
+ }
+
+ return nil
+ }
+}
+
+// IsSet returns an error if the value is not set.
+func IsSet[T any](value T) error {
+ if reflect.ValueOf(value).IsZero() {
+ return errors.New("must be set")
+ }
+
+ return nil
+}
+
+// IsOneOf returns a validation function that errors if the value does
+// not equal one of the given values.
+func IsOneOf[T comparable](values ...T) Func[T] {
+ return func(value T) error {
+ for _, allowed := range values {
+ if value == allowed {
+ return nil
+ }
+ }
+
+ return fmt.Errorf("must be one of %v", values)
+ }
+}
+
+// Equal returns a validation function that errors if the value does not
+// equal the given value.
+func Equal[T comparable](validValue T) Func[T] {
+ return func(value T) error {
+ if value != validValue {
+ return fmt.Errorf(`must equal "%v"`, validValue)
+ }
+
+ return nil
+ }
+}
+
+// IsInRange returns a validation function that errors if the value
+// is not within the given range. The boundaries are valid values.
+func IsInRange[T constraints.Ordered](min, max T) Func[T] {
+ return func(value T) error {
+ if value < min || max < value {
+ return fmt.Errorf("must be in range [%v, %v]", min, max)
+ }
+
+ return nil
+ }
+}
+
+// Field is a validatable field.
+type Field[T any] struct {
+ Value T
+ validate Func[T]
+}
+
+// NewField returns a new Field with the attached validation function. The Value
+// field is not included in the reported key path if a child field fails validation.
+func NewField[T any](value T, validate Func[T]) Field[T] {
+ return Field[T]{
+ Value: value,
+ validate: validate,
+ }
+}
+
+// Validate validates the field.
+func (f Field[T]) Validate() error {
+ return f.validate(f.Value)
+}
+
+// skipTo skips the Field itself from the validation walk
+// and makes it proceed directly to the value.
+func (f Field[T]) skipTo() reflect.Value {
+ return reflect.ValueOf(f.Value)
+}
diff --git a/internal/validator/validator_test.go b/internal/validator/validator_test.go
new file mode 100644
index 000000000..b4ea7cfca
--- /dev/null
+++ b/internal/validator/validator_test.go
@@ -0,0 +1,210 @@
+package validator_test
+
+import (
+ "errors"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/gitaly/v15/internal/validator"
+)
+
+type CustomString string
+
+// Custom types can have validators attached to them. Their path in the
+// struct's hierarchy is automatically annotated on the errors.
+func (str CustomString) Validate() error {
+ // The validator functions are composable.
+ return validator.Combine(
+ validator.IsSet[CustomString],
+ validator.Equal[CustomString]("success"),
+ )(str)
+}
+
+type StringEnum string
+
+func (enum StringEnum) Validate() error {
+ return validator.IsOneOf("good-1", "good-2", "good-3")(string(enum))
+}
+
+type ChildStruct struct {
+ ShouldFail bool
+ EnumValue StringEnum `toml:"enum_value"`
+ CustomString CustomString `toml:"custom_string"`
+ // Fields which don't have validation logic are ignored.
+ UnvalidatedString string `toml:"unvalidated_string"`
+}
+
+// Validate functons are called on field that implements one, even in subfields.
+func (c ChildStruct) Validate() error {
+ if c.ShouldFail {
+ // Returning a validator.Error allows for annotating the returned error with a
+ // key so the correct location in the struct can be reported even if a subfield
+ // was being validated from a parent struct level validator.
+ return validator.NewError("should_fail", errors.New("struct validation error"))
+ }
+
+ return nil
+}
+
+type StringSlice []CustomString
+
+func (slc StringSlice) Validate() error {
+ if len(slc) == 0 {
+ return errors.New("must have elements")
+ }
+
+ return nil
+}
+
+type Configuration struct {
+ ShouldFail bool
+ StringSlice StringSlice `toml:"string_slice"`
+ CustomString CustomString `toml:"custom_string"`
+ // The ChildStruct gets walked into as well and validation invoked for each
+ // field separately.
+ ChildStruct ChildStruct `toml:"child_struct"`
+ ValidatedInt validator.Field[int] `toml:"validated_int"`
+ ValidatedStruct validator.Field[ChildStruct] `toml:"validated_struct"`
+}
+
+// Validate functons are called on field that implements one, even on the root level.
+func (c Configuration) Validate() error {
+ if c.ShouldFail {
+ // Returned errors don't need to have a special type.
+ return errors.New("struct validation error")
+ }
+
+ return nil
+}
+
+func newChildStruct() ChildStruct {
+ return ChildStruct{
+ EnumValue: "good-1",
+ CustomString: "success",
+ UnvalidatedString: "default_unvalidated",
+ }
+}
+
+func NewConfiguration() Configuration {
+ return Configuration{
+ CustomString: "success",
+ StringSlice: StringSlice{"success", "success", "success"},
+ ChildStruct: newChildStruct(),
+ ValidatedInt: validator.NewField(12, validator.IsInRange(10, 15)),
+ ValidatedStruct: validator.NewField(newChildStruct(), func(value ChildStruct) error {
+ if value != newChildStruct() {
+ return errors.New("struct must not be changed")
+ }
+
+ return nil
+ }),
+ }
+}
+
+func TestValidate(t *testing.T) {
+ t.Parallel()
+
+ for _, tc := range []struct {
+ desc string
+ cfg func(cfg *Configuration)
+ expectedErr string
+ }{
+ {
+ desc: "valid configuration",
+ cfg: func(*Configuration) {},
+ },
+ {
+ desc: "invalid top level value",
+ cfg: func(cfg *Configuration) {
+ cfg.CustomString = ""
+ },
+ expectedErr: "custom_string: must be set",
+ },
+ {
+ desc: "invalid child value",
+ cfg: func(cfg *Configuration) {
+ cfg.ChildStruct.CustomString = ""
+ },
+ expectedErr: "child_struct.custom_string: must be set",
+ },
+ {
+ desc: "multiple invalid",
+ cfg: func(cfg *Configuration) {
+ cfg.CustomString = ""
+ cfg.ChildStruct.CustomString = "fail"
+ },
+ expectedErr: `custom_string: must be set
+child_struct.custom_string: must equal "success"`,
+ },
+ {
+ desc: "field not in range",
+ cfg: func(cfg *Configuration) {
+ cfg.ValidatedInt.Value = 2
+ },
+ expectedErr: "validated_int: must be in range [10, 15]",
+ },
+ {
+ desc: "value equals min of the range",
+ cfg: func(cfg *Configuration) {
+ cfg.ValidatedInt.Value = 10
+ },
+ },
+ {
+ desc: "value equals max of the range",
+ cfg: func(cfg *Configuration) {
+ cfg.ValidatedInt.Value = 15
+ },
+ },
+ {
+ desc: "struct validation",
+ cfg: func(cfg *Configuration) {
+ cfg.ShouldFail = true
+ },
+ expectedErr: "struct validation error",
+ },
+ {
+ desc: "child struct validation",
+ cfg: func(cfg *Configuration) {
+ cfg.ChildStruct.ShouldFail = true
+ },
+ expectedErr: "child_struct.should_fail: struct validation error",
+ },
+ {
+ desc: "field struct validation",
+ cfg: func(cfg *Configuration) {
+ cfg.ValidatedStruct.Value.EnumValue = "bad-1"
+ },
+ expectedErr: `validated_struct: struct must not be changed
+validated_struct.enum_value: must be one of [good-1 good-2 good-3]`,
+ },
+ {
+ desc: "slice validation",
+ cfg: func(cfg *Configuration) {
+ cfg.StringSlice = nil
+ },
+ expectedErr: `string_slice: must have elements`,
+ },
+ {
+ desc: "invalid elements in slice",
+ cfg: func(cfg *Configuration) {
+ cfg.StringSlice = StringSlice{"", "success", "fail"}
+ },
+ expectedErr: `string_slice.0: must be set
+string_slice.2: must equal "success"`,
+ },
+ } {
+ tc := tc
+ t.Run(tc.desc, func(t *testing.T) {
+ t.Parallel()
+
+ cfg := NewConfiguration()
+ tc.cfg(&cfg)
+
+ if err := validator.Validate(cfg); tc.expectedErr == "" {
+ require.NoError(t, err)
+ } else {
+ require.EqualError(t, err, tc.expectedErr)
+ }
+ })
+ }
+}