diff options
-rw-r--r-- | internal/validator/error.go | 77 | ||||
-rw-r--r-- | internal/validator/validate.go | 81 | ||||
-rw-r--r-- | internal/validator/validator.go | 97 | ||||
-rw-r--r-- | internal/validator/validator_test.go | 210 |
4 files changed, 465 insertions, 0 deletions
diff --git a/internal/validator/error.go b/internal/validator/error.go new file mode 100644 index 000000000..f9cee3cde --- /dev/null +++ b/internal/validator/error.go @@ -0,0 +1,77 @@ +package validator + +import ( + "errors" + "fmt" + "strings" +) + +// Error combines a key with an erorr. The Key is the path to the field +// within the struct. +type Error struct { + Key []string + Cause error +} + +// NewError wraps the error with a key so it's location in the struct +// can be propagated upwards. +func NewError(key string, err error) error { + if key == "" { + // There's no reason to annotate something with an empty key. This + // handles not adding a key for the root node. + return err + } + + var keyedErr Error + if errors.As(err, &keyedErr) { + // If the error is a KeyedError, prepend the key to the existing key. + keyedErr.Key = append([]string{key}, keyedErr.Key...) + return keyedErr + } + + return Error{Key: []string{key}, Cause: err} +} + +// Error returns the error message. +func (err Error) Error() string { + return fmt.Sprintf("%s: %s", strings.Join(err.Key, "."), err.Cause) +} + +// Unwrap the underlying error. +func (err Error) Unwrap() error { + return err.Cause +} + +// Errors is a collection of multiple errors. +type Errors []error + +// Append adds an error to errs. If the err is Errors type itself, +// it's flattened and each error appended to errs. +func (errs Errors) Append(err error) Errors { + if otherErrs, ok := err.(Errors); ok { + return append(errs, otherErrs...) + } + + return append(errs, err) +} + +// Error returns the error message. +func (errs Errors) Error() string { + var str []string + for _, err := range errs { + str = append(str, err.Error()) + } + + return strings.Join(str, "\n") +} + +// ErrorOrNil returns an error if there are some errors, nil otherwise. +// This should be used instead of returning a typed error to avoid +// nil interface comparison problems. +func (errs Errors) ErrorOrNil() error { + if len(errs) == 0 { + return nil + } + + return errs +} diff --git a/internal/validator/validate.go b/internal/validator/validate.go new file mode 100644 index 000000000..5162b7020 --- /dev/null +++ b/internal/validator/validate.go @@ -0,0 +1,81 @@ +package validator + +import ( + "reflect" + "strconv" +) + +// Func is the type of a validation function that validates a value +// and returns an error if the validation fails. +type Func[T any] func(T) error + +// Validator is an interface that provides validation functionality. +type Validator interface { + // Validate the value and returns an error describing the validation + // failure if any. + Validate() error +} + +// Validate validates the given value by invoking the value's Validate() +// method. If the value is a struct or a slice, each field and element are +// also recursed into and Validate() invoked on each field and element. +// If Validate() return an error, it is wrapped into an Error to annotate +// it with the field's or element's path in the hierarchy. +func Validate(value any) error { + return validate("", reflect.ValueOf(value)) +} + +type skipper interface { + // skipTo returns the field that the validation walk should consider next. + // Useful for skipping some private fields from the walk. + skipTo() reflect.Value +} + +func validate(key string, value reflect.Value) error { + var errs Errors + if v, ok := value.Interface().(Validator); ok { + if err := v.Validate(); err != nil { + errs = errs.Append(err) + } + } + + if v, ok := value.Interface().(skipper); ok { + value = v.skipTo() + } + + switch value.Kind() { + case reflect.Struct: + for i := 0; i < value.NumField(); i++ { + if err := validate( + tomlName(value.Type().Field(i)), + value.Field(i), + ); err != nil { + errs = errs.Append(err) + } + } + case reflect.Slice: + for i := 0; i < value.Len(); i++ { + if err := validate( + strconv.FormatInt(int64(i), 10), + value.Index(i), + ); err != nil { + errs = errs.Append(err) + } + } + } + + for i := range errs { + errs[i] = NewError(key, errs[i]) + } + + return errs.ErrorOrNil() +} + +func tomlName(value reflect.StructField) string { + name, ok := value.Tag.Lookup("toml") + if !ok { + name = value.Name + } + + return name +} diff --git a/internal/validator/validator.go b/internal/validator/validator.go new file mode 100644 index 000000000..bff307d87 --- /dev/null +++ b/internal/validator/validator.go @@ -0,0 +1,97 @@ +package validator + +import ( + "errors" + "fmt" + "reflect" + + "golang.org/x/exp/constraints" +) + +// Combine combines the passed validation functions into a single function. +// The returned function runs each validation function in order on the value +// and returns the first error it encounters. +func Combine[T any](validationFuncs ...Func[T]) Func[T] { + return func(value T) error { + for _, validate := range validationFuncs { + if err := validate(value); err != nil { + return err + } + } + + return nil + } +} + +// IsSet returns an error if the value is not set. +func IsSet[T any](value T) error { + if reflect.ValueOf(value).IsZero() { + return errors.New("must be set") + } + + return nil +} + +// IsOneOf returns a validation function that errors if the value does +// not equal one of the given values. +func IsOneOf[T comparable](values ...T) Func[T] { + return func(value T) error { + for _, allowed := range values { + if value == allowed { + return nil + } + } + + return fmt.Errorf("must be one of %v", values) + } +} + +// Equal returns a validation function that errors if the value does not +// equal the given value. +func Equal[T comparable](validValue T) Func[T] { + return func(value T) error { + if value != validValue { + return fmt.Errorf(`must equal "%v"`, validValue) + } + + return nil + } +} + +// IsInRange returns a validation function that errors if the value +// is not within the given range. The boundaries are valid values. +func IsInRange[T constraints.Ordered](min, max T) Func[T] { + return func(value T) error { + if value < min || max < value { + return fmt.Errorf("must be in range [%v, %v]", min, max) + } + + return nil + } +} + +// Field is a validatable field. +type Field[T any] struct { + Value T + validate Func[T] +} + +// NewField returns a new Field with the attached validation function. The Value +// field is not included in the reported key path if a child field fails validation. +func NewField[T any](value T, validate Func[T]) Field[T] { + return Field[T]{ + Value: value, + validate: validate, + } +} + +// Validate validates the field. +func (f Field[T]) Validate() error { + return f.validate(f.Value) +} + +// skipTo skips the Field itself from the validation walk +// and makes it proceed directly to the value. +func (f Field[T]) skipTo() reflect.Value { + return reflect.ValueOf(f.Value) +} diff --git a/internal/validator/validator_test.go b/internal/validator/validator_test.go new file mode 100644 index 000000000..b4ea7cfca --- /dev/null +++ b/internal/validator/validator_test.go @@ -0,0 +1,210 @@ +package validator_test + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/require" + "gitlab.com/gitlab-org/gitaly/v15/internal/validator" +) + +type CustomString string + +// Custom types can have validators attached to them. Their path in the +// struct's hierarchy is automatically annotated on the errors. +func (str CustomString) Validate() error { + // The validator functions are composable. + return validator.Combine( + validator.IsSet[CustomString], + validator.Equal[CustomString]("success"), + )(str) +} + +type StringEnum string + +func (enum StringEnum) Validate() error { + return validator.IsOneOf("good-1", "good-2", "good-3")(string(enum)) +} + +type ChildStruct struct { + ShouldFail bool + EnumValue StringEnum `toml:"enum_value"` + CustomString CustomString `toml:"custom_string"` + // Fields which don't have validation logic are ignored. + UnvalidatedString string `toml:"unvalidated_string"` +} + +// Validate functons are called on field that implements one, even in subfields. +func (c ChildStruct) Validate() error { + if c.ShouldFail { + // Returning a validator.Error allows for annotating the returned error with a + // key so the correct location in the struct can be reported even if a subfield + // was being validated from a parent struct level validator. + return validator.NewError("should_fail", errors.New("struct validation error")) + } + + return nil +} + +type StringSlice []CustomString + +func (slc StringSlice) Validate() error { + if len(slc) == 0 { + return errors.New("must have elements") + } + + return nil +} + +type Configuration struct { + ShouldFail bool + StringSlice StringSlice `toml:"string_slice"` + CustomString CustomString `toml:"custom_string"` + // The ChildStruct gets walked into as well and validation invoked for each + // field separately. + ChildStruct ChildStruct `toml:"child_struct"` + ValidatedInt validator.Field[int] `toml:"validated_int"` + ValidatedStruct validator.Field[ChildStruct] `toml:"validated_struct"` +} + +// Validate functons are called on field that implements one, even on the root level. +func (c Configuration) Validate() error { + if c.ShouldFail { + // Returned errors don't need to have a special type. + return errors.New("struct validation error") + } + + return nil +} + +func newChildStruct() ChildStruct { + return ChildStruct{ + EnumValue: "good-1", + CustomString: "success", + UnvalidatedString: "default_unvalidated", + } +} + +func NewConfiguration() Configuration { + return Configuration{ + CustomString: "success", + StringSlice: StringSlice{"success", "success", "success"}, + ChildStruct: newChildStruct(), + ValidatedInt: validator.NewField(12, validator.IsInRange(10, 15)), + ValidatedStruct: validator.NewField(newChildStruct(), func(value ChildStruct) error { + if value != newChildStruct() { + return errors.New("struct must not be changed") + } + + return nil + }), + } +} + +func TestValidate(t *testing.T) { + t.Parallel() + + for _, tc := range []struct { + desc string + cfg func(cfg *Configuration) + expectedErr string + }{ + { + desc: "valid configuration", + cfg: func(*Configuration) {}, + }, + { + desc: "invalid top level value", + cfg: func(cfg *Configuration) { + cfg.CustomString = "" + }, + expectedErr: "custom_string: must be set", + }, + { + desc: "invalid child value", + cfg: func(cfg *Configuration) { + cfg.ChildStruct.CustomString = "" + }, + expectedErr: "child_struct.custom_string: must be set", + }, + { + desc: "multiple invalid", + cfg: func(cfg *Configuration) { + cfg.CustomString = "" + cfg.ChildStruct.CustomString = "fail" + }, + expectedErr: `custom_string: must be set +child_struct.custom_string: must equal "success"`, + }, + { + desc: "field not in range", + cfg: func(cfg *Configuration) { + cfg.ValidatedInt.Value = 2 + }, + expectedErr: "validated_int: must be in range [10, 15]", + }, + { + desc: "value equals min of the range", + cfg: func(cfg *Configuration) { + cfg.ValidatedInt.Value = 10 + }, + }, + { + desc: "value equals max of the range", + cfg: func(cfg *Configuration) { + cfg.ValidatedInt.Value = 15 + }, + }, + { + desc: "struct validation", + cfg: func(cfg *Configuration) { + cfg.ShouldFail = true + }, + expectedErr: "struct validation error", + }, + { + desc: "child struct validation", + cfg: func(cfg *Configuration) { + cfg.ChildStruct.ShouldFail = true + }, + expectedErr: "child_struct.should_fail: struct validation error", + }, + { + desc: "field struct validation", + cfg: func(cfg *Configuration) { + cfg.ValidatedStruct.Value.EnumValue = "bad-1" + }, + expectedErr: `validated_struct: struct must not be changed +validated_struct.enum_value: must be one of [good-1 good-2 good-3]`, + }, + { + desc: "slice validation", + cfg: func(cfg *Configuration) { + cfg.StringSlice = nil + }, + expectedErr: `string_slice: must have elements`, + }, + { + desc: "invalid elements in slice", + cfg: func(cfg *Configuration) { + cfg.StringSlice = StringSlice{"", "success", "fail"} + }, + expectedErr: `string_slice.0: must be set +string_slice.2: must equal "success"`, + }, + } { + tc := tc + t.Run(tc.desc, func(t *testing.T) { + t.Parallel() + + cfg := NewConfiguration() + tc.cfg(&cfg) + + if err := validator.Validate(cfg); tc.expectedErr == "" { + require.NoError(t, err) + } else { + require.EqualError(t, err, tc.expectedErr) + } + }) + } +} |