Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/gohugoio/hugo.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'tpl/compare')
-rw-r--r--tpl/compare/compare.go198
-rw-r--r--tpl/compare/compare_test.go197
2 files changed, 395 insertions, 0 deletions
diff --git a/tpl/compare/compare.go b/tpl/compare/compare.go
new file mode 100644
index 000000000..8b7a96bf0
--- /dev/null
+++ b/tpl/compare/compare.go
@@ -0,0 +1,198 @@
+// Copyright 2017 The Hugo Authors. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package compare
+
+import (
+ "fmt"
+ "reflect"
+ "strconv"
+ "time"
+)
+
+// Default checks whether a given value is set and returns a default value if it
+// is not. "Set" in this context means non-zero for numeric types and times;
+// non-zero length for strings, arrays, slices, and maps;
+// any boolean or struct value; or non-nil for any other types.
+func Default(dflt interface{}, given ...interface{}) (interface{}, error) {
+ // given is variadic because the following construct will not pass a piped
+ // argument when the key is missing: {{ index . "key" | default "foo" }}
+ // The Go template will complain that we got 1 argument when we expectd 2.
+
+ if len(given) == 0 {
+ return dflt, nil
+ }
+ if len(given) != 1 {
+ return nil, fmt.Errorf("wrong number of args for default: want 2 got %d", len(given)+1)
+ }
+
+ g := reflect.ValueOf(given[0])
+ if !g.IsValid() {
+ return dflt, nil
+ }
+
+ set := false
+
+ switch g.Kind() {
+ case reflect.Bool:
+ set = true
+ case reflect.String, reflect.Array, reflect.Slice, reflect.Map:
+ set = g.Len() != 0
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ set = g.Int() != 0
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+ set = g.Uint() != 0
+ case reflect.Float32, reflect.Float64:
+ set = g.Float() != 0
+ case reflect.Complex64, reflect.Complex128:
+ set = g.Complex() != 0
+ case reflect.Struct:
+ switch actual := given[0].(type) {
+ case time.Time:
+ set = !actual.IsZero()
+ default:
+ set = true
+ }
+ default:
+ set = !g.IsNil()
+ }
+
+ if set {
+ return given[0], nil
+ }
+
+ return dflt, nil
+}
+
+// Eq returns the boolean truth of arg1 == arg2.
+func Eq(x, y interface{}) bool {
+ normalize := func(v interface{}) interface{} {
+ vv := reflect.ValueOf(v)
+ switch vv.Kind() {
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ return vv.Int()
+ case reflect.Float32, reflect.Float64:
+ return vv.Float()
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+ return vv.Uint()
+ default:
+ return v
+ }
+ }
+ x = normalize(x)
+ y = normalize(y)
+ return reflect.DeepEqual(x, y)
+}
+
+// Ne returns the boolean truth of arg1 != arg2.
+func Ne(x, y interface{}) bool {
+ return !Eq(x, y)
+}
+
+// Ge returns the boolean truth of arg1 >= arg2.
+func Ge(a, b interface{}) bool {
+ left, right := compareGetFloat(a, b)
+ return left >= right
+}
+
+// Gt returns the boolean truth of arg1 > arg2.
+func Gt(a, b interface{}) bool {
+ left, right := compareGetFloat(a, b)
+ return left > right
+}
+
+// Le returns the boolean truth of arg1 <= arg2.
+func Le(a, b interface{}) bool {
+ left, right := compareGetFloat(a, b)
+ return left <= right
+}
+
+// Lt returns the boolean truth of arg1 < arg2.
+func Lt(a, b interface{}) bool {
+ left, right := compareGetFloat(a, b)
+ return left < right
+}
+
+func compareGetFloat(a interface{}, b interface{}) (float64, float64) {
+ var left, right float64
+ var leftStr, rightStr *string
+ av := reflect.ValueOf(a)
+
+ switch av.Kind() {
+ case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice:
+ left = float64(av.Len())
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ left = float64(av.Int())
+ case reflect.Float32, reflect.Float64:
+ left = av.Float()
+ case reflect.String:
+ var err error
+ left, err = strconv.ParseFloat(av.String(), 64)
+ if err != nil {
+ str := av.String()
+ leftStr = &str
+ }
+ case reflect.Struct:
+ switch av.Type() {
+ case timeType:
+ left = float64(toTimeUnix(av))
+ }
+ }
+
+ bv := reflect.ValueOf(b)
+
+ switch bv.Kind() {
+ case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice:
+ right = float64(bv.Len())
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ right = float64(bv.Int())
+ case reflect.Float32, reflect.Float64:
+ right = bv.Float()
+ case reflect.String:
+ var err error
+ right, err = strconv.ParseFloat(bv.String(), 64)
+ if err != nil {
+ str := bv.String()
+ rightStr = &str
+ }
+ case reflect.Struct:
+ switch bv.Type() {
+ case timeType:
+ right = float64(toTimeUnix(bv))
+ }
+ }
+
+ switch {
+ case leftStr == nil || rightStr == nil:
+ case *leftStr < *rightStr:
+ return 0, 1
+ case *leftStr > *rightStr:
+ return 1, 0
+ default:
+ return 0, 0
+ }
+
+ return left, right
+}
+
+var timeType = reflect.TypeOf((*time.Time)(nil)).Elem()
+
+func toTimeUnix(v reflect.Value) int64 {
+ if v.Kind() == reflect.Interface {
+ return toTimeUnix(v.Elem())
+ }
+ if v.Type() != timeType {
+ panic("coding error: argument must be time.Time type reflect Value")
+ }
+ return v.MethodByName("Unix").Call([]reflect.Value{})[0].Int()
+}
diff --git a/tpl/compare/compare_test.go b/tpl/compare/compare_test.go
new file mode 100644
index 000000000..d40a6fe5f
--- /dev/null
+++ b/tpl/compare/compare_test.go
@@ -0,0 +1,197 @@
+// Copyright 2017 The Hugo Authors. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package compare
+
+import (
+ "fmt"
+ "path"
+ "reflect"
+ "runtime"
+ "testing"
+ "time"
+
+ "github.com/spf13/cast"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+type tstCompareType int
+
+const (
+ tstEq tstCompareType = iota
+ tstNe
+ tstGt
+ tstGe
+ tstLt
+ tstLe
+)
+
+func tstIsEq(tp tstCompareType) bool { return tp == tstEq || tp == tstGe || tp == tstLe }
+func tstIsGt(tp tstCompareType) bool { return tp == tstGt || tp == tstGe }
+func tstIsLt(tp tstCompareType) bool { return tp == tstLt || tp == tstLe }
+
+func TestDefaultFunc(t *testing.T) {
+ t.Parallel()
+
+ then := time.Now()
+ now := time.Now()
+
+ for i, test := range []struct {
+ dflt interface{}
+ given interface{}
+ expect interface{}
+ }{
+ {true, false, false},
+ {"5", 0, "5"},
+
+ {"test1", "set", "set"},
+ {"test2", "", "test2"},
+ {"test3", nil, "test3"},
+
+ {[2]int{10, 20}, [2]int{1, 2}, [2]int{1, 2}},
+ {[2]int{10, 20}, [0]int{}, [2]int{10, 20}},
+ {[2]int{100, 200}, nil, [2]int{100, 200}},
+
+ {[]string{"one"}, []string{"uno"}, []string{"uno"}},
+ {[]string{"two"}, []string{}, []string{"two"}},
+ {[]string{"three"}, nil, []string{"three"}},
+
+ {map[string]int{"one": 1}, map[string]int{"uno": 1}, map[string]int{"uno": 1}},
+ {map[string]int{"one": 1}, map[string]int{}, map[string]int{"one": 1}},
+ {map[string]int{"two": 2}, nil, map[string]int{"two": 2}},
+
+ {10, 1, 1},
+ {10, 0, 10},
+ {20, nil, 20},
+
+ {float32(10), float32(1), float32(1)},
+ {float32(10), 0, float32(10)},
+ {float32(20), nil, float32(20)},
+
+ {complex(2, -2), complex(1, -1), complex(1, -1)},
+ {complex(2, -2), complex(0, 0), complex(2, -2)},
+ {complex(3, -3), nil, complex(3, -3)},
+
+ {struct{ f string }{f: "one"}, struct{}{}, struct{}{}},
+ {struct{ f string }{f: "two"}, nil, struct{ f string }{f: "two"}},
+
+ {then, now, now},
+ {then, time.Time{}, then},
+ } {
+ errMsg := fmt.Sprintf("[%d] %v", i, test)
+
+ result, err := Default(test.dflt, test.given)
+
+ require.NoError(t, err, errMsg)
+ assert.Equal(t, result, test.expect, errMsg)
+ }
+}
+
+func TestCompare(t *testing.T) {
+ t.Parallel()
+
+ for _, test := range []struct {
+ tstCompareType
+ funcUnderTest func(a, b interface{}) bool
+ }{
+ {tstGt, Gt},
+ {tstLt, Lt},
+ {tstGe, Ge},
+ {tstLe, Le},
+ {tstEq, Eq},
+ {tstNe, Ne},
+ } {
+ doTestCompare(t, test.tstCompareType, test.funcUnderTest)
+ }
+}
+
+func doTestCompare(t *testing.T, tp tstCompareType, funcUnderTest func(a, b interface{}) bool) {
+ for i, test := range []struct {
+ left interface{}
+ right interface{}
+ expectIndicator int
+ }{
+ {5, 8, -1},
+ {8, 5, 1},
+ {5, 5, 0},
+ {int(5), int64(5), 0},
+ {int32(5), int(5), 0},
+ {int16(4), int(5), -1},
+ {uint(15), uint64(15), 0},
+ {-2, 1, -1},
+ {2, -5, 1},
+ {0.0, 1.23, -1},
+ {1.1, 1.1, 0},
+ {float32(1.0), float64(1.0), 0},
+ {1.23, 0.0, 1},
+ {"5", "5", 0},
+ {"8", "5", 1},
+ {"5", "0001", 1},
+ {[]int{100, 99}, []int{1, 2, 3, 4}, -1},
+ {cast.ToTime("2015-11-20"), cast.ToTime("2015-11-20"), 0},
+ {cast.ToTime("2015-11-19"), cast.ToTime("2015-11-20"), -1},
+ {cast.ToTime("2015-11-20"), cast.ToTime("2015-11-19"), 1},
+ {"a", "a", 0},
+ {"a", "b", -1},
+ {"b", "a", 1},
+ } {
+ result := funcUnderTest(test.left, test.right)
+ success := false
+
+ if test.expectIndicator == 0 {
+ if tstIsEq(tp) {
+ success = result
+ } else {
+ success = !result
+ }
+ }
+
+ if test.expectIndicator < 0 {
+ success = result && (tstIsLt(tp) || tp == tstNe)
+ success = success || (!result && !tstIsLt(tp))
+ }
+
+ if test.expectIndicator > 0 {
+ success = result && (tstIsGt(tp) || tp == tstNe)
+ success = success || (!result && (!tstIsGt(tp) || tp != tstNe))
+ }
+
+ if !success {
+ t.Errorf("[%d][%s] %v compared to %v: %t", i, path.Base(runtime.FuncForPC(reflect.ValueOf(funcUnderTest).Pointer()).Name()), test.left, test.right, result)
+ }
+ }
+}
+
+func TestTimeUnix(t *testing.T) {
+ t.Parallel()
+ var sec int64 = 1234567890
+ tv := reflect.ValueOf(time.Unix(sec, 0))
+ i := 1
+
+ res := toTimeUnix(tv)
+ if sec != res {
+ t.Errorf("[%d] timeUnix got %v but expected %v", i, res, sec)
+ }
+
+ i++
+ func(t *testing.T) {
+ defer func() {
+ if err := recover(); err == nil {
+ t.Errorf("[%d] timeUnix didn't return an expected error", i)
+ }
+ }()
+ iv := reflect.ValueOf(sec)
+ toTimeUnix(iv)
+ }(t)
+}