Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitlab-pages.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKamil TrzciƄski <ayufan@ayufan.eu>2020-09-03 16:53:26 +0300
committerVladimir Shushlin <vshushlin@gitlab.com>2020-09-03 16:53:26 +0300
commite8a29071e5f4d0a8c35eede932cff7f4ed03adfe (patch)
tree7f62366277976df67c70fcf5a56b8080f000d98e
parentab92a1bf2fd3d3b28498d79245f19623cec5f618 (diff)
Abstract `VFS` `Root`
-rw-r--r--internal/serving/disk/helpers.go9
-rw-r--r--internal/serving/disk/reader.go63
-rw-r--r--internal/serving/disk/symlink/path_test.go149
-rw-r--r--internal/serving/disk/symlink/shims.go4
-rw-r--r--internal/serving/disk/symlink/symlink.go12
-rw-r--r--internal/source/gitlab/client/client_test.go6
-rw-r--r--internal/vfs/file.go10
-rw-r--r--internal/vfs/local/local_test.go96
-rw-r--r--internal/vfs/local/root.go99
-rw-r--r--internal/vfs/local/root_test.go273
-rw-r--r--internal/vfs/local/vfs.go30
-rw-r--r--internal/vfs/local/vfs_test.go103
-rw-r--r--internal/vfs/root.go69
-rw-r--r--internal/vfs/vfs.go49
14 files changed, 696 insertions, 276 deletions
diff --git a/internal/serving/disk/helpers.go b/internal/serving/disk/helpers.go
index e6d3f8ab..1ff83683 100644
--- a/internal/serving/disk/helpers.go
+++ b/internal/serving/disk/helpers.go
@@ -10,6 +10,7 @@ import (
"strings"
"gitlab.com/gitlab-org/gitlab-pages/internal/httputil"
+ "gitlab.com/gitlab-org/gitlab-pages/internal/vfs"
)
func endsWithSlash(path string) bool {
@@ -23,13 +24,13 @@ func endsWithoutHTMLExtension(path string) bool {
// Detect file's content-type either by extension or mime-sniffing.
// Implementation is adapted from Golang's `http.serveContent()`
// See https://github.com/golang/go/blob/902fc114272978a40d2e65c2510a18e870077559/src/net/http/fs.go#L194
-func (reader *Reader) detectContentType(ctx context.Context, path string) (string, error) {
+func (reader *Reader) detectContentType(ctx context.Context, root vfs.Root, path string) (string, error) {
contentType := mime.TypeByExtension(filepath.Ext(path))
if contentType == "" {
var buf [512]byte
- file, err := reader.vfs.Open(ctx, path)
+ file, err := root.Open(ctx, path)
if err != nil {
return "", err
}
@@ -55,7 +56,7 @@ func acceptsGZip(r *http.Request) bool {
return acceptedEncoding == "gzip"
}
-func (reader *Reader) handleGZip(ctx context.Context, w http.ResponseWriter, r *http.Request, fullPath string) string {
+func (reader *Reader) handleGZip(ctx context.Context, w http.ResponseWriter, r *http.Request, root vfs.Root, fullPath string) string {
if !acceptsGZip(r) {
return fullPath
}
@@ -63,7 +64,7 @@ func (reader *Reader) handleGZip(ctx context.Context, w http.ResponseWriter, r *
gzipPath := fullPath + ".gz"
// Ensure the .gz file is not a symlink
- fi, err := reader.vfs.Lstat(ctx, gzipPath)
+ fi, err := root.Lstat(ctx, gzipPath)
if err != nil || !fi.Mode().IsRegular() {
return fullPath
}
diff --git a/internal/serving/disk/reader.go b/internal/serving/disk/reader.go
index 8c7ee3bf..0d11a184 100644
--- a/internal/serving/disk/reader.go
+++ b/internal/serving/disk/reader.go
@@ -5,7 +5,6 @@ import (
"fmt"
"io"
"net/http"
- "path/filepath"
"strconv"
"strings"
"time"
@@ -25,7 +24,13 @@ type Reader struct {
func (reader *Reader) tryFile(h serving.Handler) error {
ctx := h.Request.Context()
- fullPath, err := reader.resolvePath(ctx, h.LookupPath.Path, h.SubPath)
+
+ root, err := reader.vfs.Root(ctx, h.LookupPath.Path)
+ if err != nil {
+ return err
+ }
+
+ fullPath, err := reader.resolvePath(ctx, root, h.SubPath)
request := h.Request
host := request.Host
@@ -33,7 +38,7 @@ func (reader *Reader) tryFile(h serving.Handler) error {
if locationError, _ := err.(*locationDirectoryError); locationError != nil {
if endsWithSlash(urlPath) {
- fullPath, err = reader.resolvePath(ctx, h.LookupPath.Path, h.SubPath, "index.html")
+ fullPath, err = reader.resolvePath(ctx, root, h.SubPath, "index.html")
} else {
// TODO why are we doing that? In tests it redirects to HTTPS. This seems wrong,
// issue about this: https://gitlab.com/gitlab-org/gitlab-pages/issues/273
@@ -50,24 +55,30 @@ func (reader *Reader) tryFile(h serving.Handler) error {
}
if locationError, _ := err.(*locationFileNoExtensionError); locationError != nil {
- fullPath, err = reader.resolvePath(ctx, h.LookupPath.Path, strings.TrimSuffix(h.SubPath, "/")+".html")
+ fullPath, err = reader.resolvePath(ctx, root, strings.TrimSuffix(h.SubPath, "/")+".html")
}
if err != nil {
return err
}
- return reader.serveFile(ctx, h.Writer, h.Request, fullPath, h.LookupPath.HasAccessControl)
+ return reader.serveFile(ctx, h.Writer, h.Request, root, fullPath, h.LookupPath.HasAccessControl)
}
func (reader *Reader) tryNotFound(h serving.Handler) error {
ctx := h.Request.Context()
- page404, err := reader.resolvePath(ctx, h.LookupPath.Path, "404.html")
+
+ root, err := reader.vfs.Root(ctx, h.LookupPath.Path)
if err != nil {
return err
}
- err = reader.serveCustomFile(ctx, h.Writer, h.Request, http.StatusNotFound, page404)
+ page404, err := reader.resolvePath(ctx, root, "404.html")
+ if err != nil {
+ return err
+ }
+
+ err = reader.serveCustomFile(ctx, h.Writer, h.Request, http.StatusNotFound, root, page404)
if err != nil {
return err
}
@@ -76,15 +87,12 @@ func (reader *Reader) tryNotFound(h serving.Handler) error {
// Resolve the HTTP request to a path on disk, converting requests for
// directories to requests for index.html inside the directory if appropriate.
-func (reader *Reader) resolvePath(ctx context.Context, publicPath string, subPath ...string) (string, error) {
- // Ensure that publicPath always ends with "/"
- publicPath = strings.TrimSuffix(publicPath, "/") + "/"
-
+func (reader *Reader) resolvePath(ctx context.Context, root vfs.Root, subPath ...string) (string, error) {
// Don't use filepath.Join as cleans the path,
// where we want to traverse full path as supplied by user
// (including ..)
- testPath := publicPath + strings.Join(subPath, "/")
- fullPath, err := symlink.EvalSymlinks(ctx, reader.vfs, testPath)
+ testPath := strings.Join(subPath, "/")
+ fullPath, err := symlink.EvalSymlinks(ctx, root, testPath)
if err != nil {
if endsWithoutHTMLExtension(testPath) {
@@ -96,12 +104,7 @@ func (reader *Reader) resolvePath(ctx context.Context, publicPath string, subPat
return "", err
}
- // The requested path resolved to somewhere outside of the public/ directory
- if !strings.HasPrefix(fullPath, publicPath) && fullPath != filepath.Clean(publicPath) {
- return "", fmt.Errorf("%q should be in %q", fullPath, publicPath)
- }
-
- fi, err := reader.vfs.Lstat(ctx, fullPath)
+ fi, err := root.Lstat(ctx, fullPath)
if err != nil {
return "", err
}
@@ -110,7 +113,7 @@ func (reader *Reader) resolvePath(ctx context.Context, publicPath string, subPat
if fi.IsDir() {
return "", &locationDirectoryError{
FullPath: fullPath,
- RelativePath: strings.TrimPrefix(fullPath, publicPath),
+ RelativePath: testPath,
}
}
@@ -123,17 +126,17 @@ func (reader *Reader) resolvePath(ctx context.Context, publicPath string, subPat
return fullPath, nil
}
-func (reader *Reader) serveFile(ctx context.Context, w http.ResponseWriter, r *http.Request, origPath string, accessControl bool) error {
- fullPath := reader.handleGZip(ctx, w, r, origPath)
+func (reader *Reader) serveFile(ctx context.Context, w http.ResponseWriter, r *http.Request, root vfs.Root, origPath string, accessControl bool) error {
+ fullPath := reader.handleGZip(ctx, w, r, root, origPath)
- file, err := reader.vfs.Open(ctx, fullPath)
+ file, err := root.Open(ctx, fullPath)
if err != nil {
return err
}
defer file.Close()
- fi, err := reader.vfs.Lstat(ctx, fullPath)
+ fi, err := root.Lstat(ctx, fullPath)
if err != nil {
return err
}
@@ -144,7 +147,7 @@ func (reader *Reader) serveFile(ctx context.Context, w http.ResponseWriter, r *h
w.Header().Set("Expires", time.Now().Add(10*time.Minute).Format(time.RFC1123))
}
- contentType, err := reader.detectContentType(ctx, origPath)
+ contentType, err := reader.detectContentType(ctx, root, origPath)
if err != nil {
return err
}
@@ -157,22 +160,22 @@ func (reader *Reader) serveFile(ctx context.Context, w http.ResponseWriter, r *h
return nil
}
-func (reader *Reader) serveCustomFile(ctx context.Context, w http.ResponseWriter, r *http.Request, code int, origPath string) error {
- fullPath := reader.handleGZip(ctx, w, r, origPath)
+func (reader *Reader) serveCustomFile(ctx context.Context, w http.ResponseWriter, r *http.Request, code int, root vfs.Root, origPath string) error {
+ fullPath := reader.handleGZip(ctx, w, r, root, origPath)
// Open and serve content of file
- file, err := reader.vfs.Open(ctx, fullPath)
+ file, err := root.Open(ctx, fullPath)
if err != nil {
return err
}
defer file.Close()
- fi, err := reader.vfs.Lstat(ctx, fullPath)
+ fi, err := root.Lstat(ctx, fullPath)
if err != nil {
return err
}
- contentType, err := reader.detectContentType(ctx, origPath)
+ contentType, err := reader.detectContentType(ctx, root, origPath)
if err != nil {
return err
}
diff --git a/internal/serving/disk/symlink/path_test.go b/internal/serving/disk/symlink/path_test.go
index 4d590db5..6b0a41f3 100644
--- a/internal/serving/disk/symlink/path_test.go
+++ b/internal/serving/disk/symlink/path_test.go
@@ -10,31 +10,27 @@ import (
"os"
"path/filepath"
"runtime"
+ "strings"
"testing"
+ "github.com/stretchr/testify/require"
+
"gitlab.com/gitlab-org/gitlab-pages/internal/serving/disk/symlink"
+ "gitlab.com/gitlab-org/gitlab-pages/internal/vfs"
"gitlab.com/gitlab-org/gitlab-pages/internal/vfs/local"
)
-var fs = local.VFS{}
+var fs = vfs.Instrumented(&local.VFS{}, "local")
-func chtmpdir(t *testing.T) (restore func()) {
- oldwd, err := os.Getwd()
- if err != nil {
- t.Fatalf("chtmpdir: %v", err)
- }
- d, err := ioutil.TempDir("", "test")
- if err != nil {
- t.Fatalf("chtmpdir: %v", err)
- }
- if err := os.Chdir(d); err != nil {
- t.Fatalf("chtmpdir: %v", err)
- }
- return func() {
- if err := os.Chdir(oldwd); err != nil {
- t.Fatalf("chtmpdir: %v", err)
- }
- os.RemoveAll(d)
+func tmpDir(t *testing.T) (vfs.Root, string, func()) {
+ tmpDir, err := ioutil.TempDir("", "symlink_tests")
+ require.NoError(t, err)
+
+ root, err := fs.Root(context.Background(), tmpDir)
+ require.NoError(t, err)
+
+ return root, tmpDir, func() {
+ os.RemoveAll(tmpDir)
}
}
@@ -74,19 +70,22 @@ var EvalSymlinksTests = []EvalSymlinksTest{
{"test/link2/..", "test"},
{"test/dir/link3", "."},
{"test/link2/link3/test", "test"},
- {"test/linkabs", "/"},
+ {"test/linkabs", "../.."},
{"test/link4/..", "test"},
{"src/versions/current/modules/test", "src/pool/test"},
}
// simpleJoin builds a file name from the directory and path.
// It does not use Join because we don't want ".." to be evaluated.
-func simpleJoin(dir, path string) string {
- return dir + string(filepath.Separator) + path
+func simpleJoin(path ...string) string {
+ return strings.Join(path, string(filepath.Separator))
}
-func testEvalSymlinks(t *testing.T, path, want string) {
- have, err := symlink.EvalSymlinks(context.Background(), fs, path)
+func testEvalSymlinks(t *testing.T, wd, path, want string) {
+ root, err := fs.Root(context.Background(), wd)
+ require.NoError(t, err)
+
+ have, err := symlink.EvalSymlinks(context.Background(), root, path)
if err != nil {
t.Errorf("EvalSymlinks(%q) error: %v", path, err)
return
@@ -96,46 +95,9 @@ func testEvalSymlinks(t *testing.T, path, want string) {
}
}
-func testEvalSymlinksAfterChdir(t *testing.T, wd, path, want string) {
- cwd, err := os.Getwd()
- if err != nil {
- t.Fatal(err)
- }
- defer func() {
- err := os.Chdir(cwd)
- if err != nil {
- t.Fatal(err)
- }
- }()
-
- err = os.Chdir(wd)
- if err != nil {
- t.Fatal(err)
- }
-
- have, err := symlink.EvalSymlinks(context.Background(), fs, path)
- if err != nil {
- t.Errorf("EvalSymlinks(%q) in %q directory error: %v", path, wd, err)
- return
- }
- if filepath.Clean(have) != filepath.Clean(want) {
- t.Errorf("EvalSymlinks(%q) in %q directory returns %q, want %q", path, wd, have, want)
- }
-}
-
func TestEvalSymlinks(t *testing.T) {
- tmpDir, err := ioutil.TempDir("", "evalsymlink")
- if err != nil {
- t.Fatal("creating temp dir:", err)
- }
- defer os.RemoveAll(tmpDir)
-
- // /tmp may itself be a symlink! Avoid the confusion, although
- // it means trusting the thing we're testing.
- tmpDir, err = filepath.EvalSymlinks(tmpDir)
- if err != nil {
- t.Fatal("eval symlink for tmp dir:", err)
- }
+ _, tmpDir, cleanup := tmpDir(t)
+ defer cleanup()
// Create the symlink farm using relative paths.
for _, d := range EvalSymlinksTestDirs {
@@ -153,42 +115,30 @@ func TestEvalSymlinks(t *testing.T) {
// Evaluate the symlink farm.
for _, test := range EvalSymlinksTests {
- path := simpleJoin(tmpDir, test.path)
-
- dest := simpleJoin(tmpDir, test.dest)
- if filepath.IsAbs(test.dest) || os.IsPathSeparator(test.dest[0]) {
- dest = test.dest
- }
- testEvalSymlinks(t, path, dest)
+ testEvalSymlinks(t, tmpDir, test.path, test.dest)
// test EvalSymlinks(".")
- testEvalSymlinksAfterChdir(t, path, ".", ".")
+ testEvalSymlinks(t, simpleJoin(tmpDir, test.path), ".", ".")
// test EvalSymlinks("C:.") on Windows
if runtime.GOOS == "windows" {
volDot := filepath.VolumeName(tmpDir) + "."
- testEvalSymlinksAfterChdir(t, path, volDot, volDot)
+ testEvalSymlinks(t, simpleJoin(tmpDir, test.path), volDot, volDot)
}
// test EvalSymlinks(".."+path)
- dotdotPath := simpleJoin("..", test.dest)
- if filepath.IsAbs(test.dest) || os.IsPathSeparator(test.dest[0]) {
- dotdotPath = test.dest
- }
- testEvalSymlinksAfterChdir(t,
- simpleJoin(tmpDir, "test"),
- simpleJoin("..", test.path),
- dotdotPath)
-
- // test EvalSymlinks(p) where p is relative path
- testEvalSymlinksAfterChdir(t, tmpDir, test.path, test.dest)
+ testEvalSymlinks(t,
+ tmpDir,
+ simpleJoin("test", "..", test.path),
+ test.dest)
}
}
func TestEvalSymlinksIsNotExist(t *testing.T) {
- defer chtmpdir(t)()
+ root, _, cleanup := tmpDir(t)
+ defer cleanup()
- _, err := symlink.EvalSymlinks(context.Background(), fs, "notexist")
+ _, err := symlink.EvalSymlinks(context.Background(), root, "notexist")
if !os.IsNotExist(err) {
t.Errorf("expected the file is not found, got %v\n", err)
}
@@ -199,21 +149,18 @@ func TestEvalSymlinksIsNotExist(t *testing.T) {
}
defer os.Remove("link")
- _, err = symlink.EvalSymlinks(context.Background(), fs, "link")
+ _, err = symlink.EvalSymlinks(context.Background(), root, "link")
if !os.IsNotExist(err) {
t.Errorf("expected the file is not found, got %v\n", err)
}
}
func TestIssue13582(t *testing.T) {
- tmpDir, err := ioutil.TempDir("", "issue13582")
- if err != nil {
- t.Fatal(err)
- }
- defer os.RemoveAll(tmpDir)
+ root, tmpDir, cleanup := tmpDir(t)
+ defer cleanup()
dir := filepath.Join(tmpDir, "dir")
- err = os.Mkdir(dir, 0755)
+ err := os.Mkdir(dir, 0755)
if err != nil {
t.Fatal(err)
}
@@ -238,25 +185,17 @@ func TestIssue13582(t *testing.T) {
t.Fatal(err)
}
- // /tmp may itself be a symlink!
- realTmpDir, err := filepath.EvalSymlinks(tmpDir)
- if err != nil {
- t.Fatal(err)
- }
- realDir := filepath.Join(realTmpDir, "dir")
- realFile := filepath.Join(realDir, "file")
-
tests := []struct {
path, want string
}{
- {dir, realDir},
- {linkToDir, realDir},
- {file, realFile},
- {link1, realFile},
- {link2, realFile},
+ {"dir", "dir"},
+ {"link_to_dir", "dir"},
+ {"link_to_dir/file", "dir/file"},
+ {"link_to_dir/link1", "dir/file"},
+ {"link_to_dir/link2", "dir/file"},
}
for i, test := range tests {
- have, err := symlink.EvalSymlinks(context.Background(), fs, test.path)
+ have, err := symlink.EvalSymlinks(context.Background(), root, test.path)
if err != nil {
t.Fatal(err)
}
diff --git a/internal/serving/disk/symlink/shims.go b/internal/serving/disk/symlink/shims.go
index d383b96b..90f67d45 100644
--- a/internal/serving/disk/symlink/shims.go
+++ b/internal/serving/disk/symlink/shims.go
@@ -12,6 +12,6 @@ func volumeNameLen(s string) int { return 0 }
func IsAbs(path string) bool { return filepath.IsAbs(path) }
func Clean(path string) string { return filepath.Clean(path) }
-func EvalSymlinks(ctx context.Context, fs vfs.VFS, path string) (string, error) {
- return walkSymlinks(ctx, fs, path)
+func EvalSymlinks(ctx context.Context, root vfs.Root, path string) (string, error) {
+ return walkSymlinks(ctx, root, path)
}
diff --git a/internal/serving/disk/symlink/symlink.go b/internal/serving/disk/symlink/symlink.go
index 50714811..3b5d242a 100644
--- a/internal/serving/disk/symlink/symlink.go
+++ b/internal/serving/disk/symlink/symlink.go
@@ -14,7 +14,7 @@ import (
"gitlab.com/gitlab-org/gitlab-pages/internal/vfs"
)
-func walkSymlinks(ctx context.Context, fs vfs.VFS, path string) (string, error) {
+func walkSymlinks(ctx context.Context, root vfs.Root, path string) (string, error) {
volLen := volumeNameLen(path)
pathSeparator := string(os.PathSeparator)
@@ -57,7 +57,10 @@ func walkSymlinks(ctx context.Context, fs vfs.VFS, path string) (string, error)
break
}
}
- if r < volLen || dest[r+1:] == ".." {
+
+ if r >= 0 && r+1 == volLen && os.IsPathSeparator(dest[r]) {
+ return "", errors.New("EvalSymlinks: cannot backtrack root path")
+ } else if r < volLen || dest[r+1:] == ".." {
// Either path has no slashes
// (it's empty or just "C:")
// or it ends in a ".." we had to keep.
@@ -83,7 +86,7 @@ func walkSymlinks(ctx context.Context, fs vfs.VFS, path string) (string, error)
// Resolve symlink.
- fi, err := fs.Lstat(ctx, dest)
+ fi, err := root.Lstat(ctx, dest)
if err != nil {
return "", err
}
@@ -102,7 +105,7 @@ func walkSymlinks(ctx context.Context, fs vfs.VFS, path string) (string, error)
return "", errors.New("EvalSymlinks: too many links")
}
- link, err := fs.Readlink(ctx, dest)
+ link, err := root.Readlink(ctx, dest)
if err != nil {
return "", err
}
@@ -145,5 +148,6 @@ func walkSymlinks(ctx context.Context, fs vfs.VFS, path string) (string, error)
end = 0
}
}
+
return Clean(dest), nil
}
diff --git a/internal/source/gitlab/client/client_test.go b/internal/source/gitlab/client/client_test.go
index 57c479d7..c888a059 100644
--- a/internal/source/gitlab/client/client_test.go
+++ b/internal/source/gitlab/client/client_test.go
@@ -300,7 +300,11 @@ func TestClientStatusClientTimeout(t *testing.T) {
err := client.Status()
require.Error(t, err)
- require.Contains(t, err.Error(), "Client.Timeout")
+ // we can receive any of these messages
+ // - context deadline exceeded (Client.Timeout exceeded while awaiting headers)
+ // - net/http: request canceled (Client.Timeout exceeded while awaiting headers)
+ // - context deadline exceeded
+ require.Contains(t, err.Error(), "exceeded")
}
func TestClientStatusConnectionRefused(t *testing.T) {
diff --git a/internal/vfs/file.go b/internal/vfs/file.go
new file mode 100644
index 00000000..5260c847
--- /dev/null
+++ b/internal/vfs/file.go
@@ -0,0 +1,10 @@
+package vfs
+
+import "io"
+
+// File represents an open file, which will typically be the response body of a Pages request.
+type File interface {
+ io.Reader
+ io.Seeker
+ io.Closer
+}
diff --git a/internal/vfs/local/local_test.go b/internal/vfs/local/local_test.go
deleted file mode 100644
index c41f96bc..00000000
--- a/internal/vfs/local/local_test.go
+++ /dev/null
@@ -1,96 +0,0 @@
-package local
-
-import (
- "context"
- "io/ioutil"
- "os"
- "testing"
-
- "github.com/stretchr/testify/require"
-)
-
-func TestReadlink(t *testing.T) {
- ctx := context.Background()
- fs := VFS{}
-
- target, err := fs.Readlink(ctx, "testdata/link")
- require.NoError(t, err)
- require.Equal(t, "file", target)
-}
-
-func TestReadlinkNotSymlink(t *testing.T) {
- ctx := context.Background()
- fs := VFS{}
-
- for _, path := range []string{"testdata", "testdata/file"} {
- t.Run(path, func(t *testing.T) {
- _, err := os.Lstat(path)
- require.NoError(t, err, "sanity check: input must actually exist")
-
- _, err = fs.Readlink(ctx, path)
- require.Error(t, err, "expect readlink to fail")
- })
- }
-}
-
-func TestLstat(t *testing.T) {
- ctx := context.Background()
- fs := VFS{}
-
- testCases := []struct {
- path string
- modePerm os.FileMode
- modeType os.FileMode
- }{
- {path: "testdata", modeType: os.ModeDir, modePerm: 0755},
- {path: "testdata/file", modeType: os.FileMode(0), modePerm: 0644},
- {path: "testdata/link", modeType: os.ModeSymlink}, // Permissions of symlinks are platform dependent
- }
-
- for _, tc := range testCases {
- t.Run(tc.path, func(t *testing.T) {
- if tc.modePerm > 0 {
- require.NoError(t, os.Chmod(tc.path, tc.modePerm), "preparation: deterministic permissions")
- }
-
- fi, err := fs.Lstat(ctx, tc.path)
- require.NoError(t, err, "lstat error")
-
- require.Equal(t, tc.modeType, fi.Mode()&os.ModeType, "file mode: type")
- if tc.modePerm > 0 {
- require.Equal(t, tc.modePerm, fi.Mode()&os.ModePerm, "file mode: permissions")
- }
- })
- }
-}
-
-func TestOpen(t *testing.T) {
- ctx := context.Background()
- fs := VFS{}
-
- f, err := fs.Open(ctx, "testdata/file")
- require.NoError(t, err, "open file")
-
- data, err := ioutil.ReadAll(f)
- require.NoError(t, err, "read from file")
- require.Equal(t, "hello\n", string(data), "file contents")
-
- require.NoError(t, f.Close(), "close file")
-}
-
-func TestOpenDenySymlink(t *testing.T) {
- ctx := context.Background()
- fs := VFS{}
- const symlinkPath = "testdata/link"
-
- fi, err := os.Stat(symlinkPath)
- require.NoError(t, err, "stat link")
- require.Equal(t, os.FileMode(0), fi.Mode()&os.ModeType, "sanity check: link target should be a regular file")
-
- fi, err = os.Lstat(symlinkPath)
- require.NoError(t, err, "lstat link")
- require.Equal(t, os.ModeSymlink, fi.Mode()&os.ModeType, "sanity check: testdata/link should be a symlink")
-
- _, err = fs.Open(ctx, symlinkPath)
- require.Error(t, err, "opening symlink should fail (security mechanism)")
-}
diff --git a/internal/vfs/local/root.go b/internal/vfs/local/root.go
new file mode 100644
index 00000000..0ed2206d
--- /dev/null
+++ b/internal/vfs/local/root.go
@@ -0,0 +1,99 @@
+package local
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "os"
+ "path/filepath"
+ "strings"
+
+ "golang.org/x/sys/unix"
+
+ "gitlab.com/gitlab-org/gitlab-pages/internal/vfs"
+)
+
+var errNotFile = errors.New("path needs to be a file")
+
+type invalidPathError struct {
+ rootPath string
+ requestedPath string
+}
+
+func (i *invalidPathError) Error() string {
+ return fmt.Sprintf("%q should be in %q", i.requestedPath, i.rootPath)
+}
+
+type Root struct {
+ rootPath string
+}
+
+func (r *Root) validatePath(path string) (string, string, error) {
+ fullPath := filepath.Join(r.rootPath, path)
+
+ if r.rootPath == fullPath {
+ return fullPath, "", nil
+ }
+
+ vfsPath := strings.TrimPrefix(fullPath, r.rootPath+"/")
+
+ // The requested path resolved to somewhere outside of the `r.rootPath` directory
+ if fullPath == vfsPath {
+ return "", "", &invalidPathError{rootPath: r.rootPath, requestedPath: fullPath}
+ }
+
+ return fullPath, vfsPath, nil
+}
+
+func (r *Root) Lstat(ctx context.Context, name string) (os.FileInfo, error) {
+ fullPath, _, err := r.validatePath(name)
+ if err != nil {
+ return nil, err
+ }
+
+ return os.Lstat(fullPath)
+}
+
+func (r *Root) Readlink(ctx context.Context, name string) (string, error) {
+ fullPath, _, err := r.validatePath(name)
+ if err != nil {
+ return "", err
+ }
+
+ target, err := os.Readlink(fullPath)
+ if err != nil {
+ return "", err
+ }
+
+ if filepath.IsAbs(target) {
+ return filepath.Rel(filepath.Dir(fullPath), target)
+ }
+
+ return target, nil
+}
+
+func (r *Root) Open(ctx context.Context, name string) (vfs.File, error) {
+ fullPath, _, err := r.validatePath(name)
+ if err != nil {
+ return nil, err
+ }
+
+ file, err := os.OpenFile(fullPath, os.O_RDONLY|unix.O_NOFOLLOW, 0)
+ if err != nil {
+ return nil, err
+ }
+
+ // We do a `Stat()` on a file due to race-conditions
+ // Someone could update (unlikely) a file between `Stat()/Open()`
+ fi, err := file.Stat()
+ if err != nil {
+ return nil, err
+ }
+
+ if !fi.Mode().IsRegular() {
+ file.Close()
+ return nil, errNotFile
+ }
+
+ return file, nil
+}
diff --git a/internal/vfs/local/root_test.go b/internal/vfs/local/root_test.go
new file mode 100644
index 00000000..a4711169
--- /dev/null
+++ b/internal/vfs/local/root_test.go
@@ -0,0 +1,273 @@
+package local
+
+import (
+ "context"
+ "io/ioutil"
+ "os"
+ "path/filepath"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestValidatePath(t *testing.T) {
+ ctx := context.Background()
+ rootVFS, err := localVFS.Root(ctx, ".")
+ require.NoError(t, err)
+
+ root := rootVFS.(*Root)
+
+ wd, err := os.Getwd()
+ require.NoError(t, err)
+
+ tests := map[string]struct {
+ path string
+ expectedFullPath string
+ expectedVFSPath string
+ expectedInvalidPath bool
+ }{
+ "a valid path": {
+ path: "testdata/link",
+ expectedFullPath: filepath.Join(wd, "testdata", "link"),
+ expectedVFSPath: filepath.Join("testdata", "link"),
+ },
+ "a path outside of root directory": {
+ path: "testdata/../../link",
+ expectedInvalidPath: true,
+ },
+ "an absolute path": {
+ // we don't support absolute paths, thus the `wd` will be preprended to `path`
+ path: filepath.Join(wd, "testdata", "link"),
+ expectedFullPath: filepath.Join(wd, wd, "testdata", "link"),
+ expectedVFSPath: filepath.Join(wd, "testdata", "link")[1:], // strip leading `/`
+ },
+ }
+
+ for name, test := range tests {
+ t.Run(name, func(t *testing.T) {
+ fullPath, vfsPath, err := root.validatePath(test.path)
+
+ if test.expectedInvalidPath {
+ require.IsType(t, &invalidPathError{}, err, "InvalidPath")
+ return
+ }
+
+ require.NoError(t, err, "validatePath")
+ assert.Equal(t, test.expectedFullPath, fullPath, "FullPath")
+ assert.Equal(t, test.expectedVFSPath, vfsPath, "VFSPath")
+ })
+ }
+}
+
+func TestReadlink(t *testing.T) {
+ ctx := context.Background()
+ root, err := localVFS.Root(ctx, ".")
+ require.NoError(t, err)
+
+ tests := map[string]struct {
+ path string
+ expectedTarget string
+ expectedErr string
+ expectedInvalidPath bool
+ expectedIsNotExist bool
+ }{
+ "a valid link": {
+ path: "testdata/link",
+ expectedTarget: "file",
+ },
+ "a file": {
+ path: "testdata/file",
+ expectedErr: "invalid argument",
+ },
+ "a path outside of root directory": {
+ path: "testdata/../../link",
+ expectedInvalidPath: true,
+ },
+ "a non-existing link": {
+ path: "non-existing",
+ expectedIsNotExist: true,
+ },
+ }
+
+ for name, test := range tests {
+ t.Run(name, func(t *testing.T) {
+ target, err := root.Readlink(ctx, test.path)
+
+ if test.expectedIsNotExist {
+ require.Equal(t, test.expectedIsNotExist, os.IsNotExist(err), "IsNotExist")
+ return
+ }
+
+ if test.expectedInvalidPath {
+ require.IsType(t, &invalidPathError{}, err, "InvalidPath")
+ return
+ }
+
+ if test.expectedErr != "" {
+ require.Error(t, err)
+ require.Contains(t, err.Error(), test.expectedErr, "Readlink")
+ return
+ }
+
+ require.NoError(t, err, "Readlink")
+ assert.Equal(t, test.expectedTarget, target, "target")
+ })
+ }
+}
+
+func TestReadlinkAbsolutePath(t *testing.T) {
+ // create structure as:
+ // /tmp/dir: directory
+ // /tmp/dir/symlink: points to `/tmp/file`
+ tmpDir, err := ioutil.TempDir("", "vfs")
+ require.NoError(t, err)
+ defer os.RemoveAll(tmpDir)
+
+ dirPath := filepath.Join(tmpDir, "dir")
+ err = os.Mkdir(dirPath, 0755)
+ require.NoError(t, err)
+
+ symlinkPath := filepath.Join(dirPath, "symlink")
+ filePath := filepath.Join(tmpDir, "file")
+ err = os.Symlink(filePath, symlinkPath)
+ require.NoError(t, err)
+
+ root, err := localVFS.Root(context.Background(), dirPath)
+ require.NoError(t, err)
+
+ targetPath, err := root.Readlink(context.Background(), "symlink")
+ require.NoError(t, err)
+
+ assert.Equal(t, "../file", targetPath, "the relative path is returned")
+}
+
+func TestLstat(t *testing.T) {
+ ctx := context.Background()
+ root, err := localVFS.Root(ctx, ".")
+ require.NoError(t, err)
+
+ tests := map[string]struct {
+ path string
+ modePerm os.FileMode
+ modeType os.FileMode
+ expectedInvalidPath bool
+ expectedIsNotExist bool
+ }{
+ "a directory": {
+ path: "testdata",
+ modeType: os.ModeDir,
+ modePerm: 0755,
+ },
+ "a file": {
+ path: "testdata/file",
+ modeType: os.FileMode(0),
+ modePerm: 0644,
+ },
+ "a link": {
+ path: "testdata/link",
+ modeType: os.ModeSymlink,
+ // modePerm: Permissions of symlinks are platform dependent
+ },
+ "a path outside of root directory": {
+ path: "testdata/../../link",
+ expectedInvalidPath: true,
+ },
+ "a non-existing link": {
+ path: "non-existing",
+ expectedIsNotExist: true,
+ },
+ }
+
+ for name, test := range tests {
+ t.Run(name, func(t *testing.T) {
+ if test.modePerm > 0 {
+ require.NoError(t, os.Chmod(test.path, test.modePerm), "preparation: deterministic permissions")
+ }
+
+ fi, err := root.Lstat(ctx, test.path)
+
+ if test.expectedIsNotExist {
+ require.Equal(t, test.expectedIsNotExist, os.IsNotExist(err), "IsNotExist")
+ return
+ }
+
+ if test.expectedInvalidPath {
+ require.IsType(t, &invalidPathError{}, err, "InvalidPath")
+ return
+ }
+
+ require.NoError(t, err, "Lstat")
+ require.Equal(t, test.modeType, fi.Mode()&os.ModeType, "file mode: type")
+ if test.modePerm > 0 {
+ require.Equal(t, test.modePerm, fi.Mode()&os.ModePerm, "file mode: permissions")
+ }
+ })
+ }
+}
+
+func TestOpen(t *testing.T) {
+ ctx := context.Background()
+ root, err := localVFS.Root(ctx, ".")
+ require.NoError(t, err)
+
+ tests := map[string]struct {
+ path string
+ expectedInvalidPath bool
+ expectedIsNotExist bool
+ expectedContent string
+ expectedErr string
+ }{
+ "a file": {
+ path: "testdata/file",
+ expectedContent: "hello\n",
+ },
+ "a directory": {
+ path: "testdata",
+ expectedErr: errNotFile.Error(),
+ },
+ "a link": {
+ path: "testdata/link",
+ expectedErr: "too many levels of symbolic links",
+ },
+ "a path outside of root directory": {
+ path: "testdata/../../link",
+ expectedInvalidPath: true,
+ },
+ "a non-existing file": {
+ path: "non-existing",
+ expectedIsNotExist: true,
+ },
+ }
+
+ for name, test := range tests {
+ t.Run(name, func(t *testing.T) {
+ file, err := root.Open(ctx, test.path)
+ if file != nil {
+ defer file.Close()
+ }
+
+ if test.expectedIsNotExist {
+ require.Equal(t, test.expectedIsNotExist, os.IsNotExist(err), "IsNotExist")
+ return
+ }
+
+ if test.expectedErr != "" {
+ require.Error(t, err, "Open")
+ require.Contains(t, err.Error(), test.expectedErr, "Open")
+ return
+ }
+
+ if test.expectedInvalidPath {
+ require.IsType(t, &invalidPathError{}, err, "InvalidPath")
+ return
+ }
+
+ require.NoError(t, err, "Open")
+
+ data, err := ioutil.ReadAll(file)
+ require.NoError(t, err, "ReadAll")
+ require.Equal(t, test.expectedContent, string(data), "ReadAll")
+ })
+ }
+}
diff --git a/internal/vfs/local/vfs.go b/internal/vfs/local/vfs.go
index 7c6f3ba6..a2eb14f7 100644
--- a/internal/vfs/local/vfs.go
+++ b/internal/vfs/local/vfs.go
@@ -2,18 +2,36 @@ package local
import (
"context"
+ "errors"
"os"
-
- "golang.org/x/sys/unix"
+ "path/filepath"
"gitlab.com/gitlab-org/gitlab-pages/internal/vfs"
)
+var errNotDirectory = errors.New("path needs to be a directory")
+
type VFS struct{}
-func (fs VFS) Lstat(ctx context.Context, name string) (os.FileInfo, error) { return os.Lstat(name) }
-func (fs VFS) Readlink(ctx context.Context, name string) (string, error) { return os.Readlink(name) }
+func (fs VFS) Root(ctx context.Context, path string) (vfs.Root, error) {
+ rootPath, err := filepath.Abs(path)
+ if err != nil {
+ return nil, err
+ }
+
+ rootPath, err = filepath.EvalSymlinks(rootPath)
+ if err != nil {
+ return nil, err
+ }
+
+ fi, err := os.Lstat(rootPath)
+ if err != nil {
+ return nil, err
+ }
+
+ if !fi.Mode().IsDir() {
+ return nil, errNotDirectory
+ }
-func (fs VFS) Open(ctx context.Context, name string) (vfs.File, error) {
- return os.OpenFile(name, os.O_RDONLY|unix.O_NOFOLLOW, 0)
+ return &Root{rootPath: rootPath}, nil
}
diff --git a/internal/vfs/local/vfs_test.go b/internal/vfs/local/vfs_test.go
new file mode 100644
index 00000000..6ceb08a5
--- /dev/null
+++ b/internal/vfs/local/vfs_test.go
@@ -0,0 +1,103 @@
+package local
+
+import (
+ "context"
+ "io/ioutil"
+ "os"
+ "path/filepath"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+var localVFS = &VFS{}
+
+func TestVFSRoot(t *testing.T) {
+ // create structure as:
+ // /tmp/dir: directory
+ // /tmp/dir_link: symlink to `dir`
+ // /tmp/dir_absolute_link: symlink to `/tmp/dir`
+ // /tmp/file: file
+ // /tmp/file_link: symlink to `file`
+ // /tmp/file_absolute_link: symlink to `/tmp/file`
+ tmpDir, err := ioutil.TempDir("", "vfs")
+ require.NoError(t, err)
+ defer os.RemoveAll(tmpDir)
+
+ dirPath := filepath.Join(tmpDir, "dir")
+ err = os.Mkdir(dirPath, 0755)
+ require.NoError(t, err)
+
+ filePath := filepath.Join(tmpDir, "file")
+ err = ioutil.WriteFile(filePath, []byte{}, 0644)
+ require.NoError(t, err)
+
+ symlinks := map[string]string{
+ "dir_link": "dir",
+ "dir_absolute_link": dirPath,
+ "file_link": "file",
+ "file_absolute_link": filePath,
+ }
+
+ for dest, src := range symlinks {
+ err := os.Symlink(src, filepath.Join(tmpDir, dest))
+ require.NoError(t, err)
+ }
+
+ tests := map[string]struct {
+ path string
+ expectedPath string
+ expectedErr error
+ expectedIsNotExist bool
+ }{
+ "a valid directory": {
+ path: "dir",
+ expectedPath: dirPath,
+ },
+ "a symlink to directory": {
+ path: "dir_link",
+ expectedPath: dirPath,
+ },
+ "a symlink to absolute directory": {
+ path: "dir_absolute_link",
+ expectedPath: dirPath,
+ },
+ "a file": {
+ path: "file",
+ expectedErr: errNotDirectory,
+ },
+ "a symlink to file": {
+ path: "file_link",
+ expectedErr: errNotDirectory,
+ },
+ "a symlink to absolute file": {
+ path: "file_absolute_link",
+ expectedErr: errNotDirectory,
+ },
+ "a non-existing file": {
+ path: "not-existing",
+ expectedIsNotExist: true,
+ },
+ }
+
+ for name, test := range tests {
+ t.Run(name, func(t *testing.T) {
+ rootVFS, err := localVFS.Root(context.Background(), filepath.Join(tmpDir, test.path))
+
+ if test.expectedIsNotExist {
+ require.Equal(t, test.expectedIsNotExist, os.IsNotExist(err))
+ return
+ }
+
+ if test.expectedErr != nil {
+ require.EqualError(t, err, test.expectedErr.Error())
+ return
+ }
+
+ require.NoError(t, err)
+ require.IsType(t, &Root{}, rootVFS)
+ assert.Equal(t, test.expectedPath, rootVFS.(*Root).rootPath)
+ })
+ }
+}
diff --git a/internal/vfs/root.go b/internal/vfs/root.go
new file mode 100644
index 00000000..30d97b0b
--- /dev/null
+++ b/internal/vfs/root.go
@@ -0,0 +1,69 @@
+package vfs
+
+import (
+ "context"
+ "os"
+ "strconv"
+
+ log "github.com/sirupsen/logrus"
+
+ "gitlab.com/gitlab-org/gitlab-pages/metrics"
+)
+
+// Root abstracts the things Pages needs to serve a static site from a given root rootPath.
+type Root interface {
+ Lstat(ctx context.Context, name string) (os.FileInfo, error)
+ Readlink(ctx context.Context, name string) (string, error)
+ Open(ctx context.Context, name string) (File, error)
+}
+
+type instrumentedRoot struct {
+ root Root
+ name string
+ rootPath string
+}
+
+func (i *instrumentedRoot) increment(operation string, err error) {
+ metrics.VFSOperations.WithLabelValues(i.name, operation, strconv.FormatBool(err == nil)).Inc()
+}
+
+func (i *instrumentedRoot) log() *log.Entry {
+ return log.WithField("vfs", i.name).WithField("root-path", i.rootPath)
+}
+
+func (i *instrumentedRoot) Lstat(ctx context.Context, name string) (os.FileInfo, error) {
+ fi, err := i.root.Lstat(ctx, name)
+
+ i.increment("Lstat", err)
+ i.log().
+ WithField("name", name).
+ WithError(err).
+ Traceln("Lstat call")
+
+ return fi, err
+}
+
+func (i *instrumentedRoot) Readlink(ctx context.Context, name string) (string, error) {
+ target, err := i.root.Readlink(ctx, name)
+
+ i.increment("Readlink", err)
+ i.log().
+ WithField("name", name).
+ WithField("ret-target", target).
+ WithError(err).
+ Traceln("Readlink call")
+
+ return target, err
+}
+
+func (i *instrumentedRoot) Open(ctx context.Context, name string) (File, error) {
+ f, err := i.root.Open(ctx, name)
+
+ i.increment("Open", err)
+ i.log().
+ WithField("name", name).
+ WithError(err).
+ Traceln("Open call")
+
+ return f, err
+}
diff --git a/internal/vfs/vfs.go b/internal/vfs/vfs.go
index 07c99b77..9d9551a1 100644
--- a/internal/vfs/vfs.go
+++ b/internal/vfs/vfs.go
@@ -2,54 +2,47 @@ package vfs
import (
"context"
- "io"
- "os"
"strconv"
+ log "github.com/sirupsen/logrus"
+
"gitlab.com/gitlab-org/gitlab-pages/metrics"
)
// VFS abstracts the things Pages needs to serve a static site from disk.
type VFS interface {
- Lstat(ctx context.Context, name string) (os.FileInfo, error)
- Readlink(ctx context.Context, name string) (string, error)
- Open(ctx context.Context, name string) (File, error)
-}
-
-// File represents an open file, which will typically be the response body of a Pages request.
-type File interface {
- io.Reader
- io.Seeker
- io.Closer
+ Root(ctx context.Context, path string) (Root, error)
}
func Instrumented(fs VFS, name string) VFS {
- return &InstrumentedVFS{fs: fs, name: name}
+ return &instrumentedVFS{fs: fs, name: name}
}
-type InstrumentedVFS struct {
+type instrumentedVFS struct {
fs VFS
name string
}
-func (i *InstrumentedVFS) increment(operation string, err error) {
+func (i *instrumentedVFS) increment(operation string, err error) {
metrics.VFSOperations.WithLabelValues(i.name, operation, strconv.FormatBool(err == nil)).Inc()
}
-func (i *InstrumentedVFS) Lstat(ctx context.Context, name string) (os.FileInfo, error) {
- fi, err := i.fs.Lstat(ctx, name)
- i.increment("Lstat", err)
- return fi, err
+func (i *instrumentedVFS) log() *log.Entry {
+ return log.WithField("vfs", i.name)
}
-func (i *InstrumentedVFS) Readlink(ctx context.Context, name string) (string, error) {
- target, err := i.fs.Readlink(ctx, name)
- i.increment("Readlink", err)
- return target, err
-}
+func (i *instrumentedVFS) Root(ctx context.Context, path string) (Root, error) {
+ root, err := i.fs.Root(ctx, path)
+
+ i.increment("Root", err)
+ i.log().
+ WithField("path", path).
+ WithError(err).
+ Traceln("Root call")
+
+ if err != nil {
+ return nil, err
+ }
-func (i *InstrumentedVFS) Open(ctx context.Context, name string) (File, error) {
- f, err := i.fs.Open(ctx, name)
- i.increment("Open", err)
- return f, err
+ return &instrumentedRoot{root: root, name: i.name, rootPath: path}, nil
}