diff options
author | Kamil Trzciński <ayufan@ayufan.eu> | 2020-09-03 16:53:26 +0300 |
---|---|---|
committer | Vladimir Shushlin <vshushlin@gitlab.com> | 2020-09-03 16:53:26 +0300 |
commit | e8a29071e5f4d0a8c35eede932cff7f4ed03adfe (patch) | |
tree | 7f62366277976df67c70fcf5a56b8080f000d98e | |
parent | ab92a1bf2fd3d3b28498d79245f19623cec5f618 (diff) |
Abstract `VFS` `Root`
-rw-r--r-- | internal/serving/disk/helpers.go | 9 | ||||
-rw-r--r-- | internal/serving/disk/reader.go | 63 | ||||
-rw-r--r-- | internal/serving/disk/symlink/path_test.go | 149 | ||||
-rw-r--r-- | internal/serving/disk/symlink/shims.go | 4 | ||||
-rw-r--r-- | internal/serving/disk/symlink/symlink.go | 12 | ||||
-rw-r--r-- | internal/source/gitlab/client/client_test.go | 6 | ||||
-rw-r--r-- | internal/vfs/file.go | 10 | ||||
-rw-r--r-- | internal/vfs/local/local_test.go | 96 | ||||
-rw-r--r-- | internal/vfs/local/root.go | 99 | ||||
-rw-r--r-- | internal/vfs/local/root_test.go | 273 | ||||
-rw-r--r-- | internal/vfs/local/vfs.go | 30 | ||||
-rw-r--r-- | internal/vfs/local/vfs_test.go | 103 | ||||
-rw-r--r-- | internal/vfs/root.go | 69 | ||||
-rw-r--r-- | internal/vfs/vfs.go | 49 |
14 files changed, 696 insertions, 276 deletions
diff --git a/internal/serving/disk/helpers.go b/internal/serving/disk/helpers.go index e6d3f8ab..1ff83683 100644 --- a/internal/serving/disk/helpers.go +++ b/internal/serving/disk/helpers.go @@ -10,6 +10,7 @@ import ( "strings" "gitlab.com/gitlab-org/gitlab-pages/internal/httputil" + "gitlab.com/gitlab-org/gitlab-pages/internal/vfs" ) func endsWithSlash(path string) bool { @@ -23,13 +24,13 @@ func endsWithoutHTMLExtension(path string) bool { // Detect file's content-type either by extension or mime-sniffing. // Implementation is adapted from Golang's `http.serveContent()` // See https://github.com/golang/go/blob/902fc114272978a40d2e65c2510a18e870077559/src/net/http/fs.go#L194 -func (reader *Reader) detectContentType(ctx context.Context, path string) (string, error) { +func (reader *Reader) detectContentType(ctx context.Context, root vfs.Root, path string) (string, error) { contentType := mime.TypeByExtension(filepath.Ext(path)) if contentType == "" { var buf [512]byte - file, err := reader.vfs.Open(ctx, path) + file, err := root.Open(ctx, path) if err != nil { return "", err } @@ -55,7 +56,7 @@ func acceptsGZip(r *http.Request) bool { return acceptedEncoding == "gzip" } -func (reader *Reader) handleGZip(ctx context.Context, w http.ResponseWriter, r *http.Request, fullPath string) string { +func (reader *Reader) handleGZip(ctx context.Context, w http.ResponseWriter, r *http.Request, root vfs.Root, fullPath string) string { if !acceptsGZip(r) { return fullPath } @@ -63,7 +64,7 @@ func (reader *Reader) handleGZip(ctx context.Context, w http.ResponseWriter, r * gzipPath := fullPath + ".gz" // Ensure the .gz file is not a symlink - fi, err := reader.vfs.Lstat(ctx, gzipPath) + fi, err := root.Lstat(ctx, gzipPath) if err != nil || !fi.Mode().IsRegular() { return fullPath } diff --git a/internal/serving/disk/reader.go b/internal/serving/disk/reader.go index 8c7ee3bf..0d11a184 100644 --- a/internal/serving/disk/reader.go +++ b/internal/serving/disk/reader.go @@ -5,7 +5,6 @@ 
import ( "fmt" "io" "net/http" - "path/filepath" "strconv" "strings" "time" @@ -25,7 +24,13 @@ type Reader struct { func (reader *Reader) tryFile(h serving.Handler) error { ctx := h.Request.Context() - fullPath, err := reader.resolvePath(ctx, h.LookupPath.Path, h.SubPath) + + root, err := reader.vfs.Root(ctx, h.LookupPath.Path) + if err != nil { + return err + } + + fullPath, err := reader.resolvePath(ctx, root, h.SubPath) request := h.Request host := request.Host @@ -33,7 +38,7 @@ func (reader *Reader) tryFile(h serving.Handler) error { if locationError, _ := err.(*locationDirectoryError); locationError != nil { if endsWithSlash(urlPath) { - fullPath, err = reader.resolvePath(ctx, h.LookupPath.Path, h.SubPath, "index.html") + fullPath, err = reader.resolvePath(ctx, root, h.SubPath, "index.html") } else { // TODO why are we doing that? In tests it redirects to HTTPS. This seems wrong, // issue about this: https://gitlab.com/gitlab-org/gitlab-pages/issues/273 @@ -50,24 +55,30 @@ func (reader *Reader) tryFile(h serving.Handler) error { } if locationError, _ := err.(*locationFileNoExtensionError); locationError != nil { - fullPath, err = reader.resolvePath(ctx, h.LookupPath.Path, strings.TrimSuffix(h.SubPath, "/")+".html") + fullPath, err = reader.resolvePath(ctx, root, strings.TrimSuffix(h.SubPath, "/")+".html") } if err != nil { return err } - return reader.serveFile(ctx, h.Writer, h.Request, fullPath, h.LookupPath.HasAccessControl) + return reader.serveFile(ctx, h.Writer, h.Request, root, fullPath, h.LookupPath.HasAccessControl) } func (reader *Reader) tryNotFound(h serving.Handler) error { ctx := h.Request.Context() - page404, err := reader.resolvePath(ctx, h.LookupPath.Path, "404.html") + + root, err := reader.vfs.Root(ctx, h.LookupPath.Path) if err != nil { return err } - err = reader.serveCustomFile(ctx, h.Writer, h.Request, http.StatusNotFound, page404) + page404, err := reader.resolvePath(ctx, root, "404.html") + if err != nil { + return err + } + + err = 
reader.serveCustomFile(ctx, h.Writer, h.Request, http.StatusNotFound, root, page404) if err != nil { return err } @@ -76,15 +87,12 @@ func (reader *Reader) tryNotFound(h serving.Handler) error { // Resolve the HTTP request to a path on disk, converting requests for // directories to requests for index.html inside the directory if appropriate. -func (reader *Reader) resolvePath(ctx context.Context, publicPath string, subPath ...string) (string, error) { - // Ensure that publicPath always ends with "/" - publicPath = strings.TrimSuffix(publicPath, "/") + "/" - +func (reader *Reader) resolvePath(ctx context.Context, root vfs.Root, subPath ...string) (string, error) { // Don't use filepath.Join as cleans the path, // where we want to traverse full path as supplied by user // (including ..) - testPath := publicPath + strings.Join(subPath, "/") - fullPath, err := symlink.EvalSymlinks(ctx, reader.vfs, testPath) + testPath := strings.Join(subPath, "/") + fullPath, err := symlink.EvalSymlinks(ctx, root, testPath) if err != nil { if endsWithoutHTMLExtension(testPath) { @@ -96,12 +104,7 @@ func (reader *Reader) resolvePath(ctx context.Context, publicPath string, subPat return "", err } - // The requested path resolved to somewhere outside of the public/ directory - if !strings.HasPrefix(fullPath, publicPath) && fullPath != filepath.Clean(publicPath) { - return "", fmt.Errorf("%q should be in %q", fullPath, publicPath) - } - - fi, err := reader.vfs.Lstat(ctx, fullPath) + fi, err := root.Lstat(ctx, fullPath) if err != nil { return "", err } @@ -110,7 +113,7 @@ func (reader *Reader) resolvePath(ctx context.Context, publicPath string, subPat if fi.IsDir() { return "", &locationDirectoryError{ FullPath: fullPath, - RelativePath: strings.TrimPrefix(fullPath, publicPath), + RelativePath: testPath, } } @@ -123,17 +126,17 @@ func (reader *Reader) resolvePath(ctx context.Context, publicPath string, subPat return fullPath, nil } -func (reader *Reader) serveFile(ctx context.Context, w 
http.ResponseWriter, r *http.Request, origPath string, accessControl bool) error { - fullPath := reader.handleGZip(ctx, w, r, origPath) +func (reader *Reader) serveFile(ctx context.Context, w http.ResponseWriter, r *http.Request, root vfs.Root, origPath string, accessControl bool) error { + fullPath := reader.handleGZip(ctx, w, r, root, origPath) - file, err := reader.vfs.Open(ctx, fullPath) + file, err := root.Open(ctx, fullPath) if err != nil { return err } defer file.Close() - fi, err := reader.vfs.Lstat(ctx, fullPath) + fi, err := root.Lstat(ctx, fullPath) if err != nil { return err } @@ -144,7 +147,7 @@ func (reader *Reader) serveFile(ctx context.Context, w http.ResponseWriter, r *h w.Header().Set("Expires", time.Now().Add(10*time.Minute).Format(time.RFC1123)) } - contentType, err := reader.detectContentType(ctx, origPath) + contentType, err := reader.detectContentType(ctx, root, origPath) if err != nil { return err } @@ -157,22 +160,22 @@ func (reader *Reader) serveFile(ctx context.Context, w http.ResponseWriter, r *h return nil } -func (reader *Reader) serveCustomFile(ctx context.Context, w http.ResponseWriter, r *http.Request, code int, origPath string) error { - fullPath := reader.handleGZip(ctx, w, r, origPath) +func (reader *Reader) serveCustomFile(ctx context.Context, w http.ResponseWriter, r *http.Request, code int, root vfs.Root, origPath string) error { + fullPath := reader.handleGZip(ctx, w, r, root, origPath) // Open and serve content of file - file, err := reader.vfs.Open(ctx, fullPath) + file, err := root.Open(ctx, fullPath) if err != nil { return err } defer file.Close() - fi, err := reader.vfs.Lstat(ctx, fullPath) + fi, err := root.Lstat(ctx, fullPath) if err != nil { return err } - contentType, err := reader.detectContentType(ctx, origPath) + contentType, err := reader.detectContentType(ctx, root, origPath) if err != nil { return err } diff --git a/internal/serving/disk/symlink/path_test.go b/internal/serving/disk/symlink/path_test.go index 
4d590db5..6b0a41f3 100644 --- a/internal/serving/disk/symlink/path_test.go +++ b/internal/serving/disk/symlink/path_test.go @@ -10,31 +10,27 @@ import ( "os" "path/filepath" "runtime" + "strings" "testing" + "github.com/stretchr/testify/require" + "gitlab.com/gitlab-org/gitlab-pages/internal/serving/disk/symlink" + "gitlab.com/gitlab-org/gitlab-pages/internal/vfs" "gitlab.com/gitlab-org/gitlab-pages/internal/vfs/local" ) -var fs = local.VFS{} +var fs = vfs.Instrumented(&local.VFS{}, "local") -func chtmpdir(t *testing.T) (restore func()) { - oldwd, err := os.Getwd() - if err != nil { - t.Fatalf("chtmpdir: %v", err) - } - d, err := ioutil.TempDir("", "test") - if err != nil { - t.Fatalf("chtmpdir: %v", err) - } - if err := os.Chdir(d); err != nil { - t.Fatalf("chtmpdir: %v", err) - } - return func() { - if err := os.Chdir(oldwd); err != nil { - t.Fatalf("chtmpdir: %v", err) - } - os.RemoveAll(d) +func tmpDir(t *testing.T) (vfs.Root, string, func()) { + tmpDir, err := ioutil.TempDir("", "symlink_tests") + require.NoError(t, err) + + root, err := fs.Root(context.Background(), tmpDir) + require.NoError(t, err) + + return root, tmpDir, func() { + os.RemoveAll(tmpDir) } } @@ -74,19 +70,22 @@ var EvalSymlinksTests = []EvalSymlinksTest{ {"test/link2/..", "test"}, {"test/dir/link3", "."}, {"test/link2/link3/test", "test"}, - {"test/linkabs", "/"}, + {"test/linkabs", "../.."}, {"test/link4/..", "test"}, {"src/versions/current/modules/test", "src/pool/test"}, } // simpleJoin builds a file name from the directory and path. // It does not use Join because we don't want ".." to be evaluated. 
-func simpleJoin(dir, path string) string { - return dir + string(filepath.Separator) + path +func simpleJoin(path ...string) string { + return strings.Join(path, string(filepath.Separator)) } -func testEvalSymlinks(t *testing.T, path, want string) { - have, err := symlink.EvalSymlinks(context.Background(), fs, path) +func testEvalSymlinks(t *testing.T, wd, path, want string) { + root, err := fs.Root(context.Background(), wd) + require.NoError(t, err) + + have, err := symlink.EvalSymlinks(context.Background(), root, path) if err != nil { t.Errorf("EvalSymlinks(%q) error: %v", path, err) return @@ -96,46 +95,9 @@ func testEvalSymlinks(t *testing.T, path, want string) { } } -func testEvalSymlinksAfterChdir(t *testing.T, wd, path, want string) { - cwd, err := os.Getwd() - if err != nil { - t.Fatal(err) - } - defer func() { - err := os.Chdir(cwd) - if err != nil { - t.Fatal(err) - } - }() - - err = os.Chdir(wd) - if err != nil { - t.Fatal(err) - } - - have, err := symlink.EvalSymlinks(context.Background(), fs, path) - if err != nil { - t.Errorf("EvalSymlinks(%q) in %q directory error: %v", path, wd, err) - return - } - if filepath.Clean(have) != filepath.Clean(want) { - t.Errorf("EvalSymlinks(%q) in %q directory returns %q, want %q", path, wd, have, want) - } -} - func TestEvalSymlinks(t *testing.T) { - tmpDir, err := ioutil.TempDir("", "evalsymlink") - if err != nil { - t.Fatal("creating temp dir:", err) - } - defer os.RemoveAll(tmpDir) - - // /tmp may itself be a symlink! Avoid the confusion, although - // it means trusting the thing we're testing. - tmpDir, err = filepath.EvalSymlinks(tmpDir) - if err != nil { - t.Fatal("eval symlink for tmp dir:", err) - } + _, tmpDir, cleanup := tmpDir(t) + defer cleanup() // Create the symlink farm using relative paths. for _, d := range EvalSymlinksTestDirs { @@ -153,42 +115,30 @@ func TestEvalSymlinks(t *testing.T) { // Evaluate the symlink farm. 
for _, test := range EvalSymlinksTests { - path := simpleJoin(tmpDir, test.path) - - dest := simpleJoin(tmpDir, test.dest) - if filepath.IsAbs(test.dest) || os.IsPathSeparator(test.dest[0]) { - dest = test.dest - } - testEvalSymlinks(t, path, dest) + testEvalSymlinks(t, tmpDir, test.path, test.dest) // test EvalSymlinks(".") - testEvalSymlinksAfterChdir(t, path, ".", ".") + testEvalSymlinks(t, simpleJoin(tmpDir, test.path), ".", ".") // test EvalSymlinks("C:.") on Windows if runtime.GOOS == "windows" { volDot := filepath.VolumeName(tmpDir) + "." - testEvalSymlinksAfterChdir(t, path, volDot, volDot) + testEvalSymlinks(t, simpleJoin(tmpDir, test.path), volDot, volDot) } // test EvalSymlinks(".."+path) - dotdotPath := simpleJoin("..", test.dest) - if filepath.IsAbs(test.dest) || os.IsPathSeparator(test.dest[0]) { - dotdotPath = test.dest - } - testEvalSymlinksAfterChdir(t, - simpleJoin(tmpDir, "test"), - simpleJoin("..", test.path), - dotdotPath) - - // test EvalSymlinks(p) where p is relative path - testEvalSymlinksAfterChdir(t, tmpDir, test.path, test.dest) + testEvalSymlinks(t, + tmpDir, + simpleJoin("test", "..", test.path), + test.dest) } } func TestEvalSymlinksIsNotExist(t *testing.T) { - defer chtmpdir(t)() + root, _, cleanup := tmpDir(t) + defer cleanup() - _, err := symlink.EvalSymlinks(context.Background(), fs, "notexist") + _, err := symlink.EvalSymlinks(context.Background(), root, "notexist") if !os.IsNotExist(err) { t.Errorf("expected the file is not found, got %v\n", err) } @@ -199,21 +149,18 @@ func TestEvalSymlinksIsNotExist(t *testing.T) { } defer os.Remove("link") - _, err = symlink.EvalSymlinks(context.Background(), fs, "link") + _, err = symlink.EvalSymlinks(context.Background(), root, "link") if !os.IsNotExist(err) { t.Errorf("expected the file is not found, got %v\n", err) } } func TestIssue13582(t *testing.T) { - tmpDir, err := ioutil.TempDir("", "issue13582") - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(tmpDir) + root, tmpDir, 
cleanup := tmpDir(t) + defer cleanup() dir := filepath.Join(tmpDir, "dir") - err = os.Mkdir(dir, 0755) + err := os.Mkdir(dir, 0755) if err != nil { t.Fatal(err) } @@ -238,25 +185,17 @@ func TestIssue13582(t *testing.T) { t.Fatal(err) } - // /tmp may itself be a symlink! - realTmpDir, err := filepath.EvalSymlinks(tmpDir) - if err != nil { - t.Fatal(err) - } - realDir := filepath.Join(realTmpDir, "dir") - realFile := filepath.Join(realDir, "file") - tests := []struct { path, want string }{ - {dir, realDir}, - {linkToDir, realDir}, - {file, realFile}, - {link1, realFile}, - {link2, realFile}, + {"dir", "dir"}, + {"link_to_dir", "dir"}, + {"link_to_dir/file", "dir/file"}, + {"link_to_dir/link1", "dir/file"}, + {"link_to_dir/link2", "dir/file"}, } for i, test := range tests { - have, err := symlink.EvalSymlinks(context.Background(), fs, test.path) + have, err := symlink.EvalSymlinks(context.Background(), root, test.path) if err != nil { t.Fatal(err) } diff --git a/internal/serving/disk/symlink/shims.go b/internal/serving/disk/symlink/shims.go index d383b96b..90f67d45 100644 --- a/internal/serving/disk/symlink/shims.go +++ b/internal/serving/disk/symlink/shims.go @@ -12,6 +12,6 @@ func volumeNameLen(s string) int { return 0 } func IsAbs(path string) bool { return filepath.IsAbs(path) } func Clean(path string) string { return filepath.Clean(path) } -func EvalSymlinks(ctx context.Context, fs vfs.VFS, path string) (string, error) { - return walkSymlinks(ctx, fs, path) +func EvalSymlinks(ctx context.Context, root vfs.Root, path string) (string, error) { + return walkSymlinks(ctx, root, path) } diff --git a/internal/serving/disk/symlink/symlink.go b/internal/serving/disk/symlink/symlink.go index 50714811..3b5d242a 100644 --- a/internal/serving/disk/symlink/symlink.go +++ b/internal/serving/disk/symlink/symlink.go @@ -14,7 +14,7 @@ import ( "gitlab.com/gitlab-org/gitlab-pages/internal/vfs" ) -func walkSymlinks(ctx context.Context, fs vfs.VFS, path string) (string, error) { 
+func walkSymlinks(ctx context.Context, root vfs.Root, path string) (string, error) { volLen := volumeNameLen(path) pathSeparator := string(os.PathSeparator) @@ -57,7 +57,10 @@ func walkSymlinks(ctx context.Context, fs vfs.VFS, path string) (string, error) break } } - if r < volLen || dest[r+1:] == ".." { + + if r >= 0 && r+1 == volLen && os.IsPathSeparator(dest[r]) { + return "", errors.New("EvalSymlinks: cannot backtrack root path") + } else if r < volLen || dest[r+1:] == ".." { // Either path has no slashes // (it's empty or just "C:") // or it ends in a ".." we had to keep. @@ -83,7 +86,7 @@ func walkSymlinks(ctx context.Context, fs vfs.VFS, path string) (string, error) // Resolve symlink. - fi, err := fs.Lstat(ctx, dest) + fi, err := root.Lstat(ctx, dest) if err != nil { return "", err } @@ -102,7 +105,7 @@ func walkSymlinks(ctx context.Context, fs vfs.VFS, path string) (string, error) return "", errors.New("EvalSymlinks: too many links") } - link, err := fs.Readlink(ctx, dest) + link, err := root.Readlink(ctx, dest) if err != nil { return "", err } @@ -145,5 +148,6 @@ func walkSymlinks(ctx context.Context, fs vfs.VFS, path string) (string, error) end = 0 } } + return Clean(dest), nil } diff --git a/internal/source/gitlab/client/client_test.go b/internal/source/gitlab/client/client_test.go index 57c479d7..c888a059 100644 --- a/internal/source/gitlab/client/client_test.go +++ b/internal/source/gitlab/client/client_test.go @@ -300,7 +300,11 @@ func TestClientStatusClientTimeout(t *testing.T) { err := client.Status() require.Error(t, err) - require.Contains(t, err.Error(), "Client.Timeout") + // we can receive any of these messages + // - context deadline exceeded (Client.Timeout exceeded while awaiting headers) + // - net/http: request canceled (Client.Timeout exceeded while awaiting headers) + // - context deadline exceeded + require.Contains(t, err.Error(), "exceeded") } func TestClientStatusConnectionRefused(t *testing.T) { diff --git a/internal/vfs/file.go 
b/internal/vfs/file.go new file mode 100644 index 00000000..5260c847 --- /dev/null +++ b/internal/vfs/file.go @@ -0,0 +1,10 @@ +package vfs + +import "io" + +// File represents an open file, which will typically be the response body of a Pages request. +type File interface { + io.Reader + io.Seeker + io.Closer +} diff --git a/internal/vfs/local/local_test.go b/internal/vfs/local/local_test.go deleted file mode 100644 index c41f96bc..00000000 --- a/internal/vfs/local/local_test.go +++ /dev/null @@ -1,96 +0,0 @@ -package local - -import ( - "context" - "io/ioutil" - "os" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestReadlink(t *testing.T) { - ctx := context.Background() - fs := VFS{} - - target, err := fs.Readlink(ctx, "testdata/link") - require.NoError(t, err) - require.Equal(t, "file", target) -} - -func TestReadlinkNotSymlink(t *testing.T) { - ctx := context.Background() - fs := VFS{} - - for _, path := range []string{"testdata", "testdata/file"} { - t.Run(path, func(t *testing.T) { - _, err := os.Lstat(path) - require.NoError(t, err, "sanity check: input must actually exist") - - _, err = fs.Readlink(ctx, path) - require.Error(t, err, "expect readlink to fail") - }) - } -} - -func TestLstat(t *testing.T) { - ctx := context.Background() - fs := VFS{} - - testCases := []struct { - path string - modePerm os.FileMode - modeType os.FileMode - }{ - {path: "testdata", modeType: os.ModeDir, modePerm: 0755}, - {path: "testdata/file", modeType: os.FileMode(0), modePerm: 0644}, - {path: "testdata/link", modeType: os.ModeSymlink}, // Permissions of symlinks are platform dependent - } - - for _, tc := range testCases { - t.Run(tc.path, func(t *testing.T) { - if tc.modePerm > 0 { - require.NoError(t, os.Chmod(tc.path, tc.modePerm), "preparation: deterministic permissions") - } - - fi, err := fs.Lstat(ctx, tc.path) - require.NoError(t, err, "lstat error") - - require.Equal(t, tc.modeType, fi.Mode()&os.ModeType, "file mode: type") - if tc.modePerm > 0 { - 
require.Equal(t, tc.modePerm, fi.Mode()&os.ModePerm, "file mode: permissions") - } - }) - } -} - -func TestOpen(t *testing.T) { - ctx := context.Background() - fs := VFS{} - - f, err := fs.Open(ctx, "testdata/file") - require.NoError(t, err, "open file") - - data, err := ioutil.ReadAll(f) - require.NoError(t, err, "read from file") - require.Equal(t, "hello\n", string(data), "file contents") - - require.NoError(t, f.Close(), "close file") -} - -func TestOpenDenySymlink(t *testing.T) { - ctx := context.Background() - fs := VFS{} - const symlinkPath = "testdata/link" - - fi, err := os.Stat(symlinkPath) - require.NoError(t, err, "stat link") - require.Equal(t, os.FileMode(0), fi.Mode()&os.ModeType, "sanity check: link target should be a regular file") - - fi, err = os.Lstat(symlinkPath) - require.NoError(t, err, "lstat link") - require.Equal(t, os.ModeSymlink, fi.Mode()&os.ModeType, "sanity check: testdata/link should be a symlink") - - _, err = fs.Open(ctx, symlinkPath) - require.Error(t, err, "opening symlink should fail (security mechanism)") -} diff --git a/internal/vfs/local/root.go b/internal/vfs/local/root.go new file mode 100644 index 00000000..0ed2206d --- /dev/null +++ b/internal/vfs/local/root.go @@ -0,0 +1,99 @@ +package local + +import ( + "context" + "errors" + "fmt" + "os" + "path/filepath" + "strings" + + "golang.org/x/sys/unix" + + "gitlab.com/gitlab-org/gitlab-pages/internal/vfs" +) + +var errNotFile = errors.New("path needs to be a file") + +type invalidPathError struct { + rootPath string + requestedPath string +} + +func (i *invalidPathError) Error() string { + return fmt.Sprintf("%q should be in %q", i.requestedPath, i.rootPath) +} + +type Root struct { + rootPath string +} + +func (r *Root) validatePath(path string) (string, string, error) { + fullPath := filepath.Join(r.rootPath, path) + + if r.rootPath == fullPath { + return fullPath, "", nil + } + + vfsPath := strings.TrimPrefix(fullPath, r.rootPath+"/") + + // The requested path resolved to 
somewhere outside of the `r.rootPath` directory + if fullPath == vfsPath { + return "", "", &invalidPathError{rootPath: r.rootPath, requestedPath: fullPath} + } + + return fullPath, vfsPath, nil +} + +func (r *Root) Lstat(ctx context.Context, name string) (os.FileInfo, error) { + fullPath, _, err := r.validatePath(name) + if err != nil { + return nil, err + } + + return os.Lstat(fullPath) +} + +func (r *Root) Readlink(ctx context.Context, name string) (string, error) { + fullPath, _, err := r.validatePath(name) + if err != nil { + return "", err + } + + target, err := os.Readlink(fullPath) + if err != nil { + return "", err + } + + if filepath.IsAbs(target) { + return filepath.Rel(filepath.Dir(fullPath), target) + } + + return target, nil +} + +func (r *Root) Open(ctx context.Context, name string) (vfs.File, error) { + fullPath, _, err := r.validatePath(name) + if err != nil { + return nil, err + } + + file, err := os.OpenFile(fullPath, os.O_RDONLY|unix.O_NOFOLLOW, 0) + if err != nil { + return nil, err + } + + // We do a `Stat()` on a file due to race-conditions + // Someone could update (unlikely) a file between `Stat()/Open()` + fi, err := file.Stat() + if err != nil { + return nil, err + } + + if !fi.Mode().IsRegular() { + file.Close() + return nil, errNotFile + } + + return file, nil +} diff --git a/internal/vfs/local/root_test.go b/internal/vfs/local/root_test.go new file mode 100644 index 00000000..a4711169 --- /dev/null +++ b/internal/vfs/local/root_test.go @@ -0,0 +1,273 @@ +package local + +import ( + "context" + "io/ioutil" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestValidatePath(t *testing.T) { + ctx := context.Background() + rootVFS, err := localVFS.Root(ctx, ".") + require.NoError(t, err) + + root := rootVFS.(*Root) + + wd, err := os.Getwd() + require.NoError(t, err) + + tests := map[string]struct { + path string + expectedFullPath string + expectedVFSPath string 
+ expectedInvalidPath bool + }{ + "a valid path": { + path: "testdata/link", + expectedFullPath: filepath.Join(wd, "testdata", "link"), + expectedVFSPath: filepath.Join("testdata", "link"), + }, + "a path outside of root directory": { + path: "testdata/../../link", + expectedInvalidPath: true, + }, + "an absolute path": { + // we don't support absolute paths, thus the `wd` will be preprended to `path` + path: filepath.Join(wd, "testdata", "link"), + expectedFullPath: filepath.Join(wd, wd, "testdata", "link"), + expectedVFSPath: filepath.Join(wd, "testdata", "link")[1:], // strip leading `/` + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + fullPath, vfsPath, err := root.validatePath(test.path) + + if test.expectedInvalidPath { + require.IsType(t, &invalidPathError{}, err, "InvalidPath") + return + } + + require.NoError(t, err, "validatePath") + assert.Equal(t, test.expectedFullPath, fullPath, "FullPath") + assert.Equal(t, test.expectedVFSPath, vfsPath, "VFSPath") + }) + } +} + +func TestReadlink(t *testing.T) { + ctx := context.Background() + root, err := localVFS.Root(ctx, ".") + require.NoError(t, err) + + tests := map[string]struct { + path string + expectedTarget string + expectedErr string + expectedInvalidPath bool + expectedIsNotExist bool + }{ + "a valid link": { + path: "testdata/link", + expectedTarget: "file", + }, + "a file": { + path: "testdata/file", + expectedErr: "invalid argument", + }, + "a path outside of root directory": { + path: "testdata/../../link", + expectedInvalidPath: true, + }, + "a non-existing link": { + path: "non-existing", + expectedIsNotExist: true, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + target, err := root.Readlink(ctx, test.path) + + if test.expectedIsNotExist { + require.Equal(t, test.expectedIsNotExist, os.IsNotExist(err), "IsNotExist") + return + } + + if test.expectedInvalidPath { + require.IsType(t, &invalidPathError{}, err, "InvalidPath") + return + 
} + + if test.expectedErr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), test.expectedErr, "Readlink") + return + } + + require.NoError(t, err, "Readlink") + assert.Equal(t, test.expectedTarget, target, "target") + }) + } +} + +func TestReadlinkAbsolutePath(t *testing.T) { + // create structure as: + // /tmp/dir: directory + // /tmp/dir/symlink: points to `/tmp/file` + tmpDir, err := ioutil.TempDir("", "vfs") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + dirPath := filepath.Join(tmpDir, "dir") + err = os.Mkdir(dirPath, 0755) + require.NoError(t, err) + + symlinkPath := filepath.Join(dirPath, "symlink") + filePath := filepath.Join(tmpDir, "file") + err = os.Symlink(filePath, symlinkPath) + require.NoError(t, err) + + root, err := localVFS.Root(context.Background(), dirPath) + require.NoError(t, err) + + targetPath, err := root.Readlink(context.Background(), "symlink") + require.NoError(t, err) + + assert.Equal(t, "../file", targetPath, "the relative path is returned") +} + +func TestLstat(t *testing.T) { + ctx := context.Background() + root, err := localVFS.Root(ctx, ".") + require.NoError(t, err) + + tests := map[string]struct { + path string + modePerm os.FileMode + modeType os.FileMode + expectedInvalidPath bool + expectedIsNotExist bool + }{ + "a directory": { + path: "testdata", + modeType: os.ModeDir, + modePerm: 0755, + }, + "a file": { + path: "testdata/file", + modeType: os.FileMode(0), + modePerm: 0644, + }, + "a link": { + path: "testdata/link", + modeType: os.ModeSymlink, + // modePerm: Permissions of symlinks are platform dependent + }, + "a path outside of root directory": { + path: "testdata/../../link", + expectedInvalidPath: true, + }, + "a non-existing link": { + path: "non-existing", + expectedIsNotExist: true, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.modePerm > 0 { + require.NoError(t, os.Chmod(test.path, test.modePerm), "preparation: deterministic permissions") + } 
+ + fi, err := root.Lstat(ctx, test.path) + + if test.expectedIsNotExist { + require.Equal(t, test.expectedIsNotExist, os.IsNotExist(err), "IsNotExist") + return + } + + if test.expectedInvalidPath { + require.IsType(t, &invalidPathError{}, err, "InvalidPath") + return + } + + require.NoError(t, err, "Lstat") + require.Equal(t, test.modeType, fi.Mode()&os.ModeType, "file mode: type") + if test.modePerm > 0 { + require.Equal(t, test.modePerm, fi.Mode()&os.ModePerm, "file mode: permissions") + } + }) + } +} + +func TestOpen(t *testing.T) { + ctx := context.Background() + root, err := localVFS.Root(ctx, ".") + require.NoError(t, err) + + tests := map[string]struct { + path string + expectedInvalidPath bool + expectedIsNotExist bool + expectedContent string + expectedErr string + }{ + "a file": { + path: "testdata/file", + expectedContent: "hello\n", + }, + "a directory": { + path: "testdata", + expectedErr: errNotFile.Error(), + }, + "a link": { + path: "testdata/link", + expectedErr: "too many levels of symbolic links", + }, + "a path outside of root directory": { + path: "testdata/../../link", + expectedInvalidPath: true, + }, + "a non-existing file": { + path: "non-existing", + expectedIsNotExist: true, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + file, err := root.Open(ctx, test.path) + if file != nil { + defer file.Close() + } + + if test.expectedIsNotExist { + require.Equal(t, test.expectedIsNotExist, os.IsNotExist(err), "IsNotExist") + return + } + + if test.expectedErr != "" { + require.Error(t, err, "Open") + require.Contains(t, err.Error(), test.expectedErr, "Open") + return + } + + if test.expectedInvalidPath { + require.IsType(t, &invalidPathError{}, err, "InvalidPath") + return + } + + require.NoError(t, err, "Open") + + data, err := ioutil.ReadAll(file) + require.NoError(t, err, "ReadAll") + require.Equal(t, test.expectedContent, string(data), "ReadAll") + }) + } +} diff --git a/internal/vfs/local/vfs.go 
b/internal/vfs/local/vfs.go index 7c6f3ba6..a2eb14f7 100644 --- a/internal/vfs/local/vfs.go +++ b/internal/vfs/local/vfs.go @@ -2,18 +2,36 @@ package local import ( "context" + "errors" "os" - - "golang.org/x/sys/unix" + "path/filepath" "gitlab.com/gitlab-org/gitlab-pages/internal/vfs" ) +var errNotDirectory = errors.New("path needs to be a directory") + type VFS struct{} -func (fs VFS) Lstat(ctx context.Context, name string) (os.FileInfo, error) { return os.Lstat(name) } -func (fs VFS) Readlink(ctx context.Context, name string) (string, error) { return os.Readlink(name) } +func (fs VFS) Root(ctx context.Context, path string) (vfs.Root, error) { + rootPath, err := filepath.Abs(path) + if err != nil { + return nil, err + } + + rootPath, err = filepath.EvalSymlinks(rootPath) + if err != nil { + return nil, err + } + + fi, err := os.Lstat(rootPath) + if err != nil { + return nil, err + } + + if !fi.Mode().IsDir() { + return nil, errNotDirectory + } -func (fs VFS) Open(ctx context.Context, name string) (vfs.File, error) { - return os.OpenFile(name, os.O_RDONLY|unix.O_NOFOLLOW, 0) + return &Root{rootPath: rootPath}, nil } diff --git a/internal/vfs/local/vfs_test.go b/internal/vfs/local/vfs_test.go new file mode 100644 index 00000000..6ceb08a5 --- /dev/null +++ b/internal/vfs/local/vfs_test.go @@ -0,0 +1,103 @@ +package local + +import ( + "context" + "io/ioutil" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var localVFS = &VFS{} + +func TestVFSRoot(t *testing.T) { + // create structure as: + // /tmp/dir: directory + // /tmp/dir_link: symlink to `dir` + // /tmp/dir_absolute_link: symlink to `/tmp/dir` + // /tmp/file: file + // /tmp/file_link: symlink to `file` + // /tmp/file_absolute_link: symlink to `/tmp/file` + tmpDir, err := ioutil.TempDir("", "vfs") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + dirPath := filepath.Join(tmpDir, "dir") + err = os.Mkdir(dirPath, 0755) + 
require.NoError(t, err) + + filePath := filepath.Join(tmpDir, "file") + err = ioutil.WriteFile(filePath, []byte{}, 0644) + require.NoError(t, err) + + symlinks := map[string]string{ + "dir_link": "dir", + "dir_absolute_link": dirPath, + "file_link": "file", + "file_absolute_link": filePath, + } + + for dest, src := range symlinks { + err := os.Symlink(src, filepath.Join(tmpDir, dest)) + require.NoError(t, err) + } + + tests := map[string]struct { + path string + expectedPath string + expectedErr error + expectedIsNotExist bool + }{ + "a valid directory": { + path: "dir", + expectedPath: dirPath, + }, + "a symlink to directory": { + path: "dir_link", + expectedPath: dirPath, + }, + "a symlink to absolute directory": { + path: "dir_absolute_link", + expectedPath: dirPath, + }, + "a file": { + path: "file", + expectedErr: errNotDirectory, + }, + "a symlink to file": { + path: "file_link", + expectedErr: errNotDirectory, + }, + "a symlink to absolute file": { + path: "file_absolute_link", + expectedErr: errNotDirectory, + }, + "a non-existing file": { + path: "not-existing", + expectedIsNotExist: true, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + rootVFS, err := localVFS.Root(context.Background(), filepath.Join(tmpDir, test.path)) + + if test.expectedIsNotExist { + require.Equal(t, test.expectedIsNotExist, os.IsNotExist(err)) + return + } + + if test.expectedErr != nil { + require.EqualError(t, err, test.expectedErr.Error()) + return + } + + require.NoError(t, err) + require.IsType(t, &Root{}, rootVFS) + assert.Equal(t, test.expectedPath, rootVFS.(*Root).rootPath) + }) + } +} diff --git a/internal/vfs/root.go b/internal/vfs/root.go new file mode 100644 index 00000000..30d97b0b --- /dev/null +++ b/internal/vfs/root.go @@ -0,0 +1,69 @@ +package vfs + +import ( + "context" + "os" + "strconv" + + log "github.com/sirupsen/logrus" + + "gitlab.com/gitlab-org/gitlab-pages/metrics" +) + +// Root abstracts the things Pages needs to serve a 
static site from a given root rootPath. +type Root interface { + Lstat(ctx context.Context, name string) (os.FileInfo, error) + Readlink(ctx context.Context, name string) (string, error) + Open(ctx context.Context, name string) (File, error) +} + +type instrumentedRoot struct { + root Root + name string + rootPath string +} + +func (i *instrumentedRoot) increment(operation string, err error) { + metrics.VFSOperations.WithLabelValues(i.name, operation, strconv.FormatBool(err == nil)).Inc() +} + +func (i *instrumentedRoot) log() *log.Entry { + return log.WithField("vfs", i.name).WithField("root-path", i.rootPath) +} + +func (i *instrumentedRoot) Lstat(ctx context.Context, name string) (os.FileInfo, error) { + fi, err := i.root.Lstat(ctx, name) + + i.increment("Lstat", err) + i.log(). + WithField("name", name). + WithError(err). + Traceln("Lstat call") + + return fi, err +} + +func (i *instrumentedRoot) Readlink(ctx context.Context, name string) (string, error) { + target, err := i.root.Readlink(ctx, name) + + i.increment("Readlink", err) + i.log(). + WithField("name", name). + WithField("ret-target", target). + WithError(err). + Traceln("Readlink call") + + return target, err +} + +func (i *instrumentedRoot) Open(ctx context.Context, name string) (File, error) { + f, err := i.root.Open(ctx, name) + + i.increment("Open", err) + i.log(). + WithField("name", name). + WithError(err). + Traceln("Open call") + + return f, err +} diff --git a/internal/vfs/vfs.go b/internal/vfs/vfs.go index 07c99b77..9d9551a1 100644 --- a/internal/vfs/vfs.go +++ b/internal/vfs/vfs.go @@ -2,54 +2,47 @@ package vfs import ( "context" - "io" - "os" "strconv" + log "github.com/sirupsen/logrus" + "gitlab.com/gitlab-org/gitlab-pages/metrics" ) // VFS abstracts the things Pages needs to serve a static site from disk. 
type VFS interface { - Lstat(ctx context.Context, name string) (os.FileInfo, error) - Readlink(ctx context.Context, name string) (string, error) - Open(ctx context.Context, name string) (File, error) -} - -// File represents an open file, which will typically be the response body of a Pages request. -type File interface { - io.Reader - io.Seeker - io.Closer + Root(ctx context.Context, path string) (Root, error) } func Instrumented(fs VFS, name string) VFS { - return &InstrumentedVFS{fs: fs, name: name} + return &instrumentedVFS{fs: fs, name: name} } -type InstrumentedVFS struct { +type instrumentedVFS struct { fs VFS name string } -func (i *InstrumentedVFS) increment(operation string, err error) { +func (i *instrumentedVFS) increment(operation string, err error) { metrics.VFSOperations.WithLabelValues(i.name, operation, strconv.FormatBool(err == nil)).Inc() } -func (i *InstrumentedVFS) Lstat(ctx context.Context, name string) (os.FileInfo, error) { - fi, err := i.fs.Lstat(ctx, name) - i.increment("Lstat", err) - return fi, err +func (i *instrumentedVFS) log() *log.Entry { + return log.WithField("vfs", i.name) } -func (i *InstrumentedVFS) Readlink(ctx context.Context, name string) (string, error) { - target, err := i.fs.Readlink(ctx, name) - i.increment("Readlink", err) - return target, err -} +func (i *instrumentedVFS) Root(ctx context.Context, path string) (Root, error) { + root, err := i.fs.Root(ctx, path) + + i.increment("Root", err) + i.log(). + WithField("path", path). + WithError(err). + Traceln("Root call") + + if err != nil { + return nil, err + } -func (i *InstrumentedVFS) Open(ctx context.Context, name string) (File, error) { - f, err := i.fs.Open(ctx, name) - i.increment("Open", err) - return f, err + return &instrumentedRoot{root: root, name: i.name, rootPath: path}, nil } |