diff options
author | Jacob Vosmaer <jacob@gitlab.com> | 2019-06-28 17:09:46 +0300 |
---|---|---|
committer | Jacob Vosmaer <jacob@gitlab.com> | 2019-06-28 17:09:46 +0300 |
commit | d549e8a84774fe41fd126c56a6e4223063546acc (patch) | |
tree | 620297d21980d23c1110070782d7eea1c6460484 | |
parent | 23fdc8c8937cfd353d09564702e6178266a4f28f (diff) |
Factor out hashfile checksum checks
-rw-r--r-- | internal/git/gitio/hashfile.go | 47 | ||||
-rw-r--r-- | internal/git/gitio/hashfile_test.go | 56 | ||||
-rw-r--r-- | parse-bitmap.go | 27 |
3 files changed, 105 insertions, 25 deletions
diff --git a/internal/git/gitio/hashfile.go b/internal/git/gitio/hashfile.go new file mode 100644 index 000000000..07ece1e86 --- /dev/null +++ b/internal/git/gitio/hashfile.go @@ -0,0 +1,47 @@ +package gitio + +import ( + "bytes" + "crypto/sha1" + "fmt" + "hash" + "io" +) + +type HashfileReader struct { + tr *TrailerReader + tee io.Reader + sum hash.Hash +} + +func NewHashfileReader(r io.Reader) *HashfileReader { + sum := sha1.New() + tr := NewTrailerReader(r, sum.Size()) + return &HashfileReader{ + tr: tr, + tee: io.TeeReader(tr, sum), + sum: sum, + } +} + +func (hr *HashfileReader) Read(p []byte) (int, error) { + n, err := hr.tee.Read(p) + if err == io.EOF { + return n, hr.validateChecksum() + } + + return n, err +} + +func (hr *HashfileReader) validateChecksum() error { + trailer, err := hr.tr.Trailer() + if err != nil { + return err + } + + if actualSum := hr.sum.Sum(nil); !bytes.Equal(trailer, actualSum) { + return fmt.Errorf("hashfile checksum mismatch: expected %x got %x", trailer, actualSum) + } + + return io.EOF +} diff --git a/internal/git/gitio/hashfile_test.go b/internal/git/gitio/hashfile_test.go new file mode 100644 index 000000000..aaa5a9e7a --- /dev/null +++ b/internal/git/gitio/hashfile_test.go @@ -0,0 +1,56 @@ +package gitio + +import ( + "io/ioutil" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestHashfileReader(t *testing.T) { + testCases := []struct { + desc string + in string + sum string + out string + fail bool + }{ + { + desc: "simple input", + in: "hello\xaa\xf4\xc6\x1d\xdc\xc5\xe8\xa2\xda\xbe\xde\x0f\x3b\x48\x2c\xd9\xae\xa9\x43\x4d", + out: "hello", + }, + { + desc: "empty input", + in: "\xda\x39\xa3\xee\x5e\x6b\x4b\x0d\x32\x55\xbf\xef\x95\x60\x18\x90\xaf\xd8\x07\x09", + out: "", + }, + { + desc: "checksum mismatch", + in: "hello\xff\xf4\xc6\x1d\xdc\xc5\xe8\xa2\xda\xbe\xde\x0f\x3b\x48\x2c\xd9\xae\xa9\x43\x4d", + out: "hello", + fail: true, + }, + { + desc: "input too short", + in: "hello world", + out: "", + fail: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + r := NewHashfileReader(strings.NewReader(tc.in)) + out, err := ioutil.ReadAll(r) + if tc.fail { + require.Error(t, err, "invalid input should cause error") + return + } + + require.NoError(t, err, "valid input") + require.Equal(t, tc.out, string(out), "compare output") + }) + } +} diff --git a/parse-bitmap.go b/parse-bitmap.go index e6ef9e6ce..ec5dde85b 100644 --- a/parse-bitmap.go +++ b/parse-bitmap.go @@ -2,8 +2,6 @@ package main import ( "bufio" - "bytes" - "crypto/sha1" "encoding/binary" "encoding/hex" "flag" @@ -58,9 +56,7 @@ func _main(packIdx string) error { } defer f.Close() - tr := gitio.NewTrailerReader(f, sumSize) - sum := sha1.New() - r := bufio.NewReader(io.TeeReader(tr, sum)) + r := bufio.NewReader(gitio.NewHashfileReader(f)) nBitmapCommits, err := parseBitmapHeader(r, packID) if err != nil { @@ -115,15 +111,6 @@ func _main(packIdx string) error { return fmt.Errorf("expected EOF, got %v", err) } - expectedSum, err := tr.Trailer() - if err != nil { - return err - } - - if !bytes.Equal(expectedSum, sum.Sum(nil)) { - return fmt.Errorf("bitmap checksum mismatch") - } - out := bufio.NewWriter(os.Stdout) defer out.Flush() @@ -320,9 +307,7 @@ func readIndex(packBase, packID string) ([]*packObject, error) { } defer f.Close() - tr := gitio.NewTrailerReader(f, sumSize) - sum := sha1.New() - r := bufio.NewReader(io.TeeReader(tr, sum)) + r := bufio.NewReader(gitio.NewHashfileReader(f)) const sig = "\377tOc\x00\x00\x00\x02" actualSig, err := readN(r, len(sig)) @@ -413,14 +398,6 @@ func readIndex(packBase, packID string) ([]*packObject, error) { return nil, err } - expectedSum, err := tr.Trailer() - if err != nil { - return nil, err - } - if !bytes.Equal(expectedSum, sum.Sum(nil)) { - return nil, fmt.Errorf("idx file checksum mismatch") - } - return objects, nil } |