Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitaly.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKim "BKC" Carlbäcker <kim.carlbacker@gmail.com>2017-05-09 05:44:25 +0300
committerKim "BKC" Carlbäcker <kim.carlbacker@gmail.com>2017-05-10 18:27:52 +0300
commit2f88d5d9ab699e9ada65fd0ac6b376009ff3ab30 (patch)
treef08fb37dd16c543fb6c0c59fc744e10584f6b5c0
parent2ce9392ae0b73cc7a68b307c96b82fa95ba9f5dc (diff)
Use pbhelper.NewSendWriter
- Use Stream-helpers from pbhelper - Handle GL_REPOSITORY
-rw-r--r--internal/service/ssh/receive_pack.go66
-rw-r--r--internal/service/ssh/receive_pack_test.go188
-rw-r--r--internal/service/ssh/server.go8
-rw-r--r--internal/service/ssh/testhelper_test.go10
-rw-r--r--internal/service/ssh/upload_pack_test.go193
-rw-r--r--internal/service/ssh/uploadpack.go63
6 files changed, 37 insertions, 491 deletions
diff --git a/internal/service/ssh/receive_pack.go b/internal/service/ssh/receive_pack.go
index e6c199052..24cb0ddf9 100644
--- a/internal/service/ssh/receive_pack.go
+++ b/internal/service/ssh/receive_pack.go
@@ -8,47 +8,44 @@ import (
"gitlab.com/gitlab-org/gitaly/internal/helper"
pb "gitlab.com/gitlab-org/gitaly-proto/go"
- pbh "gitlab.com/gitlab-org/gitaly-proto/go/helper"
+ pbhelper "gitlab.com/gitlab-org/gitaly-proto/go/helper"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
)
-type receivePackBytesReader struct {
- pb.SSH_SSHReceivePackServer
-}
-
-type receivePackWriter struct {
- pb.SSH_SSHReceivePackServer
-}
-
-type receivePackErrorWriter struct {
- pb.SSH_SSHReceivePackServer
-}
-
func (s *server) SSHReceivePack(stream pb.SSH_SSHReceivePackServer) error {
req, err := stream.Recv() // First request contains only Repository and GlId
if err != nil {
return err
}
- if err = validateReceivePackRequest(req); err != nil {
+ if err = validateFirstReceivePackRequest(req); err != nil {
return err
}
- streamBytesReader := receivePackBytesReader{stream}
- stdin := pbh.NewReceiveReader(streamBytesReader.ReceiveBytes)
- stdout := receivePackWriter{stream}
- stderr := receivePackErrorWriter{stream}
+ stdin := pbhelper.NewReceiveReader(func() ([]byte, error) {
+ request, err := stream.Recv()
+ return request.GetStdin(), err
+ })
+ stdout := pbhelper.NewSendWriter(func(p []byte) error {
+ return stream.Send(&pb.SSHReceivePackResponse{Stdout: p})
+ })
+ stderr := pbhelper.NewSendWriter(func(p []byte) error {
+ return stream.Send(&pb.SSHReceivePackResponse{Stderr: p})
+ })
env := []string{
fmt.Sprintf("GL_ID=%s", req.GlId),
"GL_PROTOCOL=ssh",
}
+ if req.GlRepository != "" {
+ env = append(env, fmt.Sprintf("GL_REPOSITORY=%s", req.GlRepository))
+ }
repoPath, err := helper.GetRepoPath(req.Repository)
if err != nil {
return err
}
- log.Printf("PostReceivePack: RepoPath=%q GlID=%q", repoPath, req.GlId)
+ log.Printf("PostReceivePack: RepoPath=%q GlID=%q GlRepository=%q", repoPath, req.GlId, req.GlRepository)
osCommand := exec.Command("git-receive-pack", repoPath)
cmd, err := helper.NewCommand(osCommand, stdin, stdout, stderr, env...)
@@ -60,9 +57,7 @@ func (s *server) SSHReceivePack(stream pb.SSH_SSHReceivePackServer) error {
if err := cmd.Wait(); err != nil {
if status, ok := helper.ExitStatus(err); ok {
- log.Printf("Exit Status: %d", status)
- stream.Send(&pb.SSHReceivePackResponse{ExitStatus: &pb.ExitStatus{Value: int32(status)}})
- return nil
+ return stream.Send(&pb.SSHReceivePackResponse{ExitStatus: &pb.ExitStatus{Value: int32(status)}})
}
return grpc.Errorf(codes.Unavailable, "PostReceivePack: cmd wait for %v: %v", cmd.Args, err)
}
@@ -70,7 +65,7 @@ func (s *server) SSHReceivePack(stream pb.SSH_SSHReceivePackServer) error {
return nil
}
-func validateReceivePackRequest(req *pb.SSHReceivePackRequest) error {
+func validateFirstReceivePackRequest(req *pb.SSHReceivePackRequest) error {
if req.GlId == "" {
return grpc.Errorf(codes.InvalidArgument, "PostReceivePack: empty GlId")
}
@@ -80,28 +75,3 @@ func validateReceivePackRequest(req *pb.SSHReceivePackRequest) error {
return nil
}
-
-func (rw receivePackWriter) Write(p []byte) (int, error) {
- resp := &pb.SSHReceivePackResponse{Stdout: p}
- if err := rw.Send(resp); err != nil {
- return 0, err
- }
- return len(p), nil
-}
-
-func (rw receivePackErrorWriter) Write(p []byte) (int, error) {
- resp := &pb.SSHReceivePackResponse{Stderr: p}
- if err := rw.Send(resp); err != nil {
- return 0, err
- }
- return len(p), nil
-}
-
-func (br receivePackBytesReader) ReceiveBytes() ([]byte, error) {
- resp, err := br.Recv()
- if err != nil {
- return nil, err
- }
-
- return resp.GetStdin(), nil
-}
diff --git a/internal/service/ssh/receive_pack_test.go b/internal/service/ssh/receive_pack_test.go
index e0b42cf46..813e292c8 100644
--- a/internal/service/ssh/receive_pack_test.go
+++ b/internal/service/ssh/receive_pack_test.go
@@ -1,14 +1,7 @@
package ssh
import (
- "bytes"
- "fmt"
- "io"
- "os"
- "path"
- "strconv"
"testing"
- "time"
"gitlab.com/gitlab-org/gitaly/internal/testhelper"
@@ -18,118 +11,6 @@ import (
"google.golang.org/grpc/codes"
)
-func TestSuccessfulReceivePackRequest(t *testing.T) {
- server := runSSHServer(t)
- defer server.Stop()
-
- remoteRepoPath := path.Join(testRepoRoot, "gitlab-test-remote")
- localRepoPath := path.Join(testRepoRoot, "gitlab-test-local")
- // Make a non-bare clone of the test repo to act as a local one
- testhelper.MustRunCommand(t, nil, "git", "clone", testhelper.GitlabTestRepoPath(), localRepoPath)
- // Make a bare clone of the test repo to act as a remote one and to leave the original repo intact for other tests
- testhelper.MustRunCommand(t, nil, "git", "clone", "--bare", testhelper.GitlabTestRepoPath(), remoteRepoPath)
- defer os.RemoveAll(remoteRepoPath)
- defer os.RemoveAll(localRepoPath)
-
- commitMsg := fmt.Sprintf("Testing ReceivePack RPC around %d", time.Now().Unix())
- committerName := "Scrooge McDuck"
- committerEmail := "scrooge@mcduck.com"
- clientCapabilities := "report-status side-band-64k agent=git/2.12.0"
-
- // The latest commit ID on the remote repo
- oldHead := bytes.TrimSpace(testhelper.MustRunCommand(t, nil, "git", "-C", localRepoPath, "rev-parse", "master"))
-
- testhelper.MustRunCommand(t, nil, "git", "-C", localRepoPath,
- "-c", fmt.Sprintf("user.name=%s", committerName),
- "-c", fmt.Sprintf("user.email=%s", committerEmail),
- "commit", "--allow-empty", "-m", commitMsg)
-
- // The commit ID we want to push to the remote repo
- newHead := bytes.TrimSpace(testhelper.MustRunCommand(t, nil, "git", "-C", localRepoPath, "rev-parse", "master"))
-
- // ReceivePack request is a packet line followed by a packet flush, then the pack file of the objects we want to push.
- // This is explained a bit in https://git-scm.com/book/en/v2/Git-Internals-Transfer-Protocols#_uploading_data
- // We form the packet line the same way git executable does: https://github.com/git/git/blob/d1a13d3fcb252631361a961cb5e2bf10ed467cba/send-pack.c#L524-L527
- pkt := fmt.Sprintf("%s %s refs/heads/master\x00 %s", oldHead, newHead, clientCapabilities)
- // We need to get a pack file containing the objects we want to push, so we use git pack-objects
- // which expects a list of revisions passed through standard input. The list format means
- // pack the objects needed if I have oldHead but not newHead (think of it from the perspective of the remote repo).
- // For more info, check the man pages of both `git-pack-objects` and `git-rev-list --objects`.
- stdin := bytes.NewBufferString(fmt.Sprintf("^%s\n%s\n", oldHead, newHead))
- // The options passed are the same ones used when doing an actual push.
- pack := testhelper.MustRunCommand(t, stdin, "git", "-C", localRepoPath, "pack-objects", "--stdout", "--revs", "--thin", "--delta-base-offset", "-q")
-
- // We chop the request into multiple small pieces to exercise the server code that handles
- // the stream sent by the client, so we use a buffer to read chunks of data in a nice way.
- requestBuffer := &bytes.Buffer{}
- fmt.Fprintf(requestBuffer, "%04x%s%s", len(pkt)+4, pkt, pktFlushStr)
- requestBuffer.Write(pack)
-
- client := newSSHClient(t)
- repo := &pb.Repository{Path: remoteRepoPath}
- rpcRequest := &pb.SSHReceivePackRequest{Repository: repo, GlId: "user-123"}
- stream, err := client.SSHReceivePack(context.Background())
- if err != nil {
- t.Fatal(err)
- }
-
- if err := stream.Send(rpcRequest); err != nil {
- t.Fatal(err)
- }
-
- data := make([]byte, 16)
- for {
- n, err := requestBuffer.Read(data)
- if err == io.EOF {
- break
- } else if err != nil {
- t.Fatal(err)
- }
-
- rpcRequest = &pb.SSHReceivePackRequest{Stdin: data[:n]}
- if err := stream.Send(rpcRequest); err != nil {
- t.Fatal(err)
- }
- }
- stream.CloseSend()
-
- // Verify everything is going as planned
- responseBuffer := bytes.Buffer{}
- for {
- rpcResponse, err := stream.Recv()
- if err != nil {
- if err == io.EOF {
- if rpcResponse.GetExitStatus().GetValue() != 0 {
- t.Fatalf("Expected ExitStatus to be %d, got %d", 0, rpcResponse.GetExitStatus().GetValue())
- }
- break
- } else {
- t.Fatal(err)
- }
- }
-
- if rpcResponse.Stdout != nil {
- responseBuffer.Write(rpcResponse.GetStdout())
- }
- if rpcResponse.Stderr != nil {
- t.Fatalf("Got something on StdErr: %q", rpcResponse.GetStderr())
- responseBuffer.Write(rpcResponse.GetStderr())
- }
- }
-
- expectedResponse := "0030\x01000eunpack ok\n0019ok refs/heads/master\n00000000"
- extractedResponse, ok := extractUnpackDataFromResponse(t, &responseBuffer)
- if !ok {
- t.Errorf(`Expected response status to be "true", got "false"`)
- }
- if string(extractedResponse) != expectedResponse {
- t.Errorf("Expected response to be %q, got %q", expectedResponse, responseBuffer.String())
- }
-
- // The fact that this command succeeds means that we got the commit correctly, no further checks should be needed.
- testhelper.MustRunCommand(t, nil, "git", "-C", remoteRepoPath, "show", string(newHead))
-}
-
func TestFailedReceivePackRequestDueToValidationError(t *testing.T) {
server := runSSHServer(t)
defer server.Stop()
@@ -167,72 +48,3 @@ func drainPostReceivePackResponse(stream pb.SSH_SSHReceivePackClient) error {
}
return err
}
-
-// The response contains bunch of things; metadata, progress messages, and a pack file. We're only
-// interested in the pack file and its header values.
-func extractUnpackDataFromResponse(t *testing.T, buf *bytes.Buffer) ([]byte, bool) {
- var pack []byte
-
- // The response should have the following format, where <length> is always four hexadecimal digits.
- // <length><data>
- // <length><data>
- // ...
- // 0000
- for {
- pktLenStr := buf.Next(4)
- if len(pktLenStr) != 4 {
- return nil, false
- }
- if string(pktLenStr) == pktFlushStr {
- break
- }
-
- pktLen, err := strconv.ParseUint(string(pktLenStr), 16, 16)
- if err != nil {
- t.Fatal(err)
- }
-
- restPktLen := int(pktLen) - 4
- pkt := buf.Next(restPktLen)
- if len(pkt) != restPktLen {
- t.Fatalf("Incomplete packet read")
- }
- }
-
- t.Logf("resulting buf: %q", buf.String())
-
- // NOTE: This seems like an ugly hack...
- temp := buf.Next(5)
- if len(temp) != 5 {
- t.Fatalf("Could not read unpack magic...")
- }
-
- pack = buf.Bytes()
- t.Logf("pack: %q", pack)
-
- if len(pack) < 4 {
- t.Fatalf("Invalid unpack signature %q", pack)
- }
- pktLenStr := pack[:4]
- pktLen, err := strconv.ParseUint(string(pktLenStr), 16, 16)
- if err != nil {
- t.Fatal(err)
- }
-
- restPktLen := int(pktLen)
- unpkt := pack[4:restPktLen]
- t.Logf("unpkt: %q", unpkt)
-
- // The packet is structured as follows:
- // 4 bytes for signature, here it's "PACK"
- // 4 bytes for header version
- // 4 bytes for header entries
- // The rest is the pack file
- if len(unpkt) < 6 || string(unpkt[:6]) != "unpack" {
- t.Fatalf("Invalid packet signature %q", pack)
- }
-
- t.Logf("status: %q", unpkt[6:])
-
- return append(temp, pack...), string(unpkt[6:]) == " ok\n"
-}
diff --git a/internal/service/ssh/server.go b/internal/service/ssh/server.go
index 7082505c4..5012f022f 100644
--- a/internal/service/ssh/server.go
+++ b/internal/service/ssh/server.go
@@ -2,13 +2,9 @@ package ssh
import pb "gitlab.com/gitlab-org/gitaly-proto/go"
-const maxChunkSize = 1024
-
-type server struct {
- ChunkSize int
-}
+type server struct{}
// NewServer creates a new instance of a grpc SSHServer
func NewServer() pb.SSHServer {
- return &server{ChunkSize: maxChunkSize}
+ return &server{}
}
diff --git a/internal/service/ssh/testhelper_test.go b/internal/service/ssh/testhelper_test.go
index 20112cb88..84b037519 100644
--- a/internal/service/ssh/testhelper_test.go
+++ b/internal/service/ssh/testhelper_test.go
@@ -8,8 +8,6 @@ import (
"testing"
"time"
- "gitlab.com/gitlab-org/gitaly/internal/testhelper"
-
pb "gitlab.com/gitlab-org/gitaly-proto/go"
"google.golang.org/grpc"
@@ -17,22 +15,18 @@ import (
)
const (
- scratchDir = "testdata/scratch"
- testRepoRoot = "testdata/data"
- pktFlushStr = "0000"
+ scratchDir = "testdata/scratch"
)
var (
serverSocketPath = path.Join(scratchDir, "gitaly.sock")
- testRepoPath = ""
)
func TestMain(m *testing.M) {
- testRepoPath = testhelper.GitlabTestRepoPath()
-
if err := os.MkdirAll(scratchDir, 0755); err != nil {
log.Fatal(err)
}
+ defer os.RemoveAll(scratchDir)
os.Exit(func() int {
return m.Run()
diff --git a/internal/service/ssh/upload_pack_test.go b/internal/service/ssh/upload_pack_test.go
index f26da2e3c..da223e100 100644
--- a/internal/service/ssh/upload_pack_test.go
+++ b/internal/service/ssh/upload_pack_test.go
@@ -1,15 +1,7 @@
package ssh
import (
- "bytes"
- "encoding/binary"
- "fmt"
- "io"
- "os"
- "path"
- "strconv"
"testing"
- "time"
"gitlab.com/gitlab-org/gitaly/internal/testhelper"
@@ -19,132 +11,6 @@ import (
"google.golang.org/grpc/codes"
)
-func TestSuccessfulUploadPackRequest(t *testing.T) {
- server := runSSHServer(t)
- defer server.Stop()
-
- localRepoPath := path.Join(testRepoRoot, "gitlab-test-local")
- remoteRepoPath := path.Join(testRepoRoot, "gitlab-test-remote")
- // Make a non-bare clone of the test repo to act as a remote one
- testhelper.MustRunCommand(t, nil, "git", "clone", testhelper.GitlabTestRepoPath(), remoteRepoPath)
- // Make a bare clone of the test repo to act as a local one and to leave the original repo intact for other tests
- testhelper.MustRunCommand(t, nil, "git", "clone", "--bare", testhelper.GitlabTestRepoPath(), localRepoPath)
- defer os.RemoveAll(localRepoPath)
- defer os.RemoveAll(remoteRepoPath)
-
- commitMsg := fmt.Sprintf("Testing UploadPack RPC around %d", time.Now().Unix())
- committerName := "Scrooge McDuck"
- committerEmail := "scrooge@mcduck.com"
- clientCapabilities := "multi_ack_detailed no-done side-band-64k thin-pack include-tag ofs-delta deepen-since deepen-not agent=git/2.12.0"
-
- // The latest commit ID on the local repo
- oldHead := bytes.TrimSpace(testhelper.MustRunCommand(t, nil, "git", "-C", remoteRepoPath, "rev-parse", "master"))
-
- testhelper.MustRunCommand(t, nil, "git", "-C", remoteRepoPath,
- "-c", fmt.Sprintf("user.name=%s", committerName),
- "-c", fmt.Sprintf("user.email=%s", committerEmail),
- "commit", "--allow-empty", "-m", commitMsg)
-
- // The commit ID we want to pull from the remote repo
- newHead := bytes.TrimSpace(testhelper.MustRunCommand(t, nil, "git", "-C", remoteRepoPath, "rev-parse", "master"))
-
- // UploadPack request is a "want" packet line followed by a packet flush, then many "have" packets followed by a packet flush.
- // This is explained a bit in https://git-scm.com/book/en/v2/Git-Internals-Transfer-Protocols#_downloading_data
- wantPkt := fmt.Sprintf("want %s %s\n", newHead, clientCapabilities)
- havePkt := fmt.Sprintf("have %s\n", oldHead)
-
- // We don't check for errors because per bytes.Buffer docs, Buffer.Write will always return a nil error.
- requestBuffer := &bytes.Buffer{}
- fmt.Fprintf(requestBuffer, "%04x%s%s", len(wantPkt)+4, wantPkt, pktFlushStr)
- fmt.Fprintf(requestBuffer, "%04x%s%s", len(havePkt)+4, havePkt, pktFlushStr)
-
- client := newSSHClient(t)
- repo := &pb.Repository{Path: path.Join(remoteRepoPath, ".git")}
- rpcRequest := &pb.SSHUploadPackRequest{Repository: repo}
- stream, err := client.SSHUploadPack(context.Background())
- if err != nil {
- t.Fatal(err)
- }
-
- if err = stream.Send(rpcRequest); err != nil {
- t.Fatal(err)
- }
-
- data := make([]byte, 16)
- for {
- n, err := requestBuffer.Read(data)
- if err == io.EOF {
- break
- } else if err != nil {
- t.Fatal(err)
- }
-
- rpcRequest = &pb.SSHUploadPackRequest{Stdin: data[:n]}
- if err := stream.Send(rpcRequest); err != nil {
- t.Fatal(err)
- }
- }
- stream.CloseSend()
-
- responseBuffer := &bytes.Buffer{}
- var chunk int
- for {
- rpcResponse, err := stream.Recv()
- if err != nil {
- if err == io.EOF {
- break
- } else {
- t.Fatal(err)
- }
- }
- chunk++
-
- if rpcResponse.Stdout != nil {
- responseBuffer.Write(rpcResponse.GetStdout())
- }
- if rpcResponse.Stderr != nil {
- responseBuffer.Write(rpcResponse.GetStderr())
- }
- // responseBuffer.Write(rpcResponse.GetStderr())
- t.Logf("Read chunk %d", chunk)
- }
-
- // There's no git command we can pass it this response and do the work for us (extracting pack file, ...),
- // so we have to do it ourselves.
- pack, version, entries := extractPackDataFromResponse(t, responseBuffer)
- if pack == nil {
- t.Errorf("Expected to find a pack file in response, found none")
- return
- }
-
- err = drainUploadStreamAndVerifyExitStatus(t, 0, stream)
- if err != nil {
- t.Fatal(err)
- }
-
- testhelper.MustRunCommand(t, bytes.NewReader(pack), "git", "-C", localRepoPath, "unpack-objects", fmt.Sprintf("--pack_header=%d,%d", version, entries))
-
- // The fact that this command succeeds means that we got the commit correctly, no further checks should be needed.
- testhelper.MustRunCommand(t, nil, "git", "-C", localRepoPath, "show", string(newHead))
-}
-
-func drainUploadStreamAndVerifyExitStatus(t *testing.T, status int32, stream pb.SSH_SSHUploadPackClient) error {
- var (
- err error
- chunk *pb.SSHUploadPackResponse
- )
- for err == nil {
- chunk, err = stream.Recv()
- }
- if chunk.GetExitStatus().GetValue() != status {
- t.Fatalf("Expected ExitStatus to be %d, got %d", status, chunk.GetExitStatus().GetValue())
- }
- if err != io.EOF {
- return err
- }
- return nil
-}
-
func TestFailedUploadPackRequestDueToValidationError(t *testing.T) {
server := runSSHServer(t)
defer server.Stop()
@@ -181,62 +47,3 @@ func drainPostUploadPackResponse(stream pb.SSH_SSHUploadPackClient) error {
}
return err
}
-
-// The response contains bunch of things; metadata, progress messages, and a pack file. We're only
-// interested in the pack file and its header values.
-func extractPackDataFromResponse(t *testing.T, buf *bytes.Buffer) ([]byte, int, int) {
- var pack []byte
- t.Logf("complete buf: %q", buf.String())
-
- // Since this is Smart Protocol we need to do this twice...
- // The first pass is listing all the refs that the server has
- // The second pass is listing the commonly known last ref
- for i := 0; i < 2; i++ {
- // The response should have the following format, where <length> is always four hexadecimal digits.
- // <length><data>
- // <length><data>
- // ...
- // 0000
- for {
- pktLenStr := buf.Next(4)
- if len(pktLenStr) != 4 {
- return nil, 0, 0
- }
- if string(pktLenStr) == pktFlushStr {
- break
- }
-
- pktLen, err := strconv.ParseUint(string(pktLenStr), 16, 16)
- if err != nil {
- t.Fatal(err)
- }
-
- restPktLen := int(pktLen) - 4
- pkt := buf.Next(restPktLen)
- if len(pkt) != restPktLen {
- t.Fatalf("Incomplete packet read")
- }
-
- // The first byte of the packet is the band designator. We only care about data in band 1.
- if pkt[0] == 1 {
- pack = append(pack, pkt[1:]...)
- }
- }
- }
-
- t.Logf("resulting buf: %s", buf.String())
-
- // The packet is structured as follows:
- // 4 bytes for signature, here it's "PACK"
- // 4 bytes for header version
- // 4 bytes for header entries
- // The rest is the pack file
- if len(pack) < 4 || string(pack[:4]) != "PACK" {
- t.Fatalf("Invalid packet signature %q", pack)
- }
- version := int(binary.BigEndian.Uint32(pack[4:8]))
- entries := int(binary.BigEndian.Uint32(pack[8:12]))
- pack = pack[12:]
-
- return pack, version, entries
-}
diff --git a/internal/service/ssh/uploadpack.go b/internal/service/ssh/uploadpack.go
index 1ff0d988f..ad18816ba 100644
--- a/internal/service/ssh/uploadpack.go
+++ b/internal/service/ssh/uploadpack.go
@@ -5,37 +5,31 @@ import (
"os/exec"
pb "gitlab.com/gitlab-org/gitaly-proto/go"
- pbh "gitlab.com/gitlab-org/gitaly-proto/go/helper"
+ pbhelper "gitlab.com/gitlab-org/gitaly-proto/go/helper"
"gitlab.com/gitlab-org/gitaly/internal/helper"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
)
-type uploadPackBytesReader struct {
- pb.SSH_SSHUploadPackServer
-}
-
-type uploadPackWriter struct {
- pb.SSH_SSHUploadPackServer
-}
-
-type uploadPackErrorWriter struct {
- pb.SSH_SSHUploadPackServer
-}
-
func (s *server) SSHUploadPack(stream pb.SSH_SSHUploadPackServer) error {
req, err := stream.Recv() // First request contains Repository only
if err != nil {
return err
}
- if err = validateUploadPackRequest(req); err != nil {
+ if err = validateFirstUploadPackRequest(req); err != nil {
return err
}
- streamBytesReader := uploadPackBytesReader{stream}
- stdin := pbh.NewReceiveReader(streamBytesReader.ReceiveBytes)
- stdout := uploadPackWriter{stream}
- stderr := uploadPackErrorWriter{stream}
+ stdin := pbhelper.NewReceiveReader(func() ([]byte, error) {
+ request, err := stream.Recv()
+ return request.GetStdin(), err
+ })
+ stdout := pbhelper.NewSendWriter(func(p []byte) error {
+ return stream.Send(&pb.SSHUploadPackResponse{Stdout: p})
+ })
+ stderr := pbhelper.NewSendWriter(func(p []byte) error {
+ return stream.Send(&pb.SSHUploadPackResponse{Stderr: p})
+ })
repoPath, err := helper.GetRepoPath(req.Repository)
if err != nil {
return err
@@ -43,7 +37,7 @@ func (s *server) SSHUploadPack(stream pb.SSH_SSHUploadPackServer) error {
log.Printf("PostUploadPack: RepoPath=%q", repoPath)
- osCommand := exec.Command("git", "upload-pack", repoPath)
+ osCommand := exec.Command("git-upload-pack", repoPath)
cmd, err := helper.NewCommand(osCommand, stdin, stdout, stderr)
if err != nil {
@@ -53,9 +47,7 @@ func (s *server) SSHUploadPack(stream pb.SSH_SSHUploadPackServer) error {
if err := cmd.Wait(); err != nil {
if status, ok := helper.ExitStatus(err); ok {
- log.Printf("Exit Status: %d", status)
- stream.Send(&pb.SSHUploadPackResponse{ExitStatus: &pb.ExitStatus{Value: int32(status)}})
- return nil
+ return stream.Send(&pb.SSHUploadPackResponse{ExitStatus: &pb.ExitStatus{Value: int32(status)}})
}
return grpc.Errorf(codes.Unavailable, "PostUploadPack: cmd wait for %v: %v", cmd.Args, err)
}
@@ -63,35 +55,10 @@ func (s *server) SSHUploadPack(stream pb.SSH_SSHUploadPackServer) error {
return nil
}
-func validateUploadPackRequest(req *pb.SSHUploadPackRequest) error {
+func validateFirstUploadPackRequest(req *pb.SSHUploadPackRequest) error {
if req.Stdin != nil {
return grpc.Errorf(codes.InvalidArgument, "PostUploadPack: non-empty stdin")
}
return nil
}
-
-func (rw uploadPackWriter) Write(p []byte) (int, error) {
- resp := &pb.SSHUploadPackResponse{Stdout: p}
- if err := rw.Send(resp); err != nil {
- return 0, err
- }
- return len(p), nil
-}
-
-func (rw uploadPackErrorWriter) Write(p []byte) (int, error) {
- resp := &pb.SSHUploadPackResponse{Stderr: p}
- if err := rw.Send(resp); err != nil {
- return 0, err
- }
- return len(p), nil
-}
-
-func (br uploadPackBytesReader) ReceiveBytes() ([]byte, error) {
- resp, err := br.Recv()
- if err != nil {
- return nil, err
- }
-
- return resp.GetStdin(), nil
-}