diff options
author | Jaime Martinez <jmartinez@gitlab.com> | 2020-09-08 14:42:37 +0300 |
---|---|---|
committer | Vladimir Shushlin <vshushlin@gitlab.com> | 2020-09-08 14:42:37 +0300 |
commit | e3470b6c41e97aa5124cf5168421172ff119aa21 (patch) | |
tree | c82f271267f4685ac4f8a90161766670eba75cca | |
parent | c2a7040d020736419adf3afa70023b505bd83ab5 (diff) |
Add httprange package
Adds a slightly modified version of the `httprange` package obtained from
https://gitlab.com/gitlab-org/gitlab-pages/-/merge_requests/326.
This commit only adds the functionality; further improvements will be made in
subsequent iterations.
-rw-r--r-- | internal/httprange/http_ranged_reader.go | 61 | ||||
-rw-r--r-- | internal/httprange/http_ranged_reader_test.go | 260 | ||||
-rw-r--r-- | internal/httprange/http_reader.go | 197 | ||||
-rw-r--r-- | internal/httprange/http_reader_test.go | 326 | ||||
-rw-r--r-- | internal/httprange/resource.go | 77 | ||||
-rw-r--r-- | internal/httprange/resource_test.go | 96 |
6 files changed, 1017 insertions, 0 deletions
diff --git a/internal/httprange/http_ranged_reader.go b/internal/httprange/http_ranged_reader.go
new file mode 100644
index 00000000..d023521d
--- /dev/null
+++ b/internal/httprange/http_ranged_reader.go
@@ -0,0 +1,61 @@
+package httprange
+
+import (
+	"io"
+)
+
+// RangedReader gives random access to the bytes of a Resource.
+// It implements io.ReaderAt, which lets it be used with Go's archive/zip package.
+type RangedReader struct {
+	// Resource is the remote resource every read is served from.
+	Resource *Resource
+	// cachedReader, when non-nil, is the Reader that ReadAt reuses instead of
+	// creating a new one per call.
+	cachedReader *Reader
+}
+
+// cachedRead positions the cached reader at off (relative to the start of its
+// range) and fills buf from it. It must only be called while cachedReader is set.
+func (rr *RangedReader) cachedRead(buf []byte, off int64) (int, error) {
+	if _, err := rr.cachedReader.Seek(off, io.SeekStart); err != nil {
+		return 0, err
+	}
+
+	return io.ReadFull(rr.cachedReader, buf)
+}
+
+// ephemeralRead creates a Reader covering exactly len(buf) bytes starting at
+// offset, fills buf from it and closes that Reader before returning.
+func (rr *RangedReader) ephemeralRead(buf []byte, offset int64) (n int, err error) {
+	size := int64(len(buf))
+
+	r := NewReader(rr.Resource, offset, size)
+	defer r.Close()
+
+	n, err = io.ReadFull(r, buf)
+	return n, err
+}
+
+// SectionReader returns a new Reader restricted to the `size` bytes of the
+// resource that start at `offset`.
+func (rr *RangedReader) SectionReader(offset, size int64) *Reader {
+	section := NewReader(rr.Resource, offset, size)
+	return section
+}
+
+// ReadAt reads len(buf) bytes of the resource, starting at offset, into buf.
+// The cached reader is used when one is set; otherwise a Reader is created
+// just for this call.
+func (rr *RangedReader) ReadAt(buf []byte, offset int64) (n int, err error) {
+	if rr.cachedReader == nil {
+		return rr.ephemeralRead(buf, offset)
+	}
+
+	return rr.cachedRead(buf, offset)
+}
+
+// WithCachedReader creates a Reader and saves it to the RangedReader instance
+// while readFunc runs, so that ReadAt calls made inside readFunc share it.
+func (rr *RangedReader) WithCachedReader(readFunc func()) { + rr.cachedReader = NewReader(rr.Resource, 0, rr.Resource.Size) + + defer func() { + rr.cachedReader.Close() + rr.cachedReader = nil + }() + + readFunc() +} + +// NewRangedReader creates a RangedReader object on a given resource +func NewRangedReader(resource *Resource) *RangedReader { + return &RangedReader{Resource: resource} +} diff --git a/internal/httprange/http_ranged_reader_test.go b/internal/httprange/http_ranged_reader_test.go new file mode 100644 index 00000000..72e645db --- /dev/null +++ b/internal/httprange/http_ranged_reader_test.go @@ -0,0 +1,260 @@ +package httprange + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +const ( + testData = "1234567890abcdefghij0987654321" + testDataLen = len(testData) +) + +func TestSectionReader(t *testing.T) { + tests := map[string]struct { + sectionOffset int + sectionSize int + readSize int + expectedContent string + expectedErr error + }{ + "no_buffer_no_err": { + sectionOffset: 0, + sectionSize: testDataLen, + readSize: 0, + expectedContent: "", + expectedErr: nil, + }, + "offset_starts_at_size": { + sectionOffset: testDataLen, + sectionSize: 1, + readSize: 1, + expectedContent: "", + expectedErr: ErrInvalidRange, + }, + "read_all": { + sectionOffset: 0, + sectionSize: testDataLen, + readSize: testDataLen, + expectedContent: testData, + expectedErr: io.EOF, + }, + "read_first_half": { + sectionOffset: 0, + sectionSize: testDataLen / 2, + readSize: testDataLen / 2, + expectedContent: testData[:testDataLen/2], + expectedErr: io.EOF, + }, + "read_second_half": { + sectionOffset: testDataLen / 2, + sectionSize: testDataLen / 2, + readSize: testDataLen / 2, + expectedContent: testData[testDataLen/2:], + expectedErr: io.EOF, + }, + "read_15_bytes_with_offset": { + sectionOffset: 3, + sectionSize: testDataLen / 2, + readSize: testDataLen / 2, + 
expectedContent: testData[3 : 3+testDataLen/2], + expectedErr: io.EOF, + }, + "read_13_bytes_with_offset": { + sectionOffset: 10, + sectionSize: testDataLen/2 - 2, + readSize: testDataLen/2 - 2, + expectedContent: testData[10 : 10+testDataLen/2-2], + expectedErr: io.EOF, + }, + } + + testServer := newTestServer(t, nil) + defer testServer.Close() + + resource, err := NewResource(context.Background(), testServer.URL+"/resource") + require.NoError(t, err) + + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + rr := NewRangedReader(resource) + s := rr.SectionReader(int64(tt.sectionOffset), int64(tt.sectionSize)) + defer s.Close() + + buf := make([]byte, tt.readSize) + n, err := s.Read(buf) + if tt.expectedErr != nil && err != io.EOF { + require.EqualError(t, err, tt.expectedErr.Error()) + return + } + + require.Equal(t, tt.expectedErr, err) + require.Equal(t, len(tt.expectedContent), n) + require.Equal(t, tt.expectedContent, string(buf[:n])) + }) + } +} + +func TestReadAt(t *testing.T) { + tests := map[string]struct { + sectionOffset int + readSize int + expectedContent string + expectedErr error + }{ + "no_buffer_no_err": { + sectionOffset: 0, + readSize: 0, + expectedContent: "", + expectedErr: nil, + }, + "offset_starts_at_size": { + sectionOffset: testDataLen, + readSize: 1, + expectedContent: "", + expectedErr: ErrInvalidRange, + }, + "read_at_end": { + sectionOffset: testDataLen, + readSize: 1, + expectedContent: "", + expectedErr: ErrInvalidRange, + }, + "read_all": { + sectionOffset: 0, + readSize: testDataLen, + expectedContent: testData, + expectedErr: nil, + }, + "read_first_half": { + sectionOffset: 0, + readSize: testDataLen / 2, + expectedContent: testData[:testDataLen/2], + expectedErr: nil, + }, + "read_second_half": { + sectionOffset: testDataLen / 2, + readSize: testDataLen / 2, + expectedContent: testData[testDataLen/2:], + expectedErr: nil, + }, + "read_15_bytes_with_offset": { + sectionOffset: 3, + readSize: testDataLen / 2, + 
expectedContent: testData[3 : 3+testDataLen/2], + expectedErr: nil, + }, + "read_13_bytes_with_offset": { + sectionOffset: 10, + readSize: testDataLen/2 - 2, + expectedContent: testData[10 : 10+testDataLen/2-2], + expectedErr: nil, + }, + } + + testServer := newTestServer(t, nil) + defer testServer.Close() + + resource, err := NewResource(context.Background(), testServer.URL+"/resource") + require.NoError(t, err) + + for name, tt := range tests { + rr := NewRangedReader(resource) + testFn := func(reader *RangedReader) func(t *testing.T) { + return func(t *testing.T) { + buf := make([]byte, tt.readSize) + + n, err := reader.ReadAt(buf, int64(tt.sectionOffset)) + if tt.expectedErr != nil { + require.EqualError(t, err, tt.expectedErr.Error()) + return + } + + require.NoError(t, err) + require.Equal(t, len(tt.expectedContent), n) + require.Equal(t, tt.expectedContent, string(buf[:n])) + } + } + + t.Run(name, func(t *testing.T) { + rr.WithCachedReader(func() { + t.Run("cachedReader", testFn(rr)) + }) + + t.Run("ephemeralReader", testFn(rr)) + }) + } +} + +func TestReadAtMultipart(t *testing.T) { + var counter int32 + + testServer := newTestServer(t, func() { + atomic.AddInt32(&counter, 1) + }) + defer testServer.Close() + + resource, err := NewResource(context.Background(), testServer.URL+"/resource") + require.NoError(t, err) + require.Equal(t, int32(1), counter) + + rr := NewRangedReader(resource) + + assertReadAtFunc := func(t *testing.T, bufLen, offset int, expectedDat string, expectedCounter int32) { + buf := make([]byte, bufLen) + n, err := rr.ReadAt(buf, int64(offset)) + require.NoError(t, err) + require.Equal(t, expectedCounter, counter) + + require.NoError(t, err) + require.Equal(t, bufLen, n) + require.Equal(t, expectedDat, string(buf)) + } + bufLen := testDataLen / 3 + + t.Run("ephemeralRead", func(t *testing.T) { + // "1234567890" + assertReadAtFunc(t, bufLen, 0, testData[:bufLen], 2) + // "abcdefghij" + assertReadAtFunc(t, bufLen, bufLen, 
testData[bufLen:2*bufLen], 3) + // "0987654321" + assertReadAtFunc(t, bufLen, 2*bufLen, testData[2*bufLen:], 4) + }) + + // cachedReader should not make extra requests, the expectedCounter should always be the same + counter = 1 + t.Run("cachedReader", func(t *testing.T) { + rr.WithCachedReader(func() { + // "1234567890" + assertReadAtFunc(t, bufLen, 0, testData[:bufLen], 2) + // "abcdefghij" + assertReadAtFunc(t, bufLen, bufLen, testData[bufLen:2*bufLen], 2) + // "0987654321" + assertReadAtFunc(t, bufLen, 2*bufLen, testData[2*bufLen:], 2) + }) + }) +} + +func newTestServer(t *testing.T, do func()) *httptest.Server { + t.Helper() + + // use a constant known time or else http.ServeContent will change Last-Modified value + tNow, err := time.Parse(time.RFC3339, "2006-01-02T15:04:05Z") + require.NoError(t, err) + + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if do != nil { + do() + } + + http.ServeContent(w, r, r.URL.Path, tNow, strings.NewReader(testData)) + })) +} diff --git a/internal/httprange/http_reader.go b/internal/httprange/http_reader.go new file mode 100644 index 00000000..72ca3cf9 --- /dev/null +++ b/internal/httprange/http_reader.go @@ -0,0 +1,197 @@ +package httprange + +import ( + "errors" + "fmt" + "io" + "net/http" + "time" + + "gitlab.com/gitlab-org/gitlab-pages/internal/httptransport" +) + +var ( + // ErrRangeRequestsNotSupported is returned by Seek and Read + // when the remote server does not allow range requests (Accept-Ranges was not set) + ErrRangeRequestsNotSupported = errors.New("range requests are not supported by the remote server") + + // ErrInvalidRange is returned by Read when trying to read past the end of the file + ErrInvalidRange = errors.New("invalid range") + + // ErrContentHasChanged is returned by Read when the content has changed since the first request + ErrContentHasChanged = errors.New("content has changed since first request") + + // seek errors no need to export them + 
errSeekInvalidWhence = errors.New("invalid whence") + errSeekOutsideRange = errors.New("outside of range") +) + +// Reader holds a Resource and specifies ranges to read from at a time. +// Implements the io.Reader, io.Seeker and io.Closer interfaces. +type Reader struct { + Resource *Resource + // res defines a current response serving data + res *http.Response + // rangeStart defines a starting range + rangeStart int64 + // rangeSize defines a size of range + rangeSize int64 + // offset defines a current place where data is being read from + offset int64 +} + +// TODO: make this configurable/take an http client when creating a reader/ranged reader +// instead https://gitlab.com/gitlab-org/gitlab-pages/-/issues/457 +var httpClient = &http.Client{ + // The longest time the request can be executed + Timeout: 30 * time.Minute, + Transport: httptransport.InternalTransport, + // TODO: add metrics https://gitlab.com/gitlab-org/gitlab-pages/-/issues/448 + // Transport: httptransport.NewTransportWithMetrics(metrics.ZIPHttpReaderReqDuration, metrics.ZIPHttpReaderReqTotal), +} + +// ensureResponse is set before reading from it. +// It will do the request if the reader hasn't got it yet. 
+func (r *Reader) ensureResponse() error { + if r.res != nil { + return nil + } + + req, err := r.prepareRequest() + if err != nil { + return err + } + + // TODO: add Traceln info for HTTP calls with headers and response https://gitlab.com/gitlab-org/gitlab-pages/-/issues/448 + res, err := httpClient.Do(req) + if err != nil { + return err + } + + err = r.setResponse(res) + if err != nil { + // cleanup body on failure from r.setResponse to avoid memory leak + res.Body.Close() + } + + return err +} + +func (r *Reader) prepareRequest() (*http.Request, error) { + if r.rangeStart < 0 || r.rangeSize < 0 || r.rangeStart+r.rangeSize > r.Resource.Size { + return nil, ErrInvalidRange + } + + if r.offset < r.rangeStart || r.offset >= r.rangeStart+r.rangeSize { + return nil, ErrInvalidRange + } + + req, err := http.NewRequest("GET", r.Resource.URL, nil) + if err != nil { + return nil, err + } + + if r.Resource.ETag != "" { + req.Header.Set("ETag", r.Resource.ETag) + } else if r.Resource.LastModified != "" { + // Last-Modified should be a fallback mechanism in case ETag is not present + // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Last-Modified + req.Header.Set("If-Range", r.Resource.LastModified) + } + + req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", r.offset, r.rangeStart+r.rangeSize-1)) + + return req, nil +} + +func (r *Reader) setResponse(res *http.Response) error { + // TODO: add metrics https://gitlab.com/gitlab-org/gitlab-pages/-/issues/448 + switch res.StatusCode { + case http.StatusOK: + // some servers return 200 OK for bytes=0- + // TODO: should we handle r.Resource.Last-Modified as well? 
+ if r.offset > 0 || r.Resource.ETag != "" && r.Resource.ETag != res.Header.Get("ETag") { + return ErrContentHasChanged + } + case http.StatusPartialContent: + // Requested `Range` request succeeded https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/206 + break + case http.StatusRequestedRangeNotSatisfiable: + return ErrRangeRequestsNotSupported + default: + return fmt.Errorf("httprange: read response %d: %q", res.StatusCode, res.Status) + } + + r.res = res + + return nil +} + +// Seek returns the new offset relative to the start of the file and an error, if any. +// io.SeekStart means relative to the start of the file, +// io.SeekCurrent means relative to the current offset, and +// io.SeekEnd means relative to the end. +func (r *Reader) Seek(offset int64, whence int) (int64, error) { + var newOffset int64 + + switch whence { + case io.SeekStart: + newOffset = r.rangeStart + offset + + case io.SeekCurrent: + newOffset = r.offset + offset + + case io.SeekEnd: + newOffset = r.rangeStart + r.rangeSize + offset + + default: + return 0, errSeekInvalidWhence + } + + if newOffset < r.rangeStart || newOffset > r.rangeStart+r.rangeSize { + return 0, errSeekOutsideRange + } + + if newOffset != r.offset { + // recycle r.res + r.Close() + } + + r.offset = newOffset + return newOffset - r.rangeStart, nil +} + +// Read data into a given buffer. 
+func (r *Reader) Read(buf []byte) (int, error) { + if len(buf) == 0 { + return 0, nil + } + + if err := r.ensureResponse(); err != nil { + return 0, err + } + + n, err := r.res.Body.Read(buf) + if err == nil || err == io.EOF { + r.offset += int64(n) + } + + return n, err +} + +// Close closes a requests body +func (r *Reader) Close() error { + if r.res != nil { + // no need to read until the end + err := r.res.Body.Close() + r.res = nil + return err + } + + return nil +} + +// NewReader creates a Reader object on a given resource for a given range +func NewReader(resource *Resource, offset, size int64) *Reader { + return &Reader{Resource: resource, rangeStart: offset, rangeSize: size, offset: offset} +} diff --git a/internal/httprange/http_reader_test.go b/internal/httprange/http_reader_test.go new file mode 100644 index 00000000..507a7fe8 --- /dev/null +++ b/internal/httprange/http_reader_test.go @@ -0,0 +1,326 @@ +package httprange + +import ( + "context" + "io" + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSeekAndRead(t *testing.T) { + testServer := newTestServer(t, nil) + defer testServer.Close() + + resource, err := NewResource(context.Background(), testServer.URL+"/data") + require.NoError(t, err) + + tests := map[string]struct { + readerOffset int64 + seekOffset int64 + seekWhence int + readSize int + expectedContent string + expectedSeekErrMsg string + expectedReadErr error + }{ + // io.SeekStart ... 
+ "read_all_from_seek_start": { + readSize: testDataLen, + seekWhence: io.SeekStart, + expectedContent: testData, + expectedReadErr: io.EOF, + }, + "read_10_bytes_from_seek_start": { + readSize: testDataLen / 3, + seekWhence: io.SeekStart, + // "1234567890" + expectedContent: testData[:testDataLen/3], + expectedReadErr: nil, + }, + "read_10_bytes_from_seek_start_with_seek_offset": { + readSize: testDataLen / 3, + seekOffset: int64(testDataLen / 3), + seekWhence: io.SeekStart, + // "abcdefghij" + expectedContent: testData[testDataLen/3 : 2*testDataLen/3], + expectedReadErr: nil, + }, + "read_10_bytes_from_seek_offset_until_eof": { + readSize: testDataLen / 3, + seekOffset: int64(2 * testDataLen / 3), + seekWhence: io.SeekStart, + // "0987654321" + expectedContent: testData[2*testDataLen/3:], + expectedReadErr: io.EOF, + }, + "read_10_bytes_from_reader_offset_with_seek_offset_to_eof": { + readSize: testDataLen / 3, + readerOffset: int64(testDataLen / 3), // reader offset at "a" + seekOffset: int64(testDataLen / 3), // seek offset at "0" + seekWhence: io.SeekStart, + // "0987654321" + expectedContent: testData[2*testDataLen/3:], + expectedReadErr: io.EOF, + }, + "invalid_seek_start_negative_seek_offset": { + seekOffset: -1, + seekWhence: io.SeekStart, + expectedSeekErrMsg: "outside of range", + }, + "invalid_range_seek_at_end": { + readSize: testDataLen, + seekOffset: int64(testDataLen), + seekWhence: io.SeekStart, + expectedReadErr: ErrInvalidRange, + }, + // io.SeekCurrent ... 
+ "read_all_from_seek_current": { + readSize: testDataLen, + seekWhence: io.SeekCurrent, + expectedContent: testData, + expectedReadErr: io.EOF, + }, + "read_10_bytes_from_seek_current": { + readSize: testDataLen / 3, + seekWhence: io.SeekCurrent, + // "1234567890" + expectedContent: testData[:testDataLen/3], + expectedReadErr: nil, + }, + "read_10_bytes_from_seek_current_with_seek_offset": { + readSize: testDataLen / 3, + seekOffset: int64(testDataLen / 3), + seekWhence: io.SeekCurrent, + // "abcdefghij" + expectedContent: testData[testDataLen/3 : 2*testDataLen/3], + expectedReadErr: nil, + }, + "read_10_bytes_from_seek_current_with_seek_offset_until_eof": { + readSize: testDataLen / 3, + seekOffset: int64(2 * testDataLen / 3), + seekWhence: io.SeekCurrent, + // "0987654321" + expectedContent: testData[2*testDataLen/3:], + expectedReadErr: io.EOF, + }, + "read_10_bytes_from_reader_offset_and_seek_current_with_seek_offset_to_eof": { + readSize: testDataLen / 3, + readerOffset: int64(testDataLen / 3), // reader offset at "a" + seekOffset: int64(testDataLen / 3), // seek offset at "0" + seekWhence: io.SeekCurrent, + // "0987654321" + expectedContent: testData[2*testDataLen/3:], + expectedReadErr: io.EOF, + }, + "invalid_seek_current_negative_seek_offset": { + seekOffset: -1, + seekWhence: io.SeekCurrent, + expectedSeekErrMsg: "outside of range", + }, + // io.SeekEnd with negative offsets + "read_all_from_seek_end": { + readSize: testDataLen, + seekWhence: io.SeekEnd, + seekOffset: -int64(testDataLen), + expectedContent: testData, + expectedReadErr: io.EOF, + }, + "read_10_bytes_from_seek_end": { + readSize: testDataLen / 3, + seekWhence: io.SeekEnd, + seekOffset: -int64(testDataLen), + // "1234567890" + expectedContent: testData[:testDataLen/3], + expectedReadErr: nil, + }, + "read_10_bytes_from_seek_end_with_seek_offset": { + readSize: testDataLen / 3, + readerOffset: int64(2 * testDataLen / 3), + seekOffset: -int64(testDataLen / 3), + seekWhence: io.SeekEnd, + // 
"0987654321" + expectedContent: testData[2*testDataLen/3:], + expectedReadErr: io.EOF, + }, + "read_10_bytes_from_seek_end_with_seek_offset_until_eof": { + readSize: testDataLen / 3, + seekOffset: -int64(testDataLen / 3), + seekWhence: io.SeekEnd, + // "0987654321" + expectedContent: testData[2*testDataLen/3:], + expectedReadErr: io.EOF, + }, + "read_10_bytes_from_reader_offset_and_seek_end_with_seek_offset_to_eof": { + readSize: testDataLen / 3, + readerOffset: int64(testDataLen / 3), // reader offset at "a" + seekOffset: -int64(2 * testDataLen / 3), // seek offset at "a" + seekWhence: io.SeekEnd, + // "abcdefghij" + expectedContent: testData[testDataLen/3 : 2*testDataLen/3], + expectedReadErr: nil, + }, + "invalid_seek_end_positive_seek_offset": { + readSize: testDataLen, + seekOffset: 1, + seekWhence: io.SeekEnd, + expectedSeekErrMsg: "outside of range", + }, + "invalid_range_reading_from_end": { + readSize: testDataLen / 3, + seekWhence: io.SeekEnd, + expectedReadErr: ErrInvalidRange, + }, + } + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + r := NewReader(resource, tt.readerOffset, resource.Size-tt.readerOffset) + + _, err := r.Seek(tt.seekOffset, tt.seekWhence) + if tt.expectedSeekErrMsg != "" { + require.EqualError(t, err, tt.expectedSeekErrMsg) + return + } + require.NoError(t, err) + + buf := make([]byte, tt.readSize) + n, err := r.Read(buf) + if tt.expectedReadErr != nil { + require.Equal(t, tt.expectedReadErr, err) + return + } + + require.Equal(t, n, tt.readSize) + require.Equal(t, tt.expectedContent, string(buf)) + }) + } +} + +func TestReaderSetResponse(t *testing.T) { + tests := map[string]struct { + status int + offset int64 + prevETag string + resEtag string + expectedErrMsg string + }{ + "partial_content_success": { + status: http.StatusPartialContent, + }, + "status_ok_success": { + status: http.StatusOK, + }, + "status_ok_previous_response_invalid_offset": { + status: http.StatusOK, + offset: 1, + expectedErrMsg: 
ErrContentHasChanged.Error(), + }, + "status_ok_previous_response_different_etag": { + status: http.StatusOK, + prevETag: "old", + resEtag: "new", + expectedErrMsg: ErrContentHasChanged.Error(), + }, + "requested_range_not_satisfiable": { + status: http.StatusRequestedRangeNotSatisfiable, + expectedErrMsg: ErrRangeRequestsNotSupported.Error(), + }, + "unhandled_status_code": { + status: http.StatusNotFound, + expectedErrMsg: "httprange: read response 404:", + }, + } + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + r := NewReader(&Resource{ETag: tt.prevETag}, tt.offset, 0) + res := &http.Response{StatusCode: tt.status, Header: map[string][]string{}} + res.Header.Set("ETag", tt.resEtag) + + err := r.setResponse(res) + if tt.expectedErrMsg != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tt.expectedErrMsg) + return + } + + require.NoError(t, err) + require.Equal(t, r.res, res) + }) + } +} + +func TestReaderSeek(t *testing.T) { + type fields struct { + Resource *Resource + res *http.Response + rangeStart int64 + rangeSize int64 + offset int64 + } + + tests := map[string]struct { + fields fields + offset int64 + whence int + want int64 + newOffset int64 + expectedErrMsg string + }{ + "invalid_whence": { + whence: -1, + expectedErrMsg: "invalid whence", + }, + "outside_of_range_invalid_offset": { + whence: io.SeekStart, + offset: -1, + fields: fields{rangeStart: 1}, + expectedErrMsg: "outside of range", + }, + "outside_of_range_invalid_new_offset": { + whence: io.SeekStart, + offset: 2, // newOffset = 3 + fields: fields{rangeStart: 1, rangeSize: 1}, + expectedErrMsg: "outside of range", + }, + "seek_start": { + whence: io.SeekStart, + offset: 1, + want: 1, + newOffset: 2, + fields: fields{rangeStart: 1, rangeSize: 1}, + }, + "seek_current": { + whence: io.SeekCurrent, + offset: 2, + want: 1, + newOffset: 2, + fields: fields{rangeStart: 1, rangeSize: 1, offset: 0}, + }, + "seek_end": { + whence: io.SeekEnd, + want: 1, + newOffset: 2, 
+ fields: fields{rangeStart: 1, rangeSize: 1, offset: 0}, + }, + } + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + r := &Reader{ + res: tt.fields.res, + rangeStart: tt.fields.rangeStart, + rangeSize: tt.fields.rangeSize, + offset: tt.fields.offset, + } + + got, err := r.Seek(tt.offset, tt.whence) + if tt.expectedErrMsg != "" { + require.EqualError(t, err, tt.expectedErrMsg) + return + } + + require.Equal(t, tt.want, got) + require.Equal(t, tt.newOffset, r.offset) + }) + } +} diff --git a/internal/httprange/resource.go b/internal/httprange/resource.go new file mode 100644 index 00000000..7e21ef29 --- /dev/null +++ b/internal/httprange/resource.go @@ -0,0 +1,77 @@ +package httprange + +import ( + "context" + "fmt" + "io" + "io/ioutil" + "net/http" + "strconv" + "strings" +) + +// Resource represents any HTTP resource that can be read by a GET operation. +// It holds the resource's URL and metadata about it. +type Resource struct { + URL string + ETag string + LastModified string + Size int64 +} + +func NewResource(ctx context.Context, url string) (*Resource, error) { + // the `h.URL` is likely pre-signed URL that only supports GET requests + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, err + } + + req = req.WithContext(ctx) + + // we fetch a single byte and ensure that range requests is additionally supported + req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", 0, 0)) + + // nolint: bodyclose + // body will be closed by discardAndClose + res, err := httpClient.Do(req) + if err != nil { + return nil, err + } + + defer func() { + io.CopyN(ioutil.Discard, res.Body, 1) // since we want to read a single byte + res.Body.Close() + }() + + resource := &Resource{ + URL: url, + ETag: res.Header.Get("ETag"), + LastModified: res.Header.Get("Last-Modified"), + } + + switch res.StatusCode { + case http.StatusOK: + resource.Size = res.ContentLength + return resource, nil + + case http.StatusPartialContent: + contentRange := 
res.Header.Get("Content-Range") + ranges := strings.SplitN(contentRange, "/", 2) + if len(ranges) != 2 { + return nil, fmt.Errorf("invalid `Content-Range`: %q", contentRange) + } + + resource.Size, err = strconv.ParseInt(ranges[1], 0, 64) + if err != nil { + return nil, fmt.Errorf("invalid `Content-Range`: %q %w", contentRange, err) + } + + return resource, nil + + case http.StatusRequestedRangeNotSatisfiable: + return nil, ErrRangeRequestsNotSupported + + default: + return nil, fmt.Errorf("httprange: new resource %d: %q", res.StatusCode, res.Status) + } +} diff --git a/internal/httprange/resource_test.go b/internal/httprange/resource_test.go new file mode 100644 index 00000000..89d15a21 --- /dev/null +++ b/internal/httprange/resource_test.go @@ -0,0 +1,96 @@ +package httprange + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewResource(t *testing.T) { + resource := Resource{ + URL: "/some/resource", + ETag: "etag", + LastModified: "Wed, 21 Oct 2015 07:28:00 GMT", + Size: 1, + } + + tests := map[string]struct { + url string + status int + contentRange string + want Resource + expectedErrMsg string + }{ + "status_ok": { + url: "/some/resource", + status: http.StatusOK, + want: resource, + }, + "status_partial_content_success": { + url: "/some/resource", + status: http.StatusPartialContent, + contentRange: "bytes 200-1000/67589", + want: func() Resource { + r := resource + r.Size = 67589 + return r + }(), + }, + "status_partial_content_invalid_content_range": { + url: "/some/resource", + status: http.StatusPartialContent, + contentRange: "invalid", + expectedErrMsg: "invalid `Content-Range`:", + }, + "status_partial_content_content_range_not_a_number": { + url: "/some/resource", + status: http.StatusPartialContent, + contentRange: "bytes 200-1000/notanumber", + expectedErrMsg: "invalid `Content-Range`:", + }, + "StatusRequestedRangeNotSatisfiable": { + url: "/some/resource", + 
status: http.StatusRequestedRangeNotSatisfiable, + expectedErrMsg: ErrRangeRequestsNotSupported.Error(), + }, + "not_found": { + url: "/some/resource", + status: http.StatusNotFound, + expectedErrMsg: fmt.Sprintf("httprange: new resource %d: %q", http.StatusNotFound, "404 Not Found"), + }, + "invalid_url": { + url: "/%", + expectedErrMsg: "invalid URL escape", + }, + } + + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("ETag", tt.want.ETag) + w.Header().Set("Last-Modified", tt.want.LastModified) + w.Header().Set("Content-Range", tt.contentRange) + w.WriteHeader(tt.status) + w.Write([]byte("1")) + })) + defer testServer.Close() + + got, err := NewResource(context.Background(), testServer.URL+tt.url) + if tt.expectedErrMsg != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tt.expectedErrMsg) + return + } + + require.NoError(t, err) + require.Contains(t, got.URL, tt.want.URL) + require.Equal(t, tt.want.LastModified, got.LastModified) + require.Equal(t, tt.want.ETag, got.ETag) + require.Equal(t, tt.want.Size, got.Size) + }) + } +} |