diff options
author | Alejandro Rodríguez <alejorro70@gmail.com> | 2017-02-28 17:53:02 +0300 |
---|---|---|
committer | Alejandro Rodríguez <alejorro70@gmail.com> | 2017-03-02 23:42:24 +0300 |
commit | ad5e5732498dfc6b275f38ac16846242b67b6425 (patch) | |
tree | 66b3548e08874b3477e5cfe7a2db52700b6431bb | |
parent | 262f60a19ed2f477be91a700a34f85d2ff0fd474 (diff) |
Implement refs operations
-rw-r--r-- | internal/service/ref/refs.go | 178 | ||||
-rw-r--r-- | internal/service/ref/refs_test.go | 307 | ||||
-rw-r--r-- | internal/service/ref/server.go | 14 | ||||
-rw-r--r-- | internal/service/ref/util.go | 80 | ||||
-rw-r--r-- | internal/service/register.go | 2 |
5 files changed, 581 insertions, 0 deletions
diff --git a/internal/service/ref/refs.go b/internal/service/ref/refs.go new file mode 100644 index 000000000..1097d24d1 --- /dev/null +++ b/internal/service/ref/refs.go @@ -0,0 +1,178 @@ +package ref + +import ( + "bufio" + "bytes" + "fmt" + "io" + "log" + + pb "gitlab.com/gitlab-org/gitaly-proto/go" + "gitlab.com/gitlab-org/gitaly/internal/helper" + "golang.org/x/net/context" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" +) + +var ( + master = []byte("refs/heads/master") + // We declare the following functions in variables so that we can override them in our tests + findBranchNames = _findBranchNames + headReference = _headReference +) + +func handleGitCommand(w refNamesWriter, r io.Reader) error { + scanner := bufio.NewScanner(r) + for scanner.Scan() { + if err := w.AddRef(scanner.Bytes()); err != nil { + return err + } + } + if err := scanner.Err(); err != nil { + return err + } + return w.Flush() +} + +func findRefs(writer refNamesWriter, repo *pb.Repository, pattern string) error { + if repo == nil { + message := "Bad Request (empty repository)" + log.Printf("FindRefs: %q", message) + return grpc.Errorf(codes.InvalidArgument, message) + } + + repoPath := repo.Path + + log.Printf("FindRefs: RepoPath=%q Pattern=%q", repoPath, pattern) + + cmd, err := helper.GitCommand("--git-dir", repoPath, "for-each-ref", pattern, "--format=%(refname)") + if err != nil { + return err + } + defer cmd.Kill() + + handleGitCommand(writer, cmd) + + return cmd.Wait() +} + +// FindAllBranchNames creates a stream of ref names for all branches in the given repository +func (s *server) FindAllBranchNames(in *pb.FindAllBranchNamesRequest, stream pb.Ref_FindAllBranchNamesServer) error { + return findRefs(newFindAllBranchNamesWriter(stream, s.MaxMsgSize), in.Repository, "refs/heads") +} + +// FindAllTagNames creates a stream of ref names for all tags in the given repository +func (s *server) FindAllTagNames(in *pb.FindAllTagNamesRequest, stream pb.Ref_FindAllTagNamesServer) error { + return findRefs(newFindAllTagNamesWriter(stream, s.MaxMsgSize), in.Repository, "refs/tags") +} + +func _findBranchNames(repoPath string) ([][]byte, error) { + var names [][]byte + + cmd, err := helper.GitCommand("--git-dir", repoPath, "for-each-ref", "refs/heads", "--format=%(refname)") + if err != nil { + return nil, err + } + defer cmd.Kill() + + scanner := bufio.NewScanner(cmd) + for scanner.Scan() { + names, _ = appendRef(names, scanner.Bytes()) + } + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("reading standard input: %v", err) + } + + if err := cmd.Wait(); err != nil { + return nil, err + } + + return names, nil +} + +func _headReference(repoPath string) ([]byte, error) { + var headRef []byte + + cmd, err := helper.GitCommand("--git-dir", repoPath, "rev-parse", "--symbolic-full-name", "HEAD") + if err != nil { + return nil, err + } + defer cmd.Kill() + + scanner := bufio.NewScanner(cmd) + scanner.Scan() + if err := scanner.Err(); err != nil { + return nil, err + } + headRef = scanner.Bytes() + + if err := cmd.Wait(); err != nil { + return nil, err + } + + return headRef, nil +} + +func defaultBranchName(repoPath string) ([]byte, error) { + branches, err := findBranchNames(repoPath) + + if err != nil { + return nil, err + } + + // Return empty ref name if there are no branches + if len(branches) == 0 { + return nil, nil + } + + // Return first branch name if there's only one + if len(branches) == 1 { + return branches[0], nil + } + + hasMaster := false + headRef, err := headReference(repoPath) + if err != nil { + return nil, err + } + for _, branch := range branches { + // Return HEAD if it corresponds to a branch + if bytes.Equal(headRef, branch) { + return headRef, nil + } + if bytes.Equal(branch, master) { + hasMaster = true + } + } + // Return `ref/names/master` if it exists + if hasMaster { + return master, nil + } + // If all else fails, return the first branch name + return branches[0], nil +} + +// FindDefaultBranchName returns the default branch name for the given repository +func (s *server) FindDefaultBranchName(ctx context.Context, in *pb.FindDefaultBranchNameRequest) (*pb.FindDefaultBranchNameResponse, error) { + if in.Repository == nil { + message := "Bad Request (empty repository)" + log.Printf("FindDefaultBranchName: %q", message) + return nil, grpc.Errorf(codes.InvalidArgument, message) + } + + repoPath := in.Repository.Path + + log.Printf("FindDefaultBranchName: RepoPath=%q", repoPath) + + defaultBranchName, err := defaultBranchName(repoPath) + if err != nil { + return nil, err + } + + return &pb.FindDefaultBranchNameResponse{Name: defaultBranchName}, nil +} + +// FindRefName returns the first refname of a Repository +func (s *server) FindRefName(ctx context.Context, in *pb.FindRefNameRequest) (*pb.FindRefNameResponse, error) { + return nil, nil +} diff --git a/internal/service/ref/refs_test.go b/internal/service/ref/refs_test.go new file mode 100644 index 000000000..03873302c --- /dev/null +++ b/internal/service/ref/refs_test.go @@ -0,0 +1,307 @@ +package ref + +import ( + "bytes" + "io" + "log" + "net" + "os" + "os/exec" + "path" + "testing" + "time" + + pb "gitlab.com/gitlab-org/gitaly-proto/go" + "golang.org/x/net/context" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/reflection" +) + +const ( + scratchDir = "testdata/scratch" + testRepoRoot = "testdata/data" + testRepo = "group/test.git" +) + +var serverSocketPath = path.Join(scratchDir, "gitaly.sock") + +func containsRef(refs [][]byte, ref string) bool { + for _, b := range refs { + if string(b) == ref { + return true + } + } + return false +} + +func TestMain(m *testing.M) { + source := "https://gitlab.com/gitlab-org/gitlab-test.git" + clonePath := path.Join(testRepoRoot, testRepo) + if _, err := os.Stat(clonePath); err != nil { + testCmd := exec.Command("git", "clone", "--bare", source, clonePath) + testCmd.Stdout = os.Stdout + testCmd.Stderr = os.Stderr + + if err := testCmd.Run(); err != nil { + log.Printf("Test setup: failed to run %v", testCmd) + os.Exit(-1) + } + } + + if err := os.MkdirAll(scratchDir, 0755); err != nil { + log.Fatal(err) + } + + os.Exit(func() int { + return m.Run() + }()) +} + +func TestSuccessfulFindAllBranchNames(t *testing.T) { + server := runRefServer(t) + defer server.Stop() + + client := newRefClient(t) + repo := &pb.Repository{Path: path.Join(testRepoRoot, testRepo)} + rpcRequest := &pb.FindAllBranchNamesRequest{Repository: repo} + + c, err := client.FindAllBranchNames(context.Background(), rpcRequest) + if err != nil { + t.Fatal(err) + } + + var names [][]byte + for { + r, err := c.Recv() + if err == io.EOF { + break + } + if err != nil { + t.Fatal(err) + } + names = append(names, r.GetNames()...) + } + + for _, branch := range []string{"master", "100%branch", "improve/awesome", "'test'"} { + if !containsRef(names, "refs/heads/"+branch) { + t.Fatalf("Expected to find branch %q in all branch names", branch) + } + } +} + +func TestEmptyFindAllBranchNamesRequest(t *testing.T) { + server := runRefServer(t) + defer server.Stop() + + client := newRefClient(t) + rpcRequest := &pb.FindAllBranchNamesRequest{} + + c, err := client.FindAllBranchNames(context.Background(), rpcRequest) + if err != nil { + t.Fatal(err) + } + + var recvError error + for recvError == nil { + _, recvError = c.Recv() + } + + if grpc.Code(recvError) != codes.InvalidArgument { + t.Fatal(recvError) + } +} + +func TestSuccessfulFindAllTagNames(t *testing.T) { + server := runRefServer(t) + defer server.Stop() + + client := newRefClient(t) + repo := &pb.Repository{Path: path.Join(testRepoRoot, testRepo)} + rpcRequest := &pb.FindAllTagNamesRequest{Repository: repo} + + c, err := client.FindAllTagNames(context.Background(), rpcRequest) + if err != nil { + t.Fatal(err) + } + + var names [][]byte + for { + r, err := c.Recv() + if err == io.EOF { + break + } + if err != nil { + t.Fatal(err) + } + names = append(names, r.GetNames()...) + } + + for _, tag := range []string{"v1.0.0", "v1.1.0"} { + if !containsRef(names, "refs/tags/"+tag) { + t.Fatal("Expected to find tag", tag, "in all tag names") + } + } +} + +func TestEmptyFindAllTagNamesRequest(t *testing.T) { + server := runRefServer(t) + defer server.Stop() + + client := newRefClient(t) + rpcRequest := &pb.FindAllTagNamesRequest{} + + c, err := client.FindAllTagNames(context.Background(), rpcRequest) + if err != nil { + t.Fatal(err) + } + + var recvError error + for recvError == nil { + _, recvError = c.Recv() + } + + if grpc.Code(recvError) != codes.InvalidArgument { + t.Fatal(recvError) + } +} + +func TestHeadReference(t *testing.T) { + headRef, err := headReference(path.Join(testRepoRoot, testRepo)) + if err != nil { + t.Fatal(err) + } + if string(headRef) != "refs/heads/master" { + t.Fatal("Expected HEAD reference to be 'ref/heads/master', got '", string(headRef), "'") + } +} + +func TestDefaultBranchName(t *testing.T) { + // We are going to override these functions during this test. Restore them after we're done + defer func() { + findBranchNames = _findBranchNames + headReference = _headReference + }() + + testCases := []struct { + desc string + findBranchNames func(string) ([][]byte, error) + headReference func(string) ([]byte, error) + expected []byte + }{ + { + desc: "Get first branch when only one branch exists", + expected: []byte("refs/heads/foo"), + findBranchNames: func(string) ([][]byte, error) { + return [][]byte{[]byte("refs/heads/foo")}, nil + }, + headReference: func(string) ([]byte, error) { return nil, nil }, + }, + { + desc: "Get empy ref if no branches exists", + expected: nil, + findBranchNames: func(string) ([][]byte, error) { return [][]byte{}, nil }, + headReference: func(string) ([]byte, error) { return nil, nil }, + }, + { + desc: "Get the name of the head reference when more than one branch exists", + expected: []byte("refs/heads/bar"), + findBranchNames: func(string) ([][]byte, error) { + return [][]byte{[]byte("refs/heads/foo"), []byte("refs/heads/bar")}, nil + }, + headReference: func(string) ([]byte, error) { return []byte("refs/heads/bar"), nil }, + }, + { + desc: "Get `ref/heads/master` when several branches exist", + expected: []byte("refs/heads/master"), + findBranchNames: func(string) ([][]byte, error) { + return [][]byte{[]byte("refs/heads/foo"), []byte("refs/heads/master"), []byte("refs/heads/bar")}, nil + }, + headReference: func(string) ([]byte, error) { return nil, nil }, + }, + { + desc: "Get the name of the first branch when several branches exists and no other conditions are met", + expected: []byte("refs/heads/foo"), + findBranchNames: func(string) ([][]byte, error) { + return [][]byte{[]byte("refs/heads/foo"), []byte("refs/heads/bar"), []byte("refs/heads/baz")}, nil + }, + headReference: func(string) ([]byte, error) { return nil, nil }, + }, + } + + for _, testCase := range testCases { + findBranchNames = testCase.findBranchNames + headReference = testCase.headReference + + defaultBranch, err := defaultBranchName("") + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(defaultBranch, testCase.expected) { + t.Fatalf("%s: expected %s, got %s instead", testCase.desc, testCase.expected, defaultBranch) + } + } +} + +func TestSuccessfulFindDefaultBranchName(t *testing.T) { + server := runRefServer(t) + defer server.Stop() + + client := newRefClient(t) + repo := &pb.Repository{Path: path.Join(testRepoRoot, testRepo)} + rpcRequest := &pb.FindDefaultBranchNameRequest{Repository: repo} + + r, err := client.FindDefaultBranchName(context.Background(), rpcRequest) + if err != nil { + t.Fatal(err) + } + + if name := r.GetName(); string(name) != "refs/heads/master" { + t.Fatal("Expected HEAD reference to be 'ref/heads/master', got '", string(name), "'") + } +} + +func TestEmptyFindDefaultBranchNameRequest(t *testing.T) { + server := runRefServer(t) + defer server.Stop() + + client := newRefClient(t) + rpcRequest := &pb.FindDefaultBranchNameRequest{} + + _, err := client.FindDefaultBranchName(context.Background(), rpcRequest) + + if grpc.Code(err) != codes.InvalidArgument { + t.Fatal(err) + } +} + +func runRefServer(t *testing.T) *grpc.Server { + grpcServer := grpc.NewServer() + listener, err := net.Listen("unix", serverSocketPath) + if err != nil { + t.Fatal(err) + } + + // Use 100 bytes as the maximum message size to test that fragmenting the ref list works correctly + pb.RegisterRefServer(grpcServer, &server{MaxMsgSize: 100}) + reflection.Register(grpcServer) + + go grpcServer.Serve(listener) + + return grpcServer +} + +func newRefClient(t *testing.T) pb.RefClient { + connOpts := []grpc.DialOption{ + grpc.WithInsecure(), + grpc.WithDialer(func(addr string, _ time.Duration) (net.Conn, error) { + return net.Dial("unix", addr) + }), + } + conn, err := grpc.Dial(serverSocketPath, connOpts...) + if err != nil { + t.Fatal(err) + } + + return pb.NewRefClient(conn) +} diff --git a/internal/service/ref/server.go b/internal/service/ref/server.go new file mode 100644 index 000000000..22e377890 --- /dev/null +++ b/internal/service/ref/server.go @@ -0,0 +1,14 @@ +package ref + +import pb "gitlab.com/gitlab-org/gitaly-proto/go" + +const maxMsgSize = 1024 + +type server struct { + MaxMsgSize int +} + +// NewServer creates a new instance of a grpc RefServer +func NewServer() pb.RefServer { + return &server{MaxMsgSize: maxMsgSize} +} diff --git a/internal/service/ref/util.go b/internal/service/ref/util.go new file mode 100644 index 000000000..74d7527e6 --- /dev/null +++ b/internal/service/ref/util.go @@ -0,0 +1,80 @@ +package ref + +import ( + pb "gitlab.com/gitlab-org/gitaly-proto/go" +) + +func appendRef(refs [][]byte, p []byte) ([][]byte, int) { + ref := make([]byte, len(p)) + size := copy(ref, p) + return append(refs, ref), size +} + +type refNamesSender interface { + sendRefs([][]byte) error +} + +type refNamesWriter struct { + refNamesSender + MaxMsgSize int + refsSize int + refs [][]byte +} + +func (w *refNamesWriter) Flush() error { + if len(w.refs) == 0 { // No message to send, just return + return nil + } + + if err := w.refNamesSender.sendRefs(w.refs); err != nil { + return err + } + + // Reset the message + w.refs = nil + w.refsSize = 0 + + return nil +} + +func (w *refNamesWriter) AddRef(p []byte) error { + refs, size := appendRef(w.refs, p) + w.refsSize += size + w.refs = refs + + if w.refsSize > w.MaxMsgSize { + return w.Flush() + } + + return nil +} + +type branchesSender struct { + stream pb.Ref_FindAllBranchNamesServer +} + +func (w branchesSender) sendRefs(refs [][]byte) error { + return w.stream.Send(&pb.FindAllBranchNamesResponse{Names: refs}) +} + +type tagsSender struct { + stream pb.Ref_FindAllTagNamesServer +} + +func (w tagsSender) sendRefs(refs [][]byte) error { + return w.stream.Send(&pb.FindAllTagNamesResponse{Names: refs}) +} + +func newFindAllBranchNamesWriter(stream pb.Ref_FindAllBranchNamesServer, maxMsgSize int) refNamesWriter { + return refNamesWriter{ + refNamesSender: branchesSender{stream}, + MaxMsgSize: maxMsgSize, + } +} + +func newFindAllTagNamesWriter(stream pb.Ref_FindAllTagNamesServer, maxMsgSize int) refNamesWriter { + return refNamesWriter{ + refNamesSender: tagsSender{stream}, + MaxMsgSize: maxMsgSize, + } +} diff --git a/internal/service/register.go b/internal/service/register.go index eab18803f..a6388e1ff 100644 --- a/internal/service/register.go +++ b/internal/service/register.go @@ -3,6 +3,7 @@ package service import ( pb "gitlab.com/gitlab-org/gitaly-proto/go" "gitlab.com/gitlab-org/gitaly/internal/service/notifications" + "gitlab.com/gitlab-org/gitaly/internal/service/ref" "gitlab.com/gitlab-org/gitaly/internal/service/smarthttp" "google.golang.org/grpc" @@ -12,5 +13,6 @@ import ( // the specified grpc service instance func RegisterAll(grpcServer *grpc.Server) { pb.RegisterNotificationsServer(grpcServer, notifications.NewServer()) + pb.RegisterRefServer(grpcServer, ref.NewServer()) pb.RegisterSmartHTTPServer(grpcServer, smarthttp.NewServer()) } |