diff options
-rw-r--r-- | multi_string_flag.go | 9 | ||||
-rw-r--r-- | multi_string_flag_test.go | 12 |
2 files changed, 12 insertions, 9 deletions
diff --git a/multi_string_flag.go b/multi_string_flag.go index b8e85991..1dd76b28 100644 --- a/multi_string_flag.go +++ b/multi_string_flag.go @@ -1,9 +1,12 @@ package main import ( + "errors" "strings" ) +var errMultiStringSetEmptyValue = errors.New("set value cannot be empty") + // MultiStringFlag implements the flag.Value interface and allows a string flag // to be specified multiple times on the command line. // @@ -17,6 +20,9 @@ func (s *MultiStringFlag) String() string { // Set appends the value to the list of parameters func (s *MultiStringFlag) Set(value string) error { + if value == "" { + return errMultiStringSetEmptyValue + } *s = append(*s, value) return nil } @@ -24,9 +30,6 @@ func (s *MultiStringFlag) Set(value string) error { // Split each flag func (s *MultiStringFlag) Split() (result []string) { for _, str := range *s { - if str == "" { - continue - } result = append(result, strings.Split(str, ",")...) } diff --git a/multi_string_flag_test.go b/multi_string_flag_test.go index 4d1c1f8c..cca0fd05 100644 --- a/multi_string_flag_test.go +++ b/multi_string_flag_test.go @@ -2,7 +2,6 @@ package main import ( "flag" - "reflect" "testing" "github.com/stretchr/testify/require" @@ -17,6 +16,8 @@ func TestMultiStringFlagAppendsOnSet(t *testing.T) { require.NoError(t, iface.Set("foo")) require.NoError(t, iface.Set("bar")) + require.Error(t, iface.Set(""), errMultiStringSetEmptyValue) + require.Equal(t, MultiStringFlag{"foo", "bar"}, concrete) } @@ -29,7 +30,7 @@ func TestMultiStringFlag_Split(t *testing.T) { { name: "empty_string", s: &MultiStringFlag{}, // -flag "" - wantResult: nil, + wantResult: []string{}, }, { name: "one_value", @@ -39,14 +40,13 @@ func TestMultiStringFlag_Split(t *testing.T) { { name: "multiple_values", s: &MultiStringFlag{"value1", "", "value3"}, // -flag "value1,,value3" - wantResult: []string{"value1", "value3"}, + wantResult: []string{"value1", "", "value3"}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if gotResult := tt.s.Split(); !reflect.DeepEqual(gotResult, tt.wantResult) { - t.Errorf("MultiStringFlag.Split() = %v, want %v", gotResult, tt.wantResult) - } + gotResult := tt.s.Split() + require.ElementsMatch(t, tt.wantResult, gotResult) }) } } |