diff options
Diffstat (limited to 'tpl/compare')
-rw-r--r-- | tpl/compare/compare.go | 198 | ||||
-rw-r--r-- | tpl/compare/compare_test.go | 197 |
2 files changed, 395 insertions, 0 deletions
diff --git a/tpl/compare/compare.go b/tpl/compare/compare.go new file mode 100644 index 000000000..8b7a96bf0 --- /dev/null +++ b/tpl/compare/compare.go @@ -0,0 +1,198 @@ +// Copyright 2017 The Hugo Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package compare + +import ( + "fmt" + "reflect" + "strconv" + "time" +) + +// Default checks whether a given value is set and returns a default value if it +// is not. "Set" in this context means non-zero for numeric types and times; +// non-zero length for strings, arrays, slices, and maps; +// any boolean or struct value; or non-nil for any other types. +func Default(dflt interface{}, given ...interface{}) (interface{}, error) { + // given is variadic because the following construct will not pass a piped + // argument when the key is missing: {{ index . "key" | default "foo" }} + // The Go template will complain that we got 1 argument when we expectd 2. + + if len(given) == 0 { + return dflt, nil + } + if len(given) != 1 { + return nil, fmt.Errorf("wrong number of args for default: want 2 got %d", len(given)+1) + } + + g := reflect.ValueOf(given[0]) + if !g.IsValid() { + return dflt, nil + } + + set := false + + switch g.Kind() { + case reflect.Bool: + set = true + case reflect.String, reflect.Array, reflect.Slice, reflect.Map: + set = g.Len() != 0 + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + set = g.Int() != 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + set = g.Uint() != 0 + case reflect.Float32, reflect.Float64: + set = g.Float() != 0 + case reflect.Complex64, reflect.Complex128: + set = g.Complex() != 0 + case reflect.Struct: + switch actual := given[0].(type) { + case time.Time: + set = !actual.IsZero() + default: + set = true + } + default: + set = !g.IsNil() + } + + if set { + return given[0], nil + } + + return dflt, nil +} + +// Eq returns the boolean truth of arg1 == arg2. +func Eq(x, y interface{}) bool { + normalize := func(v interface{}) interface{} { + vv := reflect.ValueOf(v) + switch vv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return vv.Int() + case reflect.Float32, reflect.Float64: + return vv.Float() + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return vv.Uint() + default: + return v + } + } + x = normalize(x) + y = normalize(y) + return reflect.DeepEqual(x, y) +} + +// Ne returns the boolean truth of arg1 != arg2. +func Ne(x, y interface{}) bool { + return !Eq(x, y) +} + +// Ge returns the boolean truth of arg1 >= arg2. +func Ge(a, b interface{}) bool { + left, right := compareGetFloat(a, b) + return left >= right +} + +// Gt returns the boolean truth of arg1 > arg2. +func Gt(a, b interface{}) bool { + left, right := compareGetFloat(a, b) + return left > right +} + +// Le returns the boolean truth of arg1 <= arg2. +func Le(a, b interface{}) bool { + left, right := compareGetFloat(a, b) + return left <= right +} + +// Lt returns the boolean truth of arg1 < arg2. +func Lt(a, b interface{}) bool { + left, right := compareGetFloat(a, b) + return left < right +} + +func compareGetFloat(a interface{}, b interface{}) (float64, float64) { + var left, right float64 + var leftStr, rightStr *string + av := reflect.ValueOf(a) + + switch av.Kind() { + case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice: + left = float64(av.Len()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + left = float64(av.Int()) + case reflect.Float32, reflect.Float64: + left = av.Float() + case reflect.String: + var err error + left, err = strconv.ParseFloat(av.String(), 64) + if err != nil { + str := av.String() + leftStr = &str + } + case reflect.Struct: + switch av.Type() { + case timeType: + left = float64(toTimeUnix(av)) + } + } + + bv := reflect.ValueOf(b) + + switch bv.Kind() { + case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice: + right = float64(bv.Len()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + right = float64(bv.Int()) + case reflect.Float32, reflect.Float64: + right = bv.Float() + case reflect.String: + var err error + right, err = strconv.ParseFloat(bv.String(), 64) + if err != nil { + str := bv.String() + rightStr = &str + } + case reflect.Struct: + switch bv.Type() { + case timeType: + right = float64(toTimeUnix(bv)) + } + } + + switch { + case leftStr == nil || rightStr == nil: + case *leftStr < *rightStr: + return 0, 1 + case *leftStr > *rightStr: + return 1, 0 + default: + return 0, 0 + } + + return left, right +} + +var timeType = reflect.TypeOf((*time.Time)(nil)).Elem() + +func toTimeUnix(v reflect.Value) int64 { + if v.Kind() == reflect.Interface { + return toTimeUnix(v.Elem()) + } + if v.Type() != timeType { + panic("coding error: argument must be time.Time type reflect Value") + } + return v.MethodByName("Unix").Call([]reflect.Value{})[0].Int() +} diff --git a/tpl/compare/compare_test.go b/tpl/compare/compare_test.go new file mode 100644 index 000000000..d40a6fe5f --- /dev/null +++ b/tpl/compare/compare_test.go @@ -0,0 +1,197 @@ +// Copyright 2017 The Hugo Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package compare + +import ( + "fmt" + "path" + "reflect" + "runtime" + "testing" + "time" + + "github.com/spf13/cast" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type tstCompareType int + +const ( + tstEq tstCompareType = iota + tstNe + tstGt + tstGe + tstLt + tstLe +) + +func tstIsEq(tp tstCompareType) bool { return tp == tstEq || tp == tstGe || tp == tstLe } +func tstIsGt(tp tstCompareType) bool { return tp == tstGt || tp == tstGe } +func tstIsLt(tp tstCompareType) bool { return tp == tstLt || tp == tstLe } + +func TestDefaultFunc(t *testing.T) { + t.Parallel() + + then := time.Now() + now := time.Now() + + for i, test := range []struct { + dflt interface{} + given interface{} + expect interface{} + }{ + {true, false, false}, + {"5", 0, "5"}, + + {"test1", "set", "set"}, + {"test2", "", "test2"}, + {"test3", nil, "test3"}, + + {[2]int{10, 20}, [2]int{1, 2}, [2]int{1, 2}}, + {[2]int{10, 20}, [0]int{}, [2]int{10, 20}}, + {[2]int{100, 200}, nil, [2]int{100, 200}}, + + {[]string{"one"}, []string{"uno"}, []string{"uno"}}, + {[]string{"two"}, []string{}, []string{"two"}}, + {[]string{"three"}, nil, []string{"three"}}, + + {map[string]int{"one": 1}, map[string]int{"uno": 1}, map[string]int{"uno": 1}}, + {map[string]int{"one": 1}, map[string]int{}, map[string]int{"one": 1}}, + {map[string]int{"two": 2}, nil, map[string]int{"two": 2}}, + + {10, 1, 1}, + {10, 0, 10}, + {20, nil, 20}, + + {float32(10), float32(1), float32(1)}, + {float32(10), 0, float32(10)}, + {float32(20), nil, float32(20)}, + + {complex(2, -2), complex(1, -1), complex(1, -1)}, + {complex(2, -2), complex(0, 0), complex(2, -2)}, + {complex(3, -3), nil, complex(3, -3)}, + + {struct{ f string }{f: "one"}, struct{}{}, struct{}{}}, + {struct{ f string }{f: "two"}, nil, struct{ f string }{f: "two"}}, + + {then, now, now}, + {then, time.Time{}, then}, + } { + errMsg := fmt.Sprintf("[%d] %v", i, test) + + result, err := Default(test.dflt, test.given) + + require.NoError(t, err, errMsg) + assert.Equal(t, result, test.expect, errMsg) + } +} + +func TestCompare(t *testing.T) { + t.Parallel() + + for _, test := range []struct { + tstCompareType + funcUnderTest func(a, b interface{}) bool + }{ + {tstGt, Gt}, + {tstLt, Lt}, + {tstGe, Ge}, + {tstLe, Le}, + {tstEq, Eq}, + {tstNe, Ne}, + } { + doTestCompare(t, test.tstCompareType, test.funcUnderTest) + } +} + +func doTestCompare(t *testing.T, tp tstCompareType, funcUnderTest func(a, b interface{}) bool) { + for i, test := range []struct { + left interface{} + right interface{} + expectIndicator int + }{ + {5, 8, -1}, + {8, 5, 1}, + {5, 5, 0}, + {int(5), int64(5), 0}, + {int32(5), int(5), 0}, + {int16(4), int(5), -1}, + {uint(15), uint64(15), 0}, + {-2, 1, -1}, + {2, -5, 1}, + {0.0, 1.23, -1}, + {1.1, 1.1, 0}, + {float32(1.0), float64(1.0), 0}, + {1.23, 0.0, 1}, + {"5", "5", 0}, + {"8", "5", 1}, + {"5", "0001", 1}, + {[]int{100, 99}, []int{1, 2, 3, 4}, -1}, + {cast.ToTime("2015-11-20"), cast.ToTime("2015-11-20"), 0}, + {cast.ToTime("2015-11-19"), cast.ToTime("2015-11-20"), -1}, + {cast.ToTime("2015-11-20"), cast.ToTime("2015-11-19"), 1}, + {"a", "a", 0}, + {"a", "b", -1}, + {"b", "a", 1}, + } { + result := funcUnderTest(test.left, test.right) + success := false + + if test.expectIndicator == 0 { + if tstIsEq(tp) { + success = result + } else { + success = !result + } + } + + if test.expectIndicator < 0 { + success = result && (tstIsLt(tp) || tp == tstNe) + success = success || (!result && !tstIsLt(tp)) + } + + if test.expectIndicator > 0 { + success = result && (tstIsGt(tp) || tp == tstNe) + success = success || (!result && (!tstIsGt(tp) || tp != tstNe)) + } + + if !success { + t.Errorf("[%d][%s] %v compared to %v: %t", i, path.Base(runtime.FuncForPC(reflect.ValueOf(funcUnderTest).Pointer()).Name()), test.left, test.right, result) + } + } +} + +func TestTimeUnix(t *testing.T) { + t.Parallel() + var sec int64 = 1234567890 + tv := reflect.ValueOf(time.Unix(sec, 0)) + i := 1 + + res := toTimeUnix(tv) + if sec != res { + t.Errorf("[%d] timeUnix got %v but expected %v", i, res, sec) + } + + i++ + func(t *testing.T) { + defer func() { + if err := recover(); err == nil { + t.Errorf("[%d] timeUnix didn't return an expected error", i) + } + }() + iv := reflect.ValueOf(sec) + toTimeUnix(iv) + }(t) +} |