diff options
Diffstat (limited to 'workhorse/internal/testhelper')
-rw-r--r-- | workhorse/internal/testhelper/gitaly.go | 384 | ||||
-rw-r--r-- | workhorse/internal/testhelper/testhelper.go | 152 |
2 files changed, 536 insertions, 0 deletions
diff --git a/workhorse/internal/testhelper/gitaly.go b/workhorse/internal/testhelper/gitaly.go new file mode 100644 index 00000000000..24884505440 --- /dev/null +++ b/workhorse/internal/testhelper/gitaly.go @@ -0,0 +1,384 @@ +package testhelper + +import ( + "fmt" + "io" + "io/ioutil" + "path" + "strings" + "sync" + + "github.com/golang/protobuf/jsonpb" //lint:ignore SA1019 https://gitlab.com/gitlab-org/gitlab-workhorse/-/issues/274 + "github.com/golang/protobuf/proto" //lint:ignore SA1019 https://gitlab.com/gitlab-org/gitlab-workhorse/-/issues/274 + "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb" + "gitlab.com/gitlab-org/labkit/log" + "golang.org/x/net/context" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" +) + +type GitalyTestServer struct { + finalMessageCode codes.Code + sync.WaitGroup + LastIncomingMetadata metadata.MD + gitalypb.UnimplementedRepositoryServiceServer + gitalypb.UnimplementedBlobServiceServer + gitalypb.UnimplementedDiffServiceServer +} + +var ( + GitalyInfoRefsResponseMock = strings.Repeat("Mock Gitaly InfoRefsResponse data", 100000) + GitalyGetBlobResponseMock = strings.Repeat("Mock Gitaly GetBlobResponse data", 100000) + GitalyGetArchiveResponseMock = strings.Repeat("Mock Gitaly GetArchiveResponse data", 100000) + GitalyGetDiffResponseMock = strings.Repeat("Mock Gitaly GetDiffResponse data", 100000) + GitalyGetPatchResponseMock = strings.Repeat("Mock Gitaly GetPatchResponse data", 100000) + + GitalyGetSnapshotResponseMock = strings.Repeat("Mock Gitaly GetSnapshotResponse data", 100000) + + GitalyReceivePackResponseMock []byte + GitalyUploadPackResponseMock []byte +) + +func init() { + var err error + if GitalyReceivePackResponseMock, err = ioutil.ReadFile(path.Join(RootDir(), "testdata/receive-pack-fixture.txt")); err != nil { + log.WithError(err).Fatal("Unable to read pack response") + } + if GitalyUploadPackResponseMock, err = ioutil.ReadFile(path.Join(RootDir(), "testdata/upload-pack-fixture.txt")); err != nil { + log.WithError(err).Fatal("Unable to read pack response") + } +} + +func NewGitalyServer(finalMessageCode codes.Code) *GitalyTestServer { + return &GitalyTestServer{finalMessageCode: finalMessageCode} +} + +func (s *GitalyTestServer) InfoRefsUploadPack(in *gitalypb.InfoRefsRequest, stream gitalypb.SmartHTTPService_InfoRefsUploadPackServer) error { + s.WaitGroup.Add(1) + defer s.WaitGroup.Done() + + if err := validateRepository(in.GetRepository()); err != nil { + return err + } + + fmt.Printf("Result: %+v\n", in) + + marshaler := &jsonpb.Marshaler{} + jsonString, err := marshaler.MarshalToString(in) + if err != nil { + return err + } + + data := []byte(strings.Join([]string{ + jsonString, + "git-upload-pack", + GitalyInfoRefsResponseMock, + }, "\000")) + + s.LastIncomingMetadata = nil + if md, ok := metadata.FromIncomingContext(stream.Context()); ok { + s.LastIncomingMetadata = md + } + + return s.sendInfoRefs(stream, data) +} + +func (s *GitalyTestServer) InfoRefsReceivePack(in *gitalypb.InfoRefsRequest, stream gitalypb.SmartHTTPService_InfoRefsReceivePackServer) error { + s.WaitGroup.Add(1) + defer s.WaitGroup.Done() + + if err := validateRepository(in.GetRepository()); err != nil { + return err + } + + fmt.Printf("Result: %+v\n", in) + + jsonString, err := marshalJSON(in) + if err != nil { + return err + } + + data := []byte(strings.Join([]string{ + jsonString, + "git-receive-pack", + GitalyInfoRefsResponseMock, + }, "\000")) + + return s.sendInfoRefs(stream, data) +} + +func marshalJSON(msg proto.Message) (string, error) { + marshaler := &jsonpb.Marshaler{} + return marshaler.MarshalToString(msg) +} + +type infoRefsSender interface { + Send(*gitalypb.InfoRefsResponse) error +} + +func (s *GitalyTestServer) sendInfoRefs(stream infoRefsSender, data []byte) error { + nSends, err := sendBytes(data, 100, func(p []byte) error { + return stream.Send(&gitalypb.InfoRefsResponse{Data: p}) + }) + if err != nil { + return err + } + if nSends <= 1 { + panic("should have sent more than one message") + } + + return s.finalError() +} + +func (s *GitalyTestServer) PostReceivePack(stream gitalypb.SmartHTTPService_PostReceivePackServer) error { + s.WaitGroup.Add(1) + defer s.WaitGroup.Done() + + req, err := stream.Recv() + if err != nil { + return err + } + + repo := req.GetRepository() + if err := validateRepository(repo); err != nil { + return err + } + + jsonString, err := marshalJSON(req) + if err != nil { + return err + } + + data := []byte(jsonString + "\000") + + // The body of the request starts in the second message + for { + req, err := stream.Recv() + if err != nil { + if err != io.EOF { + return err + } + break + } + + // We want to echo the request data back + data = append(data, req.GetData()...) + } + + nSends, _ := sendBytes(data, 100, func(p []byte) error { + return stream.Send(&gitalypb.PostReceivePackResponse{Data: p}) + }) + + if nSends <= 1 { + panic("should have sent more than one message") + } + + return s.finalError() +} + +func (s *GitalyTestServer) PostUploadPack(stream gitalypb.SmartHTTPService_PostUploadPackServer) error { + s.WaitGroup.Add(1) + defer s.WaitGroup.Done() + + req, err := stream.Recv() + if err != nil { + return err + } + + if err := validateRepository(req.GetRepository()); err != nil { + return err + } + + jsonString, err := marshalJSON(req) + if err != nil { + return err + } + + if err := stream.Send(&gitalypb.PostUploadPackResponse{ + Data: []byte(strings.Join([]string{jsonString}, "\000") + "\000"), + }); err != nil { + return err + } + + nSends := 0 + // The body of the request starts in the second message. Gitaly streams PostUploadPack responses + // as soon as possible without reading the request completely first. We stream messages here + // directly back to the client to simulate the streaming of the actual implementation. + for { + req, err := stream.Recv() + if err != nil { + if err != io.EOF { + return err + } + break + } + + if err := stream.Send(&gitalypb.PostUploadPackResponse{Data: req.GetData()}); err != nil { + return err + } + + nSends++ + } + + if nSends <= 1 { + panic("should have sent more than one message") + } + + return s.finalError() +} + +func (s *GitalyTestServer) CommitIsAncestor(ctx context.Context, in *gitalypb.CommitIsAncestorRequest) (*gitalypb.CommitIsAncestorResponse, error) { + return nil, nil +} + +func (s *GitalyTestServer) GetBlob(in *gitalypb.GetBlobRequest, stream gitalypb.BlobService_GetBlobServer) error { + s.WaitGroup.Add(1) + defer s.WaitGroup.Done() + + if err := validateRepository(in.GetRepository()); err != nil { + return err + } + + response := &gitalypb.GetBlobResponse{ + Oid: in.GetOid(), + Size: int64(len(GitalyGetBlobResponseMock)), + } + nSends, err := sendBytes([]byte(GitalyGetBlobResponseMock), 100, func(p []byte) error { + response.Data = p + + if err := stream.Send(response); err != nil { + return err + } + + // Use a new response so we don't send other fields (Size, ...) over and over + response = &gitalypb.GetBlobResponse{} + + return nil + }) + if err != nil { + return err + } + if nSends <= 1 { + panic("should have sent more than one message") + } + + return s.finalError() +} + +func (s *GitalyTestServer) GetArchive(in *gitalypb.GetArchiveRequest, stream gitalypb.RepositoryService_GetArchiveServer) error { + s.WaitGroup.Add(1) + defer s.WaitGroup.Done() + + if err := validateRepository(in.GetRepository()); err != nil { + return err + } + + nSends, err := sendBytes([]byte(GitalyGetArchiveResponseMock), 100, func(p []byte) error { + return stream.Send(&gitalypb.GetArchiveResponse{Data: p}) + }) + if err != nil { + return err + } + if nSends <= 1 { + panic("should have sent more than one message") + } + + return s.finalError() +} + +func (s *GitalyTestServer) RawDiff(in *gitalypb.RawDiffRequest, stream gitalypb.DiffService_RawDiffServer) error { + nSends, err := sendBytes([]byte(GitalyGetDiffResponseMock), 100, func(p []byte) error { + return stream.Send(&gitalypb.RawDiffResponse{ + Data: p, + }) + }) + if err != nil { + return err + } + if nSends <= 1 { + panic("should have sent more than one message") + } + + return s.finalError() +} + +func (s *GitalyTestServer) RawPatch(in *gitalypb.RawPatchRequest, stream gitalypb.DiffService_RawPatchServer) error { + s.WaitGroup.Add(1) + defer s.WaitGroup.Done() + + if err := validateRepository(in.GetRepository()); err != nil { + return err + } + + nSends, err := sendBytes([]byte(GitalyGetPatchResponseMock), 100, func(p []byte) error { + return stream.Send(&gitalypb.RawPatchResponse{ + Data: p, + }) + }) + if err != nil { + return err + } + if nSends <= 1 { + panic("should have sent more than one message") + } + + return s.finalError() +} + +func (s *GitalyTestServer) GetSnapshot(in *gitalypb.GetSnapshotRequest, stream gitalypb.RepositoryService_GetSnapshotServer) error { + s.WaitGroup.Add(1) + defer s.WaitGroup.Done() + + if err := validateRepository(in.GetRepository()); err != nil { + return err + } + + nSends, err := sendBytes([]byte(GitalyGetSnapshotResponseMock), 100, func(p []byte) error { + return stream.Send(&gitalypb.GetSnapshotResponse{Data: p}) + }) + if err != nil { + return err + } + if nSends <= 1 { + panic("should have sent more than one message") + } + + return s.finalError() +} + +// sendBytes returns the number of times the 'sender' function was called and an error. +func sendBytes(data []byte, chunkSize int, sender func([]byte) error) (int, error) { + i := 0 + for ; len(data) > 0; i++ { + n := chunkSize + if n > len(data) { + n = len(data) + } + + if err := sender(data[:n]); err != nil { + return i, err + } + data = data[n:] + } + + return i, nil +} + +func (s *GitalyTestServer) finalError() error { + if code := s.finalMessageCode; code != codes.OK { + return status.Errorf(code, "error as specified by test") + } + + return nil +} + +func validateRepository(repo *gitalypb.Repository) error { + if len(repo.GetStorageName()) == 0 { + return fmt.Errorf("missing storage_name: %v", repo) + } + if len(repo.GetRelativePath()) == 0 { + return fmt.Errorf("missing relative_path: %v", repo) + } + return nil +} diff --git a/workhorse/internal/testhelper/testhelper.go b/workhorse/internal/testhelper/testhelper.go new file mode 100644 index 00000000000..40097bd453a --- /dev/null +++ b/workhorse/internal/testhelper/testhelper.go @@ -0,0 +1,152 @@ +package testhelper + +import ( + "errors" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "os" + "path" + "regexp" + "runtime" + "testing" + "time" + + "github.com/dgrijalva/jwt-go" + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/labkit/log" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/secret" +) + +func ConfigureSecret() { + secret.SetPath(path.Join(RootDir(), "testdata/test-secret")) +} + +func RequireResponseBody(t *testing.T, response *httptest.ResponseRecorder, expectedBody string) { + t.Helper() + require.Equal(t, expectedBody, response.Body.String(), "response body") +} + +func RequireResponseHeader(t *testing.T, w interface{}, header string, expected ...string) { + t.Helper() + var actual []string + + header = http.CanonicalHeaderKey(header) + type headerer interface{ Header() http.Header } + + switch resp := w.(type) { + case *http.Response: + actual = resp.Header[header] + case headerer: + actual = resp.Header()[header] + default: + t.Fatal("invalid type of w passed RequireResponseHeader") + } + + require.Equal(t, expected, actual, "values for HTTP header %s", header) +} + +func TestServerWithHandler(url *regexp.Regexp, handler http.HandlerFunc) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + logEntry := log.WithFields(log.Fields{ + "method": r.Method, + "url": r.URL, + "action": "DENY", + }) + + if url != nil && !url.MatchString(r.URL.Path) { + logEntry.Info("UPSTREAM") + w.WriteHeader(404) + return + } + + if version := r.Header.Get("Gitlab-Workhorse"); version == "" { + logEntry.Info("UPSTREAM") + w.WriteHeader(403) + return + } + + handler(w, r) + })) +} + +var workhorseExecutables = []string{"gitlab-workhorse", "gitlab-zip-cat", "gitlab-zip-metadata", "gitlab-resize-image"} + +func BuildExecutables() error { + rootDir := RootDir() + + for _, exe := range workhorseExecutables { + if _, err := os.Stat(path.Join(rootDir, exe)); os.IsNotExist(err) { + return fmt.Errorf("cannot find executable %s. Please run 'make prepare-tests'", exe) + } + } + + oldPath := os.Getenv("PATH") + testPath := fmt.Sprintf("%s:%s", rootDir, oldPath) + if err := os.Setenv("PATH", testPath); err != nil { + return fmt.Errorf("failed to set PATH to %v", testPath) + } + + return nil +} + +func RootDir() string { + _, currentFile, _, ok := runtime.Caller(0) + if !ok { + panic(errors.New("RootDir: calling runtime.Caller failed")) + } + return path.Join(path.Dir(currentFile), "../..") +} + +func LoadFile(t *testing.T, filePath string) string { + t.Helper() + content, err := ioutil.ReadFile(path.Join(RootDir(), filePath)) + require.NoError(t, err) + return string(content) +} + +func ReadAll(t *testing.T, r io.Reader) []byte { + t.Helper() + + b, err := ioutil.ReadAll(r) + require.NoError(t, err) + return b +} + +func ParseJWT(token *jwt.Token) (interface{}, error) { + // Don't forget to validate the alg is what you expect: + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + + ConfigureSecret() + secretBytes, err := secret.Bytes() + if err != nil { + return nil, fmt.Errorf("read secret from file: %v", err) + } + + return secretBytes, nil +} + +// UploadClaims represents the JWT claim for upload parameters +type UploadClaims struct { + Upload map[string]string `json:"upload"` + jwt.StandardClaims +} + +func Retry(t testing.TB, timeout time.Duration, fn func() error) { + t.Helper() + start := time.Now() + var err error + for ; time.Since(start) < timeout; time.Sleep(time.Millisecond) { + err = fn() + if err == nil { + return + } + } + t.Fatalf("test timeout after %v; last error: %v", timeout, err) +} |