diff options
author | Kim Carlbäcker <kim.carlbacker@gmail.com> | 2018-01-18 22:58:54 +0300 |
---|---|---|
committer | Kim Carlbäcker <kim.carlbacker@gmail.com> | 2018-01-18 22:58:54 +0300 |
commit | 864a60b7a58201466623bab5fcb7eaf755fb3738 (patch) | |
tree | ee6aed5d8b2b244b17801f38fc48718cdce1fbc6 | |
parent | a613c884192059321f28d17aec5aa811ecad2ffc (diff) | |
parent | 76c862ba55eab6047367122665a69562b064a816 (diff) |
Use grpc-go 1.9.1
See merge request gitlab-org/gitaly!547
107 files changed, 2191 insertions, 1430 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md index b0964e7be..668f07f53 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,10 @@ # Gitaly changelog +UNRELEASED + +- Use grpc-go 1.9.1 + https://gitlab.com/gitlab-org/gitaly/merge_requests/547 + v0.71.0 - Implement GetLfsPointers RPC diff --git a/internal/git/catfile/catfile.go b/internal/git/catfile/catfile.go index 7a41acb80..db925fe5f 100644 --- a/internal/git/catfile/catfile.go +++ b/internal/git/catfile/catfile.go @@ -13,8 +13,8 @@ import ( "gitlab.com/gitlab-org/gitaly/internal/command" "gitlab.com/gitlab-org/gitaly/internal/git/alternates" - "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) // ObjectInfo represents a header returned by `git cat-file --batch` @@ -41,7 +41,7 @@ func CatFile(ctx context.Context, repo *pb.Repository, handler Handler) error { cmdArgs := []string{"--git-dir", repoPath, "cat-file", "--batch"} cmd, err := command.New(ctx, exec.Command(command.GitPath(), cmdArgs...), stdinReader, nil, nil, env...) if err != nil { - return grpc.Errorf(codes.Internal, "CatFile: cmd: %v", err) + return status.Errorf(codes.Internal, "CatFile: cmd: %v", err) } defer stdinWriter.Close() defer stdinReader.Close() diff --git a/internal/helper/error.go b/internal/helper/error.go index 604be7764..4966b788d 100644 --- a/internal/helper/error.go +++ b/internal/helper/error.go @@ -1,18 +1,32 @@ package helper import ( - "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) // Unimplemented is a Go error with gRPC error code 'Unimplemented' -var Unimplemented = grpc.Errorf(codes.Unimplemented, "this rpc is not implemented") +var Unimplemented = status.Errorf(codes.Unimplemented, "this rpc is not implemented") // DecorateError unless it's already a grpc error. // If given nil it will return nil. func DecorateError(code codes.Code, err error) error { - if err != nil && grpc.Code(err) == codes.Unknown { - return grpc.Errorf(code, "%v", err) + if err != nil && GrpcCode(err) == codes.Unknown { + return status.Errorf(code, "%v", err) } return err } + +// GrpcCode emulates the old grpc.Code function: it translates errors into codes.Code values. +func GrpcCode(err error) codes.Code { + if err == nil { + return codes.OK + } + + st, ok := status.FromError(err) + if !ok { + return codes.Unknown + } + + return st.Code() +} diff --git a/internal/helper/repo.go b/internal/helper/repo.go index 326088814..a7348075a 100644 --- a/internal/helper/repo.go +++ b/internal/helper/repo.go @@ -9,8 +9,8 @@ import ( pb "gitlab.com/gitlab-org/gitaly-proto/go" - "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) // GetRepoPath returns the full path of the repository referenced by an @@ -24,14 +24,14 @@ func GetRepoPath(repo *pb.Repository) (string, error) { } if repoPath == "" { - return "", grpc.Errorf(codes.InvalidArgument, "GetRepoPath: empty repo") + return "", status.Errorf(codes.InvalidArgument, "GetRepoPath: empty repo") } if IsGitDirectory(repoPath) { return repoPath, nil } - return "", grpc.Errorf(codes.NotFound, "GetRepoPath: not a git repository '%s'", repoPath) + return "", status.Errorf(codes.NotFound, "GetRepoPath: not a git repository '%s'", repoPath) } // GetPath returns the path of the repo passed as first argument. An error is @@ -44,12 +44,12 @@ func GetPath(repo *pb.Repository) (string, error) { } if _, err := os.Stat(storagePath); err != nil { - return "", grpc.Errorf(codes.Internal, "GetPath: storage path: %v", err) + return "", status.Errorf(codes.Internal, "GetPath: storage path: %v", err) } relativePath := repo.GetRelativePath() if len(relativePath) == 0 { - err := grpc.Errorf(codes.InvalidArgument, "GetPath: relative path missing from %+v", repo) + err := status.Errorf(codes.InvalidArgument, "GetPath: relative path missing from %+v", repo) return "", err } @@ -58,7 +58,7 @@ func GetPath(repo *pb.Repository) (string, error) { if strings.HasPrefix(relativePath, ".."+separator) || strings.Contains(relativePath, separator+".."+separator) || strings.HasSuffix(relativePath, separator+"..") { - return "", grpc.Errorf(codes.InvalidArgument, "GetRepoPath: relative path can't contain directory traversal") + return "", status.Errorf(codes.InvalidArgument, "GetRepoPath: relative path can't contain directory traversal") } return path.Join(storagePath, relativePath), nil @@ -69,7 +69,7 @@ func GetPath(repo *pb.Repository) (string, error) { func GetStorageByName(storageName string) (string, error) { storagePath, ok := config.StoragePath(storageName) if !ok { - return "", grpc.Errorf(codes.InvalidArgument, "Storage can not be found by name '%s'", storageName) + return "", status.Errorf(codes.InvalidArgument, "Storage can not be found by name '%s'", storageName) } return storagePath, nil diff --git a/internal/middleware/cancelhandler/cancelhandler.go b/internal/middleware/cancelhandler/cancelhandler.go index a1e8d3a09..4592059d5 100644 --- a/internal/middleware/cancelhandler/cancelhandler.go +++ b/internal/middleware/cancelhandler/cancelhandler.go @@ -4,6 +4,7 @@ import ( "golang.org/x/net/context" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) // Unary is a unary server interceptor that puts cancel codes on errors @@ -28,5 +29,5 @@ func wrapErr(ctx context.Context, err error) error { if ctx.Err() == context.DeadlineExceeded { code = codes.DeadlineExceeded } - return grpc.Errorf(code, "%v", err) + return status.Errorf(code, "%v", err) } diff --git a/internal/middleware/metadatahandler/metadatahandler.go b/internal/middleware/metadatahandler/metadatahandler.go index 3f70f24bd..3c95e550c 100644 --- a/internal/middleware/metadatahandler/metadatahandler.go +++ b/internal/middleware/metadatahandler/metadatahandler.go @@ -3,6 +3,7 @@ package metadatahandler import ( "github.com/grpc-ecosystem/go-grpc-middleware/tags" "github.com/prometheus/client_golang/prometheus" + "gitlab.com/gitlab-org/gitaly/internal/helper" "golang.org/x/net/context" "google.golang.org/grpc" "google.golang.org/grpc/metadata" @@ -77,7 +78,7 @@ func UnaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServ res, err := handler(ctx, req) - grpcCode := grpc.Code(err) + grpcCode := helper.GrpcCode(err) requests.WithLabelValues(clientName, callSite, grpcCode.String()).Inc() return res, err @@ -90,7 +91,7 @@ func StreamInterceptor(srv interface{}, stream grpc.ServerStream, info *grpc.Str err := handler(srv, stream) - grpcCode := grpc.Code(err) + grpcCode := helper.GrpcCode(err) requests.WithLabelValues(clientName, callSite, grpcCode.String()).Inc() return err diff --git a/internal/middleware/panichandler/panic_handler.go b/internal/middleware/panichandler/panic_handler.go index dd68f22c6..637a854e8 100644 --- a/internal/middleware/panichandler/panic_handler.go +++ b/internal/middleware/panichandler/panic_handler.go @@ -6,6 +6,7 @@ import ( "golang.org/x/net/context" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) var _ grpc.UnaryServerInterceptor = UnaryPanicHandler @@ -15,7 +16,7 @@ var _ grpc.StreamServerInterceptor = StreamPanicHandler type PanicHandler func(methodName string, error interface{}) func toPanicError(grpcMethodName string, r interface{}) error { - return grpc.Errorf(codes.Internal, "panic: %v", r) + return status.Errorf(codes.Internal, "panic: %v", r) } // UnaryPanicHandler handles request-response panics diff --git a/internal/middleware/sentryhandler/sentryhandler.go b/internal/middleware/sentryhandler/sentryhandler.go index 2557c670f..ee79ea494 100644 --- a/internal/middleware/sentryhandler/sentryhandler.go +++ b/internal/middleware/sentryhandler/sentryhandler.go @@ -6,6 +6,7 @@ import ( raven "github.com/getsentry/raven-go" "github.com/grpc-ecosystem/go-grpc-middleware/tags" + "gitlab.com/gitlab-org/gitaly/internal/helper" "fmt" @@ -53,7 +54,7 @@ func methodToCulprit(methodName string) string { } func logErrorToSentry(err error) (code codes.Code, bypass bool) { - code = grpc.Code(err) + code = helper.GrpcCode(err) bypass = code == codes.OK || code == codes.Canceled return code, bypass diff --git a/internal/middleware/sentryhandler/sentryhandler_test.go b/internal/middleware/sentryhandler/sentryhandler_test.go index c8cf199ea..b156de474 100644 --- a/internal/middleware/sentryhandler/sentryhandler_test.go +++ b/internal/middleware/sentryhandler/sentryhandler_test.go @@ -6,10 +6,10 @@ import ( "time" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "github.com/stretchr/testify/assert" "golang.org/x/net/context" - "google.golang.org/grpc" ) func Test_generateRavenPacket(t *testing.T) { @@ -36,7 +36,7 @@ func Test_generateRavenPacket(t *testing.T) { name: "GRPC error", method: "/gitaly.RepoService/RepoExists", sinceStart: 500 * time.Millisecond, - err: grpc.Errorf(codes.NotFound, "Something failed"), + err: status.Errorf(codes.NotFound, "Something failed"), wantCode: codes.NotFound, wantMessage: "rpc error: code = NotFound desc = Something failed", wantCulprit: "RepoService::RepoExists", diff --git a/internal/server/auth/auth.go b/internal/server/auth/auth.go index 0498756b5..da24914bd 100644 --- a/internal/server/auth/auth.go +++ b/internal/server/auth/auth.go @@ -10,6 +10,7 @@ import ( "golang.org/x/net/context" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) var ( @@ -45,20 +46,20 @@ func check(ctx context.Context) (context.Context, error) { encodedToken, err := grpc_auth.AuthFromMD(ctx, "bearer") if err != nil { countStatus("unauthenticated").Inc() - err = grpc.Errorf(codes.Unauthenticated, "authentication required") + err = status.Errorf(codes.Unauthenticated, "authentication required") return ctx, ifEnforced(err) } token, err := base64.StdEncoding.DecodeString(encodedToken) if err != nil { countStatus("invalid").Inc() - err = grpc.Errorf(codes.Unauthenticated, "authentication required") + err = status.Errorf(codes.Unauthenticated, "authentication required") return ctx, ifEnforced(err) } if !config.Config.Auth.Token.Equal(string(token)) { countStatus("denied").Inc() - err = grpc.Errorf(codes.PermissionDenied, "permission denied") + err = status.Errorf(codes.PermissionDenied, "permission denied") return ctx, ifEnforced(err) } diff --git a/internal/service/blob/get_blob.go b/internal/service/blob/get_blob.go index 573292e66..f34a3be01 100644 --- a/internal/service/blob/get_blob.go +++ b/internal/service/blob/get_blob.go @@ -13,13 +13,13 @@ import ( pb "gitlab.com/gitlab-org/gitaly-proto/go" "gitlab.com/gitlab-org/gitaly/streamio" - "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) func (s *server) GetBlob(in *pb.GetBlobRequest, stream pb.BlobService_GetBlobServer) error { if err := validateRequest(in); err != nil { - return grpc.Errorf(codes.InvalidArgument, "GetBlob: %v", err) + return status.Errorf(codes.InvalidArgument, "GetBlob: %v", err) } repoPath, err := helper.GetRepoPath(in.Repository) @@ -32,13 +32,13 @@ func (s *server) GetBlob(in *pb.GetBlobRequest, stream pb.BlobService_GetBlobSer cmdArgs := []string{"--git-dir", repoPath, "cat-file", "--batch"} cmd, err := command.New(stream.Context(), exec.Command(command.GitPath(), cmdArgs...), stdinReader, nil, nil) if err != nil { - return grpc.Errorf(codes.Internal, "GetBlob: cmd: %v", err) + return status.Errorf(codes.Internal, "GetBlob: cmd: %v", err) } defer stdinWriter.Close() defer stdinReader.Close() if _, err := fmt.Fprintln(stdinWriter, in.Oid); err != nil { - return grpc.Errorf(codes.Internal, "GetBlob: stdin write: %v", err) + return status.Errorf(codes.Internal, "GetBlob: stdin write: %v", err) } stdinWriter.Close() @@ -46,7 +46,7 @@ func (s *server) GetBlob(in *pb.GetBlobRequest, stream pb.BlobService_GetBlobSer objectInfo, err := catfile.ParseObjectInfo(stdout) if err != nil { - return grpc.Errorf(codes.Internal, "GetBlob: %v", err) + return status.Errorf(codes.Internal, "GetBlob: %v", err) } if objectInfo.Type != "blob" { return helper.DecorateError(codes.Unavailable, stream.Send(&pb.GetBlobResponse{})) @@ -77,10 +77,10 @@ func (s *server) GetBlob(in *pb.GetBlobRequest, stream pb.BlobService_GetBlobSer n, err := io.Copy(sw, io.LimitReader(stdout, readLimit)) if err != nil { - return grpc.Errorf(codes.Unavailable, "GetBlob: send: %v", err) + return status.Errorf(codes.Unavailable, "GetBlob: send: %v", err) } if n != readLimit { - return grpc.Errorf(codes.Unavailable, "GetBlob: short send: %d/%d bytes", n, objectInfo.Size) + return status.Errorf(codes.Unavailable, "GetBlob: short send: %d/%d bytes", n, objectInfo.Size) } return nil diff --git a/internal/service/blob/lfs_pointers.go b/internal/service/blob/lfs_pointers.go index bc8f0dd56..297a83e25 100644 --- a/internal/service/blob/lfs_pointers.go +++ b/internal/service/blob/lfs_pointers.go @@ -7,15 +7,15 @@ import ( pb "gitlab.com/gitlab-org/gitaly-proto/go" - "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) func (s *server) GetLFSPointers(req *pb.GetLFSPointersRequest, stream pb.BlobService_GetLFSPointersServer) error { ctx := stream.Context() if err := validateGetLFSPointersRequest(req); err != nil { - return grpc.Errorf(codes.InvalidArgument, "GetLFSPointers: %v", err) + return status.Errorf(codes.InvalidArgument, "GetLFSPointers: %v", err) } client, err := s.BlobServiceClient(ctx) diff --git a/internal/service/commit/between.go b/internal/service/commit/between.go index f2d9d5ccf..6fc186828 100644 --- a/internal/service/commit/between.go +++ b/internal/service/commit/between.go @@ -5,8 +5,8 @@ import ( pb "gitlab.com/gitlab-org/gitaly-proto/go" "gitlab.com/gitlab-org/gitaly/internal/git" - "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) type commitsBetweenSender struct { @@ -15,10 +15,10 @@ type commitsBetweenSender struct { func (s *server) CommitsBetween(in *pb.CommitsBetweenRequest, stream pb.CommitService_CommitsBetweenServer) error { if err := git.ValidateRevision(in.GetFrom()); err != nil { - return grpc.Errorf(codes.InvalidArgument, "CommitsBetween: from: %v", err) + return status.Errorf(codes.InvalidArgument, "CommitsBetween: from: %v", err) } if err := git.ValidateRevision(in.GetTo()); err != nil { - return grpc.Errorf(codes.InvalidArgument, "CommitsBetween: to: %v", err) + return status.Errorf(codes.InvalidArgument, "CommitsBetween: to: %v", err) } sender := &commitsBetweenSender{stream} diff --git a/internal/service/commit/commits_by_message.go b/internal/service/commit/commits_by_message.go index 0178b768b..2c7e1fc63 100644 --- a/internal/service/commit/commits_by_message.go +++ b/internal/service/commit/commits_by_message.go @@ -5,7 +5,6 @@ import ( pb "gitlab.com/gitlab-org/gitaly-proto/go" - "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -16,7 +15,7 @@ type commitsByMessageSender struct { func (s *server) CommitsByMessage(in *pb.CommitsByMessageRequest, stream pb.CommitService_CommitsByMessageServer) error { if err := validateCommitsByMessageRequest(in); err != nil { - return grpc.Errorf(codes.InvalidArgument, "CommitsByMessage: %v", err) + return status.Errorf(codes.InvalidArgument, "CommitsByMessage: %v", err) } ctx := stream.Context() @@ -42,7 +41,7 @@ func (s *server) CommitsByMessage(in *pb.CommitsByMessageRequest, stream pb.Comm if _, ok := status.FromError(err); ok { return err } - return grpc.Errorf(codes.Internal, "CommitsByMessage: defaultBranchName: %v", err) + return status.Errorf(codes.Internal, "CommitsByMessage: defaultBranchName: %v", err) } } diff --git a/internal/service/commit/count_commits.go b/internal/service/commit/count_commits.go index 061083646..f2aff4663 100644 --- a/internal/service/commit/count_commits.go +++ b/internal/service/commit/count_commits.go @@ -13,14 +13,13 @@ import ( "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus" "golang.org/x/net/context" - "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) func (s *server) CountCommits(ctx context.Context, in *pb.CountCommitsRequest) (*pb.CountCommitsResponse, error) { if err := validateCountCommitsRequest(in); err != nil { - return nil, grpc.Errorf(codes.InvalidArgument, "CountCommits: %v", err) + return nil, status.Errorf(codes.InvalidArgument, "CountCommits: %v", err) } cmdArgs := []string{"rev-list", "--count", string(in.GetRevision())} @@ -43,7 +42,7 @@ func (s *server) CountCommits(ctx context.Context, in *pb.CountCommitsRequest) ( if _, ok := status.FromError(err); ok { return nil, err } - return nil, grpc.Errorf(codes.Internal, "CountCommits: cmd: %v", err) + return nil, status.Errorf(codes.Internal, "CountCommits: cmd: %v", err) } var count int64 @@ -61,7 +60,7 @@ func (s *server) CountCommits(ctx context.Context, in *pb.CountCommitsRequest) ( count, err = strconv.ParseInt(string(countStr), 10, 0) if err != nil { - return nil, grpc.Errorf(codes.Internal, "CountCommits: parse count: %v", err) + return nil, status.Errorf(codes.Internal, "CountCommits: parse count: %v", err) } } diff --git a/internal/service/commit/filter_shas_with_signatures.go b/internal/service/commit/filter_shas_with_signatures.go index f3ed01181..c64453fa5 100644 --- a/internal/service/commit/filter_shas_with_signatures.go +++ b/internal/service/commit/filter_shas_with_signatures.go @@ -3,8 +3,8 @@ package commit import ( pb "gitlab.com/gitlab-org/gitaly-proto/go" "gitlab.com/gitlab-org/gitaly/internal/rubyserver" - "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) func (s *server) FilterShasWithSignatures(bidi pb.CommitService_FilterShasWithSignaturesServer) error { @@ -60,7 +60,7 @@ func (s *server) FilterShasWithSignatures(bidi pb.CommitService_FilterShasWithSi func verifyFirstFilterShasWithSignaturesRequest(in *pb.FilterShasWithSignaturesRequest) error { if in.Repository == nil { - return grpc.Errorf(codes.InvalidArgument, "no repository given") + return status.Errorf(codes.InvalidArgument, "no repository given") } return nil } diff --git a/internal/service/commit/find_all_commits.go b/internal/service/commit/find_all_commits.go index a3c44471b..4197f8992 100644 --- a/internal/service/commit/find_all_commits.go +++ b/internal/service/commit/find_all_commits.go @@ -7,8 +7,8 @@ import ( pb "gitlab.com/gitlab-org/gitaly-proto/go" - "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) // We declare this function in variables so that we can override them in our tests @@ -41,7 +41,7 @@ func (s *server) FindAllCommits(in *pb.FindAllCommitsRequest, stream pb.CommitSe if len(in.GetRevision()) == 0 { branchNames, err := _findBranchNamesFunc(stream.Context(), in.Repository) if err != nil { - return grpc.Errorf(codes.InvalidArgument, "FindAllCommits: %v", err) + return status.Errorf(codes.InvalidArgument, "FindAllCommits: %v", err) } for _, branch := range branchNames { diff --git a/internal/service/commit/find_commit.go b/internal/service/commit/find_commit.go index 905557c7f..d56f3eb13 100644 --- a/internal/service/commit/find_commit.go +++ b/internal/service/commit/find_commit.go @@ -6,13 +6,13 @@ import ( pb "gitlab.com/gitlab-org/gitaly-proto/go" "gitlab.com/gitlab-org/gitaly/internal/git" "golang.org/x/net/context" - "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) func (s *server) FindCommit(ctx context.Context, in *pb.FindCommitRequest) (*pb.FindCommitResponse, error) { if err := git.ValidateRevision(in.GetRevision()); err != nil { - return nil, grpc.Errorf(codes.InvalidArgument, "FindCommit: revision: %v", err) + return nil, status.Errorf(codes.InvalidArgument, "FindCommit: revision: %v", err) } commit, err := log.GetCommit(ctx, in.GetRepository(), string(in.GetRevision()), "") diff --git a/internal/service/commit/find_commits.go b/internal/service/commit/find_commits.go index effcdd822..3060aaf52 100644 --- a/internal/service/commit/find_commits.go +++ b/internal/service/commit/find_commits.go @@ -5,8 +5,8 @@ import ( pb "gitlab.com/gitlab-org/gitaly-proto/go" - "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) func (s *server) FindCommits(req *pb.FindCommitsRequest, stream pb.CommitService_FindCommitsServer) error { @@ -18,14 +18,14 @@ func (s *server) FindCommits(req *pb.FindCommitsRequest, stream pb.CommitService var err error req.Revision, err = defaultBranchName(ctx, req.Repository) if err != nil { - return grpc.Errorf(codes.Internal, "defaultBranchName: %v", err) + return status.Errorf(codes.Internal, "defaultBranchName: %v", err) } } // Clients might send empty paths. That is an error for _, path := range req.Paths { if len(path) == 0 { - return grpc.Errorf(codes.InvalidArgument, "path is empty string") + return status.Errorf(codes.InvalidArgument, "path is empty string") } } diff --git a/internal/service/commit/isancestor.go b/internal/service/commit/isancestor.go index 4292b5c08..120508858 100644 --- a/internal/service/commit/isancestor.go +++ b/internal/service/commit/isancestor.go @@ -2,7 +2,6 @@ package commit import ( "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus" - "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -16,10 +15,10 @@ import ( func (s *server) CommitIsAncestor(ctx context.Context, in *pb.CommitIsAncestorRequest) (*pb.CommitIsAncestorResponse, error) { if in.AncestorId == "" { - return nil, grpc.Errorf(codes.InvalidArgument, "Bad Request (empty ancestor sha)") + return nil, status.Errorf(codes.InvalidArgument, "Bad Request (empty ancestor sha)") } if in.ChildId == "" { - return nil, grpc.Errorf(codes.InvalidArgument, "Bad Request (empty child sha)") + return nil, status.Errorf(codes.InvalidArgument, "Bad Request (empty child sha)") } ret, err := commitIsAncestorName(ctx, in.Repository, in.AncestorId, in.ChildId) @@ -38,7 +37,7 @@ func commitIsAncestorName(ctx context.Context, repo *pb.Repository, ancestorID, if _, ok := status.FromError(err); ok { return false, err } - return false, grpc.Errorf(codes.Internal, err.Error()) + return false, status.Errorf(codes.Internal, err.Error()) } return cmd.Wait() == nil, nil diff --git a/internal/service/commit/isancestor_test.go b/internal/service/commit/isancestor_test.go index 3c7999aa1..1bd7c6c81 100644 --- a/internal/service/commit/isancestor_test.go +++ b/internal/service/commit/isancestor_test.go @@ -7,11 +7,11 @@ import ( "path" "testing" + "gitlab.com/gitlab-org/gitaly/internal/helper" "gitlab.com/gitlab-org/gitaly/internal/testhelper" "github.com/stretchr/testify/require" "golang.org/x/net/context" - "google.golang.org/grpc" "google.golang.org/grpc/codes" pb "gitlab.com/gitlab-org/gitaly-proto/go" @@ -76,7 +76,7 @@ func TestCommitIsAncestorFailure(t *testing.T) { defer cancel() if _, err := client.CommitIsAncestor(ctx, v.Request); err == nil { t.Error("Expected to throw an error") - } else if grpc.Code(err) != v.ErrorCode { + } else if helper.GrpcCode(err) != v.ErrorCode { t.Errorf(v.ErrMsg, err) } }) diff --git a/internal/service/commit/languages.go b/internal/service/commit/languages.go index adcbcb3fd..5d8227a6a 100644 --- a/internal/service/commit/languages.go +++ b/internal/service/commit/languages.go @@ -5,8 +5,8 @@ import ( "sort" "strings" - "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "gitlab.com/gitlab-org/gitaly/internal/git" "gitlab.com/gitlab-org/gitaly/internal/helper" @@ -55,7 +55,7 @@ func (*server) CommitLanguages(ctx context.Context, req *pb.CommitLanguagesReque } if total == 0 { - return nil, grpc.Errorf(codes.Internal, "linguist stats added up to zero: %v", stats) + return nil, status.Errorf(codes.Internal, "linguist stats added up to zero: %v", stats) } for lang, count := range stats { diff --git a/internal/service/commit/last_commit_for_path.go b/internal/service/commit/last_commit_for_path.go index d9d465c21..df4c4b133 100644 --- a/internal/service/commit/last_commit_for_path.go +++ b/internal/service/commit/last_commit_for_path.go @@ -8,13 +8,13 @@ import ( pb "gitlab.com/gitlab-org/gitaly-proto/go" "golang.org/x/net/context" - "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) func (s *server) LastCommitForPath(ctx context.Context, in *pb.LastCommitForPathRequest) (*pb.LastCommitForPathResponse, error) { if err := validateLastCommitForPathRequest(in); err != nil { - return nil, grpc.Errorf(codes.InvalidArgument, "LastCommitForPath: %v", err) + return nil, status.Errorf(codes.InvalidArgument, "LastCommitForPath: %v", err) } path := string(in.GetPath()) diff --git a/internal/service/commit/list_files.go b/internal/service/commit/list_files.go index 42c5ac8ce..c764e71dd 100644 --- a/internal/service/commit/list_files.go +++ b/internal/service/commit/list_files.go @@ -5,7 +5,6 @@ import ( "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus" log "github.com/sirupsen/logrus" - "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -34,7 +33,7 @@ func (s *server) ListFiles(in *pb.ListFilesRequest, stream pb.CommitService_List if _, ok := status.FromError(err); ok { return err } - return grpc.Errorf(codes.NotFound, "Revision not found %q", in.GetRevision()) + return status.Errorf(codes.NotFound, "Revision not found %q", in.GetRevision()) } } if !git.IsValidRef(stream.Context(), repo, string(revision)) { @@ -46,7 +45,7 @@ func (s *server) ListFiles(in *pb.ListFilesRequest, stream pb.CommitService_List if _, ok := status.FromError(err); ok { return err } - return grpc.Errorf(codes.Internal, err.Error()) + return status.Errorf(codes.Internal, err.Error()) } return lines.Send(cmd, listFilesWriter(stream), []byte{'\x00'}) @@ -58,12 +57,12 @@ func listFilesWriter(stream pb.CommitService_ListFilesServer) lines.Sender { for _, obj := range objs { data := bytes.SplitN(obj, []byte{'\t'}, 2) if len(data) != 2 { - return grpc.Errorf(codes.Internal, "ListFiles: failed parsing line") + return status.Errorf(codes.Internal, "ListFiles: failed parsing line") } meta := bytes.SplitN(data[0], []byte{' '}, 3) if len(meta) != 3 { - return grpc.Errorf(codes.Internal, "ListFiles: failed parsing meta") + return status.Errorf(codes.Internal, "ListFiles: failed parsing meta") } if bytes.Equal(meta[1], []byte("blob")) { diff --git a/internal/service/commit/raw_blame.go b/internal/service/commit/raw_blame.go index 948953396..59b8a083f 100644 --- a/internal/service/commit/raw_blame.go +++ b/internal/service/commit/raw_blame.go @@ -10,14 +10,13 @@ import ( pb "gitlab.com/gitlab-org/gitaly-proto/go" "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus" - "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) func (s *server) RawBlame(in *pb.RawBlameRequest, stream pb.CommitService_RawBlameServer) error { if err := validateRawBlameRequest(in); err != nil { - return grpc.Errorf(codes.InvalidArgument, "RawBlame: %v", err) + return status.Errorf(codes.InvalidArgument, "RawBlame: %v", err) } ctx := stream.Context() @@ -29,7 +28,7 @@ func (s *server) RawBlame(in *pb.RawBlameRequest, stream pb.CommitService_RawBla if _, ok := status.FromError(err); ok { return err } - return grpc.Errorf(codes.Internal, "RawBlame: cmd: %v", err) + return status.Errorf(codes.Internal, "RawBlame: cmd: %v", err) } sw := streamio.NewWriter(func(p []byte) error { @@ -38,7 +37,7 @@ func (s *server) RawBlame(in *pb.RawBlameRequest, stream pb.CommitService_RawBla _, err = io.Copy(sw, cmd) if err != nil { - return grpc.Errorf(codes.Unavailable, "RawBlame: send: %v", err) + return status.Errorf(codes.Unavailable, "RawBlame: send: %v", err) } if err := cmd.Wait(); err != nil { diff --git a/internal/service/commit/tree_entries.go b/internal/service/commit/tree_entries.go index 0b834261e..660732499 100644 --- a/internal/service/commit/tree_entries.go +++ b/internal/service/commit/tree_entries.go @@ -10,8 +10,8 @@ import ( "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus" pb "gitlab.com/gitlab-org/gitaly-proto/go" "gitlab.com/gitlab-org/gitaly/internal/git/catfile" - "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) var maxTreeEntries = 1000 @@ -89,7 +89,7 @@ func (s *server) GetTreeEntries(in *pb.GetTreeEntriesRequest, stream pb.CommitSe }).Debug("GetTreeEntries") if err := validateGetTreeEntriesRequest(in); err != nil { - return grpc.Errorf(codes.InvalidArgument, "TreeEntry: %v", err) + return status.Errorf(codes.InvalidArgument, "TreeEntry: %v", err) } revision := string(in.GetRevision()) diff --git a/internal/service/commit/tree_entries_helper.go b/internal/service/commit/tree_entries_helper.go index f6509aab4..4036f81fe 100644 --- a/internal/service/commit/tree_entries_helper.go +++ b/internal/service/commit/tree_entries_helper.go @@ -5,8 +5,8 @@ import ( "fmt" "io" - "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" pb "gitlab.com/gitlab-org/gitaly-proto/go" "gitlab.com/gitlab-org/gitaly/internal/git/catfile" @@ -14,12 +14,12 @@ import ( func getTreeInfo(revision, path string, stdin io.Writer, stdout *bufio.Reader) (*catfile.ObjectInfo, error) { if _, err := fmt.Fprintf(stdin, "%s^{tree}:%s\n", revision, path); err != nil { - return nil, grpc.Errorf(codes.Internal, "TreeEntry: stdin write: %v", err) + return nil, status.Errorf(codes.Internal, "TreeEntry: stdin write: %v", err) } treeInfo, err := catfile.ParseObjectInfo(stdout) if err != nil { - return nil, grpc.Errorf(codes.Internal, "TreeEntry: %v", err) + return nil, status.Errorf(codes.Internal, "TreeEntry: %v", err) } return treeInfo, nil } diff --git a/internal/service/commit/tree_entry.go b/internal/service/commit/tree_entry.go index b73b3cba8..2a22a3d1b 100644 --- a/internal/service/commit/tree_entry.go +++ b/internal/service/commit/tree_entry.go @@ -13,8 +13,8 @@ import ( pb "gitlab.com/gitlab-org/gitaly-proto/go" - "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) func treeEntryHandler(stream pb.Commit_TreeEntryServer, revision, path, baseName string, limit int64) catfile.Handler { @@ -44,7 +44,7 @@ func treeEntryHandler(stream pb.Commit_TreeEntryServer, revision, path, baseName Oid: treeEntry.Oid, } if err := stream.Send(response); err != nil { - return grpc.Errorf(codes.Unavailable, "TreeEntry: send: %v", err) + return status.Errorf(codes.Unavailable, "TreeEntry: send: %v", err) } return nil @@ -54,11 +54,11 @@ func treeEntryHandler(stream pb.Commit_TreeEntryServer, revision, path, baseName objectInfo, err := catfile.ParseObjectInfo(stdout) if err != nil { - return grpc.Errorf(codes.Internal, "TreeEntry: %v", err) + return status.Errorf(codes.Internal, "TreeEntry: %v", err) } if strings.ToLower(treeEntry.Type.String()) != objectInfo.Type { - return grpc.Errorf( + return status.Errorf( codes.Internal, "TreeEntry: mismatched object type: tree-oid=%s object-oid=%s entry-type=%s object-type=%s", treeEntry.Oid, objectInfo.Oid, treeEntry.Type.String(), objectInfo.Type, @@ -94,7 +94,7 @@ func treeEntryHandler(stream pb.Commit_TreeEntryServer, revision, path, baseName response.Data = p if err := stream.Send(response); err != nil { - return grpc.Errorf(codes.Unavailable, "TreeEntry: send: %v", err) + return status.Errorf(codes.Unavailable, "TreeEntry: send: %v", err) } // Use a new response so we don't send other fields (Size, ...) over and over @@ -105,7 +105,7 @@ func treeEntryHandler(stream pb.Commit_TreeEntryServer, revision, path, baseName n, err := io.Copy(sw, io.LimitReader(stdout, dataLength)) if n < dataLength && err == nil { - return grpc.Errorf(codes.Internal, "TreeEntry: Incomplete copy") + return status.Errorf(codes.Internal, "TreeEntry: Incomplete copy") } return err @@ -114,7 +114,7 @@ func treeEntryHandler(stream pb.Commit_TreeEntryServer, revision, path, baseName func (s *server) TreeEntry(in *pb.TreeEntryRequest, stream pb.CommitService_TreeEntryServer) error { if err := validateRequest(in); err != nil { - return grpc.Errorf(codes.InvalidArgument, "TreeEntry: %v", err) + return status.Errorf(codes.InvalidArgument, "TreeEntry: %v", err) } requestPath := string(in.GetPath()) diff --git a/internal/service/conflicts/list_conflict_files.go b/internal/service/conflicts/list_conflict_files.go index 29ac0d7fc..c9b79a7d5 100644 --- a/internal/service/conflicts/list_conflict_files.go +++ b/internal/service/conflicts/list_conflict_files.go @@ -5,15 +5,15 @@ import ( pb "gitlab.com/gitlab-org/gitaly-proto/go" "gitlab.com/gitlab-org/gitaly/internal/rubyserver" - "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) func (s *server) ListConflictFiles(in *pb.ListConflictFilesRequest, stream pb.ConflictsService_ListConflictFilesServer) error { ctx := stream.Context() if err := validateListConflictFilesRequest(in); err != nil { - return grpc.Errorf(codes.InvalidArgument, "ListConflictFiles: %v", err) + return status.Errorf(codes.InvalidArgument, "ListConflictFiles: %v", err) } client, err := s.ConflictsServiceClient(ctx) diff --git a/internal/service/conflicts/resolve_conflicts.go b/internal/service/conflicts/resolve_conflicts.go index 9a121aa37..d54aedf90 100644 --- a/internal/service/conflicts/resolve_conflicts.go +++ b/internal/service/conflicts/resolve_conflicts.go @@ -5,8 +5,8 @@ import ( pb "gitlab.com/gitlab-org/gitaly-proto/go" "gitlab.com/gitlab-org/gitaly/internal/rubyserver" - "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) func (s *server) ResolveConflicts(stream pb.ConflictsService_ResolveConflictsServer) error { @@ -17,11 +17,11 @@ func (s *server) ResolveConflicts(stream pb.ConflictsService_ResolveConflictsSer header := firstRequest.GetHeader() if header == nil { - return grpc.Errorf(codes.InvalidArgument, "ResolveConflicts: empty ResolveConflictsRequestHeader") + return status.Errorf(codes.InvalidArgument, "ResolveConflicts: empty ResolveConflictsRequestHeader") } if err = validateResolveConflictsHeader(header); err != nil { - return grpc.Errorf(codes.InvalidArgument, "ResolveConflicts: %v", err) + return status.Errorf(codes.InvalidArgument, "ResolveConflicts: %v", err) } ctx := stream.Context() diff --git a/internal/service/diff/commit.go b/internal/service/diff/commit.go index 59b478809..7b5ff8d71 100644 --- a/internal/service/diff/commit.go +++ b/internal/service/diff/commit.go @@ -9,7 +9,6 @@ import ( pb "gitlab.com/gitlab-org/gitaly-proto/go" "gitlab.com/gitlab-org/gitaly/internal/diff" "gitlab.com/gitlab-org/gitaly/internal/git" - "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -28,7 +27,7 @@ func (s *server) CommitDiff(in *pb.CommitDiffRequest, stream pb.DiffService_Comm }).Debug("CommitDiff") if err := validateRequest(in); err != nil { - return grpc.Errorf(codes.InvalidArgument, "CommitDiff: %v", err) + return status.Errorf(codes.InvalidArgument, "CommitDiff: %v", err) } leftSha := in.LeftCommitId @@ -85,7 +84,7 @@ func (s *server) CommitDiff(in *pb.CommitDiffRequest, stream pb.DiffService_Comm response.EndOfPatch = true if err := stream.Send(response); err != nil { - return grpc.Errorf(codes.Unavailable, "CommitDiff: send: %v", err) + return status.Errorf(codes.Unavailable, "CommitDiff: send: %v", err) } } else { patch := diff.Patch @@ -101,7 +100,7 @@ func (s *server) CommitDiff(in *pb.CommitDiffRequest, stream pb.DiffService_Comm } if err := stream.Send(response); err != nil { - return grpc.Errorf(codes.Unavailable, "CommitDiff: send: %v", err) + return status.Errorf(codes.Unavailable, "CommitDiff: send: %v", err) } // Use a new response so we don't send other fields (FromPath, ...) over and over @@ -121,7 +120,7 @@ func (s *server) CommitDelta(in *pb.CommitDeltaRequest, stream pb.DiffService_Co }).Debug("CommitDelta") if err := validateRequest(in); err != nil { - return grpc.Errorf(codes.InvalidArgument, "CommitDelta: %v", err) + return status.Errorf(codes.InvalidArgument, "CommitDelta: %v", err) } leftSha := in.LeftCommitId @@ -153,7 +152,7 @@ func (s *server) CommitDelta(in *pb.CommitDeltaRequest, stream pb.DiffService_Co } if err := stream.Send(&pb.CommitDeltaResponse{Deltas: batch}); err != nil { - return grpc.Errorf(codes.Unavailable, "CommitDelta: send: %v", err) + return status.Errorf(codes.Unavailable, "CommitDelta: send: %v", err) } return nil @@ -208,7 +207,7 @@ func eachDiff(ctx context.Context, rpc string, repo *pb.Repository, cmdArgs []st if _, ok := status.FromError(err); ok { return err } - return grpc.Errorf(codes.Internal, "%s: cmd: %v", rpc, err) + return status.Errorf(codes.Internal, "%s: cmd: %v", rpc, err) } diffParser := diff.NewDiffParser(cmd, limits) @@ -220,11 +219,11 @@ func eachDiff(ctx context.Context, rpc string, repo *pb.Repository, cmdArgs []st } if err := diffParser.Err(); err != nil { - return grpc.Errorf(codes.Internal, "%s: parse failure: %v", rpc, err) + return status.Errorf(codes.Internal, "%s: parse failure: %v", rpc, err) } if err := cmd.Wait(); err != nil { - return grpc.Errorf(codes.Unavailable, "%s: %v", rpc, err) + return status.Errorf(codes.Unavailable, "%s: %v", rpc, err) } return nil diff --git a/internal/service/diff/raw.go b/internal/service/diff/raw.go index ce93cd931..3fade9527 100644 --- a/internal/service/diff/raw.go +++ b/internal/service/diff/raw.go @@ -9,14 +9,13 @@ import ( pb "gitlab.com/gitlab-org/gitaly-proto/go" "golang.org/x/net/context" - "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) func (s *server) RawDiff(in *pb.RawDiffRequest, stream pb.DiffService_RawDiffServer) error { if err := validateRequest(in); err != nil { - return grpc.Errorf(codes.InvalidArgument, "RawDiff: %v", err) + return status.Errorf(codes.InvalidArgument, "RawDiff: %v", err) } cmdArgs := []string{"diff", "--full-index", in.LeftCommitId, in.RightCommitId} @@ -30,7 +29,7 @@ func (s *server) RawDiff(in *pb.RawDiffRequest, stream pb.DiffService_RawDiffSer func (s *server) RawPatch(in *pb.RawPatchRequest, stream pb.DiffService_RawPatchServer) error { if err := validateRequest(in); err != nil { - return grpc.Errorf(codes.InvalidArgument, "RawPatch: %v", err) + return status.Errorf(codes.InvalidArgument, "RawPatch: %v", err) } cmdArgs := []string{"format-patch", "--stdout", in.LeftCommitId + ".." + in.RightCommitId} @@ -48,11 +47,11 @@ func sendRawOutput(ctx context.Context, rpc string, repo *pb.Repository, sender if _, ok := status.FromError(err); ok { return err } - return grpc.Errorf(codes.Internal, "%s: cmd: %v", rpc, err) + return status.Errorf(codes.Internal, "%s: cmd: %v", rpc, err) } if _, err := io.Copy(sender, cmd); err != nil { - return grpc.Errorf(codes.Unavailable, "%s: send: %v", rpc, err) + return status.Errorf(codes.Unavailable, "%s: send: %v", rpc, err) } return cmd.Wait() diff --git a/internal/service/namespace/namespace.go b/internal/service/namespace/namespace.go index 4f5aa3658..290397d01 100644 --- a/internal/service/namespace/namespace.go +++ b/internal/service/namespace/namespace.go @@ -7,11 +7,11 @@ import ( pb "gitlab.com/gitlab-org/gitaly-proto/go" "gitlab.com/gitlab-org/gitaly/internal/helper" "golang.org/x/net/context" - "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) -var noNameError = grpc.Errorf(codes.InvalidArgument, "Name: cannot be empty") +var noNameError = status.Errorf(codes.InvalidArgument, "Name: cannot be empty") func (s *server) NamespaceExists(ctx context.Context, in *pb.NamespaceExistsRequest) (*pb.NamespaceExistsResponse, error) { storagePath, err := helper.GetStorageByName(in.GetStorageName()) @@ -28,7 +28,7 @@ func (s *server) NamespaceExists(ctx context.Context, in *pb.NamespaceExistsRequ if fi, err := os.Stat(namespacePath(storagePath, in.GetName())); os.IsNotExist(err) { return &pb.NamespaceExistsResponse{Exists: false}, nil } else if err != nil { - return nil, grpc.Errorf(codes.Internal, "could not stat the directory: %v", err) + return nil, status.Errorf(codes.Internal, "could not stat the directory: %v", err) } else { return &pb.NamespaceExistsResponse{Exists: fi.IsDir()}, nil } @@ -50,7 +50,7 @@ func (s *server) AddNamespace(ctx context.Context, in *pb.AddNamespaceRequest) ( } if err = os.MkdirAll(namespacePath(storagePath, in.GetName()), 0770); err != nil { - return nil, grpc.Errorf(codes.Internal, "create directory: %v", err) + return nil, status.Errorf(codes.Internal, "create directory: %v", err) } return &pb.AddNamespaceResponse{}, nil @@ -63,7 +63,7 @@ func (s *server) RenameNamespace(ctx context.Context, in *pb.RenameNamespaceRequ } if in.GetFrom() == "" || in.GetTo() == "" { - return nil, grpc.Errorf(codes.InvalidArgument, "from and to cannot be empty") + return nil, status.Errorf(codes.InvalidArgument, "from and to cannot be empty") } // No need to check if the from path exists, if it doesn't, we'd later get an @@ -72,14 +72,14 @@ func (s *server) RenameNamespace(ctx context.Context, in *pb.RenameNamespaceRequ if exists, err := s.NamespaceExists(ctx, toExistsCheck); err != nil { return nil, err } else if exists.Exists { - return nil, grpc.Errorf(codes.InvalidArgument, "to directory %s already exists", in.GetTo()) + return nil, status.Errorf(codes.InvalidArgument, "to directory %s already exists", in.GetTo()) } err = os.Rename(namespacePath(storagePath, in.GetFrom()), namespacePath(storagePath, in.GetTo())) if _, ok := err.(*os.LinkError); ok { - return nil, grpc.Errorf(codes.InvalidArgument, "from directory %s not found", in.GetFrom()) + return nil, status.Errorf(codes.InvalidArgument, "from directory %s not found", in.GetFrom()) } else if err != nil { - return nil, grpc.Errorf(codes.Internal, "rename: %v", err) + return nil, status.Errorf(codes.Internal, "rename: %v", err) } return &pb.RenameNamespaceResponse{}, nil @@ -99,7 +99,7 @@ func (s *server) RemoveNamespace(ctx context.Context, in *pb.RemoveNamespaceRequ // os.RemoveAll is idempotent by itself // No need to check if the directory exists, or not if err = os.RemoveAll(namespacePath(storagePath, in.GetName())); err != nil { - return nil, grpc.Errorf(codes.Internal, "removal: %v", err) + return nil, status.Errorf(codes.Internal, "removal: %v", err) } return &pb.RemoveNamespaceResponse{}, nil } diff --git a/internal/service/namespace/namespace_test.go b/internal/service/namespace/namespace_test.go index 435757148..dd71db26c 100644 --- a/internal/service/namespace/namespace_test.go +++ b/internal/service/namespace/namespace_test.go @@ -8,9 +8,9 @@ import ( "github.com/stretchr/testify/require" pb "gitlab.com/gitlab-org/gitaly-proto/go" "gitlab.com/gitlab-org/gitaly/internal/config" + "gitlab.com/gitlab-org/gitaly/internal/helper" "gitlab.com/gitlab-org/gitaly/internal/testhelper" "golang.org/x/net/context" - "google.golang.org/grpc" "google.golang.org/grpc/codes" ) @@ -96,7 +96,7 @@ func TestNamespaceExists(t *testing.T) { defer cancel() response, err := client.NamespaceExists(ctx, tc.request) - require.Equal(t, tc.errorCode, grpc.Code(err)) + require.Equal(t, tc.errorCode, helper.GrpcCode(err)) if tc.errorCode == codes.OK { require.Equal(t, tc.exists, response.Exists) @@ -158,7 +158,7 @@ func TestAddNamespace(t *testing.T) { _, err := client.AddNamespace(ctx, tc.request) - require.Equal(t, tc.errorCode, grpc.Code(err)) + require.Equal(t, tc.errorCode, helper.GrpcCode(err)) // Clean up if tc.errorCode == codes.OK { @@ -218,7 +218,7 @@ func TestRemoveNamespace(t *testing.T) { require.NoError(t, err, "setup failed") _, err = client.RemoveNamespace(ctx, tc.request) - require.Equal(t, tc.errorCode, grpc.Code(err)) + require.Equal(t, tc.errorCode, helper.GrpcCode(err)) }) } } @@ -286,7 +286,7 @@ func TestRenameNamespace(t *testing.T) { t.Run(tc.desc, func(t *testing.T) { _, err := client.RenameNamespace(ctx, tc.request) - require.Equal(t, tc.errorCode, grpc.Code(err)) + require.Equal(t, tc.errorCode, helper.GrpcCode(err)) if tc.errorCode == codes.OK { client.RemoveNamespace(ctx, &pb.RemoveNamespaceRequest{ diff --git a/internal/service/operations/branches.go b/internal/service/operations/branches.go index 3f2c7841f..09548dbf5 100644 --- a/internal/service/operations/branches.go +++ b/internal/service/operations/branches.go @@ -2,8 +2,8 @@ package operations import ( "gitlab.com/gitlab-org/gitaly/internal/rubyserver" - "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" pb "gitlab.com/gitlab-org/gitaly-proto/go" @@ -26,11 +26,11 @@ func (s *server) UserCreateBranch(ctx context.Context, req *pb.UserCreateBranchR func (s *server) UserDeleteBranch(ctx context.Context, req *pb.UserDeleteBranchRequest) (*pb.UserDeleteBranchResponse, error) { if len(req.BranchName) == 0 { - return nil, grpc.Errorf(codes.InvalidArgument, "Bad Request (empty branch name)") + return nil, status.Errorf(codes.InvalidArgument, "Bad Request (empty branch name)") } if req.User == nil { - return nil, grpc.Errorf(codes.InvalidArgument, "Bad Request (empty user)") + return nil, status.Errorf(codes.InvalidArgument, "Bad Request (empty user)") } client, err := s.OperationServiceClient(ctx) diff --git a/internal/service/operations/cherry_pick.go b/internal/service/operations/cherry_pick.go index e73b63467..9459da612 100644 --- a/internal/service/operations/cherry_pick.go +++ b/internal/service/operations/cherry_pick.go @@ -5,13 +5,13 @@ import ( "gitlab.com/gitlab-org/gitaly/internal/rubyserver" "golang.org/x/net/context" - "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) func (s *server) UserCherryPick(ctx context.Context, req *pb.UserCherryPickRequest) (*pb.UserCherryPickResponse, error) { if err := validateCherryPickOrRevertRequest(req); err != nil { - return nil, grpc.Errorf(codes.InvalidArgument, "UserCherryPick: %v", err) + return nil, status.Errorf(codes.InvalidArgument, "UserCherryPick: %v", err) } client, err := s.OperationServiceClient(ctx) diff --git a/internal/service/operations/merge.go b/internal/service/operations/merge.go index 51be90296..eb5dd5d51 100644 --- a/internal/service/operations/merge.go +++ b/internal/service/operations/merge.go @@ -7,8 +7,8 @@ import ( "gitlab.com/gitlab-org/gitaly/internal/rubyserver" "golang.org/x/net/context" - "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) func (s *server) UserMergeBranch(bidi pb.OperationService_UserMergeBranchServer) error { @@ -76,7 +76,7 @@ func validateFFRequest(in *pb.UserFFBranchRequest) error { func (s *server) UserFFBranch(ctx context.Context, in *pb.UserFFBranchRequest) (*pb.UserFFBranchResponse, error) { if err := validateFFRequest(in); err != nil { - return nil, grpc.Errorf(codes.InvalidArgument, "UserFFBranch: %v", err) + return nil, status.Errorf(codes.InvalidArgument, "UserFFBranch: %v", err) } client, err := s.OperationServiceClient(ctx) diff --git a/internal/service/operations/rebase.go b/internal/service/operations/rebase.go index 313e60655..58b04666a 100644 --- a/internal/service/operations/rebase.go +++ b/internal/service/operations/rebase.go @@ -8,13 +8,13 @@ import ( pb "gitlab.com/gitlab-org/gitaly-proto/go" "golang.org/x/net/context" - "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) func (s *server) UserRebase(ctx context.Context, req *pb.UserRebaseRequest) (*pb.UserRebaseResponse, error) { if err := validateUserRebaseRequest(req); err != nil { - return nil, grpc.Errorf(codes.InvalidArgument, "UserRebase: %v", err) + return nil, status.Errorf(codes.InvalidArgument, "UserRebase: %v", err) } client, err := s.OperationServiceClient(ctx) diff --git a/internal/service/operations/revert.go b/internal/service/operations/revert.go index af1435922..c7dff23d2 100644 --- a/internal/service/operations/revert.go +++ b/internal/service/operations/revert.go @@ -5,13 +5,13 @@ import ( "gitlab.com/gitlab-org/gitaly/internal/rubyserver" "golang.org/x/net/context" - "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) func (s *server) UserRevert(ctx context.Context, req *pb.UserRevertRequest) (*pb.UserRevertResponse, error) { if err := validateCherryPickOrRevertRequest(req); err != nil { - return nil, grpc.Errorf(codes.InvalidArgument, "UserRevert: %v", err) + return nil, status.Errorf(codes.InvalidArgument, "UserRevert: %v", err) } client, err := s.OperationServiceClient(ctx) diff --git a/internal/service/ref/delete_refs.go b/internal/service/ref/delete_refs.go index 85ee093d2..6482e1220 100644 --- a/internal/service/ref/delete_refs.go +++ b/internal/service/ref/delete_refs.go @@ -4,18 +4,18 @@ import ( pb "gitlab.com/gitlab-org/gitaly-proto/go" "gitlab.com/gitlab-org/gitaly/internal/rubyserver" "golang.org/x/net/context" - "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) func (s *server) DeleteRefs(ctx context.Context, in *pb.DeleteRefsRequest) (*pb.DeleteRefsResponse, error) { if len(in.ExceptWithPrefix) == 0 { // You can't delete all refs - return nil, grpc.Errorf(codes.InvalidArgument, "DeleteRefs: empty ExceptWithPrefix") + return nil, status.Errorf(codes.InvalidArgument, "DeleteRefs: empty ExceptWithPrefix") } for _, prefix := range in.ExceptWithPrefix { if len(prefix) == 0 { - return nil, grpc.Errorf(codes.InvalidArgument, "DeleteRefs: empty prefix for exclussion") + return nil, status.Errorf(codes.InvalidArgument, "DeleteRefs: empty prefix for exclussion") } } diff --git a/internal/service/ref/refexists.go b/internal/service/ref/refexists.go index 817d627ab..6b3e82a6a 100644 --- a/internal/service/ref/refexists.go +++ b/internal/service/ref/refexists.go @@ -5,7 +5,6 @@ import ( "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus" log "github.com/sirupsen/logrus" - "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -32,7 +31,7 @@ func refExists(ctx context.Context, repo *pb.Repository, ref string) (bool, erro }).Debug("refExists") if !isValidRefName(ref) { - return false, grpc.Errorf(codes.InvalidArgument, "invalid refname") + return false, status.Errorf(codes.InvalidArgument, "invalid refname") } cmd, err := git.Command(ctx, repo, "show-ref", "--verify", "--quiet", ref) @@ -40,7 +39,7 @@ func refExists(ctx context.Context, repo *pb.Repository, ref string) (bool, erro if _, ok := status.FromError(err); ok { return false, err } - return false, grpc.Errorf(codes.Internal, err.Error()) + return false, status.Errorf(codes.Internal, err.Error()) } err = cmd.Wait() @@ -55,7 +54,7 @@ func refExists(ctx context.Context, repo *pb.Repository, ref string) (bool, erro } // This will normally occur when exit code > 1 - return false, grpc.Errorf(codes.Internal, err.Error()) + return false, status.Errorf(codes.Internal, err.Error()) } func isValidRefName(refName string) bool { diff --git a/internal/service/ref/refexists_test.go b/internal/service/ref/refexists_test.go index bfa9328a9..0eb0379d1 100644 --- a/internal/service/ref/refexists_test.go +++ b/internal/service/ref/refexists_test.go @@ -3,10 +3,10 @@ package ref import ( "testing" - "google.golang.org/grpc" "google.golang.org/grpc/codes" pb "gitlab.com/gitlab-org/gitaly-proto/go" + "gitlab.com/gitlab-org/gitaly/internal/helper" "gitlab.com/gitlab-org/gitaly/internal/testhelper" ) @@ -52,7 +52,7 @@ func TestRefExists(t *testing.T) { got, err := client.RefExists(ctx, req) - if grpc.Code(err) != tt.wantErr { + if helper.GrpcCode(err) != tt.wantErr { t.Errorf("server.RefExists() error = %v, wantErr %v", err, tt.wantErr) return } diff --git a/internal/service/ref/refname.go b/internal/service/ref/refname.go index 96145a56e..4beef0da7 100644 --- a/internal/service/ref/refname.go +++ b/internal/service/ref/refname.go @@ -4,7 +4,6 @@ import ( "bufio" "strings" - "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -20,7 +19,7 @@ import ( // returned or that the same one is returned on each call. func (s *server) FindRefName(ctx context.Context, in *pb.FindRefNameRequest) (*pb.FindRefNameResponse, error) { if in.CommitId == "" { - return nil, grpc.Errorf(codes.InvalidArgument, "Bad Request (empty commit sha)") + return nil, status.Errorf(codes.InvalidArgument, "Bad Request (empty commit sha)") } ref, err := findRefName(ctx, in.Repository, in.CommitId, string(in.Prefix)) @@ -28,7 +27,7 @@ func (s *server) FindRefName(ctx context.Context, in *pb.FindRefNameRequest) (*p if _, ok := status.FromError(err); ok { return nil, err } - return nil, grpc.Errorf(codes.Internal, err.Error()) + return nil, status.Errorf(codes.Internal, err.Error()) } return &pb.FindRefNameResponse{Name: []byte(ref)}, nil diff --git a/internal/service/ref/refname_test.go b/internal/service/ref/refname_test.go index a52e6f93d..232f67ae1 100644 --- a/internal/service/ref/refname_test.go +++ b/internal/service/ref/refname_test.go @@ -3,12 +3,12 @@ package ref import ( "testing" - "google.golang.org/grpc" "google.golang.org/grpc/codes" "golang.org/x/net/context" pb "gitlab.com/gitlab-org/gitaly-proto/go" + "gitlab.com/gitlab-org/gitaly/internal/helper" "gitlab.com/gitlab-org/gitaly/internal/testhelper" ) @@ -64,7 +64,7 @@ func TestFindRefNameEmptyCommit(t *testing.T) { if err == nil { t.Fatalf("Expected FindRefName to throw an error") } - if grpc.Code(err) != codes.InvalidArgument { + if helper.GrpcCode(err) != codes.InvalidArgument { t.Errorf("Expected FindRefName to throw InvalidArgument, got %v", err) } @@ -93,7 +93,7 @@ func TestFindRefNameInvalidRepo(t *testing.T) { if err == nil { t.Fatalf("Expected FindRefName to throw an error") } - if grpc.Code(err) != codes.InvalidArgument { + if helper.GrpcCode(err) != codes.InvalidArgument { t.Errorf("Expected FindRefName to throw InvalidArgument, got %v", err) } diff --git a/internal/service/ref/refs.go b/internal/service/ref/refs.go index 075991237..b10b14743 100644 --- a/internal/service/ref/refs.go +++ b/internal/service/ref/refs.go @@ -10,7 +10,6 @@ import ( "gitlab.com/gitlab-org/gitaly/internal/helper" log "github.com/sirupsen/logrus" - "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -202,7 +201,7 @@ func (s *server) FindDefaultBranchName(ctx context.Context, in *pb.FindDefaultBr if _, ok := status.FromError(err); ok { return nil, err } - return nil, grpc.Errorf(codes.Internal, err.Error()) + return nil, status.Errorf(codes.Internal, err.Error()) } return &pb.FindDefaultBranchNameResponse{Name: defaultBranchName}, nil @@ -249,7 +248,7 @@ func (s *server) FindAllBranches(in *pb.FindAllBranchesRequest, stream pb.RefSer if _, ok := status.FromError(err); ok { return err } - return grpc.Errorf(codes.Internal, err.Error()) + return status.Errorf(codes.Internal, err.Error()) } args = append(args, fmt.Sprintf("--merged=%s", string(defaultBranchName))) diff --git a/internal/service/ref/refs_test.go b/internal/service/ref/refs_test.go index be673818b..f86773c9d 100644 --- a/internal/service/ref/refs_test.go +++ b/internal/service/ref/refs_test.go @@ -12,9 +12,9 @@ import ( "github.com/stretchr/testify/require" pb "gitlab.com/gitlab-org/gitaly-proto/go" "gitlab.com/gitlab-org/gitaly/internal/git/log" + "gitlab.com/gitlab-org/gitaly/internal/helper" "gitlab.com/gitlab-org/gitaly/internal/testhelper" "golang.org/x/net/context" - "google.golang.org/grpc" "google.golang.org/grpc/codes" ) @@ -84,7 +84,7 @@ func TestEmptyFindAllBranchNamesRequest(t *testing.T) { _, recvError = c.Recv() } - if grpc.Code(recvError) != codes.InvalidArgument { + if helper.GrpcCode(recvError) != codes.InvalidArgument { t.Fatal(recvError) } } @@ -110,7 +110,7 @@ func TestInvalidRepoFindAllBranchNamesRequest(t *testing.T) { _, recvError = c.Recv() } - if grpc.Code(recvError) != codes.NotFound { + if helper.GrpcCode(recvError) != codes.NotFound { t.Fatal(recvError) } } @@ -173,7 +173,7 @@ func TestEmptyFindAllTagNamesRequest(t *testing.T) { _, recvError = c.Recv() } - if grpc.Code(recvError) != codes.InvalidArgument { + if helper.GrpcCode(recvError) != codes.InvalidArgument { t.Fatal(recvError) } } @@ -199,7 +199,7 @@ func TestInvalidRepoFindAllTagNamesRequest(t *testing.T) { _, recvError = c.Recv() } - if grpc.Code(recvError) != codes.NotFound { + if helper.GrpcCode(recvError) != codes.NotFound { t.Fatal(recvError) } } @@ -350,7 +350,7 @@ func TestEmptyFindDefaultBranchNameRequest(t *testing.T) { defer cancel() _, err := client.FindDefaultBranchName(ctx, rpcRequest) - if grpc.Code(err) != codes.InvalidArgument { + if helper.GrpcCode(err) != codes.InvalidArgument { t.Fatal(err) } } @@ -368,7 +368,7 @@ func TestInvalidRepoFindDefaultBranchNameRequest(t *testing.T) { defer cancel() _, err := client.FindDefaultBranchName(ctx, rpcRequest) - if grpc.Code(err) != codes.NotFound { + if helper.GrpcCode(err) != codes.NotFound { t.Fatal(err) } } @@ -721,7 +721,7 @@ func TestEmptyFindLocalBranchesRequest(t *testing.T) { _, recvError = c.Recv() } - if grpc.Code(recvError) != codes.InvalidArgument { + if helper.GrpcCode(recvError) != codes.InvalidArgument { t.Fatal(recvError) } } diff --git a/internal/service/ref/util.go b/internal/service/ref/util.go index e4b06a4b4..0db913136 100644 --- a/internal/service/ref/util.go +++ b/internal/service/ref/util.go @@ -6,8 +6,8 @@ import ( pb "gitlab.com/gitlab-org/gitaly-proto/go" "gitlab.com/gitlab-org/gitaly/internal/git" "gitlab.com/gitlab-org/gitaly/internal/helper/lines" - "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) var localBranchFormatFields = []string{ @@ -19,7 +19,7 @@ var localBranchFormatFields = []string{ func parseRef(ref []byte) ([][]byte, error) { elements := bytes.Split(ref, []byte("\x00")) if len(elements) != 9 { - return nil, grpc.Errorf(codes.Internal, "error parsing ref %q", ref) + return nil, status.Errorf(codes.Internal, "error parsing ref %q", ref) } return elements, nil } diff --git a/internal/service/remote/fetch_internal_remote.go b/internal/service/remote/fetch_internal_remote.go index ed69697e3..f63b377e1 100644 --- a/internal/service/remote/fetch_internal_remote.go +++ b/internal/service/remote/fetch_internal_remote.go @@ -4,8 +4,8 @@ import ( "fmt" "golang.org/x/net/context" - "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" pb "gitlab.com/gitlab-org/gitaly-proto/go" "gitlab.com/gitlab-org/gitaly/internal/rubyserver" @@ -14,7 +14,7 @@ import ( // FetchInternalRemote fetches another Gitaly repository set as a remote func (s *server) FetchInternalRemote(ctx context.Context, req *pb.FetchInternalRemoteRequest) (*pb.FetchInternalRemoteResponse, error) { if err := validateFetchInternalRemoteRequest(req); err != nil { - return nil, grpc.Errorf(codes.InvalidArgument, "FetchInternalRemote: %v", err) + return nil, status.Errorf(codes.InvalidArgument, "FetchInternalRemote: %v", err) } client, err := s.RemoteServiceClient(ctx) diff --git a/internal/service/remote/remotes.go b/internal/service/remote/remotes.go index 625c7804d..da473d1f9 100644 --- a/internal/service/remote/remotes.go +++ b/internal/service/remote/remotes.go @@ -5,9 +5,9 @@ import ( "strings" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "golang.org/x/net/context" - "google.golang.org/grpc" pb "gitlab.com/gitlab-org/gitaly-proto/go" "gitlab.com/gitlab-org/gitaly/internal/rubyserver" @@ -16,7 +16,7 @@ import ( // AddRemote adds a remote to the repository func (s *server) AddRemote(ctx context.Context, req *pb.AddRemoteRequest) (*pb.AddRemoteResponse, error) { if err := validateAddRemoteRequest(req); err != nil { - return nil, grpc.Errorf(codes.InvalidArgument, "AddRemote: %v", err) + return nil, status.Errorf(codes.InvalidArgument, "AddRemote: %v", err) } client, err := s.RemoteServiceClient(ctx) @@ -46,7 +46,7 @@ func validateAddRemoteRequest(req *pb.AddRemoteRequest) error { // RemoveRemote removes the given remote func (s *server) RemoveRemote(ctx context.Context, req *pb.RemoveRemoteRequest) (*pb.RemoveRemoteResponse, error) { if err := validateRemoveRemoteRequest(req); err != nil { - return nil, grpc.Errorf(codes.InvalidArgument, "AddRemote: %v", err) + return nil, status.Errorf(codes.InvalidArgument, "AddRemote: %v", err) } client, err := s.RemoteServiceClient(ctx) diff --git a/internal/service/repository/apply_gitattributes.go b/internal/service/repository/apply_gitattributes.go index 8233595e4..fee9b25ea 100644 --- a/internal/service/repository/apply_gitattributes.go +++ b/internal/service/repository/apply_gitattributes.go @@ -13,8 +13,8 @@ import ( "gitlab.com/gitlab-org/gitaly/internal/git/catfile" "gitlab.com/gitlab-org/gitaly/internal/helper" "golang.org/x/net/context" - "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) func applyGitattributesHandler(ctx context.Context, repoPath string, revision []byte) catfile.Handler { @@ -30,7 +30,7 @@ func applyGitattributesHandler(ctx context.Context, repoPath string, revision [] return err } if revisionInfo.Oid == "" { - return grpc.Errorf(codes.InvalidArgument, "Revision doesn't exist") + return status.Errorf(codes.InvalidArgument, "Revision doesn't exist") } // Discard revision info if _, err := stdout.Discard(int(revisionInfo.Size) + 1); err != nil { @@ -63,7 +63,7 @@ func applyGitattributesHandler(ctx context.Context, repoPath string, revision [] tempFile, err := ioutil.TempFile(infoPath, "attributes") if err != nil { - return grpc.Errorf(codes.Internal, "ApplyGitAttributes: creating temp file: %v", err) + return status.Errorf(codes.Internal, "ApplyGitAttributes: creating temp file: %v", err) } defer os.Remove(tempFile.Name()) @@ -74,7 +74,7 @@ func applyGitattributesHandler(ctx context.Context, repoPath string, revision [] return err } if n != blobInfo.Size { - return grpc.Errorf(codes.Internal, + return status.Errorf(codes.Internal, "ApplyGitAttributes: copy yielded %v bytes, expected %v", n, blobInfo.Size) } @@ -95,7 +95,7 @@ func (server) ApplyGitattributes(ctx context.Context, in *pb.ApplyGitattributesR } if err := git.ValidateRevision(in.GetRevision()); err != nil { - return nil, grpc.Errorf(codes.InvalidArgument, "ApplyGitAttributes: revision: %v", err) + return nil, status.Errorf(codes.InvalidArgument, "ApplyGitAttributes: revision: %v", err) } handler := applyGitattributesHandler(ctx, repoPath, in.GetRevision()) diff --git a/internal/service/repository/archive.go b/internal/service/repository/archive.go index b7f96c8b6..0540e271b 100644 --- a/internal/service/repository/archive.go +++ b/internal/service/repository/archive.go @@ -5,8 +5,8 @@ import ( "io" "os/exec" - "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "gitlab.com/gitlab-org/gitaly/internal/command" "gitlab.com/gitlab-org/gitaly/streamio" @@ -34,7 +34,7 @@ func handleArchive(ctx context.Context, writer io.Writer, repo *pb.Repository, format pb.GetArchiveRequest_Format, prefix, commitID string) error { compressCmd, formatArg := parseArchiveFormat(format) if len(formatArg) == 0 { - return grpc.Errorf(codes.InvalidArgument, "invalid format") + return status.Errorf(codes.InvalidArgument, "invalid format") } archiveCommand, err := git.Command(ctx, repo, "archive", @@ -61,7 +61,7 @@ func handleArchive(ctx context.Context, writer io.Writer, repo *pb.Repository, func (s *server) GetArchive(in *pb.GetArchiveRequest, stream pb.RepositoryService_GetArchiveServer) error { if err := git.ValidateRevision([]byte(in.CommitId)); err != nil { - return grpc.Errorf(codes.InvalidArgument, "invalid commitId: %v", err) + return status.Errorf(codes.InvalidArgument, "invalid commitId: %v", err) } writer := streamio.NewWriter(func(p []byte) error { diff --git a/internal/service/repository/create_from_url.go b/internal/service/repository/create_from_url.go index 263270b1f..fb2b26328 100644 --- a/internal/service/repository/create_from_url.go +++ b/internal/service/repository/create_from_url.go @@ -11,13 +11,13 @@ import ( pb "gitlab.com/gitlab-org/gitaly-proto/go" "golang.org/x/net/context" - "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) func (s *server) CreateRepositoryFromURL(ctx context.Context, req *pb.CreateRepositoryFromURLRequest) (*pb.CreateRepositoryFromURLResponse, error) { if err := validateCreateRepositoryFromURLRequest(req); err != nil { - return nil, grpc.Errorf(codes.InvalidArgument, "CreateRepositoryFromURL: %v", err) + return nil, status.Errorf(codes.InvalidArgument, "CreateRepositoryFromURL: %v", err) } repository := req.Repository @@ -28,7 +28,7 @@ func (s *server) CreateRepositoryFromURL(ctx context.Context, req *pb.CreateRepo } if _, err := os.Stat(repositoryFullPath); !os.IsNotExist(err) { - return nil, grpc.Errorf(codes.InvalidArgument, "CreateRepositoryFromURL: dest dir exists") + return nil, status.Errorf(codes.InvalidArgument, "CreateRepositoryFromURL: dest dir exists") } args := []string{ @@ -40,20 +40,20 @@ func (s *server) CreateRepositoryFromURL(ctx context.Context, req *pb.CreateRepo } cmd, err := command.New(ctx, exec.Command(command.GitPath(), args...), nil, nil, nil) if err != nil { - return nil, grpc.Errorf(codes.Internal, "CreateRepositoryFromURL: clone cmd start: %v", err) + return nil, status.Errorf(codes.Internal, "CreateRepositoryFromURL: clone cmd start: %v", err) } if err := cmd.Wait(); err != nil { os.RemoveAll(repositoryFullPath) - return nil, grpc.Errorf(codes.Internal, "CreateRepositoryFromURL: clone cmd wait: %v", err) + return nil, status.Errorf(codes.Internal, "CreateRepositoryFromURL: clone cmd wait: %v", err) } // CreateRepository is harmless on existing repositories with the side effect that it creates the hook symlink. if _, err := s.CreateRepository(ctx, &pb.CreateRepositoryRequest{Repository: repository}); err != nil { - return nil, grpc.Errorf(codes.Internal, "CreateRepositoryFromURL: create hooks failed: %v", err) + return nil, status.Errorf(codes.Internal, "CreateRepositoryFromURL: create hooks failed: %v", err) } if err := removeOriginInRepo(ctx, repository); err != nil { - return nil, grpc.Errorf(codes.Internal, "CreateRepositoryFromURL: %v", err) + return nil, status.Errorf(codes.Internal, "CreateRepositoryFromURL: %v", err) } return &pb.CreateRepositoryFromURLResponse{}, nil diff --git a/internal/service/repository/fork.go b/internal/service/repository/fork.go index 8a18088bb..e626ba9be 100644 --- a/internal/service/repository/fork.go +++ b/internal/service/repository/fork.go @@ -14,8 +14,8 @@ import ( "github.com/golang/protobuf/jsonpb" "golang.org/x/net/context" - "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) const gitalyInternalURL = "ssh://gitaly/internal.git" @@ -25,10 +25,10 @@ func (s *server) CreateFork(ctx context.Context, req *pb.CreateForkRequest) (*pb sourceRepository := req.SourceRepository if sourceRepository == nil { - return nil, grpc.Errorf(codes.InvalidArgument, "CreateFork: empty SourceRepository") + return nil, status.Errorf(codes.InvalidArgument, "CreateFork: empty SourceRepository") } if targetRepository == nil { - return nil, grpc.Errorf(codes.InvalidArgument, "CreateFork: empty Repository") + return nil, status.Errorf(codes.InvalidArgument, "CreateFork: empty Repository") } targetRepositoryFullPath, err := helper.GetPath(targetRepository) @@ -37,26 +37,26 @@ func (s *server) CreateFork(ctx context.Context, req *pb.CreateForkRequest) (*pb } if _, err := os.Stat(targetRepositoryFullPath); !os.IsNotExist(err) { - return nil, grpc.Errorf(codes.InvalidArgument, "CreateFork: dest dir exists") + return nil, status.Errorf(codes.InvalidArgument, "CreateFork: dest dir exists") } if err := os.MkdirAll(targetRepositoryFullPath, 0770); err != nil { - return nil, grpc.Errorf(codes.Internal, "CreateFork: create dest dir: %v", err) + return nil, status.Errorf(codes.Internal, "CreateFork: create dest dir: %v", err) } gitalyServersInfo, err := helper.ExtractGitalyServers(ctx) if err != nil { - return nil, grpc.Errorf(codes.Internal, "CreateFork: extracting Gitaly servers: %v", err) + return nil, status.Errorf(codes.Internal, "CreateFork: extracting Gitaly servers: %v", err) } sourceRepositoryStorageInfo, ok := gitalyServersInfo[sourceRepository.StorageName] if !ok { - return nil, grpc.Errorf(codes.InvalidArgument, "CreateFork: no storage info for %s", sourceRepository.StorageName) + return nil, status.Errorf(codes.InvalidArgument, "CreateFork: no storage info for %s", sourceRepository.StorageName) } sourceRepositoryGitalyAddress := sourceRepositoryStorageInfo["address"] if sourceRepositoryGitalyAddress == "" { - return nil, grpc.Errorf(codes.InvalidArgument, "CreateFork: empty gitaly address") + return nil, status.Errorf(codes.InvalidArgument, "CreateFork: empty gitaly address") } sourceRepositoryGitalyToken := sourceRepositoryStorageInfo["token"] @@ -65,7 +65,7 @@ func (s *server) CreateFork(ctx context.Context, req *pb.CreateForkRequest) (*pb pbMarshaler := &jsonpb.Marshaler{} payload, err := pbMarshaler.MarshalToString(cloneReq) if err != nil { - return nil, grpc.Errorf(codes.Internal, "CreateFork: marshalling payload failed: %v", err) + return nil, status.Errorf(codes.Internal, "CreateFork: marshalling payload failed: %v", err) } gitalySSHPath := path.Join(config.Config.BinDir, "gitaly-ssh") @@ -86,19 +86,19 @@ func (s *server) CreateFork(ctx context.Context, req *pb.CreateForkRequest) (*pb } cmd, err := command.New(ctx, exec.Command(command.GitPath(), args...), nil, nil, nil, env...) if err != nil { - return nil, grpc.Errorf(codes.Internal, "CreateFork: clone cmd start: %v", err) + return nil, status.Errorf(codes.Internal, "CreateFork: clone cmd start: %v", err) } if err := cmd.Wait(); err != nil { - return nil, grpc.Errorf(codes.Internal, "CreateFork: clone cmd wait: %v", err) + return nil, status.Errorf(codes.Internal, "CreateFork: clone cmd wait: %v", err) } if err := removeOriginInRepo(ctx, targetRepository); err != nil { - return nil, grpc.Errorf(codes.Internal, "CreateFork: %v", err) + return nil, status.Errorf(codes.Internal, "CreateFork: %v", err) } // CreateRepository is harmless on existing repositories with the side effect that it creates the hook symlink. if _, err := s.CreateRepository(ctx, &pb.CreateRepositoryRequest{Repository: targetRepository}); err != nil { - return nil, grpc.Errorf(codes.Internal, "CreateFork: create hooks failed: %v", err) + return nil, status.Errorf(codes.Internal, "CreateFork: create hooks failed: %v", err) } return &pb.CreateForkResponse{}, nil diff --git a/internal/service/repository/gc.go b/internal/service/repository/gc.go index 4296edae5..9f6724eae 100644 --- a/internal/service/repository/gc.go +++ b/internal/service/repository/gc.go @@ -5,7 +5,6 @@ import ( log "github.com/sirupsen/logrus" "gitlab.com/gitlab-org/gitaly/internal/helper/housekeeping" "golang.org/x/net/context" - "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -32,11 +31,11 @@ func (server) GarbageCollect(ctx context.Context, in *pb.GarbageCollectRequest) if _, ok := status.FromError(err); ok { return nil, err } - return nil, grpc.Errorf(codes.Internal, err.Error()) + return nil, status.Errorf(codes.Internal, err.Error()) } if err := cmd.Wait(); err != nil { - return nil, grpc.Errorf(codes.Internal, err.Error()) + return nil, status.Errorf(codes.Internal, err.Error()) } // Perform housekeeping post GC diff --git a/internal/service/repository/merge_base.go b/internal/service/repository/merge_base.go index 4dd156104..192cc3a25 100644 --- a/internal/service/repository/merge_base.go +++ b/internal/service/repository/merge_base.go @@ -6,13 +6,13 @@ import ( pb "gitlab.com/gitlab-org/gitaly-proto/go" "golang.org/x/net/context" - "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) func (s *server) FindMergeBase(ctx context.Context, req *pb.FindMergeBaseRequest) (*pb.FindMergeBaseResponse, error) { if len(req.Revisions) != 2 { - return nil, grpc.Errorf(codes.InvalidArgument, "FindMergeBase: 2 revisions are required") + return nil, status.Errorf(codes.InvalidArgument, "FindMergeBase: 2 revisions are required") } client, err := s.RepositoryServiceClient(ctx) diff --git a/internal/service/repository/rebase_in_progress.go b/internal/service/repository/rebase_in_progress.go index 6b56c78fc..2c894e198 100644 --- a/internal/service/repository/rebase_in_progress.go +++ b/internal/service/repository/rebase_in_progress.go @@ -8,13 +8,13 @@ import ( "gitlab.com/gitlab-org/gitaly/internal/rubyserver" "golang.org/x/net/context" - "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) func (s *server) IsRebaseInProgress(ctx context.Context, req *pb.IsRebaseInProgressRequest) (*pb.IsRebaseInProgressResponse, error) { if err := validateIsRebaseInProgressRequest(req); err != nil { - return nil, grpc.Errorf(codes.InvalidArgument, "IsRebaseInProgress: %v", err) + return nil, status.Errorf(codes.InvalidArgument, "IsRebaseInProgress: %v", err) } client, err := s.RepositoryServiceClient(ctx) diff --git a/internal/service/repository/repack.go b/internal/service/repository/repack.go index e6478f527..14c455bf1 100644 --- a/internal/service/repository/repack.go +++ b/internal/service/repository/repack.go @@ -4,7 +4,6 @@ import ( "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus" log "github.com/sirupsen/logrus" "golang.org/x/net/context" - "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -44,11 +43,11 @@ func repackCommand(ctx context.Context, rpcName string, repo *pb.Repository, bit if _, ok := status.FromError(err); ok { return err } - return grpc.Errorf(codes.Internal, err.Error()) + return status.Errorf(codes.Internal, err.Error()) } if err := cmd.Wait(); err != nil { - return grpc.Errorf(codes.Internal, err.Error()) + return status.Errorf(codes.Internal, err.Error()) } return nil diff --git a/internal/service/repository/repository_test.go b/internal/service/repository/repository_test.go index 7056566b4..5c03ebd26 100644 --- a/internal/service/repository/repository_test.go +++ b/internal/service/repository/repository_test.go @@ -8,11 +8,11 @@ import ( pb "gitlab.com/gitlab-org/gitaly-proto/go" "gitlab.com/gitlab-org/gitaly/internal/config" + "gitlab.com/gitlab-org/gitaly/internal/helper" "gitlab.com/gitlab-org/gitaly/internal/testhelper" "github.com/stretchr/testify/require" "golang.org/x/net/context" - "google.golang.org/grpc" "google.golang.org/grpc/codes" ) @@ -110,7 +110,7 @@ func TestRepositoryExists(t *testing.T) { defer cancel() response, err := client.RepositoryExists(ctx, tc.request) - require.Equal(t, tc.errorCode, grpc.Code(err)) + require.Equal(t, tc.errorCode, helper.GrpcCode(err)) if err != nil { // Ignore the response message if there was an error @@ -167,7 +167,7 @@ func TestSuccessfulHasLocalBranches(t *testing.T) { response, err := client.HasLocalBranches(ctx, tc.request) - require.Equal(t, tc.errorCode, grpc.Code(err)) + require.Equal(t, tc.errorCode, helper.GrpcCode(err)) if err != nil { return } @@ -209,7 +209,7 @@ func TestFailedHasLocalBranches(t *testing.T) { request := &pb.HasLocalBranchesRequest{Repository: tc.repository} _, err := client.HasLocalBranches(ctx, request) - require.Equal(t, tc.errorCode, grpc.Code(err)) + require.Equal(t, tc.errorCode, helper.GrpcCode(err)) }) } } diff --git a/internal/service/repository/write_ref.go b/internal/service/repository/write_ref.go index 86de86038..01b98ac42 100644 --- a/internal/service/repository/write_ref.go +++ b/internal/service/repository/write_ref.go @@ -7,15 +7,15 @@ import ( pb "gitlab.com/gitlab-org/gitaly-proto/go" "gitlab.com/gitlab-org/gitaly/internal/git" "gitlab.com/gitlab-org/gitaly/internal/rubyserver" - "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "golang.org/x/net/context" ) func (s *server) WriteRef(ctx context.Context, req *pb.WriteRefRequest) (*pb.WriteRefResponse, error) { if err := validateWriteRefRequest(req); err != nil { - return nil, grpc.Errorf(codes.InvalidArgument, "WriteRef: %v", err) + return nil, status.Errorf(codes.InvalidArgument, "WriteRef: %v", err) } client, err := s.RepositoryServiceClient(ctx) diff --git a/internal/service/smarthttp/inforefs.go b/internal/service/smarthttp/inforefs.go index 131818c3f..101d18b96 100644 --- a/internal/service/smarthttp/inforefs.go +++ b/internal/service/smarthttp/inforefs.go @@ -12,7 +12,6 @@ import ( "gitlab.com/gitlab-org/gitaly/internal/helper" "gitlab.com/gitlab-org/gitaly/streamio" - "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -53,23 +52,23 @@ func handleInfoRefs(ctx context.Context, service string, req *pb.InfoRefsRequest if _, ok := status.FromError(err); ok { return err } - return grpc.Errorf(codes.Internal, "GetInfoRefs: cmd: %v", err) + return status.Errorf(codes.Internal, "GetInfoRefs: cmd: %v", err) } if err := pktLine(w, fmt.Sprintf("# service=git-%s\n", service)); err != nil { - return grpc.Errorf(codes.Internal, "GetInfoRefs: pktLine: %v", err) + return status.Errorf(codes.Internal, "GetInfoRefs: pktLine: %v", err) } if err := pktFlush(w); err != nil { - return grpc.Errorf(codes.Internal, "GetInfoRefs: pktFlush: %v", err) + return status.Errorf(codes.Internal, "GetInfoRefs: pktFlush: %v", err) } if _, err := io.Copy(w, cmd); err != nil { - return grpc.Errorf(codes.Internal, "GetInfoRefs: %v", err) + return status.Errorf(codes.Internal, "GetInfoRefs: %v", err) } if err := cmd.Wait(); err != nil { - return grpc.Errorf(codes.Internal, "GetInfoRefs: %v", err) + return status.Errorf(codes.Internal, "GetInfoRefs: %v", err) } return nil diff --git a/internal/service/smarthttp/receive_pack.go b/internal/service/smarthttp/receive_pack.go index 62b66b3c7..441793a8c 100644 --- a/internal/service/smarthttp/receive_pack.go +++ b/internal/service/smarthttp/receive_pack.go @@ -11,8 +11,8 @@ import ( pb "gitlab.com/gitlab-org/gitaly-proto/go" "gitlab.com/gitlab-org/gitaly/streamio" - "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) func (s *server) PostReceivePack(stream pb.SmartHTTPService_PostReceivePackServer) error { @@ -57,11 +57,11 @@ func (s *server) PostReceivePack(stream pb.SmartHTTPService_PostReceivePackServe cmd, err := command.New(stream.Context(), osCommand, stdin, stdout, nil, env...) if err != nil { - return grpc.Errorf(codes.Unavailable, "PostReceivePack: %v", err) + return status.Errorf(codes.Unavailable, "PostReceivePack: %v", err) } if err := cmd.Wait(); err != nil { - return grpc.Errorf(codes.Unavailable, "PostReceivePack: %v", err) + return status.Errorf(codes.Unavailable, "PostReceivePack: %v", err) } return nil @@ -69,10 +69,10 @@ func (s *server) PostReceivePack(stream pb.SmartHTTPService_PostReceivePackServe func validateReceivePackRequest(req *pb.PostReceivePackRequest) error { if req.GlId == "" { - return grpc.Errorf(codes.InvalidArgument, "PostReceivePack: empty GlId") + return status.Errorf(codes.InvalidArgument, "PostReceivePack: empty GlId") } if req.Data != nil { - return grpc.Errorf(codes.InvalidArgument, "PostReceivePack: non-empty Data") + return status.Errorf(codes.InvalidArgument, "PostReceivePack: non-empty Data") } return nil diff --git a/internal/service/smarthttp/upload_pack.go b/internal/service/smarthttp/upload_pack.go index 871999c26..9d80b2fa9 100644 --- a/internal/service/smarthttp/upload_pack.go +++ b/internal/service/smarthttp/upload_pack.go @@ -12,8 +12,8 @@ import ( "gitlab.com/gitlab-org/gitaly/streamio" "github.com/prometheus/client_golang/prometheus" - "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) var ( @@ -71,7 +71,7 @@ func (s *server) PostUploadPack(stream pb.SmartHTTPService_PostUploadPackServer) cmd, err := command.New(stream.Context(), osCommand, stdin, stdout, nil) if err != nil { - return grpc.Errorf(codes.Unavailable, "PostUploadPack: cmd: %v", err) + return status.Errorf(codes.Unavailable, "PostUploadPack: cmd: %v", err) } if err := cmd.Wait(); err != nil { @@ -83,7 +83,7 @@ func (s *server) PostUploadPack(stream pb.SmartHTTPService_PostUploadPackServer) deepenCount.Inc() return nil } - return grpc.Errorf(codes.Unavailable, "PostUploadPack: %v", err) + return status.Errorf(codes.Unavailable, "PostUploadPack: %v", err) } return nil @@ -91,7 +91,7 @@ func (s *server) PostUploadPack(stream pb.SmartHTTPService_PostUploadPackServer) func validateUploadPackRequest(req *pb.PostUploadPackRequest) error { if req.Data != nil { - return grpc.Errorf(codes.InvalidArgument, "PostUploadPack: non-empty Data") + return status.Errorf(codes.InvalidArgument, "PostUploadPack: non-empty Data") } return nil diff --git a/internal/service/ssh/receive_pack.go b/internal/service/ssh/receive_pack.go index a0b45c220..f4986f315 100644 --- a/internal/service/ssh/receive_pack.go +++ b/internal/service/ssh/receive_pack.go @@ -11,8 +11,8 @@ import ( pb "gitlab.com/gitlab-org/gitaly-proto/go" "gitlab.com/gitlab-org/gitaly/streamio" - "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) func (s *server) SSHReceivePack(stream pb.SSHService_SSHReceivePackServer) error { @@ -59,7 +59,7 @@ func (s *server) SSHReceivePack(stream pb.SSHService_SSHReceivePackServer) error cmd, err := command.New(stream.Context(), osCommand, stdin, stdout, stderr, env...) if err != nil { - return grpc.Errorf(codes.Unavailable, "SSHReceivePack: cmd: %v", err) + return status.Errorf(codes.Unavailable, "SSHReceivePack: cmd: %v", err) } if err := cmd.Wait(); err != nil { @@ -69,7 +69,7 @@ func (s *server) SSHReceivePack(stream pb.SSHService_SSHReceivePackServer) error stream.Send(&pb.SSHReceivePackResponse{ExitStatus: &pb.ExitStatus{Value: int32(status)}}), ) } - return grpc.Errorf(codes.Unavailable, "SSHReceivePack: %v", err) + return status.Errorf(codes.Unavailable, "SSHReceivePack: %v", err) } return nil @@ -77,10 +77,10 @@ func (s *server) SSHReceivePack(stream pb.SSHService_SSHReceivePackServer) error func validateFirstReceivePackRequest(req *pb.SSHReceivePackRequest) error { if req.GlId == "" { - return grpc.Errorf(codes.InvalidArgument, "SSHReceivePack: empty GlId") + return status.Errorf(codes.InvalidArgument, "SSHReceivePack: empty GlId") } if req.Stdin != nil { - return grpc.Errorf(codes.InvalidArgument, "SSHReceivePack: non-empty data") + return status.Errorf(codes.InvalidArgument, "SSHReceivePack: non-empty data") } return nil diff --git a/internal/service/ssh/upload_pack.go b/internal/service/ssh/upload_pack.go index 1fe56c568..b87d70e34 100644 --- a/internal/service/ssh/upload_pack.go +++ b/internal/service/ssh/upload_pack.go @@ -8,8 +8,8 @@ import ( "gitlab.com/gitlab-org/gitaly/internal/command" "gitlab.com/gitlab-org/gitaly/internal/helper" "gitlab.com/gitlab-org/gitaly/streamio" - "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) func (s *server) SSHUploadPack(stream pb.SSHService_SSHUploadPackServer) error { @@ -51,7 +51,7 @@ func (s *server) SSHUploadPack(stream pb.SSHService_SSHUploadPackServer) error { cmd, err := command.New(stream.Context(), osCommand, stdin, stdout, stderr) if err != nil { - return grpc.Errorf(codes.Unavailable, "SSHUploadPack: cmd: %v", err) + return status.Errorf(codes.Unavailable, "SSHUploadPack: cmd: %v", err) } if err := cmd.Wait(); err != nil { @@ -61,7 +61,7 @@ func (s *server) SSHUploadPack(stream pb.SSHService_SSHUploadPackServer) error { stream.Send(&pb.SSHUploadPackResponse{ExitStatus: &pb.ExitStatus{Value: int32(status)}}), ) } - return grpc.Errorf(codes.Unavailable, "SSHUploadPack: %v", err) + return status.Errorf(codes.Unavailable, "SSHUploadPack: %v", err) } return nil @@ -69,7 +69,7 @@ func (s *server) SSHUploadPack(stream pb.SSHService_SSHUploadPackServer) error { func validateFirstUploadPackRequest(req *pb.SSHUploadPackRequest) error { if req.Stdin != nil { - return grpc.Errorf(codes.InvalidArgument, "SSHUploadPack: non-empty stdin") + return status.Errorf(codes.InvalidArgument, "SSHUploadPack: non-empty stdin") } return nil diff --git a/internal/service/wiki/delete_page.go b/internal/service/wiki/delete_page.go index b92be433e..60f7ef72d 100644 --- a/internal/service/wiki/delete_page.go +++ b/internal/service/wiki/delete_page.go @@ -8,13 +8,13 @@ import ( pb "gitlab.com/gitlab-org/gitaly-proto/go" "golang.org/x/net/context" - "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) func (s *server) WikiDeletePage(ctx context.Context, request *pb.WikiDeletePageRequest) (*pb.WikiDeletePageResponse, error) { if err := validateWikiDeletePageRequest(request); err != nil { - return nil, grpc.Errorf(codes.InvalidArgument, "WikiDeletePage: %v", err) + return nil, status.Errorf(codes.InvalidArgument, "WikiDeletePage: %v", err) } client, err := s.WikiServiceClient(ctx) diff --git a/internal/service/wiki/find_file.go b/internal/service/wiki/find_file.go index e13cf3a71..9450e4225 100644 --- a/internal/service/wiki/find_file.go +++ b/internal/service/wiki/find_file.go @@ -5,15 +5,15 @@ import ( pb "gitlab.com/gitlab-org/gitaly-proto/go" - "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) func (s *server) WikiFindFile(request *pb.WikiFindFileRequest, stream pb.WikiService_WikiFindFileServer) error { ctx := stream.Context() if len(request.GetName()) == 0 { - return grpc.Errorf(codes.InvalidArgument, "WikiFindFile: Empty Name") + return status.Errorf(codes.InvalidArgument, "WikiFindFile: Empty Name") } client, err := s.WikiServiceClient(ctx) diff --git a/internal/service/wiki/find_page.go b/internal/service/wiki/find_page.go index 79eea6900..ae3191b16 100644 --- a/internal/service/wiki/find_page.go +++ b/internal/service/wiki/find_page.go @@ -5,15 +5,15 @@ import ( pb "gitlab.com/gitlab-org/gitaly-proto/go" - "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) func (s *server) WikiFindPage(request *pb.WikiFindPageRequest, stream pb.WikiService_WikiFindPageServer) error { ctx := stream.Context() if len(request.GetTitle()) == 0 { - return grpc.Errorf(codes.InvalidArgument, "WikiFindPage: Empty Title") + return status.Errorf(codes.InvalidArgument, "WikiFindPage: Empty Title") } client, err := s.WikiServiceClient(ctx) diff --git a/internal/service/wiki/get_page_versions.go b/internal/service/wiki/get_page_versions.go index 02a5f19ae..715c279cb 100644 --- a/internal/service/wiki/get_page_versions.go +++ b/internal/service/wiki/get_page_versions.go @@ -5,15 +5,15 @@ import ( pb "gitlab.com/gitlab-org/gitaly-proto/go" - "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) func (s *server) WikiGetPageVersions(request *pb.WikiGetPageVersionsRequest, stream pb.WikiService_WikiGetPageVersionsServer) error { ctx := stream.Context() if len(request.GetPagePath()) == 0 { - return grpc.Errorf(codes.InvalidArgument, "WikiGetPageVersions: Empty Path") + return status.Errorf(codes.InvalidArgument, "WikiGetPageVersions: Empty Path") } client, err := s.WikiServiceClient(ctx) diff --git a/internal/service/wiki/update_page.go b/internal/service/wiki/update_page.go index 450971bd6..73c17684c 100644 --- a/internal/service/wiki/update_page.go +++ b/internal/service/wiki/update_page.go @@ -7,8 +7,8 @@ import ( pb "gitlab.com/gitlab-org/gitaly-proto/go" - "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) func (s *server) WikiUpdatePage(stream pb.WikiService_WikiUpdatePageServer) error { @@ -18,7 +18,7 @@ func (s *server) WikiUpdatePage(stream pb.WikiService_WikiUpdatePageServer) erro } if err := validateWikiUpdatePageRequest(firstRequest); err != nil { - return grpc.Errorf(codes.InvalidArgument, "WikiUpdatePage: %v", err) + return status.Errorf(codes.InvalidArgument, "WikiUpdatePage: %v", err) } ctx := stream.Context() diff --git a/internal/service/wiki/write_page.go b/internal/service/wiki/write_page.go index 5e5a3e874..ec23baf45 100644 --- a/internal/service/wiki/write_page.go +++ b/internal/service/wiki/write_page.go @@ -7,8 +7,8 @@ import ( pb "gitlab.com/gitlab-org/gitaly-proto/go" - "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) func (s *server) WikiWritePage(stream pb.WikiService_WikiWritePageServer) error { @@ -18,7 +18,7 @@ func (s *server) WikiWritePage(stream pb.WikiService_WikiWritePageServer) error } if err := validateWikiWritePageRequest(firstRequest); err != nil { - return grpc.Errorf(codes.InvalidArgument, "WikiWritePage: %v", err) + return status.Errorf(codes.InvalidArgument, "WikiWritePage: %v", err) } ctx := stream.Context() diff --git a/internal/testhelper/testhelper.go b/internal/testhelper/testhelper.go index 3107824e5..6923ab962 100644 --- a/internal/testhelper/testhelper.go +++ b/internal/testhelper/testhelper.go @@ -31,6 +31,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" ) // TestRelativePath is the path inside its storage of the gitlab-test repo @@ -131,7 +132,8 @@ func AssertGrpcError(t *testing.T, err error, expectedCode codes.Code, containsT } // Check that the code matches - if code := grpc.Code(err); code != expectedCode { + status, _ := status.FromError(err) + if code := status.Code(); code != expectedCode { t.Fatalf("Expected an error with code %v, got %v. The error was %v", expectedCode, code, err) } diff --git a/vendor/google.golang.org/grpc/CONTRIBUTING.md b/vendor/google.golang.org/grpc/CONTRIBUTING.md index a5c6e06e2..8ec6c9574 100644 --- a/vendor/google.golang.org/grpc/CONTRIBUTING.md +++ b/vendor/google.golang.org/grpc/CONTRIBUTING.md @@ -7,7 +7,7 @@ If you are new to github, please start by reading [Pull Request howto](https://h ## Legal requirements In order to protect both you and ourselves, you will need to sign the -[Contributor License Agreement](https://cla.developers.google.com/clas). +[Contributor License Agreement](https://identity.linuxfoundation.org/projects/cncf). ## Guidelines for Pull Requests How to get your contributions merged smoothly and quickly. diff --git a/vendor/google.golang.org/grpc/balancer.go b/vendor/google.golang.org/grpc/balancer.go index ab65049dd..300da6c5e 100644 --- a/vendor/google.golang.org/grpc/balancer.go +++ b/vendor/google.golang.org/grpc/balancer.go @@ -28,6 +28,7 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/grpclog" "google.golang.org/grpc/naming" + "google.golang.org/grpc/status" ) // Address represents a server the client connects to. @@ -310,7 +311,7 @@ func (rr *roundRobin) Get(ctx context.Context, opts BalancerGetOptions) (addr Ad if !opts.BlockingWait { if len(rr.addrs) == 0 { rr.mu.Unlock() - err = Errorf(codes.Unavailable, "there is no address available") + err = status.Errorf(codes.Unavailable, "there is no address available") return } // Returns the next addr on rr.addrs for failfast RPCs. diff --git a/vendor/google.golang.org/grpc/balancer/balancer.go b/vendor/google.golang.org/grpc/balancer/balancer.go index cd2682f5f..219a2940c 100644 --- a/vendor/google.golang.org/grpc/balancer/balancer.go +++ b/vendor/google.golang.org/grpc/balancer/balancer.go @@ -23,6 +23,7 @@ package balancer import ( "errors" "net" + "strings" "golang.org/x/net/context" "google.golang.org/grpc/connectivity" @@ -36,15 +37,17 @@ var ( ) // Register registers the balancer builder to the balancer map. -// b.Name will be used as the name registered with this builder. +// b.Name (lowercased) will be used as the name registered with +// this builder. func Register(b Builder) { - m[b.Name()] = b + m[strings.ToLower(b.Name())] = b } // Get returns the resolver builder registered with the given name. +// Note that the compare is done in a case-insenstive fashion. // If no builder is register with the name, nil will be returned. func Get(name string) Builder { - if b, ok := m[name]; ok { + if b, ok := m[strings.ToLower(name)]; ok { return b } return nil @@ -63,6 +66,11 @@ func Get(name string) Builder { // When the connection encounters an error, it will reconnect immediately. // When the connection becomes IDLE, it will not reconnect unless Connect is // called. +// +// This interface is to be implemented by gRPC. Users should not need a +// brand new implementation of this interface. For the situations like +// testing, the new implementation should embed this interface. This allows +// gRPC to add new methods to this interface. type SubConn interface { // UpdateAddresses updates the addresses used in this SubConn. // gRPC checks if currently-connected address is still in the new list. @@ -80,6 +88,11 @@ type SubConn interface { type NewSubConnOptions struct{} // ClientConn represents a gRPC ClientConn. +// +// This interface is to be implemented by gRPC. Users should not need a +// brand new implementation of this interface. For the situations like +// testing, the new implementation should embed this interface. This allows +// gRPC to add new methods to this interface. type ClientConn interface { // NewSubConn is called by balancer to create a new SubConn. // It doesn't block and wait for the connections to be established. @@ -96,6 +109,9 @@ type ClientConn interface { // on the new picker to pick new SubConn. UpdateBalancerState(s connectivity.State, p Picker) + // ResolveNow is called by balancer to notify gRPC to do a name resolving. + ResolveNow(resolver.ResolveNowOption) + // Target returns the dial target for this ClientConn. Target() string } @@ -128,6 +144,10 @@ type PickOptions struct{} type DoneInfo struct { // Err is the rpc error the RPC finished with. It could be nil. Err error + // BytesSent indicates if any bytes have been sent to the server. + BytesSent bool + // BytesReceived indicates if any byte has been received from the server. + BytesReceived bool } var ( diff --git a/vendor/google.golang.org/grpc/balancer/base/balancer.go b/vendor/google.golang.org/grpc/balancer/base/balancer.go new file mode 100644 index 000000000..1e962b724 --- /dev/null +++ b/vendor/google.golang.org/grpc/balancer/base/balancer.go @@ -0,0 +1,209 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package base + +import ( + "golang.org/x/net/context" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/resolver" +) + +type baseBuilder struct { + name string + pickerBuilder PickerBuilder +} + +func (bb *baseBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions) balancer.Balancer { + return &baseBalancer{ + cc: cc, + pickerBuilder: bb.pickerBuilder, + + subConns: make(map[resolver.Address]balancer.SubConn), + scStates: make(map[balancer.SubConn]connectivity.State), + csEvltr: &connectivityStateEvaluator{}, + // Initialize picker to a picker that always return + // ErrNoSubConnAvailable, because when state of a SubConn changes, we + // may call UpdateBalancerState with this picker. + picker: NewErrPicker(balancer.ErrNoSubConnAvailable), + } +} + +func (bb *baseBuilder) Name() string { + return bb.name +} + +type baseBalancer struct { + cc balancer.ClientConn + pickerBuilder PickerBuilder + + csEvltr *connectivityStateEvaluator + state connectivity.State + + subConns map[resolver.Address]balancer.SubConn + scStates map[balancer.SubConn]connectivity.State + picker balancer.Picker +} + +func (b *baseBalancer) HandleResolvedAddrs(addrs []resolver.Address, err error) { + if err != nil { + grpclog.Infof("base.baseBalancer: HandleResolvedAddrs called with error %v", err) + return + } + grpclog.Infoln("base.baseBalancer: got new resolved addresses: ", addrs) + // addrsSet is the set converted from addrs, it's used for quick lookup of an address. + addrsSet := make(map[resolver.Address]struct{}) + for _, a := range addrs { + addrsSet[a] = struct{}{} + if _, ok := b.subConns[a]; !ok { + // a is a new address (not existing in b.subConns). + sc, err := b.cc.NewSubConn([]resolver.Address{a}, balancer.NewSubConnOptions{}) + if err != nil { + grpclog.Warningf("base.baseBalancer: failed to create new SubConn: %v", err) + continue + } + b.subConns[a] = sc + b.scStates[sc] = connectivity.Idle + sc.Connect() + } + } + for a, sc := range b.subConns { + // a was removed by resolver. + if _, ok := addrsSet[a]; !ok { + b.cc.RemoveSubConn(sc) + delete(b.subConns, a) + // Keep the state of this sc in b.scStates until sc's state becomes Shutdown. + // The entry will be deleted in HandleSubConnStateChange. + } + } +} + +// regeneratePicker takes a snapshot of the balancer, and generates a picker +// from it. The picker is +// - errPicker with ErrTransientFailure if the balancer is in TransientFailure, +// - built by the pickerBuilder with all READY SubConns otherwise. +func (b *baseBalancer) regeneratePicker() { + if b.state == connectivity.TransientFailure { + b.picker = NewErrPicker(balancer.ErrTransientFailure) + return + } + readySCs := make(map[resolver.Address]balancer.SubConn) + + // Filter out all ready SCs from full subConn map. + for addr, sc := range b.subConns { + if st, ok := b.scStates[sc]; ok && st == connectivity.Ready { + readySCs[addr] = sc + } + } + b.picker = b.pickerBuilder.Build(readySCs) +} + +func (b *baseBalancer) HandleSubConnStateChange(sc balancer.SubConn, s connectivity.State) { + grpclog.Infof("base.baseBalancer: handle SubConn state change: %p, %v", sc, s) + oldS, ok := b.scStates[sc] + if !ok { + grpclog.Infof("base.baseBalancer: got state changes for an unknown SubConn: %p, %v", sc, s) + return + } + b.scStates[sc] = s + switch s { + case connectivity.Idle: + sc.Connect() + case connectivity.Shutdown: + // When an address was removed by resolver, b called RemoveSubConn but + // kept the sc's state in scStates. Remove state for this sc here. + delete(b.scStates, sc) + } + + oldAggrState := b.state + b.state = b.csEvltr.recordTransition(oldS, s) + + // Regenerate picker when one of the following happens: + // - this sc became ready from not-ready + // - this sc became not-ready from ready + // - the aggregated state of balancer became TransientFailure from non-TransientFailure + // - the aggregated state of balancer became non-TransientFailure from TransientFailure + if (s == connectivity.Ready) != (oldS == connectivity.Ready) || + (b.state == connectivity.TransientFailure) != (oldAggrState == connectivity.TransientFailure) { + b.regeneratePicker() + } + + b.cc.UpdateBalancerState(b.state, b.picker) + return +} + +// Close is a nop because base balancer doesn't have internal state to clean up, +// and it doesn't need to call RemoveSubConn for the SubConns. +func (b *baseBalancer) Close() { +} + +// NewErrPicker returns a picker that always returns err on Pick(). +func NewErrPicker(err error) balancer.Picker { + return &errPicker{err: err} +} + +type errPicker struct { + err error // Pick() always returns this err. +} + +func (p *errPicker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) { + return nil, nil, p.err +} + +// connectivityStateEvaluator gets updated by addrConns when their +// states transition, based on which it evaluates the state of +// ClientConn. +type connectivityStateEvaluator struct { + numReady uint64 // Number of addrConns in ready state. + numConnecting uint64 // Number of addrConns in connecting state. + numTransientFailure uint64 // Number of addrConns in transientFailure. +} + +// recordTransition records state change happening in every subConn and based on +// that it evaluates what aggregated state should be. +// It can only transition between Ready, Connecting and TransientFailure. Other states, +// Idle and Shutdown are transitioned into by ClientConn; in the beginning of the connection +// before any subConn is created ClientConn is in idle state. In the end when ClientConn +// closes it is in Shutdown state. +// +// recordTransition should only be called synchronously from the same goroutine. +func (cse *connectivityStateEvaluator) recordTransition(oldState, newState connectivity.State) connectivity.State { + // Update counters. + for idx, state := range []connectivity.State{oldState, newState} { + updateVal := 2*uint64(idx) - 1 // -1 for oldState and +1 for new. + switch state { + case connectivity.Ready: + cse.numReady += updateVal + case connectivity.Connecting: + cse.numConnecting += updateVal + case connectivity.TransientFailure: + cse.numTransientFailure += updateVal + } + } + + // Evaluate. + if cse.numReady > 0 { + return connectivity.Ready + } + if cse.numConnecting > 0 { + return connectivity.Connecting + } + return connectivity.TransientFailure +} diff --git a/vendor/google.golang.org/grpc/balancer/base/base.go b/vendor/google.golang.org/grpc/balancer/base/base.go new file mode 100644 index 000000000..012ace2f2 --- /dev/null +++ b/vendor/google.golang.org/grpc/balancer/base/base.go @@ -0,0 +1,52 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package base defines a balancer base that can be used to build balancers with +// different picking algorithms. +// +// The base balancer creates a new SubConn for each resolved address. The +// provided picker will only be notified about READY SubConns. +// +// This package is the base of round_robin balancer, its purpose is to be used +// to build round_robin like balancers with complex picking algorithms. +// Balancers with more complicated logic should try to implement a balancer +// builder from scratch. +// +// All APIs in this package are experimental. +package base + +import ( + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/resolver" +) + +// PickerBuilder creates balancer.Picker. +type PickerBuilder interface { + // Build takes a slice of ready SubConns, and returns a picker that will be + // used by gRPC to pick a SubConn. + Build(readySCs map[resolver.Address]balancer.SubConn) balancer.Picker +} + +// NewBalancerBuilder returns a balancer builder. The balancers +// built by this builder will use the picker builder to build pickers. +func NewBalancerBuilder(name string, pb PickerBuilder) balancer.Builder { + return &baseBuilder{ + name: name, + pickerBuilder: pb, + } +} diff --git a/vendor/google.golang.org/grpc/balancer/roundrobin/roundrobin.go b/vendor/google.golang.org/grpc/balancer/roundrobin/roundrobin.go index 9d2fbcd84..2eda0a1c2 100644 --- a/vendor/google.golang.org/grpc/balancer/roundrobin/roundrobin.go +++ b/vendor/google.golang.org/grpc/balancer/roundrobin/roundrobin.go @@ -26,145 +26,37 @@ import ( "golang.org/x/net/context" "google.golang.org/grpc/balancer" - "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/balancer/base" "google.golang.org/grpc/grpclog" "google.golang.org/grpc/resolver" ) +// Name is the name of round_robin balancer. +const Name = "round_robin" + // newBuilder creates a new roundrobin balancer builder. func newBuilder() balancer.Builder { - return &rrBuilder{} + return base.NewBalancerBuilder(Name, &rrPickerBuilder{}) } func init() { balancer.Register(newBuilder()) } -type rrBuilder struct{} - -func (*rrBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions) balancer.Balancer { - return &rrBalancer{ - cc: cc, - subConns: make(map[resolver.Address]balancer.SubConn), - scStates: make(map[balancer.SubConn]connectivity.State), - csEvltr: &connectivityStateEvaluator{}, - // Initialize picker to a picker that always return - // ErrNoSubConnAvailable, because when state of a SubConn changes, we - // may call UpdateBalancerState with this picker. - picker: newPicker([]balancer.SubConn{}, nil), - } -} - -func (*rrBuilder) Name() string { - return "round_robin" -} +type rrPickerBuilder struct{} -type rrBalancer struct { - cc balancer.ClientConn - - csEvltr *connectivityStateEvaluator - state connectivity.State - - subConns map[resolver.Address]balancer.SubConn - scStates map[balancer.SubConn]connectivity.State - picker *picker -} - -func (b *rrBalancer) HandleResolvedAddrs(addrs []resolver.Address, err error) { - if err != nil { - grpclog.Infof("roundrobin.rrBalancer: HandleResolvedAddrs called with error %v", err) - return - } - grpclog.Infoln("roundrobin.rrBalancer: got new resolved addresses: ", addrs) - // addrsSet is the set converted from addrs, it's used for quick lookup of an address. - addrsSet := make(map[resolver.Address]struct{}) - for _, a := range addrs { - addrsSet[a] = struct{}{} - if _, ok := b.subConns[a]; !ok { - // a is a new address (not existing in b.subConns). - sc, err := b.cc.NewSubConn([]resolver.Address{a}, balancer.NewSubConnOptions{}) - if err != nil { - grpclog.Warningf("roundrobin.rrBalancer: failed to create new SubConn: %v", err) - continue - } - b.subConns[a] = sc - b.scStates[sc] = connectivity.Idle - sc.Connect() - } +func (*rrPickerBuilder) Build(readySCs map[resolver.Address]balancer.SubConn) balancer.Picker { + grpclog.Infof("roundrobinPicker: newPicker called with readySCs: %v", readySCs) + var scs []balancer.SubConn + for _, sc := range readySCs { + scs = append(scs, sc) } - for a, sc := range b.subConns { - // a was removed by resolver. - if _, ok := addrsSet[a]; !ok { - b.cc.RemoveSubConn(sc) - delete(b.subConns, a) - // Keep the state of this sc in b.scStates until sc's state becomes Shutdown. - // The entry will be deleted in HandleSubConnStateChange. - } - } -} - -// regeneratePicker takes a snapshot of the balancer, and generates a picker -// from it. The picker -// - always returns ErrTransientFailure if the balancer is in TransientFailure, -// - or does round robin selection of all READY SubConns otherwise. -func (b *rrBalancer) regeneratePicker() { - if b.state == connectivity.TransientFailure { - b.picker = newPicker(nil, balancer.ErrTransientFailure) - return - } - var readySCs []balancer.SubConn - for sc, st := range b.scStates { - if st == connectivity.Ready { - readySCs = append(readySCs, sc) - } - } - b.picker = newPicker(readySCs, nil) -} - -func (b *rrBalancer) HandleSubConnStateChange(sc balancer.SubConn, s connectivity.State) { - grpclog.Infof("roundrobin.rrBalancer: handle SubConn state change: %p, %v", sc, s) - oldS, ok := b.scStates[sc] - if !ok { - grpclog.Infof("roundrobin.rrBalancer: got state changes for an unknown SubConn: %p, %v", sc, s) - return - } - b.scStates[sc] = s - switch s { - case connectivity.Idle: - sc.Connect() - case connectivity.Shutdown: - // When an address was removed by resolver, b called RemoveSubConn but - // kept the sc's state in scStates. Remove state for this sc here. - delete(b.scStates, sc) - } - - oldAggrState := b.state - b.state = b.csEvltr.recordTransition(oldS, s) - - // Regenerate picker when one of the following happens: - // - this sc became ready from not-ready - // - this sc became not-ready from ready - // - the aggregated state of balancer became TransientFailure from non-TransientFailure - // - the aggregated state of balancer became non-TransientFailure from TransientFailure - if (s == connectivity.Ready) != (oldS == connectivity.Ready) || - (b.state == connectivity.TransientFailure) != (oldAggrState == connectivity.TransientFailure) { - b.regeneratePicker() + return &rrPicker{ + subConns: scs, } - - b.cc.UpdateBalancerState(b.state, b.picker) - return -} - -// Close is a nop because roundrobin balancer doesn't internal state to clean -// up, and it doesn't need to call RemoveSubConn for the SubConns. -func (b *rrBalancer) Close() { } -type picker struct { - // If err is not nil, Pick always returns this err. It's immutable after - // picker is created. - err error - +type rrPicker struct { // subConns is the snapshot of the roundrobin balancer when this picker was // created. The slice is immutable. Each Get() will do a round robin // selection from it and return the selected SubConn. @@ -174,20 +66,7 @@ type picker struct { next int } -func newPicker(scs []balancer.SubConn, err error) *picker { - grpclog.Infof("roundrobinPicker: newPicker called with scs: %v, %v", scs, err) - if err != nil { - return &picker{err: err} - } - return &picker{ - subConns: scs, - } -} - -func (p *picker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) { - if p.err != nil { - return nil, nil, p.err - } +func (p *rrPicker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) { if len(p.subConns) <= 0 { return nil, nil, balancer.ErrNoSubConnAvailable } @@ -198,44 +77,3 @@ func (p *picker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer. p.mu.Unlock() return sc, nil, nil } - -// connectivityStateEvaluator gets updated by addrConns when their -// states transition, based on which it evaluates the state of -// ClientConn. -type connectivityStateEvaluator struct { - numReady uint64 // Number of addrConns in ready state. - numConnecting uint64 // Number of addrConns in connecting state. - numTransientFailure uint64 // Number of addrConns in transientFailure. -} - -// recordTransition records state change happening in every subConn and based on -// that it evaluates what aggregated state should be. -// It can only transition between Ready, Connecting and TransientFailure. Other states, -// Idle and Shutdown are transitioned into by ClientConn; in the beginning of the connection -// before any subConn is created ClientConn is in idle state. In the end when ClientConn -// closes it is in Shutdown state. -// -// recordTransition should only be called synchronously from the same goroutine. -func (cse *connectivityStateEvaluator) recordTransition(oldState, newState connectivity.State) connectivity.State { - // Update counters. - for idx, state := range []connectivity.State{oldState, newState} { - updateVal := 2*uint64(idx) - 1 // -1 for oldState and +1 for new. - switch state { - case connectivity.Ready: - cse.numReady += updateVal - case connectivity.Connecting: - cse.numConnecting += updateVal - case connectivity.TransientFailure: - cse.numTransientFailure += updateVal - } - } - - // Evaluate. - if cse.numReady > 0 { - return connectivity.Ready - } - if cse.numConnecting > 0 { - return connectivity.Connecting - } - return connectivity.TransientFailure -} diff --git a/vendor/google.golang.org/grpc/balancer_conn_wrappers.go b/vendor/google.golang.org/grpc/balancer_conn_wrappers.go index ebfee4a88..db6f0ae3f 100644 --- a/vendor/google.golang.org/grpc/balancer_conn_wrappers.go +++ b/vendor/google.golang.org/grpc/balancer_conn_wrappers.go @@ -19,6 +19,7 @@ package grpc import ( + "fmt" "sync" "google.golang.org/grpc/balancer" @@ -97,6 +98,7 @@ type ccBalancerWrapper struct { resolverUpdateCh chan *resolverUpdate done chan struct{} + mu sync.Mutex subConns map[*acBalancerWrapper]struct{} } @@ -141,7 +143,11 @@ func (ccb *ccBalancerWrapper) watcher() { select { case <-ccb.done: ccb.balancer.Close() - for acbw := range ccb.subConns { + ccb.mu.Lock() + scs := ccb.subConns + ccb.subConns = nil + ccb.mu.Unlock() + for acbw := range scs { ccb.cc.removeAddrConn(acbw.getAddrConn(), errConnDrain) } return @@ -183,6 +189,14 @@ func (ccb *ccBalancerWrapper) handleResolvedAddrs(addrs []resolver.Address, err } func (ccb *ccBalancerWrapper) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) { + if len(addrs) <= 0 { + return nil, fmt.Errorf("grpc: cannot create SubConn with empty address list") + } + ccb.mu.Lock() + defer ccb.mu.Unlock() + if ccb.subConns == nil { + return nil, fmt.Errorf("grpc: ClientConn balancer wrapper was closed") + } ac, err := ccb.cc.newAddrConn(addrs) if err != nil { return nil, err @@ -200,15 +214,29 @@ func (ccb *ccBalancerWrapper) RemoveSubConn(sc balancer.SubConn) { if !ok { return } + ccb.mu.Lock() + defer ccb.mu.Unlock() + if ccb.subConns == nil { + return + } delete(ccb.subConns, acbw) ccb.cc.removeAddrConn(acbw.getAddrConn(), errConnDrain) } func (ccb *ccBalancerWrapper) UpdateBalancerState(s connectivity.State, p balancer.Picker) { + ccb.mu.Lock() + defer ccb.mu.Unlock() + if ccb.subConns == nil { + return + } ccb.cc.csMgr.updateState(s) ccb.cc.blockingpicker.updatePicker(p) } +func (ccb *ccBalancerWrapper) ResolveNow(o resolver.ResolveNowOption) { + ccb.cc.resolveNow(o) +} + func (ccb *ccBalancerWrapper) Target() string { return ccb.cc.target } @@ -223,6 +251,10 @@ type acBalancerWrapper struct { func (acbw *acBalancerWrapper) UpdateAddresses(addrs []resolver.Address) { acbw.mu.Lock() defer acbw.mu.Unlock() + if len(addrs) <= 0 { + acbw.ac.tearDown(errConnDrain) + return + } if !acbw.ac.tryUpdateAddrs(addrs) { cc := acbw.ac.cc acbw.ac.mu.Lock() diff --git a/vendor/google.golang.org/grpc/balancer_v1_wrapper.go b/vendor/google.golang.org/grpc/balancer_v1_wrapper.go index 6cb39071c..faabf87d0 100644 --- a/vendor/google.golang.org/grpc/balancer_v1_wrapper.go +++ b/vendor/google.golang.org/grpc/balancer_v1_wrapper.go @@ -28,6 +28,7 @@ import ( "google.golang.org/grpc/connectivity" "google.golang.org/grpc/grpclog" "google.golang.org/grpc/resolver" + "google.golang.org/grpc/status" ) type balancerWrapperBuilder struct { @@ -173,10 +174,10 @@ func (bw *balancerWrapper) lbWatcher() { sc.Connect() } } else { - oldSC.UpdateAddresses(newAddrs) bw.mu.Lock() bw.connSt[oldSC].addr = addrs[0] bw.mu.Unlock() + oldSC.UpdateAddresses(newAddrs) } } else { var ( @@ -317,12 +318,12 @@ func (bw *balancerWrapper) Pick(ctx context.Context, opts balancer.PickOptions) Metadata: a.Metadata, }] if !ok && failfast { - return nil, nil, Errorf(codes.Unavailable, "there is no connection available") + return nil, nil, status.Errorf(codes.Unavailable, "there is no connection available") } if s, ok := bw.connSt[sc]; failfast && (!ok || s.s != connectivity.Ready) { // If the returned sc is not ready and RPC is failfast, // return error, and this RPC will fail. - return nil, nil, Errorf(codes.Unavailable, "there is no connection available") + return nil, nil, status.Errorf(codes.Unavailable, "there is no connection available") } } diff --git a/vendor/google.golang.org/grpc/call.go b/vendor/google.golang.org/grpc/call.go index 0854f84b9..13cf8b13b 100644 --- a/vendor/google.golang.org/grpc/call.go +++ b/vendor/google.golang.org/grpc/call.go @@ -29,6 +29,7 @@ import ( "google.golang.org/grpc/encoding" "google.golang.org/grpc/peer" "google.golang.org/grpc/stats" + "google.golang.org/grpc/status" "google.golang.org/grpc/transport" ) @@ -59,7 +60,7 @@ func recvResponse(ctx context.Context, dopts dialOptions, t transport.ClientTran } for { if c.maxReceiveMessageSize == nil { - return Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)") + return status.Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)") } // Set dc if it exists and matches the message compression type used, @@ -113,7 +114,7 @@ func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor, compressor = nil // Disable the legacy compressor. comp = encoding.GetCompressor(ct) if comp == nil { - return Errorf(codes.Internal, "grpc: Compressor is not installed for grpc-encoding %q", ct) + return status.Errorf(codes.Internal, "grpc: Compressor is not installed for grpc-encoding %q", ct) } } hdr, data, err := encode(dopts.codec, args, compressor, outPayload, comp) @@ -121,10 +122,10 @@ func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor, return err } if c.maxSendMessageSize == nil { - return Errorf(codes.Internal, "callInfo maxSendMessageSize field uninitialized(nil)") + return status.Errorf(codes.Internal, "callInfo maxSendMessageSize field uninitialized(nil)") } if len(data) > *c.maxSendMessageSize { - return Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(data), *c.maxSendMessageSize) + return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(data), *c.maxSendMessageSize) } err = t.Write(stream, hdr, data, opts) if err == nil && outPayload != nil { @@ -277,11 +278,11 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli err = sendRequest(ctx, cc.dopts, cc.dopts.cp, c, callHdr, stream, t, args, topts) if err != nil { if done != nil { - updateRPCInfoInContext(ctx, rpcInfo{ - bytesSent: true, - bytesReceived: stream.BytesReceived(), + done(balancer.DoneInfo{ + Err: err, + BytesSent: true, + BytesReceived: stream.BytesReceived(), }) - done(balancer.DoneInfo{Err: err}) } // Retry a non-failfast RPC when // i) the server started to drain before this RPC was initiated. @@ -301,11 +302,11 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli err = recvResponse(ctx, cc.dopts, t, c, stream, reply) if err != nil { if done != nil { - updateRPCInfoInContext(ctx, rpcInfo{ - bytesSent: true, - bytesReceived: stream.BytesReceived(), + done(balancer.DoneInfo{ + Err: err, + BytesSent: true, + BytesReceived: stream.BytesReceived(), }) - done(balancer.DoneInfo{Err: err}) } if !c.failFast && stream.Unprocessed() { // In these cases, the server did not receive the data, but we still @@ -323,12 +324,13 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli c.traceInfo.tr.LazyLog(&payload{sent: false, msg: reply}, true) } t.CloseStream(stream, nil) + err = stream.Status().Err() if done != nil { - updateRPCInfoInContext(ctx, rpcInfo{ - bytesSent: true, - bytesReceived: stream.BytesReceived(), + done(balancer.DoneInfo{ + Err: err, + BytesSent: true, + BytesReceived: stream.BytesReceived(), }) - done(balancer.DoneInfo{Err: err}) } if !c.failFast && stream.Unprocessed() { // In these cases, the server did not receive the data, but we still @@ -339,6 +341,6 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli continue } } - return stream.Status().Err() + return err } } diff --git a/vendor/google.golang.org/grpc/clientconn.go b/vendor/google.golang.org/grpc/clientconn.go index ae605bc32..bfbef3621 100644 --- a/vendor/google.golang.org/grpc/clientconn.go +++ b/vendor/google.golang.org/grpc/clientconn.go @@ -95,8 +95,14 @@ type dialOptions struct { scChan <-chan ServiceConfig copts transport.ConnectOptions callOptions []CallOption - // This is to support v1 balancer. + // This is used by v1 balancer dial option WithBalancer to support v1 + // balancer, and also by WithBalancerName dial option. balancerBuilder balancer.Builder + // This is to support grpclb. + resolverBuilder resolver.Builder + // Custom user options for resolver.Build. + resolverBuildUserOptions interface{} + waitForHandshake bool } const ( @@ -107,6 +113,15 @@ const ( // DialOption configures how we set up the connection. type DialOption func(*dialOptions) +// WithWaitForHandshake blocks until the initial settings frame is received from the +// server before assigning RPCs to the connection. +// Experimental API. +func WithWaitForHandshake() DialOption { + return func(o *dialOptions) { + o.waitForHandshake = true + } +} + // WithWriteBufferSize lets you set the size of write buffer, this determines how much data can be batched // before doing a write on the wire. func WithWriteBufferSize(s int) DialOption { @@ -186,7 +201,8 @@ func WithDecompressor(dc Decompressor) DialOption { // WithBalancer returns a DialOption which sets a load balancer with the v1 API. // Name resolver will be ignored if this DialOption is specified. -// Deprecated: use the new balancer APIs in balancer package instead. +// +// Deprecated: use the new balancer APIs in balancer package and WithBalancerName. func WithBalancer(b Balancer) DialOption { return func(o *dialOptions) { o.balancerBuilder = &balancerWrapperBuilder{ @@ -195,12 +211,36 @@ func WithBalancer(b Balancer) DialOption { } } -// WithBalancerBuilder is for testing only. Users using custom balancers should -// register their balancer and use service config to choose the balancer to use. -func WithBalancerBuilder(b balancer.Builder) DialOption { - // TODO(bar) remove this when switching balancer is done. +// WithBalancerName sets the balancer that the ClientConn will be initialized +// with. Balancer registered with balancerName will be used. This function +// panics if no balancer was registered by balancerName. +// +// The balancer cannot be overridden by balancer option specified by service +// config. +// +// This is an EXPERIMENTAL API. +func WithBalancerName(balancerName string) DialOption { + builder := balancer.Get(balancerName) + if builder == nil { + panic(fmt.Sprintf("grpc.WithBalancerName: no balancer is registered for name %v", balancerName)) + } + return func(o *dialOptions) { + o.balancerBuilder = builder + } +} + +// withResolverBuilder is only for grpclb. +func withResolverBuilder(b resolver.Builder) DialOption { return func(o *dialOptions) { - o.balancerBuilder = b + o.resolverBuilder = b + } +} + +// WithResolverUserOptions returns a DialOption which sets the UserOptions +// field of resolver's BuildOption. +func WithResolverUserOptions(userOpt interface{}) DialOption { + return func(o *dialOptions) { + o.resolverBuildUserOptions = userOpt } } @@ -231,7 +271,7 @@ func WithBackoffConfig(b BackoffConfig) DialOption { return withBackoff(b) } -// withBackoff sets the backoff strategy used for retries after a +// withBackoff sets the backoff strategy used for connectRetryNum after a // failed connection attempt. // // This can be exported if arbitrary backoff strategies are allowed by gRPC. @@ -283,18 +323,23 @@ func WithTimeout(d time.Duration) DialOption { } } +func withContextDialer(f func(context.Context, string) (net.Conn, error)) DialOption { + return func(o *dialOptions) { + o.copts.Dialer = f + } +} + // WithDialer returns a DialOption that specifies a function to use for dialing network addresses. // If FailOnNonTempDialError() is set to true, and an error is returned by f, gRPC checks the error's // Temporary() method to decide if it should try to reconnect to the network address. func WithDialer(f func(string, time.Duration) (net.Conn, error)) DialOption { - return func(o *dialOptions) { - o.copts.Dialer = func(ctx context.Context, addr string) (net.Conn, error) { + return withContextDialer( + func(ctx context.Context, addr string) (net.Conn, error) { if deadline, ok := ctx.Deadline(); ok { return f(addr, deadline.Sub(time.Now())) } return f(addr, 0) - } - } + }) } // WithStatsHandler returns a DialOption that specifies the stats handler @@ -480,17 +525,19 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * Dialer: cc.dopts.copts.Dialer, } - if cc.dopts.balancerBuilder != nil { - cc.customBalancer = true - // Build should not take long time. So it's ok to not have a goroutine for it. - cc.balancerWrapper = newCCBalancerWrapper(cc, cc.dopts.balancerBuilder, cc.balancerBuildOpts) - } - // Build the resolver. cc.resolverWrapper, err = newCCResolverWrapper(cc) if err != nil { return nil, fmt.Errorf("failed to build resolver: %v", err) } + // Start the resolver wrapper goroutine after resolverWrapper is created. + // + // If the goroutine is started before resolverWrapper is ready, the + // following may happen: The goroutine sends updates to cc. cc forwards + // those to balancer. Balancer creates new addrConn. addrConn fails to + // connect, and calls resolveNow(). resolveNow() tries to use the non-ready + // resolverWrapper. + cc.resolverWrapper.start() // A blocking dial blocks until the clientConn is ready. if cc.dopts.block { @@ -563,7 +610,6 @@ type ClientConn struct { dopts dialOptions csMgr *connectivityStateManager - customBalancer bool // If this is true, switching balancer will be disabled. balancerBuildOpts balancer.BuildOptions resolverWrapper *ccResolverWrapper blockingpicker *pickerWrapper @@ -575,6 +621,7 @@ type ClientConn struct { // Keepalive parameter can be updated if a GoAway is received. mkp keepalive.ClientParameters curBalancerName string + preBalancerName string // previous balancer name. curAddresses []resolver.Address balancerWrapper *ccBalancerWrapper } @@ -624,51 +671,92 @@ func (cc *ClientConn) handleResolvedAddrs(addrs []resolver.Address, err error) { cc.mu.Lock() defer cc.mu.Unlock() if cc.conns == nil { + // cc was closed. return } - // TODO(bar switching) when grpclb is submitted, check address type and start grpclb. - if !cc.customBalancer && cc.balancerWrapper == nil { - // No customBalancer was specified by DialOption, and this is the first - // time handling resolved addresses, create a pickfirst balancer. - builder := newPickfirstBuilder() - cc.curBalancerName = builder.Name() - cc.balancerWrapper = newCCBalancerWrapper(cc, builder, cc.balancerBuildOpts) + if reflect.DeepEqual(cc.curAddresses, addrs) { + return } - // TODO(bar switching) compare addresses, if there's no update, don't notify balancer. cc.curAddresses = addrs + + if cc.dopts.balancerBuilder == nil { + // Only look at balancer types and switch balancer if balancer dial + // option is not set. + var isGRPCLB bool + for _, a := range addrs { + if a.Type == resolver.GRPCLB { + isGRPCLB = true + break + } + } + var newBalancerName string + if isGRPCLB { + newBalancerName = grpclbName + } else { + // Address list doesn't contain grpclb address. Try to pick a + // non-grpclb balancer. + newBalancerName = cc.curBalancerName + // If current balancer is grpclb, switch to the previous one. + if newBalancerName == grpclbName { + newBalancerName = cc.preBalancerName + } + // The following could be true in two cases: + // - the first time handling resolved addresses + // (curBalancerName="") + // - the first time handling non-grpclb addresses + // (curBalancerName="grpclb", preBalancerName="") + if newBalancerName == "" { + newBalancerName = PickFirstBalancerName + } + } + cc.switchBalancer(newBalancerName) + } else if cc.balancerWrapper == nil { + // Balancer dial option was set, and this is the first time handling + // resolved addresses. Build a balancer with dopts.balancerBuilder. + cc.balancerWrapper = newCCBalancerWrapper(cc, cc.dopts.balancerBuilder, cc.balancerBuildOpts) + } + cc.balancerWrapper.handleResolvedAddrs(addrs, nil) } -// switchBalancer starts the switching from current balancer to the balancer with name. +// switchBalancer starts the switching from current balancer to the balancer +// with the given name. +// +// It will NOT send the current address list to the new balancer. If needed, +// caller of this function should send address list to the new balancer after +// this function returns. +// +// Caller must hold cc.mu. func (cc *ClientConn) switchBalancer(name string) { if cc.conns == nil { return } - grpclog.Infof("ClientConn switching balancer to %q", name) - if cc.customBalancer { - grpclog.Infoln("ignoring service config balancer configuration: WithBalancer DialOption used instead") + if strings.ToLower(cc.curBalancerName) == strings.ToLower(name) { return } - if cc.curBalancerName == name { + grpclog.Infof("ClientConn switching balancer to %q", name) + if cc.dopts.balancerBuilder != nil { + grpclog.Infoln("ignoring balancer switching: Balancer DialOption used instead") return } - // TODO(bar switching) change this to two steps: drain and close. // Keep track of sc in wrapper. - cc.balancerWrapper.close() + if cc.balancerWrapper != nil { + cc.balancerWrapper.close() + } builder := balancer.Get(name) if builder == nil { - grpclog.Infof("failed to get balancer builder for: %v (this should never happen...)", name) + grpclog.Infof("failed to get balancer builder for: %v, using pick_first instead", name) builder = newPickfirstBuilder() } + cc.preBalancerName = cc.curBalancerName cc.curBalancerName = builder.Name() cc.balancerWrapper = newCCBalancerWrapper(cc, builder, cc.balancerBuildOpts) - cc.balancerWrapper.handleResolvedAddrs(cc.curAddresses, nil) } func (cc *ClientConn) handleSubConnStateChange(sc balancer.SubConn, s connectivity.State) { @@ -684,6 +772,8 @@ func (cc *ClientConn) handleSubConnStateChange(sc balancer.SubConn, s connectivi } // newAddrConn creates an addrConn for addrs and adds it to cc.conns. +// +// Caller needs to make sure len(addrs) > 0. func (cc *ClientConn) newAddrConn(addrs []resolver.Address) (*addrConn, error) { ac := &addrConn{ cc: cc, @@ -774,6 +864,7 @@ func (ac *addrConn) tryUpdateAddrs(addrs []resolver.Address) bool { grpclog.Infof("addrConn: tryUpdateAddrs curAddrFound: %v", curAddrFound) if curAddrFound { ac.addrs = addrs + ac.reconnectIdx = 0 // Start reconnecting from beginning in the new list. } return curAddrFound @@ -816,13 +907,33 @@ func (cc *ClientConn) handleServiceConfig(js string) error { cc.mu.Lock() cc.scRaw = js cc.sc = sc - if sc.LB != nil { - cc.switchBalancer(*sc.LB) + if sc.LB != nil && *sc.LB != grpclbName { // "grpclb" is not a valid balancer option in service config. + if cc.curBalancerName == grpclbName { + // If current balancer is grpclb, there's at least one grpclb + // balancer address in the resolved list. Don't switch the balancer, + // but change the previous balancer name, so if a new resolved + // address list doesn't contain grpclb address, balancer will be + // switched to *sc.LB. + cc.preBalancerName = *sc.LB + } else { + cc.switchBalancer(*sc.LB) + cc.balancerWrapper.handleResolvedAddrs(cc.curAddresses, nil) + } } cc.mu.Unlock() return nil } +func (cc *ClientConn) resolveNow(o resolver.ResolveNowOption) { + cc.mu.Lock() + r := cc.resolverWrapper + cc.mu.Unlock() + if r == nil { + return + } + go r.resolveNow(o) +} + // Close tears down the ClientConn and all underlying connections. func (cc *ClientConn) Close() error { cc.cancel() @@ -859,15 +970,16 @@ type addrConn struct { ctx context.Context cancel context.CancelFunc - cc *ClientConn - curAddr resolver.Address - addrs []resolver.Address - dopts dialOptions - events trace.EventLog - acbw balancer.SubConn + cc *ClientConn + addrs []resolver.Address + dopts dialOptions + events trace.EventLog + acbw balancer.SubConn - mu sync.Mutex - state connectivity.State + mu sync.Mutex + curAddr resolver.Address + reconnectIdx int // The index in addrs list to start reconnecting from. + state connectivity.State // ready is closed and becomes nil when a new transport is up or failed // due to timeout. ready chan struct{} @@ -875,6 +987,14 @@ type addrConn struct { // The reason this addrConn is torn down. tearDownErr error + + connectRetryNum int + // backoffDeadline is the time until which resetTransport needs to + // wait before increasing connectRetryNum count. + backoffDeadline time.Time + // connectDeadline is the time by which all connection + // negotiations must complete. + connectDeadline time.Time } // adjustParams updates parameters used to create transports upon @@ -909,6 +1029,15 @@ func (ac *addrConn) errorf(format string, a ...interface{}) { // resetTransport recreates a transport to the address for ac. The old // transport will close itself on error or when the clientconn is closed. +// The created transport must receive initial settings frame from the server. +// In case that doesnt happen, transportMonitor will kill the newly created +// transport after connectDeadline has expired. +// In case there was an error on the transport before the settings frame was +// received, resetTransport resumes connecting to backends after the one that +// was previously connected to. In case end of the list is reached, resetTransport +// backs off until the original deadline. +// If the DialOption WithWaitForHandshake was set, resetTrasport returns +// successfully only after server settings are received. // // TODO(bar) make sure all state transitions are valid. func (ac *addrConn) resetTransport() error { @@ -922,19 +1051,38 @@ func (ac *addrConn) resetTransport() error { ac.ready = nil } ac.transport = nil - ac.curAddr = resolver.Address{} + ridx := ac.reconnectIdx ac.mu.Unlock() ac.cc.mu.RLock() ac.dopts.copts.KeepaliveParams = ac.cc.mkp ac.cc.mu.RUnlock() - for retries := 0; ; retries++ { - sleepTime := ac.dopts.bs.backoff(retries) - timeout := minConnectTimeout + var backoffDeadline, connectDeadline time.Time + for connectRetryNum := 0; ; connectRetryNum++ { ac.mu.Lock() - if timeout < time.Duration(int(sleepTime)/len(ac.addrs)) { - timeout = time.Duration(int(sleepTime) / len(ac.addrs)) + if ac.backoffDeadline.IsZero() { + // This means either a successful HTTP2 connection was established + // or this is the first time this addrConn is trying to establish a + // connection. + backoffFor := ac.dopts.bs.backoff(connectRetryNum) // time.Duration. + // This will be the duration that dial gets to finish. + dialDuration := minConnectTimeout + if backoffFor > dialDuration { + // Give dial more time as we keep failing to connect. + dialDuration = backoffFor + } + start := time.Now() + backoffDeadline = start.Add(backoffFor) + connectDeadline = start.Add(dialDuration) + ridx = 0 // Start connecting from the beginning. + } else { + // Continue trying to conect with the same deadlines. + connectRetryNum = ac.connectRetryNum + backoffDeadline = ac.backoffDeadline + connectDeadline = ac.connectDeadline + ac.backoffDeadline = time.Time{} + ac.connectDeadline = time.Time{} + ac.connectRetryNum = 0 } - connectTime := time.Now() if ac.state == connectivity.Shutdown { ac.mu.Unlock() return errConnClosing @@ -949,93 +1097,159 @@ func (ac *addrConn) resetTransport() error { copy(addrsIter, ac.addrs) copts := ac.dopts.copts ac.mu.Unlock() - for _, addr := range addrsIter { + connected, err := ac.createTransport(connectRetryNum, ridx, backoffDeadline, connectDeadline, addrsIter, copts) + if err != nil { + return err + } + if connected { + return nil + } + } +} + +// createTransport creates a connection to one of the backends in addrs. +// It returns true if a connection was established. +func (ac *addrConn) createTransport(connectRetryNum, ridx int, backoffDeadline, connectDeadline time.Time, addrs []resolver.Address, copts transport.ConnectOptions) (bool, error) { + for i := ridx; i < len(addrs); i++ { + addr := addrs[i] + target := transport.TargetInfo{ + Addr: addr.Addr, + Metadata: addr.Metadata, + Authority: ac.cc.authority, + } + done := make(chan struct{}) + onPrefaceReceipt := func() { ac.mu.Lock() - if ac.state == connectivity.Shutdown { - // ac.tearDown(...) has been invoked. - ac.mu.Unlock() - return errConnClosing + close(done) + if !ac.backoffDeadline.IsZero() { + // If we haven't already started reconnecting to + // other backends. + // Note, this can happen when writer notices an error + // and triggers resetTransport while at the same time + // reader receives the preface and invokes this closure. + ac.backoffDeadline = time.Time{} + ac.connectDeadline = time.Time{} + ac.connectRetryNum = 0 } ac.mu.Unlock() - sinfo := transport.TargetInfo{ - Addr: addr.Addr, - Metadata: addr.Metadata, - Authority: ac.cc.authority, - } - newTransport, err := transport.NewClientTransport(ac.cc.ctx, sinfo, copts, timeout) - if err != nil { - if e, ok := err.(transport.ConnectionError); ok && !e.Temporary() { - ac.mu.Lock() - if ac.state != connectivity.Shutdown { - ac.state = connectivity.TransientFailure - ac.cc.handleSubConnStateChange(ac.acbw, ac.state) - } - ac.mu.Unlock() - return err - } - grpclog.Warningf("grpc: addrConn.resetTransport failed to create client transport: %v; Reconnecting to %v", err, addr) + } + // Do not cancel in the success path because of + // this issue in Go1.6: https://github.com/golang/go/issues/15078. + connectCtx, cancel := context.WithDeadline(ac.ctx, connectDeadline) + newTr, err := transport.NewClientTransport(connectCtx, ac.cc.ctx, target, copts, onPrefaceReceipt) + if err != nil { + cancel() + if e, ok := err.(transport.ConnectionError); ok && !e.Temporary() { ac.mu.Lock() - if ac.state == connectivity.Shutdown { - // ac.tearDown(...) has been invoked. - ac.mu.Unlock() - return errConnClosing + if ac.state != connectivity.Shutdown { + ac.state = connectivity.TransientFailure + ac.cc.handleSubConnStateChange(ac.acbw, ac.state) } ac.mu.Unlock() - continue + return false, err } ac.mu.Lock() - ac.printf("ready") if ac.state == connectivity.Shutdown { // ac.tearDown(...) has been invoked. ac.mu.Unlock() - newTransport.Close() - return errConnClosing - } - ac.state = connectivity.Ready - ac.cc.handleSubConnStateChange(ac.acbw, ac.state) - t := ac.transport - ac.transport = newTransport - if t != nil { - t.Close() - } - ac.curAddr = addr - if ac.ready != nil { - close(ac.ready) - ac.ready = nil + return false, errConnClosing } ac.mu.Unlock() - return nil + grpclog.Warningf("grpc: addrConn.createTransport failed to connect to %v. Err :%v. Reconnecting...", addr, err) + continue + } + if ac.dopts.waitForHandshake { + select { + case <-done: + case <-connectCtx.Done(): + // Didn't receive server preface, must kill this new transport now. + grpclog.Warningf("grpc: addrConn.createTransport failed to receive server preface before deadline.") + newTr.Close() + break + case <-ac.ctx.Done(): + } } ac.mu.Lock() - ac.state = connectivity.TransientFailure + if ac.state == connectivity.Shutdown { + ac.mu.Unlock() + // ac.tearDonn(...) has been invoked. + newTr.Close() + return false, errConnClosing + } + ac.printf("ready") + ac.state = connectivity.Ready ac.cc.handleSubConnStateChange(ac.acbw, ac.state) + ac.transport = newTr + ac.curAddr = addr if ac.ready != nil { close(ac.ready) ac.ready = nil } - ac.mu.Unlock() - timer := time.NewTimer(sleepTime - time.Since(connectTime)) select { - case <-timer.C: - case <-ac.ctx.Done(): - timer.Stop() - return ac.ctx.Err() + case <-done: + // If the server has responded back with preface already, + // don't set the reconnect parameters. + default: + ac.connectRetryNum = connectRetryNum + ac.backoffDeadline = backoffDeadline + ac.connectDeadline = connectDeadline + ac.reconnectIdx = i + 1 // Start reconnecting from the next backend in the list. } + ac.mu.Unlock() + return true, nil + } + ac.mu.Lock() + ac.state = connectivity.TransientFailure + ac.cc.handleSubConnStateChange(ac.acbw, ac.state) + ac.cc.resolveNow(resolver.ResolveNowOption{}) + if ac.ready != nil { + close(ac.ready) + ac.ready = nil + } + ac.mu.Unlock() + timer := time.NewTimer(backoffDeadline.Sub(time.Now())) + select { + case <-timer.C: + case <-ac.ctx.Done(): timer.Stop() + return false, ac.ctx.Err() } + return false, nil } // Run in a goroutine to track the error in transport and create the // new transport if an error happens. It returns when the channel is closing. func (ac *addrConn) transportMonitor() { for { + var timer *time.Timer + var cdeadline <-chan time.Time ac.mu.Lock() t := ac.transport + if !ac.connectDeadline.IsZero() { + timer = time.NewTimer(ac.connectDeadline.Sub(time.Now())) + cdeadline = timer.C + } ac.mu.Unlock() // Block until we receive a goaway or an error occurs. select { case <-t.GoAway(): case <-t.Error(): + case <-cdeadline: + ac.mu.Lock() + // This implies that client received server preface. + if ac.backoffDeadline.IsZero() { + ac.mu.Unlock() + continue + } + ac.mu.Unlock() + timer = nil + // No server preface received until deadline. + // Kill the connection. + grpclog.Warningf("grpc: addrConn.transportMonitor didn't get server preface after waiting. Closing the new transport now.") + t.Close() + } + if timer != nil { + timer.Stop() } // If a GoAway happened, regardless of error, adjust our keepalive // parameters as appropriate. @@ -1053,6 +1267,7 @@ func (ac *addrConn) transportMonitor() { // resetTransport. Transition READY->CONNECTING is not valid. ac.state = connectivity.TransientFailure ac.cc.handleSubConnStateChange(ac.acbw, ac.state) + ac.cc.resolveNow(resolver.ResolveNowOption{}) ac.curAddr = resolver.Address{} ac.mu.Unlock() if err := ac.resetTransport(); err != nil { @@ -1140,6 +1355,9 @@ func (ac *addrConn) tearDown(err error) { ac.cancel() ac.mu.Lock() defer ac.mu.Unlock() + if ac.state == connectivity.Shutdown { + return + } ac.curAddr = resolver.Address{} if err == errConnDrain && ac.transport != nil { // GracefulClose(...) may be executed multiple times when @@ -1148,9 +1366,6 @@ func (ac *addrConn) tearDown(err error) { // address removal and GoAway. ac.transport.GracefulClose() } - if ac.state == connectivity.Shutdown { - return - } ac.state = connectivity.Shutdown ac.tearDownErr = err ac.cc.handleSubConnStateChange(ac.acbw, ac.state) diff --git a/vendor/google.golang.org/grpc/codec.go b/vendor/google.golang.org/grpc/codec.go index b452a4ae8..43d81ed2a 100644 --- a/vendor/google.golang.org/grpc/codec.go +++ b/vendor/google.golang.org/grpc/codec.go @@ -69,6 +69,11 @@ func (p protoCodec) marshal(v interface{}, cb *cachedProtoBuffer) ([]byte, error } func (p protoCodec) Marshal(v interface{}) ([]byte, error) { + if pm, ok := v.(proto.Marshaler); ok { + // object can marshal itself, no need for buffer + return pm.Marshal() + } + cb := protoBufferPool.Get().(*cachedProtoBuffer) out, err := p.marshal(v, cb) @@ -79,10 +84,17 @@ func (p protoCodec) Marshal(v interface{}) ([]byte, error) { } func (p protoCodec) Unmarshal(data []byte, v interface{}) error { + protoMsg := v.(proto.Message) + protoMsg.Reset() + + if pu, ok := protoMsg.(proto.Unmarshaler); ok { + // object can unmarshal itself, no need for buffer + return pu.Unmarshal(data) + } + cb := protoBufferPool.Get().(*cachedProtoBuffer) cb.SetBuf(data) - v.(proto.Message).Reset() - err := cb.Unmarshal(v.(proto.Message)) + err := cb.Unmarshal(protoMsg) cb.SetBuf(nil) protoBufferPool.Put(cb) return err diff --git a/vendor/google.golang.org/grpc/codes/code_string.go b/vendor/google.golang.org/grpc/codes/code_string.go index d9cf9675b..0b206a578 100644 --- a/vendor/google.golang.org/grpc/codes/code_string.go +++ b/vendor/google.golang.org/grpc/codes/code_string.go @@ -1,16 +1,62 @@ -// Code generated by "stringer -type=Code"; DO NOT EDIT. +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ package codes import "strconv" -const _Code_name = "OKCanceledUnknownInvalidArgumentDeadlineExceededNotFoundAlreadyExistsPermissionDeniedResourceExhaustedFailedPreconditionAbortedOutOfRangeUnimplementedInternalUnavailableDataLossUnauthenticated" - -var _Code_index = [...]uint8{0, 2, 10, 17, 32, 48, 56, 69, 85, 102, 120, 127, 137, 150, 158, 169, 177, 192} - -func (i Code) String() string { - if i >= Code(len(_Code_index)-1) { - return "Code(" + strconv.FormatInt(int64(i), 10) + ")" +func (c Code) String() string { + switch c { + case OK: + return "OK" + case Canceled: + return "Canceled" + case Unknown: + return "Unknown" + case InvalidArgument: + return "InvalidArgument" + case DeadlineExceeded: + return "DeadlineExceeded" + case NotFound: + return "NotFound" + case AlreadyExists: + return "AlreadyExists" + case PermissionDenied: + return "PermissionDenied" + case ResourceExhausted: + return "ResourceExhausted" + case FailedPrecondition: + return "FailedPrecondition" + case Aborted: + return "Aborted" + case OutOfRange: + return "OutOfRange" + case Unimplemented: + return "Unimplemented" + case Internal: + return "Internal" + case Unavailable: + return "Unavailable" + case DataLoss: + return "DataLoss" + case Unauthenticated: + return "Unauthenticated" + default: + return "Code(" + strconv.FormatInt(int64(c), 10) + ")" } - return _Code_name[_Code_index[i]:_Code_index[i+1]] } diff --git a/vendor/google.golang.org/grpc/codes/codes.go b/vendor/google.golang.org/grpc/codes/codes.go index 21e7733a5..f3719d562 100644 --- a/vendor/google.golang.org/grpc/codes/codes.go +++ b/vendor/google.golang.org/grpc/codes/codes.go @@ -19,12 +19,13 @@ // Package codes defines the canonical error codes used by gRPC. It is // consistent across various languages. package codes // import "google.golang.org/grpc/codes" +import ( + "fmt" +) // A Code is an unsigned 32-bit error code as defined in the gRPC spec. type Code uint32 -//go:generate stringer -type=Code - const ( // OK is returned on success. OK Code = 0 @@ -142,3 +143,41 @@ const ( // DataLoss indicates unrecoverable data loss or corruption. DataLoss Code = 15 ) + +var strToCode = map[string]Code{ + `"OK"`: OK, + `"CANCELLED"`:/* [sic] */ Canceled, + `"UNKNOWN"`: Unknown, + `"INVALID_ARGUMENT"`: InvalidArgument, + `"DEADLINE_EXCEEDED"`: DeadlineExceeded, + `"NOT_FOUND"`: NotFound, + `"ALREADY_EXISTS"`: AlreadyExists, + `"PERMISSION_DENIED"`: PermissionDenied, + `"RESOURCE_EXHAUSTED"`: ResourceExhausted, + `"FAILED_PRECONDITION"`: FailedPrecondition, + `"ABORTED"`: Aborted, + `"OUT_OF_RANGE"`: OutOfRange, + `"UNIMPLEMENTED"`: Unimplemented, + `"INTERNAL"`: Internal, + `"UNAVAILABLE"`: Unavailable, + `"DATA_LOSS"`: DataLoss, + `"UNAUTHENTICATED"`: Unauthenticated, +} + +// UnmarshalJSON unmarshals b into the Code. +func (c *Code) UnmarshalJSON(b []byte) error { + // From json.Unmarshaler: By convention, to approximate the behavior of + // Unmarshal itself, Unmarshalers implement UnmarshalJSON([]byte("null")) as + // a no-op. + if string(b) == "null" { + return nil + } + if c == nil { + return fmt.Errorf("nil receiver passed to UnmarshalJSON") + } + if jc, ok := strToCode[string(b)]; ok { + *c = jc + return nil + } + return fmt.Errorf("invalid code: %q", string(b)) +} diff --git a/vendor/google.golang.org/grpc/grpclb.go b/vendor/google.golang.org/grpc/grpclb.go index db56ff362..d14a5d409 100644 --- a/vendor/google.golang.org/grpc/grpclb.go +++ b/vendor/google.golang.org/grpc/grpclb.go @@ -19,21 +19,32 @@ package grpc import ( - "errors" - "fmt" - "math/rand" - "net" + "strconv" + "strings" "sync" "time" "golang.org/x/net/context" - "google.golang.org/grpc/codes" - lbmpb "google.golang.org/grpc/grpclb/grpc_lb_v1/messages" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/connectivity" + lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1/messages" "google.golang.org/grpc/grpclog" - "google.golang.org/grpc/metadata" - "google.golang.org/grpc/naming" + "google.golang.org/grpc/resolver" ) +const ( + lbTokeyKey = "lb-token" + defaultFallbackTimeout = 10 * time.Second + grpclbName = "grpclb" +) + +func convertDuration(d *lbpb.Duration) time.Duration { + if d == nil { + return 0 + } + return time.Duration(d.Seconds)*time.Second + time.Duration(d.Nanos)*time.Nanosecond +} + // Client API for LoadBalancer service. // Mostly copied from generated pb.go file. // To avoid circular dependency. @@ -59,646 +70,273 @@ type balanceLoadClientStream struct { ClientStream } -func (x *balanceLoadClientStream) Send(m *lbmpb.LoadBalanceRequest) error { +func (x *balanceLoadClientStream) Send(m *lbpb.LoadBalanceRequest) error { return x.ClientStream.SendMsg(m) } -func (x *balanceLoadClientStream) Recv() (*lbmpb.LoadBalanceResponse, error) { - m := new(lbmpb.LoadBalanceResponse) +func (x *balanceLoadClientStream) Recv() (*lbpb.LoadBalanceResponse, error) { + m := new(lbpb.LoadBalanceResponse) if err := x.ClientStream.RecvMsg(m); err != nil { return nil, err } return m, nil } -// NewGRPCLBBalancer creates a grpclb load balancer. -func NewGRPCLBBalancer(r naming.Resolver) Balancer { - return &grpclbBalancer{ - r: r, - } +func init() { + balancer.Register(newLBBuilder()) } -type remoteBalancerInfo struct { - addr string - // the server name used for authentication with the remote LB server. - name string +// newLBBuilder creates a builder for grpclb. +func newLBBuilder() balancer.Builder { + return NewLBBuilderWithFallbackTimeout(defaultFallbackTimeout) } -// grpclbAddrInfo consists of the information of a backend server. -type grpclbAddrInfo struct { - addr Address - connected bool - // dropForRateLimiting indicates whether this particular request should be - // dropped by the client for rate limiting. - dropForRateLimiting bool - // dropForLoadBalancing indicates whether this particular request should be - // dropped by the client for load balancing. - dropForLoadBalancing bool +// NewLBBuilderWithFallbackTimeout creates a grpclb builder with the given +// fallbackTimeout. If no response is received from the remote balancer within +// fallbackTimeout, the backend addresses from the resolved address list will be +// used. +// +// Only call this function when a non-default fallback timeout is needed. +func NewLBBuilderWithFallbackTimeout(fallbackTimeout time.Duration) balancer.Builder { + return &lbBuilder{ + fallbackTimeout: fallbackTimeout, + } } -type grpclbBalancer struct { - r naming.Resolver - target string - mu sync.Mutex - seq int // a sequence number to make sure addrCh does not get stale addresses. - w naming.Watcher - addrCh chan []Address - rbs []remoteBalancerInfo - addrs []*grpclbAddrInfo - next int - waitCh chan struct{} - done bool - rand *rand.Rand - - clientStats lbmpb.ClientStats +type lbBuilder struct { + fallbackTimeout time.Duration } -func (b *grpclbBalancer) watchAddrUpdates(w naming.Watcher, ch chan []remoteBalancerInfo) error { - updates, err := w.Next() - if err != nil { - grpclog.Warningf("grpclb: failed to get next addr update from watcher: %v", err) - return err - } - b.mu.Lock() - defer b.mu.Unlock() - if b.done { - return ErrClientConnClosing - } - for _, update := range updates { - switch update.Op { - case naming.Add: - var exist bool - for _, v := range b.rbs { - // TODO: Is the same addr with different server name a different balancer? - if update.Addr == v.addr { - exist = true - break - } - } - if exist { - continue - } - md, ok := update.Metadata.(*naming.AddrMetadataGRPCLB) - if !ok { - // TODO: Revisit the handling here and may introduce some fallback mechanism. - grpclog.Errorf("The name resolution contains unexpected metadata %v", update.Metadata) - continue - } - switch md.AddrType { - case naming.Backend: - // TODO: Revisit the handling here and may introduce some fallback mechanism. - grpclog.Errorf("The name resolution does not give grpclb addresses") - continue - case naming.GRPCLB: - b.rbs = append(b.rbs, remoteBalancerInfo{ - addr: update.Addr, - name: md.ServerName, - }) - default: - grpclog.Errorf("Received unknow address type %d", md.AddrType) - continue - } - case naming.Delete: - for i, v := range b.rbs { - if update.Addr == v.addr { - copy(b.rbs[i:], b.rbs[i+1:]) - b.rbs = b.rbs[:len(b.rbs)-1] - break - } - } - default: - grpclog.Errorf("Unknown update.Op %v", update.Op) - } +func (b *lbBuilder) Name() string { + return grpclbName +} + +func (b *lbBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions) balancer.Balancer { + // This generates a manual resolver builder with a random scheme. This + // scheme will be used to dial to remote LB, so we can send filtered address + // updates to remote LB ClientConn using this manual resolver. + scheme := "grpclb_internal_" + strconv.FormatInt(time.Now().UnixNano(), 36) + r := &lbManualResolver{scheme: scheme, ccb: cc} + + var target string + targetSplitted := strings.Split(cc.Target(), ":///") + if len(targetSplitted) < 2 { + target = cc.Target() + } else { + target = targetSplitted[1] } - // TODO: Fall back to the basic round-robin load balancing if the resulting address is - // not a load balancer. - select { - case <-ch: - default: + + lb := &lbBalancer{ + cc: cc, + target: target, + opt: opt, + fallbackTimeout: b.fallbackTimeout, + doneCh: make(chan struct{}), + + manualResolver: r, + csEvltr: &connectivityStateEvaluator{}, + subConns: make(map[resolver.Address]balancer.SubConn), + scStates: make(map[balancer.SubConn]connectivity.State), + picker: &errPicker{err: balancer.ErrNoSubConnAvailable}, + clientStats: &rpcStats{}, } - ch <- b.rbs - return nil + + return lb } -func convertDuration(d *lbmpb.Duration) time.Duration { - if d == nil { - return 0 - } - return time.Duration(d.Seconds)*time.Second + time.Duration(d.Nanos)*time.Nanosecond +type lbBalancer struct { + cc balancer.ClientConn + target string + opt balancer.BuildOptions + fallbackTimeout time.Duration + doneCh chan struct{} + + // manualResolver is used in the remote LB ClientConn inside grpclb. When + // resolved address updates are received by grpclb, filtered updates will be + // send to remote LB ClientConn through this resolver. + manualResolver *lbManualResolver + // The ClientConn to talk to the remote balancer. + ccRemoteLB *ClientConn + + // Support client side load reporting. Each picker gets a reference to this, + // and will update its content. + clientStats *rpcStats + + mu sync.Mutex // guards everything following. + // The full server list including drops, used to check if the newly received + // serverList contains anything new. Each generate picker will also have + // reference to this list to do the first layer pick. + fullServerList []*lbpb.Server + // All backends addresses, with metadata set to nil. This list contains all + // backend addresses in the same order and with the same duplicates as in + // serverlist. When generating picker, a SubConn slice with the same order + // but with only READY SCs will be gerenated. + backendAddrs []resolver.Address + // Roundrobin functionalities. + csEvltr *connectivityStateEvaluator + state connectivity.State + subConns map[resolver.Address]balancer.SubConn // Used to new/remove SubConn. + scStates map[balancer.SubConn]connectivity.State // Used to filter READY SubConns. + picker balancer.Picker + // Support fallback to resolved backend addresses if there's no response + // from remote balancer within fallbackTimeout. + fallbackTimerExpired bool + serverListReceived bool + // resolvedBackendAddrs is resolvedAddrs minus remote balancers. It's set + // when resolved address updates are received, and read in the goroutine + // handling fallback. + resolvedBackendAddrs []resolver.Address } -func (b *grpclbBalancer) processServerList(l *lbmpb.ServerList, seq int) { - if l == nil { +// regeneratePicker takes a snapshot of the balancer, and generates a picker from +// it. The picker +// - always returns ErrTransientFailure if the balancer is in TransientFailure, +// - does two layer roundrobin pick otherwise. +// Caller must hold lb.mu. +func (lb *lbBalancer) regeneratePicker() { + if lb.state == connectivity.TransientFailure { + lb.picker = &errPicker{err: balancer.ErrTransientFailure} return } - servers := l.GetServers() - var ( - sl []*grpclbAddrInfo - addrs []Address - ) - for _, s := range servers { - md := metadata.Pairs("lb-token", s.LoadBalanceToken) - ip := net.IP(s.IpAddress) - ipStr := ip.String() - if ip.To4() == nil { - // Add square brackets to ipv6 addresses, otherwise net.Dial() and - // net.SplitHostPort() will return too many colons error. - ipStr = fmt.Sprintf("[%s]", ipStr) - } - addr := Address{ - Addr: fmt.Sprintf("%s:%d", ipStr, s.Port), - Metadata: &md, + var readySCs []balancer.SubConn + for _, a := range lb.backendAddrs { + if sc, ok := lb.subConns[a]; ok { + if st, ok := lb.scStates[sc]; ok && st == connectivity.Ready { + readySCs = append(readySCs, sc) + } } - sl = append(sl, &grpclbAddrInfo{ - addr: addr, - dropForRateLimiting: s.DropForRateLimiting, - dropForLoadBalancing: s.DropForLoadBalancing, - }) - addrs = append(addrs, addr) } - b.mu.Lock() - defer b.mu.Unlock() - if b.done || seq < b.seq { - return - } - if len(sl) > 0 { - // reset b.next to 0 when replacing the server list. - b.next = 0 - b.addrs = sl - b.addrCh <- addrs - } - return -} -func (b *grpclbBalancer) sendLoadReport(s *balanceLoadClientStream, interval time.Duration, done <-chan struct{}) { - ticker := time.NewTicker(interval) - defer ticker.Stop() - for { - select { - case <-ticker.C: - case <-done: - return - } - b.mu.Lock() - stats := b.clientStats - b.clientStats = lbmpb.ClientStats{} // Clear the stats. - b.mu.Unlock() - t := time.Now() - stats.Timestamp = &lbmpb.Timestamp{ - Seconds: t.Unix(), - Nanos: int32(t.Nanosecond()), - } - if err := s.Send(&lbmpb.LoadBalanceRequest{ - LoadBalanceRequestType: &lbmpb.LoadBalanceRequest_ClientStats{ - ClientStats: &stats, - }, - }); err != nil { - grpclog.Errorf("grpclb: failed to send load report: %v", err) + if len(lb.fullServerList) <= 0 { + if len(readySCs) <= 0 { + lb.picker = &errPicker{err: balancer.ErrNoSubConnAvailable} return } - } -} - -func (b *grpclbBalancer) callRemoteBalancer(lbc *loadBalancerClient, seq int) (retry bool) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - stream, err := lbc.BalanceLoad(ctx) - if err != nil { - grpclog.Errorf("grpclb: failed to perform RPC to the remote balancer %v", err) + lb.picker = &rrPicker{subConns: readySCs} return } - b.mu.Lock() - if b.done { - b.mu.Unlock() - return - } - b.mu.Unlock() - initReq := &lbmpb.LoadBalanceRequest{ - LoadBalanceRequestType: &lbmpb.LoadBalanceRequest_InitialRequest{ - InitialRequest: &lbmpb.InitialLoadBalanceRequest{ - Name: b.target, - }, - }, + lb.picker = &lbPicker{ + serverList: lb.fullServerList, + subConns: readySCs, + stats: lb.clientStats, } - if err := stream.Send(initReq); err != nil { - grpclog.Errorf("grpclb: failed to send init request: %v", err) - // TODO: backoff on retry? - return true - } - reply, err := stream.Recv() - if err != nil { - grpclog.Errorf("grpclb: failed to recv init response: %v", err) - // TODO: backoff on retry? - return true - } - initResp := reply.GetInitialResponse() - if initResp == nil { - grpclog.Errorf("grpclb: reply from remote balancer did not include initial response.") - return - } - // TODO: Support delegation. - if initResp.LoadBalancerDelegate != "" { - // delegation - grpclog.Errorf("TODO: Delegation is not supported yet.") - return - } - streamDone := make(chan struct{}) - defer close(streamDone) - b.mu.Lock() - b.clientStats = lbmpb.ClientStats{} // Clear client stats. - b.mu.Unlock() - if d := convertDuration(initResp.ClientStatsReportInterval); d > 0 { - go b.sendLoadReport(stream, d, streamDone) - } - // Retrieve the server list. - for { - reply, err := stream.Recv() - if err != nil { - grpclog.Errorf("grpclb: failed to recv server list: %v", err) - break - } - b.mu.Lock() - if b.done || seq < b.seq { - b.mu.Unlock() - return - } - b.seq++ // tick when receiving a new list of servers. - seq = b.seq - b.mu.Unlock() - if serverList := reply.GetServerList(); serverList != nil { - b.processServerList(serverList, seq) - } - } - return true + return } -func (b *grpclbBalancer) Start(target string, config BalancerConfig) error { - b.rand = rand.New(rand.NewSource(time.Now().Unix())) - // TODO: Fall back to the basic direct connection if there is no name resolver. - if b.r == nil { - return errors.New("there is no name resolver installed") +func (lb *lbBalancer) HandleSubConnStateChange(sc balancer.SubConn, s connectivity.State) { + grpclog.Infof("lbBalancer: handle SubConn state change: %p, %v", sc, s) + lb.mu.Lock() + defer lb.mu.Unlock() + + oldS, ok := lb.scStates[sc] + if !ok { + grpclog.Infof("lbBalancer: got state changes for an unknown SubConn: %p, %v", sc, s) + return } - b.target = target - b.mu.Lock() - if b.done { - b.mu.Unlock() - return ErrClientConnClosing + lb.scStates[sc] = s + switch s { + case connectivity.Idle: + sc.Connect() + case connectivity.Shutdown: + // When an address was removed by resolver, b called RemoveSubConn but + // kept the sc's state in scStates. Remove state for this sc here. + delete(lb.scStates, sc) } - b.addrCh = make(chan []Address) - w, err := b.r.Resolve(target) - if err != nil { - b.mu.Unlock() - grpclog.Errorf("grpclb: failed to resolve address: %v, err: %v", target, err) - return err - } - b.w = w - b.mu.Unlock() - balancerAddrsCh := make(chan []remoteBalancerInfo, 1) - // Spawn a goroutine to monitor the name resolution of remote load balancer. - go func() { - for { - if err := b.watchAddrUpdates(w, balancerAddrsCh); err != nil { - grpclog.Warningf("grpclb: the naming watcher stops working due to %v.\n", err) - close(balancerAddrsCh) - return - } - } - }() - // Spawn a goroutine to talk to the remote load balancer. - go func() { - var ( - cc *ClientConn - // ccError is closed when there is an error in the current cc. - // A new rb should be picked from rbs and connected. - ccError chan struct{} - rb *remoteBalancerInfo - rbs []remoteBalancerInfo - rbIdx int - ) - - defer func() { - if ccError != nil { - select { - case <-ccError: - default: - close(ccError) - } - } - if cc != nil { - cc.Close() - } - }() - - for { - var ok bool - select { - case rbs, ok = <-balancerAddrsCh: - if !ok { - return - } - foundIdx := -1 - if rb != nil { - for i, trb := range rbs { - if trb == *rb { - foundIdx = i - break - } - } - } - if foundIdx >= 0 { - if foundIdx >= 1 { - // Move the address in use to the beginning of the list. - b.rbs[0], b.rbs[foundIdx] = b.rbs[foundIdx], b.rbs[0] - rbIdx = 0 - } - continue // If found, don't dial new cc. - } else if len(rbs) > 0 { - // Pick a random one from the list, instead of always using the first one. - if l := len(rbs); l > 1 && rb != nil { - tmpIdx := b.rand.Intn(l - 1) - b.rbs[0], b.rbs[tmpIdx] = b.rbs[tmpIdx], b.rbs[0] - } - rbIdx = 0 - rb = &rbs[0] - } else { - // foundIdx < 0 && len(rbs) <= 0. - rb = nil - } - case <-ccError: - ccError = nil - if rbIdx < len(rbs)-1 { - rbIdx++ - rb = &rbs[rbIdx] - } else { - rb = nil - } - } - - if rb == nil { - continue - } - if cc != nil { - cc.Close() - } - // Talk to the remote load balancer to get the server list. - var ( - err error - dopts []DialOption - ) - if creds := config.DialCreds; creds != nil { - if rb.name != "" { - if err := creds.OverrideServerName(rb.name); err != nil { - grpclog.Warningf("grpclb: failed to override the server name in the credentials: %v", err) - continue - } - } - dopts = append(dopts, WithTransportCredentials(creds)) - } else { - dopts = append(dopts, WithInsecure()) - } - if dialer := config.Dialer; dialer != nil { - // WithDialer takes a different type of function, so we instead use a special DialOption here. - dopts = append(dopts, func(o *dialOptions) { o.copts.Dialer = dialer }) - } - dopts = append(dopts, WithBlock()) - ccError = make(chan struct{}) - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - cc, err = DialContext(ctx, rb.addr, dopts...) - cancel() - if err != nil { - grpclog.Warningf("grpclb: failed to setup a connection to the remote balancer %v: %v", rb.addr, err) - close(ccError) - continue - } - b.mu.Lock() - b.seq++ // tick when getting a new balancer address - seq := b.seq - b.next = 0 - b.mu.Unlock() - go func(cc *ClientConn, ccError chan struct{}) { - lbc := &loadBalancerClient{cc} - b.callRemoteBalancer(lbc, seq) - cc.Close() - select { - case <-ccError: - default: - close(ccError) - } - }(cc, ccError) - } - }() - return nil -} + oldAggrState := lb.state + lb.state = lb.csEvltr.recordTransition(oldS, s) -func (b *grpclbBalancer) down(addr Address, err error) { - b.mu.Lock() - defer b.mu.Unlock() - for _, a := range b.addrs { - if addr == a.addr { - a.connected = false - break - } + // Regenerate picker when one of the following happens: + // - this sc became ready from not-ready + // - this sc became not-ready from ready + // - the aggregated state of balancer became TransientFailure from non-TransientFailure + // - the aggregated state of balancer became non-TransientFailure from TransientFailure + if (oldS == connectivity.Ready) != (s == connectivity.Ready) || + (lb.state == connectivity.TransientFailure) != (oldAggrState == connectivity.TransientFailure) { + lb.regeneratePicker() } + + lb.cc.UpdateBalancerState(lb.state, lb.picker) + return } -func (b *grpclbBalancer) Up(addr Address) func(error) { - b.mu.Lock() - defer b.mu.Unlock() - if b.done { - return nil - } - var cnt int - for _, a := range b.addrs { - if a.addr == addr { - if a.connected { - return nil - } - a.connected = true - } - if a.connected && !a.dropForRateLimiting && !a.dropForLoadBalancing { - cnt++ - } - } - // addr is the only one which is connected. Notify the Get() callers who are blocking. - if cnt == 1 && b.waitCh != nil { - close(b.waitCh) - b.waitCh = nil +// fallbackToBackendsAfter blocks for fallbackTimeout and falls back to use +// resolved backends (backends received from resolver, not from remote balancer) +// if no connection to remote balancers was successful. +func (lb *lbBalancer) fallbackToBackendsAfter(fallbackTimeout time.Duration) { + timer := time.NewTimer(fallbackTimeout) + defer timer.Stop() + select { + case <-timer.C: + case <-lb.doneCh: + return } - return func(err error) { - b.down(addr, err) + lb.mu.Lock() + if lb.serverListReceived { + lb.mu.Unlock() + return } + lb.fallbackTimerExpired = true + lb.refreshSubConns(lb.resolvedBackendAddrs) + lb.mu.Unlock() } -func (b *grpclbBalancer) Get(ctx context.Context, opts BalancerGetOptions) (addr Address, put func(), err error) { - var ch chan struct{} - b.mu.Lock() - if b.done { - b.mu.Unlock() - err = ErrClientConnClosing +// HandleResolvedAddrs sends the updated remoteLB addresses to remoteLB +// clientConn. The remoteLB clientConn will handle creating/removing remoteLB +// connections. +func (lb *lbBalancer) HandleResolvedAddrs(addrs []resolver.Address, err error) { + grpclog.Infof("lbBalancer: handleResolvedResult: %+v", addrs) + if len(addrs) <= 0 { return } - seq := b.seq - defer func() { - if err != nil { - return - } - put = func() { - s, ok := rpcInfoFromContext(ctx) - if !ok { - return - } - b.mu.Lock() - defer b.mu.Unlock() - if b.done || seq < b.seq { - return - } - b.clientStats.NumCallsFinished++ - if !s.bytesSent { - b.clientStats.NumCallsFinishedWithClientFailedToSend++ - } else if s.bytesReceived { - b.clientStats.NumCallsFinishedKnownReceived++ - } + var remoteBalancerAddrs, backendAddrs []resolver.Address + for _, a := range addrs { + if a.Type == resolver.GRPCLB { + remoteBalancerAddrs = append(remoteBalancerAddrs, a) + } else { + backendAddrs = append(backendAddrs, a) } - }() - - b.clientStats.NumCallsStarted++ - if len(b.addrs) > 0 { - if b.next >= len(b.addrs) { - b.next = 0 - } - next := b.next - for { - a := b.addrs[next] - next = (next + 1) % len(b.addrs) - if a.connected { - if !a.dropForRateLimiting && !a.dropForLoadBalancing { - addr = a.addr - b.next = next - b.mu.Unlock() - return - } - if !opts.BlockingWait { - b.next = next - if a.dropForLoadBalancing { - b.clientStats.NumCallsFinished++ - b.clientStats.NumCallsFinishedWithDropForLoadBalancing++ - } else if a.dropForRateLimiting { - b.clientStats.NumCallsFinished++ - b.clientStats.NumCallsFinishedWithDropForRateLimiting++ - } - b.mu.Unlock() - err = Errorf(codes.Unavailable, "%s drops requests", a.addr.Addr) - return - } - } - if next == b.next { - // Has iterated all the possible address but none is connected. - break - } - } - } - if !opts.BlockingWait { - b.clientStats.NumCallsFinished++ - b.clientStats.NumCallsFinishedWithClientFailedToSend++ - b.mu.Unlock() - err = Errorf(codes.Unavailable, "there is no address available") - return } - // Wait on b.waitCh for non-failfast RPCs. - if b.waitCh == nil { - ch = make(chan struct{}) - b.waitCh = ch - } else { - ch = b.waitCh - } - b.mu.Unlock() - for { - select { - case <-ctx.Done(): - b.mu.Lock() - b.clientStats.NumCallsFinished++ - b.clientStats.NumCallsFinishedWithClientFailedToSend++ - b.mu.Unlock() - err = ctx.Err() - return - case <-ch: - b.mu.Lock() - if b.done { - b.clientStats.NumCallsFinished++ - b.clientStats.NumCallsFinishedWithClientFailedToSend++ - b.mu.Unlock() - err = ErrClientConnClosing - return - } - if len(b.addrs) > 0 { - if b.next >= len(b.addrs) { - b.next = 0 - } - next := b.next - for { - a := b.addrs[next] - next = (next + 1) % len(b.addrs) - if a.connected { - if !a.dropForRateLimiting && !a.dropForLoadBalancing { - addr = a.addr - b.next = next - b.mu.Unlock() - return - } - if !opts.BlockingWait { - b.next = next - if a.dropForLoadBalancing { - b.clientStats.NumCallsFinished++ - b.clientStats.NumCallsFinishedWithDropForLoadBalancing++ - } else if a.dropForRateLimiting { - b.clientStats.NumCallsFinished++ - b.clientStats.NumCallsFinishedWithDropForRateLimiting++ - } - b.mu.Unlock() - err = Errorf(codes.Unavailable, "drop requests for the addreess %s", a.addr.Addr) - return - } - } - if next == b.next { - // Has iterated all the possible address but none is connected. - break - } - } - } - // The newly added addr got removed by Down() again. - if b.waitCh == nil { - ch = make(chan struct{}) - b.waitCh = ch - } else { - ch = b.waitCh - } - b.mu.Unlock() + if lb.ccRemoteLB == nil { + if len(remoteBalancerAddrs) <= 0 { + grpclog.Errorf("grpclb: no remote balancer address is available, should never happen") + return } + // First time receiving resolved addresses, create a cc to remote + // balancers. + lb.dialRemoteLB(remoteBalancerAddrs[0].ServerName) + // Start the fallback goroutine. + go lb.fallbackToBackendsAfter(lb.fallbackTimeout) } -} -func (b *grpclbBalancer) Notify() <-chan []Address { - return b.addrCh + // cc to remote balancers uses lb.manualResolver. Send the updated remote + // balancer addresses to it through manualResolver. + lb.manualResolver.NewAddress(remoteBalancerAddrs) + + lb.mu.Lock() + lb.resolvedBackendAddrs = backendAddrs + // If serverListReceived is true, connection to remote balancer was + // successful and there's no need to do fallback anymore. + // If fallbackTimerExpired is false, fallback hasn't happened yet. + if !lb.serverListReceived && lb.fallbackTimerExpired { + // This means we received a new list of resolved backends, and we are + // still in fallback mode. Need to update the list of backends we are + // using to the new list of backends. + lb.refreshSubConns(lb.resolvedBackendAddrs) + } + lb.mu.Unlock() } -func (b *grpclbBalancer) Close() error { - b.mu.Lock() - defer b.mu.Unlock() - if b.done { - return errBalancerClosed - } - b.done = true - if b.waitCh != nil { - close(b.waitCh) - } - if b.addrCh != nil { - close(b.addrCh) +func (lb *lbBalancer) Close() { + select { + case <-lb.doneCh: + return + default: } - if b.w != nil { - b.w.Close() + close(lb.doneCh) + if lb.ccRemoteLB != nil { + lb.ccRemoteLB.Close() } - return nil } diff --git a/vendor/google.golang.org/grpc/grpclb_picker.go b/vendor/google.golang.org/grpc/grpclb_picker.go new file mode 100644 index 000000000..872c7ccea --- /dev/null +++ b/vendor/google.golang.org/grpc/grpclb_picker.go @@ -0,0 +1,159 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package grpc + +import ( + "sync" + "sync/atomic" + + "golang.org/x/net/context" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/codes" + lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1/messages" + "google.golang.org/grpc/status" +) + +type rpcStats struct { + NumCallsStarted int64 + NumCallsFinished int64 + NumCallsFinishedWithDropForRateLimiting int64 + NumCallsFinishedWithDropForLoadBalancing int64 + NumCallsFinishedWithClientFailedToSend int64 + NumCallsFinishedKnownReceived int64 +} + +// toClientStats converts rpcStats to lbpb.ClientStats, and clears rpcStats. +func (s *rpcStats) toClientStats() *lbpb.ClientStats { + stats := &lbpb.ClientStats{ + NumCallsStarted: atomic.SwapInt64(&s.NumCallsStarted, 0), + NumCallsFinished: atomic.SwapInt64(&s.NumCallsFinished, 0), + NumCallsFinishedWithDropForRateLimiting: atomic.SwapInt64(&s.NumCallsFinishedWithDropForRateLimiting, 0), + NumCallsFinishedWithDropForLoadBalancing: atomic.SwapInt64(&s.NumCallsFinishedWithDropForLoadBalancing, 0), + NumCallsFinishedWithClientFailedToSend: atomic.SwapInt64(&s.NumCallsFinishedWithClientFailedToSend, 0), + NumCallsFinishedKnownReceived: atomic.SwapInt64(&s.NumCallsFinishedKnownReceived, 0), + } + return stats +} + +func (s *rpcStats) dropForRateLimiting() { + atomic.AddInt64(&s.NumCallsStarted, 1) + atomic.AddInt64(&s.NumCallsFinishedWithDropForRateLimiting, 1) + atomic.AddInt64(&s.NumCallsFinished, 1) +} + +func (s *rpcStats) dropForLoadBalancing() { + atomic.AddInt64(&s.NumCallsStarted, 1) + atomic.AddInt64(&s.NumCallsFinishedWithDropForLoadBalancing, 1) + atomic.AddInt64(&s.NumCallsFinished, 1) +} + +func (s *rpcStats) failedToSend() { + atomic.AddInt64(&s.NumCallsStarted, 1) + atomic.AddInt64(&s.NumCallsFinishedWithClientFailedToSend, 1) + atomic.AddInt64(&s.NumCallsFinished, 1) +} + +func (s *rpcStats) knownReceived() { + atomic.AddInt64(&s.NumCallsStarted, 1) + atomic.AddInt64(&s.NumCallsFinishedKnownReceived, 1) + atomic.AddInt64(&s.NumCallsFinished, 1) +} + +type errPicker struct { + // Pick always returns this err. + err error +} + +func (p *errPicker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) { + return nil, nil, p.err +} + +// rrPicker does roundrobin on subConns. It's typically used when there's no +// response from remote balancer, and grpclb falls back to the resolved +// backends. +// +// It guaranteed that len(subConns) > 0. +type rrPicker struct { + mu sync.Mutex + subConns []balancer.SubConn // The subConns that were READY when taking the snapshot. + subConnsNext int +} + +func (p *rrPicker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) { + p.mu.Lock() + defer p.mu.Unlock() + sc := p.subConns[p.subConnsNext] + p.subConnsNext = (p.subConnsNext + 1) % len(p.subConns) + return sc, nil, nil +} + +// lbPicker does two layers of picks: +// +// First layer: roundrobin on all servers in serverList, including drops and backends. +// - If it picks a drop, the RPC will fail as being dropped. +// - If it picks a backend, do a second layer pick to pick the real backend. +// +// Second layer: roundrobin on all READY backends. +// +// It's guaranteed that len(serverList) > 0. +type lbPicker struct { + mu sync.Mutex + serverList []*lbpb.Server + serverListNext int + subConns []balancer.SubConn // The subConns that were READY when taking the snapshot. + subConnsNext int + + stats *rpcStats +} + +func (p *lbPicker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) { + p.mu.Lock() + defer p.mu.Unlock() + + // Layer one roundrobin on serverList. + s := p.serverList[p.serverListNext] + p.serverListNext = (p.serverListNext + 1) % len(p.serverList) + + // If it's a drop, return an error and fail the RPC. + if s.DropForRateLimiting { + p.stats.dropForRateLimiting() + return nil, nil, status.Errorf(codes.Unavailable, "request dropped by grpclb") + } + if s.DropForLoadBalancing { + p.stats.dropForLoadBalancing() + return nil, nil, status.Errorf(codes.Unavailable, "request dropped by grpclb") + } + + // If not a drop but there's no ready subConns. + if len(p.subConns) <= 0 { + return nil, nil, balancer.ErrNoSubConnAvailable + } + + // Return the next ready subConn in the list, also collect rpc stats. + sc := p.subConns[p.subConnsNext] + p.subConnsNext = (p.subConnsNext + 1) % len(p.subConns) + done := func(info balancer.DoneInfo) { + if !info.BytesSent { + p.stats.failedToSend() + } else if info.BytesReceived { + p.stats.knownReceived() + } + } + return sc, done, nil +} diff --git a/vendor/google.golang.org/grpc/grpclb_remote_balancer.go b/vendor/google.golang.org/grpc/grpclb_remote_balancer.go new file mode 100644 index 000000000..1b580df26 --- /dev/null +++ b/vendor/google.golang.org/grpc/grpclb_remote_balancer.go @@ -0,0 +1,254 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package grpc + +import ( + "fmt" + "net" + "reflect" + "time" + + "golang.org/x/net/context" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/connectivity" + lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1/messages" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/resolver" +) + +// processServerList updates balaner's internal state, create/remove SubConns +// and regenerates picker using the received serverList. +func (lb *lbBalancer) processServerList(l *lbpb.ServerList) { + grpclog.Infof("lbBalancer: processing server list: %+v", l) + lb.mu.Lock() + defer lb.mu.Unlock() + + // Set serverListReceived to true so fallback will not take effect if it has + // not hit timeout. + lb.serverListReceived = true + + // If the new server list == old server list, do nothing. + if reflect.DeepEqual(lb.fullServerList, l.Servers) { + grpclog.Infof("lbBalancer: new serverlist same as the previous one, ignoring") + return + } + lb.fullServerList = l.Servers + + var backendAddrs []resolver.Address + for _, s := range l.Servers { + if s.DropForLoadBalancing || s.DropForRateLimiting { + continue + } + + md := metadata.Pairs(lbTokeyKey, s.LoadBalanceToken) + ip := net.IP(s.IpAddress) + ipStr := ip.String() + if ip.To4() == nil { + // Add square brackets to ipv6 addresses, otherwise net.Dial() and + // net.SplitHostPort() will return too many colons error. + ipStr = fmt.Sprintf("[%s]", ipStr) + } + addr := resolver.Address{ + Addr: fmt.Sprintf("%s:%d", ipStr, s.Port), + Metadata: &md, + } + + backendAddrs = append(backendAddrs, addr) + } + + // Call refreshSubConns to create/remove SubConns. + backendsUpdated := lb.refreshSubConns(backendAddrs) + // If no backend was updated, no SubConn will be newed/removed. But since + // the full serverList was different, there might be updates in drops or + // pick weights(different number of duplicates). We need to update picker + // with the fulllist. + if !backendsUpdated { + lb.regeneratePicker() + lb.cc.UpdateBalancerState(lb.state, lb.picker) + } +} + +// refreshSubConns creates/removes SubConns with backendAddrs. It returns a bool +// indicating whether the backendAddrs are different from the cached +// backendAddrs (whether any SubConn was newed/removed). +// Caller must hold lb.mu. +func (lb *lbBalancer) refreshSubConns(backendAddrs []resolver.Address) bool { + lb.backendAddrs = nil + var backendsUpdated bool + // addrsSet is the set converted from backendAddrs, it's used to quick + // lookup for an address. + addrsSet := make(map[resolver.Address]struct{}) + // Create new SubConns. + for _, addr := range backendAddrs { + addrWithoutMD := addr + addrWithoutMD.Metadata = nil + addrsSet[addrWithoutMD] = struct{}{} + lb.backendAddrs = append(lb.backendAddrs, addrWithoutMD) + + if _, ok := lb.subConns[addrWithoutMD]; !ok { + backendsUpdated = true + + // Use addrWithMD to create the SubConn. + sc, err := lb.cc.NewSubConn([]resolver.Address{addr}, balancer.NewSubConnOptions{}) + if err != nil { + grpclog.Warningf("roundrobinBalancer: failed to create new SubConn: %v", err) + continue + } + lb.subConns[addrWithoutMD] = sc // Use the addr without MD as key for the map. + lb.scStates[sc] = connectivity.Idle + sc.Connect() + } + } + + for a, sc := range lb.subConns { + // a was removed by resolver. + if _, ok := addrsSet[a]; !ok { + backendsUpdated = true + + lb.cc.RemoveSubConn(sc) + delete(lb.subConns, a) + // Keep the state of this sc in b.scStates until sc's state becomes Shutdown. + // The entry will be deleted in HandleSubConnStateChange. + } + } + + return backendsUpdated +} + +func (lb *lbBalancer) readServerList(s *balanceLoadClientStream) error { + for { + reply, err := s.Recv() + if err != nil { + return fmt.Errorf("grpclb: failed to recv server list: %v", err) + } + if serverList := reply.GetServerList(); serverList != nil { + lb.processServerList(serverList) + } + } +} + +func (lb *lbBalancer) sendLoadReport(s *balanceLoadClientStream, interval time.Duration) { + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + case <-s.Context().Done(): + return + } + stats := lb.clientStats.toClientStats() + t := time.Now() + stats.Timestamp = &lbpb.Timestamp{ + Seconds: t.Unix(), + Nanos: int32(t.Nanosecond()), + } + if err := s.Send(&lbpb.LoadBalanceRequest{ + LoadBalanceRequestType: &lbpb.LoadBalanceRequest_ClientStats{ + ClientStats: stats, + }, + }); err != nil { + return + } + } +} +func (lb *lbBalancer) callRemoteBalancer() error { + lbClient := &loadBalancerClient{cc: lb.ccRemoteLB} + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + stream, err := lbClient.BalanceLoad(ctx, FailFast(false)) + if err != nil { + return fmt.Errorf("grpclb: failed to perform RPC to the remote balancer %v", err) + } + + // grpclb handshake on the stream. + initReq := &lbpb.LoadBalanceRequest{ + LoadBalanceRequestType: &lbpb.LoadBalanceRequest_InitialRequest{ + InitialRequest: &lbpb.InitialLoadBalanceRequest{ + Name: lb.target, + }, + }, + } + if err := stream.Send(initReq); err != nil { + return fmt.Errorf("grpclb: failed to send init request: %v", err) + } + reply, err := stream.Recv() + if err != nil { + return fmt.Errorf("grpclb: failed to recv init response: %v", err) + } + initResp := reply.GetInitialResponse() + if initResp == nil { + return fmt.Errorf("grpclb: reply from remote balancer did not include initial response") + } + if initResp.LoadBalancerDelegate != "" { + return fmt.Errorf("grpclb: Delegation is not supported") + } + + go func() { + if d := convertDuration(initResp.ClientStatsReportInterval); d > 0 { + lb.sendLoadReport(stream, d) + } + }() + return lb.readServerList(stream) +} + +func (lb *lbBalancer) watchRemoteBalancer() { + for { + err := lb.callRemoteBalancer() + select { + case <-lb.doneCh: + return + default: + if err != nil { + grpclog.Error(err) + } + } + + } +} + +func (lb *lbBalancer) dialRemoteLB(remoteLBName string) { + var dopts []DialOption + if creds := lb.opt.DialCreds; creds != nil { + if err := creds.OverrideServerName(remoteLBName); err == nil { + dopts = append(dopts, WithTransportCredentials(creds)) + } else { + grpclog.Warningf("grpclb: failed to override the server name in the credentials: %v, using Insecure", err) + dopts = append(dopts, WithInsecure()) + } + } else { + dopts = append(dopts, WithInsecure()) + } + if lb.opt.Dialer != nil { + // WithDialer takes a different type of function, so we instead use a + // special DialOption here. + dopts = append(dopts, withContextDialer(lb.opt.Dialer)) + } + // Explicitly set pickfirst as the balancer. + dopts = append(dopts, WithBalancerName(PickFirstBalancerName)) + dopts = append(dopts, withResolverBuilder(lb.manualResolver)) + // Dial using manualResolver.Scheme, which is a random scheme generated + // when init grpclb. The target name is not important. + cc, err := Dial("grpclb:///grpclb.server", dopts...) + if err != nil { + grpclog.Fatalf("failed to dial: %v", err) + } + lb.ccRemoteLB = cc + go lb.watchRemoteBalancer() +} diff --git a/vendor/google.golang.org/grpc/grpclb_util.go b/vendor/google.golang.org/grpc/grpclb_util.go new file mode 100644 index 000000000..93ab2db32 --- /dev/null +++ b/vendor/google.golang.org/grpc/grpclb_util.go @@ -0,0 +1,90 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package grpc + +import ( + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/resolver" +) + +// The parent ClientConn should re-resolve when grpclb loses connection to the +// remote balancer. When the ClientConn inside grpclb gets a TransientFailure, +// it calls lbManualResolver.ResolveNow(), which calls parent ClientConn's +// ResolveNow, and eventually results in re-resolve happening in parent +// ClientConn's resolver (DNS for example). +// +// parent +// ClientConn +// +-----------------------------------------------------------------+ +// | parent +---------------------------------+ | +// | DNS ClientConn | grpclb | | +// | resolver balancerWrapper | | | +// | + + | grpclb grpclb | | +// | | | | ManualResolver ClientConn | | +// | | | | + + | | +// | | | | | | Transient | | +// | | | | | | Failure | | +// | | | | | <--------- | | | +// | | | <--------------- | ResolveNow | | | +// | | <--------- | ResolveNow | | | | | +// | | ResolveNow | | | | | | +// | | | | | | | | +// | + + | + + | | +// | +---------------------------------+ | +// +-----------------------------------------------------------------+ + +// lbManualResolver is used by the ClientConn inside grpclb. It's a manual +// resolver with a special ResolveNow() function. +// +// When ResolveNow() is called, it calls ResolveNow() on the parent ClientConn, +// so when grpclb client lose contact with remote balancers, the parent +// ClientConn's resolver will re-resolve. +type lbManualResolver struct { + scheme string + ccr resolver.ClientConn + + ccb balancer.ClientConn +} + +func (r *lbManualResolver) Build(_ resolver.Target, cc resolver.ClientConn, _ resolver.BuildOption) (resolver.Resolver, error) { + r.ccr = cc + return r, nil +} + +func (r *lbManualResolver) Scheme() string { + return r.scheme +} + +// ResolveNow calls resolveNow on the parent ClientConn. +func (r *lbManualResolver) ResolveNow(o resolver.ResolveNowOption) { + r.ccb.ResolveNow(o) +} + +// Close is a noop for Resolver. +func (*lbManualResolver) Close() {} + +// NewAddress calls cc.NewAddress. +func (r *lbManualResolver) NewAddress(addrs []resolver.Address) { + r.ccr.NewAddress(addrs) +} + +// NewServiceConfig calls cc.NewServiceConfig. +func (r *lbManualResolver) NewServiceConfig(sc string) { + r.ccr.NewServiceConfig(sc) +} diff --git a/vendor/google.golang.org/grpc/naming/go17.go b/vendor/google.golang.org/grpc/naming/go17.go index a537b08c6..57b65d7b8 100644 --- a/vendor/google.golang.org/grpc/naming/go17.go +++ b/vendor/google.golang.org/grpc/naming/go17.go @@ -1,4 +1,4 @@ -// +build go1.6, !go1.8 +// +build go1.6,!go1.8 /* * diff --git a/vendor/google.golang.org/grpc/picker_wrapper.go b/vendor/google.golang.org/grpc/picker_wrapper.go index 9085dbc9c..db82bfb3a 100644 --- a/vendor/google.golang.org/grpc/picker_wrapper.go +++ b/vendor/google.golang.org/grpc/picker_wrapper.go @@ -97,7 +97,7 @@ func (bp *pickerWrapper) pick(ctx context.Context, failfast bool, opts balancer. p = bp.picker bp.mu.Unlock() - subConn, put, err := p.Pick(ctx, opts) + subConn, done, err := p.Pick(ctx, opts) if err != nil { switch err { @@ -120,7 +120,7 @@ func (bp *pickerWrapper) pick(ctx context.Context, failfast bool, opts balancer. continue } if t, ok := acw.getAddrConn().getReadyTransport(); ok { - return t, put, nil + return t, done, nil } grpclog.Infof("blockingPicker: the picked transport is not ready, loop back to repick") // If ok == false, ac.state is not READY. diff --git a/vendor/google.golang.org/grpc/pickfirst.go b/vendor/google.golang.org/grpc/pickfirst.go index e83ca2b0d..bf659d49d 100644 --- a/vendor/google.golang.org/grpc/pickfirst.go +++ b/vendor/google.golang.org/grpc/pickfirst.go @@ -26,6 +26,9 @@ import ( "google.golang.org/grpc/resolver" ) +// PickFirstBalancerName is the name of the pick_first balancer. +const PickFirstBalancerName = "pick_first" + func newPickfirstBuilder() balancer.Builder { return &pickfirstBuilder{} } @@ -37,7 +40,7 @@ func (*pickfirstBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions } func (*pickfirstBuilder) Name() string { - return "pick_first" + return PickFirstBalancerName } type pickfirstBalancer struct { diff --git a/vendor/google.golang.org/grpc/resolver/resolver.go b/vendor/google.golang.org/grpc/resolver/resolver.go index 0dd887fa5..df097eedf 100644 --- a/vendor/google.golang.org/grpc/resolver/resolver.go +++ b/vendor/google.golang.org/grpc/resolver/resolver.go @@ -38,7 +38,7 @@ func Register(b Builder) { // Get returns the resolver builder registered with the given scheme. // If no builder is register with the scheme, the default scheme will // be used. -// If the default scheme is not modified, "dns" will be the default +// If the default scheme is not modified, "passthrough" will be the default // scheme, and the preinstalled dns resolver will be used. // If the default scheme is modified, and a resolver is registered with // the scheme, that resolver will be returned. @@ -55,7 +55,7 @@ func Get(scheme string) Builder { } // SetDefaultScheme sets the default scheme that will be used. -// The default default scheme is "dns". +// The default default scheme is "passthrough". func SetDefaultScheme(scheme string) { defaultScheme = scheme } @@ -78,7 +78,9 @@ type Address struct { // Type is the type of this address. Type AddressType // ServerName is the name of this address. - // It's the name of the grpc load balancer, which will be used for authentication. + // + // e.g. if Type is GRPCLB, ServerName should be the name of the remote load + // balancer, not the name of the backend. ServerName string // Metadata is the information associated with Addr, which may be used // to make load balancing decision. @@ -88,10 +90,18 @@ type Address struct { // BuildOption includes additional information for the builder to create // the resolver. type BuildOption struct { + // UserOptions can be used to pass configuration between DialOptions and the + // resolver. + UserOptions interface{} } // ClientConn contains the callbacks for resolver to notify any updates // to the gRPC ClientConn. +// +// This interface is to be implemented by gRPC. Users should not need a +// brand new implementation of this interface. For the situations like +// testing, the new implementation should embed this interface. This allows +// gRPC to add new methods to this interface. type ClientConn interface { // NewAddress is called by resolver to notify ClientConn a new list // of resolved addresses. @@ -128,8 +138,10 @@ type ResolveNowOption struct{} // Resolver watches for the updates on the specified target. // Updates include address updates and service config updates. type Resolver interface { - // ResolveNow will be called by gRPC to try to resolve the target name again. - // It's just a hint, resolver can ignore this if it's not necessary. + // ResolveNow will be called by gRPC to try to resolve the target name + // again. It's just a hint, resolver can ignore this if it's not necessary. + // + // It could be called multiple times concurrently. ResolveNow(ResolveNowOption) // Close closes the resolver. Close() diff --git a/vendor/google.golang.org/grpc/resolver_conn_wrapper.go b/vendor/google.golang.org/grpc/resolver_conn_wrapper.go index c07e174a8..ef5d4c286 100644 --- a/vendor/google.golang.org/grpc/resolver_conn_wrapper.go +++ b/vendor/google.golang.org/grpc/resolver_conn_wrapper.go @@ -61,12 +61,18 @@ func parseTarget(target string) (ret resolver.Target) { // newCCResolverWrapper parses cc.target for scheme and gets the resolver // builder for this scheme. It then builds the resolver and starts the // monitoring goroutine for it. +// +// If withResolverBuilder dial option is set, the specified resolver will be +// used instead. func newCCResolverWrapper(cc *ClientConn) (*ccResolverWrapper, error) { grpclog.Infof("dialing to target with scheme: %q", cc.parsedTarget.Scheme) - rb := resolver.Get(cc.parsedTarget.Scheme) + rb := cc.dopts.resolverBuilder if rb == nil { - return nil, fmt.Errorf("could not get resolver for scheme: %q", cc.parsedTarget.Scheme) + rb = resolver.Get(cc.parsedTarget.Scheme) + if rb == nil { + return nil, fmt.Errorf("could not get resolver for scheme: %q", cc.parsedTarget.Scheme) + } } ccr := &ccResolverWrapper{ @@ -77,14 +83,19 @@ func newCCResolverWrapper(cc *ClientConn) (*ccResolverWrapper, error) { } var err error - ccr.resolver, err = rb.Build(cc.parsedTarget, ccr, resolver.BuildOption{}) + ccr.resolver, err = rb.Build(cc.parsedTarget, ccr, resolver.BuildOption{ + UserOptions: cc.dopts.resolverBuildUserOptions, + }) if err != nil { return nil, err } - go ccr.watcher() return ccr, nil } +func (ccr *ccResolverWrapper) start() { + go ccr.watcher() +} + // watcher processes address updates and service config updates sequencially. // Otherwise, we need to resolve possible races between address and service // config (e.g. they specify different balancer types). @@ -119,6 +130,10 @@ func (ccr *ccResolverWrapper) watcher() { } } +func (ccr *ccResolverWrapper) resolveNow(o resolver.ResolveNowOption) { + ccr.resolver.ResolveNow(o) +} + func (ccr *ccResolverWrapper) close() { ccr.resolver.Close() close(ccr.done) diff --git a/vendor/google.golang.org/grpc/rpc_util.go b/vendor/google.golang.org/grpc/rpc_util.go index 1b11ae05e..bf384b644 100644 --- a/vendor/google.golang.org/grpc/rpc_util.go +++ b/vendor/google.golang.org/grpc/rpc_util.go @@ -293,10 +293,10 @@ func (p *parser) recvMsg(maxReceiveMessageSize int) (pf payloadFormat, msg []byt return pf, nil, nil } if int64(length) > int64(maxInt) { - return 0, nil, Errorf(codes.ResourceExhausted, "grpc: received message larger than max length allowed on current machine (%d vs. %d)", length, maxInt) + return 0, nil, status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max length allowed on current machine (%d vs. %d)", length, maxInt) } if int(length) > maxReceiveMessageSize { - return 0, nil, Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", length, maxReceiveMessageSize) + return 0, nil, status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", length, maxReceiveMessageSize) } // TODO(bradfitz,zhaoq): garbage. reuse buffer after proto decoding instead // of making it for each message: @@ -326,7 +326,7 @@ func encode(c Codec, msg interface{}, cp Compressor, outPayload *stats.OutPayloa var err error b, err = c.Marshal(msg) if err != nil { - return nil, nil, Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error()) + return nil, nil, status.Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error()) } if outPayload != nil { outPayload.Payload = msg @@ -340,20 +340,20 @@ func encode(c Codec, msg interface{}, cp Compressor, outPayload *stats.OutPayloa if compressor != nil { z, _ := compressor.Compress(cbuf) if _, err := z.Write(b); err != nil { - return nil, nil, Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error()) + return nil, nil, status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error()) } z.Close() } else { // If Compressor is not set by UseCompressor, use default Compressor if err := cp.Do(cbuf, b); err != nil { - return nil, nil, Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error()) + return nil, nil, status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error()) } } b = cbuf.Bytes() } } if uint(len(b)) > math.MaxUint32 { - return nil, nil, Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", len(b)) + return nil, nil, status.Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", len(b)) } bufHeader := make([]byte, payloadLen+sizeLen) @@ -409,26 +409,26 @@ func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{ if dc != nil { d, err = dc.Do(bytes.NewReader(d)) if err != nil { - return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err) + return status.Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err) } } else { dcReader, err := compressor.Decompress(bytes.NewReader(d)) if err != nil { - return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err) + return status.Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err) } d, err = ioutil.ReadAll(dcReader) if err != nil { - return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err) + return status.Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err) } } } if len(d) > maxReceiveMessageSize { // TODO: Revisit the error code. Currently keep it consistent with java // implementation. - return Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", len(d), maxReceiveMessageSize) + return status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", len(d), maxReceiveMessageSize) } if err := c.Unmarshal(d, m); err != nil { - return Errorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err) + return status.Errorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err) } if inPayload != nil { inPayload.RecvTime = time.Now() @@ -441,9 +441,7 @@ func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{ } type rpcInfo struct { - failfast bool - bytesSent bool - bytesReceived bool + failfast bool } type rpcInfoContextKey struct{} @@ -457,18 +455,10 @@ func rpcInfoFromContext(ctx context.Context) (s *rpcInfo, ok bool) { return } -func updateRPCInfoInContext(ctx context.Context, s rpcInfo) { - if ss, ok := rpcInfoFromContext(ctx); ok { - ss.bytesReceived = s.bytesReceived - ss.bytesSent = s.bytesSent - } - return -} - // Code returns the error code for err if it was produced by the rpc system. // Otherwise, it returns codes.Unknown. // -// Deprecated; use status.FromError and Code method instead. +// Deprecated: use status.FromError and Code method instead. func Code(err error) codes.Code { if s, ok := status.FromError(err); ok { return s.Code() @@ -479,7 +469,7 @@ func Code(err error) codes.Code { // ErrorDesc returns the error description of err if it was produced by the rpc system. // Otherwise, it returns err.Error() or empty string when err is nil. // -// Deprecated; use status.FromError and Message method instead. +// Deprecated: use status.FromError and Message method instead. func ErrorDesc(err error) string { if s, ok := status.FromError(err); ok { return s.Message() @@ -490,7 +480,7 @@ func ErrorDesc(err error) string { // Errorf returns an error containing an error code and a description; // Errorf returns nil if c is OK. // -// Deprecated; use status.Errorf instead. +// Deprecated: use status.Errorf instead. func Errorf(c codes.Code, format string, a ...interface{}) error { return status.Errorf(c, format, a...) } @@ -510,6 +500,6 @@ const ( ) // Version is the current grpc version. -const Version = "1.8.0" +const Version = "1.9.1" const grpcUA = "grpc-go/" + Version diff --git a/vendor/google.golang.org/grpc/server.go b/vendor/google.golang.org/grpc/server.go index e9737fc49..f65162168 100644 --- a/vendor/google.golang.org/grpc/server.go +++ b/vendor/google.golang.org/grpc/server.go @@ -92,11 +92,7 @@ type Server struct { conns map[io.Closer]bool serve bool drain bool - ctx context.Context - cancel context.CancelFunc - // A CondVar to let GracefulStop() blocks until all the pending RPCs are finished - // and all the transport goes away. - cv *sync.Cond + cv *sync.Cond // signaled when connections close for GracefulStop m map[string]*service // service name -> service info events trace.EventLog @@ -104,6 +100,7 @@ type Server struct { done chan struct{} quitOnce sync.Once doneOnce sync.Once + serveWG sync.WaitGroup // counts active Serve goroutines for GracefulStop } type options struct { @@ -343,7 +340,6 @@ func NewServer(opt ...ServerOption) *Server { done: make(chan struct{}), } s.cv = sync.NewCond(&s.mu) - s.ctx, s.cancel = context.WithCancel(context.Background()) if EnableTracing { _, file, line, _ := runtime.Caller(1) s.events = trace.NewEventLog("grpc.Server", fmt.Sprintf("%s:%d", file, line)) @@ -474,10 +470,23 @@ func (s *Server) Serve(lis net.Listener) error { s.printf("serving") s.serve = true if s.lis == nil { + // Serve called after Stop or GracefulStop. s.mu.Unlock() lis.Close() return ErrServerStopped } + + s.serveWG.Add(1) + defer func() { + s.serveWG.Done() + select { + // Stop or GracefulStop called; block until done and return nil. + case <-s.quit: + <-s.done + default: + } + }() + s.lis[lis] = true s.mu.Unlock() defer func() { @@ -511,33 +520,39 @@ func (s *Server) Serve(lis net.Listener) error { timer := time.NewTimer(tempDelay) select { case <-timer.C: - case <-s.ctx.Done(): + case <-s.quit: + timer.Stop() + return nil } - timer.Stop() continue } s.mu.Lock() s.printf("done serving; Accept = %v", err) s.mu.Unlock() - // If Stop or GracefulStop is called, block until they are done and return nil select { case <-s.quit: - <-s.done return nil default: } return err } tempDelay = 0 - // Start a new goroutine to deal with rawConn - // so we don't stall this Accept loop goroutine. - go s.handleRawConn(rawConn) + // Start a new goroutine to deal with rawConn so we don't stall this Accept + // loop goroutine. + // + // Make sure we account for the goroutine so GracefulStop doesn't nil out + // s.conns before this conn can be added. + s.serveWG.Add(1) + go func() { + s.handleRawConn(rawConn) + s.serveWG.Done() + }() } } -// handleRawConn is run in its own goroutine and handles a just-accepted -// connection that has not had any I/O performed on it yet. +// handleRawConn forks a goroutine to handle a just-accepted connection that +// has not had any I/O performed on it yet. func (s *Server) handleRawConn(rawConn net.Conn) { rawConn.SetDeadline(time.Now().Add(s.opts.connectionTimeout)) conn, authInfo, err := s.useTransportAuthenticator(rawConn) @@ -562,17 +577,28 @@ func (s *Server) handleRawConn(rawConn net.Conn) { } s.mu.Unlock() + var serve func() + c := conn.(io.Closer) if s.opts.useHandlerImpl { - rawConn.SetDeadline(time.Time{}) - s.serveUsingHandler(conn) + serve = func() { s.serveUsingHandler(conn) } } else { + // Finish handshaking (HTTP2) st := s.newHTTP2Transport(conn, authInfo) if st == nil { return } - rawConn.SetDeadline(time.Time{}) - s.serveStreams(st) + c = st + serve = func() { s.serveStreams(st) } + } + + rawConn.SetDeadline(time.Time{}) + if !s.addConn(c) { + return } + go func() { + serve() + s.removeConn(c) + }() } // newHTTP2Transport sets up a http/2 transport (using the @@ -599,15 +625,10 @@ func (s *Server) newHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) tr grpclog.Warningln("grpc: Server.Serve failed to create ServerTransport: ", err) return nil } - if !s.addConn(st) { - st.Close() - return nil - } return st } func (s *Server) serveStreams(st transport.ServerTransport) { - defer s.removeConn(st) defer st.Close() var wg sync.WaitGroup st.HandleStreams(func(stream *transport.Stream) { @@ -641,11 +662,6 @@ var _ http.Handler = (*Server)(nil) // // conn is the *tls.Conn that's already been authenticated. func (s *Server) serveUsingHandler(conn net.Conn) { - if !s.addConn(conn) { - conn.Close() - return - } - defer s.removeConn(conn) h2s := &http2.Server{ MaxConcurrentStreams: s.opts.maxConcurrentStreams, } @@ -685,7 +701,6 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } if !s.addConn(st) { - st.Close() return } defer s.removeConn(st) @@ -715,9 +730,15 @@ func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Strea func (s *Server) addConn(c io.Closer) bool { s.mu.Lock() defer s.mu.Unlock() - if s.conns == nil || s.drain { + if s.conns == nil { + c.Close() return false } + if s.drain { + // Transport added after we drained our existing conns: drain it + // immediately. + c.(transport.ServerTransport).Drain() + } s.conns[c] = true return true } @@ -826,7 +847,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. return err } if err == io.ErrUnexpectedEOF { - err = Errorf(codes.Internal, io.ErrUnexpectedEOF.Error()) + err = status.Errorf(codes.Internal, io.ErrUnexpectedEOF.Error()) } if err != nil { if st, ok := status.FromError(err); ok { @@ -868,13 +889,13 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. if dc != nil { req, err = dc.Do(bytes.NewReader(req)) if err != nil { - return Errorf(codes.Internal, err.Error()) + return status.Errorf(codes.Internal, err.Error()) } } else { tmp, _ := decomp.Decompress(bytes.NewReader(req)) req, err = ioutil.ReadAll(tmp) if err != nil { - return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err) + return status.Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err) } } } @@ -1158,6 +1179,7 @@ func (s *Server) Stop() { }) defer func() { + s.serveWG.Wait() s.doneOnce.Do(func() { close(s.done) }) @@ -1180,7 +1202,6 @@ func (s *Server) Stop() { } s.mu.Lock() - s.cancel() if s.events != nil { s.events.Finish() s.events = nil @@ -1203,21 +1224,27 @@ func (s *Server) GracefulStop() { }() s.mu.Lock() - defer s.mu.Unlock() if s.conns == nil { + s.mu.Unlock() return } for lis := range s.lis { lis.Close() } s.lis = nil - s.cancel() if !s.drain { for c := range s.conns { c.(transport.ServerTransport).Drain() } s.drain = true } + + // Wait for serving threads to be ready to exit. Only then can we be sure no + // new conns will be created. + s.mu.Unlock() + s.serveWG.Wait() + s.mu.Lock() + for len(s.conns) != 0 { s.cv.Wait() } @@ -1226,6 +1253,7 @@ func (s *Server) GracefulStop() { s.events.Finish() s.events = nil } + s.mu.Unlock() } func init() { @@ -1246,7 +1274,7 @@ func SetHeader(ctx context.Context, md metadata.MD) error { } stream, ok := transport.StreamFromContext(ctx) if !ok { - return Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx) + return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx) } return stream.SetHeader(md) } @@ -1256,7 +1284,7 @@ func SetHeader(ctx context.Context, md metadata.MD) error { func SendHeader(ctx context.Context, md metadata.MD) error { stream, ok := transport.StreamFromContext(ctx) if !ok { - return Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx) + return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx) } t := stream.ServerTransport() if t == nil { @@ -1276,7 +1304,7 @@ func SetTrailer(ctx context.Context, md metadata.MD) error { } stream, ok := transport.StreamFromContext(ctx) if !ok { - return Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx) + return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx) } return stream.SetTrailer(md) } diff --git a/vendor/google.golang.org/grpc/service_config.go b/vendor/google.golang.org/grpc/service_config.go index cde648334..53fa88f37 100644 --- a/vendor/google.golang.org/grpc/service_config.go +++ b/vendor/google.golang.org/grpc/service_config.go @@ -20,6 +20,9 @@ package grpc import ( "encoding/json" + "fmt" + "strconv" + "strings" "time" "google.golang.org/grpc/grpclog" @@ -70,12 +73,48 @@ type ServiceConfig struct { Methods map[string]MethodConfig } -func parseTimeout(t *string) (*time.Duration, error) { - if t == nil { +func parseDuration(s *string) (*time.Duration, error) { + if s == nil { return nil, nil } - d, err := time.ParseDuration(*t) - return &d, err + if !strings.HasSuffix(*s, "s") { + return nil, fmt.Errorf("malformed duration %q", *s) + } + ss := strings.SplitN((*s)[:len(*s)-1], ".", 3) + if len(ss) > 2 { + return nil, fmt.Errorf("malformed duration %q", *s) + } + // hasDigits is set if either the whole or fractional part of the number is + // present, since both are optional but one is required. + hasDigits := false + var d time.Duration + if len(ss[0]) > 0 { + i, err := strconv.ParseInt(ss[0], 10, 32) + if err != nil { + return nil, fmt.Errorf("malformed duration %q: %v", *s, err) + } + d = time.Duration(i) * time.Second + hasDigits = true + } + if len(ss) == 2 && len(ss[1]) > 0 { + if len(ss[1]) > 9 { + return nil, fmt.Errorf("malformed duration %q", *s) + } + f, err := strconv.ParseInt(ss[1], 10, 64) + if err != nil { + return nil, fmt.Errorf("malformed duration %q: %v", *s, err) + } + for i := 9; i > len(ss[1]); i-- { + f *= 10 + } + d += time.Duration(f) + hasDigits = true + } + if !hasDigits { + return nil, fmt.Errorf("malformed duration %q", *s) + } + + return &d, nil } type jsonName struct { @@ -128,7 +167,7 @@ func parseServiceConfig(js string) (ServiceConfig, error) { if m.Name == nil { continue } - d, err := parseTimeout(m.Timeout) + d, err := parseDuration(m.Timeout) if err != nil { grpclog.Warningf("grpc: parseServiceConfig error unmarshaling %s due to %v", js, err) return ServiceConfig{}, err @@ -182,18 +221,6 @@ func getMaxSize(mcMax, doptMax *int, defaultVal int) *int { return doptMax } -func newBool(b bool) *bool { - return &b -} - func newInt(b int) *int { return &b } - -func newDuration(b time.Duration) *time.Duration { - return &b -} - -func newString(b string) *string { - return &b -} diff --git a/vendor/google.golang.org/grpc/status/status.go b/vendor/google.golang.org/grpc/status/status.go index 871dc4b31..d9defaebc 100644 --- a/vendor/google.golang.org/grpc/status/status.go +++ b/vendor/google.golang.org/grpc/status/status.go @@ -125,8 +125,8 @@ func FromError(err error) (s *Status, ok bool) { if err == nil { return &Status{s: &spb.Status{Code: int32(codes.OK)}}, true } - if s, ok := err.(*statusError); ok { - return s.status(), true + if se, ok := err.(*statusError); ok { + return se.status(), true } return nil, false } @@ -166,3 +166,16 @@ func (s *Status) Details() []interface{} { } return details } + +// Code returns the Code of the error if it is a Status error, codes.OK if err +// is nil, or codes.Unknown otherwise. +func Code(err error) codes.Code { + // Don't use FromError to avoid allocation of OK status. + if err == nil { + return codes.OK + } + if se, ok := err.(*statusError); ok { + return se.status().Code() + } + return codes.Unknown +} diff --git a/vendor/google.golang.org/grpc/stream.go b/vendor/google.golang.org/grpc/stream.go index 9eeaafef8..f91381995 100644 --- a/vendor/google.golang.org/grpc/stream.go +++ b/vendor/google.golang.org/grpc/stream.go @@ -163,7 +163,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth if ct != encoding.Identity { comp = encoding.GetCompressor(ct) if comp == nil { - return nil, Errorf(codes.Internal, "grpc: Compressor is not installed for requested grpc-encoding %q", ct) + return nil, status.Errorf(codes.Internal, "grpc: Compressor is not installed for requested grpc-encoding %q", ct) } } } else if cc.dopts.cp != nil { @@ -232,7 +232,14 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth s, err = t.NewStream(ctx, callHdr) if err != nil { if done != nil { - done(balancer.DoneInfo{Err: err}) + doneInfo := balancer.DoneInfo{Err: err} + if _, ok := err.(transport.ConnectionError); ok { + // If error is connection error, transport was sending data on wire, + // and we are not sure if anything has been sent on wire. + // If error is not connection error, we are sure nothing has been sent. + doneInfo.BytesSent = true + } + done(doneInfo) done = nil } // In the event of any error from NewStream, we never attempted to write @@ -393,10 +400,10 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) { return err } if cs.c.maxSendMessageSize == nil { - return Errorf(codes.Internal, "callInfo maxSendMessageSize field uninitialized(nil)") + return status.Errorf(codes.Internal, "callInfo maxSendMessageSize field uninitialized(nil)") } if len(data) > *cs.c.maxSendMessageSize { - return Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(data), *cs.c.maxSendMessageSize) + return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(data), *cs.c.maxSendMessageSize) } err = cs.t.Write(cs.s, hdr, data, &transport.Options{Last: false}) if err == nil && outPayload != nil { @@ -414,7 +421,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) { } } if cs.c.maxReceiveMessageSize == nil { - return Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)") + return status.Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)") } if !cs.decompSet { // Block until we receive headers containing received message encoding. @@ -456,7 +463,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) { // Special handling for client streaming rpc. // This recv expects EOF or errors, so we don't collect inPayload. if cs.c.maxReceiveMessageSize == nil { - return Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)") + return status.Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)") } err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, nil, cs.decomp) cs.closeTransportStream(err) @@ -529,11 +536,11 @@ func (cs *clientStream) finish(err error) { o.after(cs.c) } if cs.done != nil { - updateRPCInfoInContext(cs.s.Context(), rpcInfo{ - bytesSent: true, - bytesReceived: cs.s.BytesReceived(), + cs.done(balancer.DoneInfo{ + Err: err, + BytesSent: true, + BytesReceived: cs.s.BytesReceived(), }) - cs.done(balancer.DoneInfo{Err: err}) cs.done = nil } if cs.statsHandler != nil { @@ -653,7 +660,7 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) { return err } if len(data) > ss.maxSendMessageSize { - return Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(data), ss.maxSendMessageSize) + return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(data), ss.maxSendMessageSize) } if err := ss.t.Write(ss.s, hdr, data, &transport.Options{Last: false}); err != nil { return toRPCErr(err) @@ -693,7 +700,7 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) { return err } if err == io.ErrUnexpectedEOF { - err = Errorf(codes.Internal, io.ErrUnexpectedEOF.Error()) + err = status.Errorf(codes.Internal, io.ErrUnexpectedEOF.Error()) } return toRPCErr(err) } diff --git a/vendor/google.golang.org/grpc/transport/control.go b/vendor/google.golang.org/grpc/transport/control.go index 63194830d..0474b0907 100644 --- a/vendor/google.golang.org/grpc/transport/control.go +++ b/vendor/google.golang.org/grpc/transport/control.go @@ -116,6 +116,7 @@ type goAway struct { func (*goAway) item() {} type flushIO struct { + closeTr bool } func (*flushIO) item() {} diff --git a/vendor/google.golang.org/grpc/transport/go16.go b/vendor/google.golang.org/grpc/transport/go16.go index 7cffee11e..5babcf9b8 100644 --- a/vendor/google.golang.org/grpc/transport/go16.go +++ b/vendor/google.golang.org/grpc/transport/go16.go @@ -22,6 +22,7 @@ package transport import ( "net" + "net/http" "google.golang.org/grpc/codes" @@ -43,3 +44,8 @@ func ContextErr(err error) StreamError { } return streamErrorf(codes.Internal, "Unexpected error from context packet: %v", err) } + +// contextFromRequest returns a background context. +func contextFromRequest(r *http.Request) context.Context { + return context.Background() +} diff --git a/vendor/google.golang.org/grpc/transport/go17.go b/vendor/google.golang.org/grpc/transport/go17.go index 2464e69fa..b7fa6bdb9 100644 --- a/vendor/google.golang.org/grpc/transport/go17.go +++ b/vendor/google.golang.org/grpc/transport/go17.go @@ -23,6 +23,7 @@ package transport import ( "context" "net" + "net/http" "google.golang.org/grpc/codes" @@ -44,3 +45,8 @@ func ContextErr(err error) StreamError { } return streamErrorf(codes.Internal, "Unexpected error from context packet: %v", err) } + +// contextFromRequest returns a context from the HTTP Request. +func contextFromRequest(r *http.Request) context.Context { + return r.Context() +} diff --git a/vendor/google.golang.org/grpc/transport/handler_server.go b/vendor/google.golang.org/grpc/transport/handler_server.go index f1f6caf89..27c4ebb5f 100644 --- a/vendor/google.golang.org/grpc/transport/handler_server.go +++ b/vendor/google.golang.org/grpc/transport/handler_server.go @@ -123,10 +123,9 @@ type serverHandlerTransport struct { // when WriteStatus is called. writes chan func() - mu sync.Mutex - // streamDone indicates whether WriteStatus has been called and writes channel - // has been closed. - streamDone bool + // block concurrent WriteStatus calls + // e.g. grpc/(*serverStream).SendMsg/RecvMsg + writeStatusMu sync.Mutex } func (ht *serverHandlerTransport) Close() error { @@ -177,13 +176,9 @@ func (ht *serverHandlerTransport) do(fn func()) error { } func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) error { - ht.mu.Lock() - if ht.streamDone { - ht.mu.Unlock() - return nil - } - ht.streamDone = true - ht.mu.Unlock() + ht.writeStatusMu.Lock() + defer ht.writeStatusMu.Unlock() + err := ht.do(func() { ht.writeCommonHeaders(s) @@ -222,7 +217,11 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) erro } } }) - close(ht.writes) + + if err == nil { // transport has not been closed + ht.Close() + close(ht.writes) + } return err } @@ -285,12 +284,12 @@ func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error { func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), traceCtx func(context.Context, string) context.Context) { // With this transport type there will be exactly 1 stream: this HTTP request. - var ctx context.Context + ctx := contextFromRequest(ht.req) var cancel context.CancelFunc if ht.timeoutSet { - ctx, cancel = context.WithTimeout(context.Background(), ht.timeout) + ctx, cancel = context.WithTimeout(ctx, ht.timeout) } else { - ctx, cancel = context.WithCancel(context.Background()) + ctx, cancel = context.WithCancel(ctx) } // requestOver is closed when either the request's context is done diff --git a/vendor/google.golang.org/grpc/transport/http2_client.go b/vendor/google.golang.org/grpc/transport/http2_client.go index 0f58a390a..4a122692a 100644 --- a/vendor/google.golang.org/grpc/transport/http2_client.go +++ b/vendor/google.golang.org/grpc/transport/http2_client.go @@ -20,6 +20,7 @@ package transport import ( "bytes" + "fmt" "io" "math" "net" @@ -93,6 +94,11 @@ type http2Client struct { bdpEst *bdpEstimator outQuotaVersion uint32 + // onSuccess is a callback that client transport calls upon + // receiving server preface to signal that a succefull HTTP2 + // connection was established. + onSuccess func() + mu sync.Mutex // guard the following variables state transportState // the state of underlying connection activeStreams map[uint32]*Stream @@ -145,16 +151,12 @@ func isTemporary(err error) bool { // newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2 // and starts to receive messages on it. Non-nil error returns if construction // fails. -func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions, timeout time.Duration) (_ ClientTransport, err error) { +func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts ConnectOptions, onSuccess func()) (_ ClientTransport, err error) { scheme := "http" ctx, cancel := context.WithCancel(ctx) - connectCtx, connectCancel := context.WithTimeout(ctx, timeout) defer func() { if err != nil { cancel() - // Don't call connectCancel in success path due to a race in Go 1.6: - // https://github.com/golang/go/issues/15078. - connectCancel() } }() @@ -240,6 +242,7 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions, t kp: kp, statsHandler: opts.StatsHandler, initialWindowSize: initialWindowSize, + onSuccess: onSuccess, } if opts.InitialWindowSize >= defaultWindowSize { t.initialWindowSize = opts.InitialWindowSize @@ -300,7 +303,7 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions, t t.framer.writer.Flush() go func() { loopyWriter(t.ctx, t.controlBuf, t.itemHandler) - t.Close() + t.conn.Close() }() if t.kp.Time != infinity { go t.keepalive() @@ -1122,7 +1125,6 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { s.mu.Unlock() return } - if len(state.mdata) > 0 { s.trailer = state.mdata } @@ -1160,6 +1162,7 @@ func (t *http2Client) reader() { t.Close() return } + t.onSuccess() t.handleSettings(sf, true) // loop to keep reading incoming messages on this transport. @@ -1234,8 +1237,7 @@ func (t *http2Client) applySettings(ss []http2.Setting) { // TODO(mmukhi): A lot of this code(and code in other places in the tranpsort layer) // is duplicated between the client and the server. // The transport layer needs to be refactored to take care of this. -func (t *http2Client) itemHandler(i item) error { - var err error +func (t *http2Client) itemHandler(i item) (err error) { defer func() { if err != nil { errorf(" error in itemHandler: %v", err) @@ -1243,10 +1245,11 @@ func (t *http2Client) itemHandler(i item) error { }() switch i := i.(type) { case *dataFrame: - err = t.framer.fr.WriteData(i.streamID, i.endStream, i.d) - if err == nil { - i.f() + if err := t.framer.fr.WriteData(i.streamID, i.endStream, i.d); err != nil { + return err } + i.f() + return nil case *headerFrame: t.hBuf.Reset() for _, f := range i.hf { @@ -1280,31 +1283,33 @@ func (t *http2Client) itemHandler(i item) error { return err } } + return nil case *windowUpdate: - err = t.framer.fr.WriteWindowUpdate(i.streamID, i.increment) + return t.framer.fr.WriteWindowUpdate(i.streamID, i.increment) case *settings: - err = t.framer.fr.WriteSettings(i.ss...) + return t.framer.fr.WriteSettings(i.ss...) case *settingsAck: - err = t.framer.fr.WriteSettingsAck() + return t.framer.fr.WriteSettingsAck() case *resetStream: // If the server needs to be to intimated about stream closing, // then we need to make sure the RST_STREAM frame is written to // the wire before the headers of the next stream waiting on // streamQuota. We ensure this by adding to the streamsQuota pool // only after having acquired the writableChan to send RST_STREAM. - err = t.framer.fr.WriteRSTStream(i.streamID, i.code) + err := t.framer.fr.WriteRSTStream(i.streamID, i.code) t.streamsQuota.add(1) + return err case *flushIO: - err = t.framer.writer.Flush() + return t.framer.writer.Flush() case *ping: if !i.ack { t.bdpEst.timesnap(i.data) } - err = t.framer.fr.WritePing(i.ack, i.data) + return t.framer.fr.WritePing(i.ack, i.data) default: errorf("transport: http2Client.controller got unexpected item type %v", i) + return fmt.Errorf("transport: http2Client.controller got unexpected item type %v", i) } - return err } // keepalive running in a separate goroutune makes sure the connection is alive by sending pings. diff --git a/vendor/google.golang.org/grpc/transport/http2_server.go b/vendor/google.golang.org/grpc/transport/http2_server.go index 4a95363cc..6d252c53a 100644 --- a/vendor/google.golang.org/grpc/transport/http2_server.go +++ b/vendor/google.golang.org/grpc/transport/http2_server.go @@ -228,6 +228,12 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err } t.framer.writer.Flush() + defer func() { + if err != nil { + t.Close() + } + }() + // Check the validity of client preface. preface := make([]byte, len(clientPreface)) if _, err := io.ReadFull(t.conn, preface); err != nil { @@ -239,8 +245,7 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err frame, err := t.framer.fr.ReadFrame() if err == io.EOF || err == io.ErrUnexpectedEOF { - t.Close() - return + return nil, err } if err != nil { return nil, connectionErrorf(false, err, "transport: http2Server.HandleStreams failed to read initial settings frame: %v", err) @@ -254,7 +259,7 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err go func() { loopyWriter(t.ctx, t.controlBuf, t.itemHandler) - t.Close() + t.conn.Close() }() go t.keepalive() return t, nil @@ -1069,6 +1074,9 @@ func (t *http2Server) itemHandler(i item) error { if !i.headsUp { // Stop accepting more streams now. t.state = draining + if len(t.activeStreams) == 0 { + i.closeConn = true + } t.mu.Unlock() if err := t.framer.fr.WriteGoAway(sid, i.code, i.debugData); err != nil { return err @@ -1076,8 +1084,7 @@ func (t *http2Server) itemHandler(i item) error { if i.closeConn { // Abruptly close the connection following the GoAway (via // loopywriter). But flush out what's inside the buffer first. - t.framer.writer.Flush() - return fmt.Errorf("transport: Connection closing") + t.controlBuf.put(&flushIO{closeTr: true}) } return nil } @@ -1107,7 +1114,13 @@ func (t *http2Server) itemHandler(i item) error { }() return nil case *flushIO: - return t.framer.writer.Flush() + if err := t.framer.writer.Flush(); err != nil { + return err + } + if i.closeTr { + return ErrConnClosing + } + return nil case *ping: if !i.ack { t.bdpEst.timesnap(i.data) @@ -1155,7 +1168,7 @@ func (t *http2Server) closeStream(s *Stream) { t.idle = time.Now() } if t.state == draining && len(t.activeStreams) == 0 { - defer t.Close() + defer t.controlBuf.put(&flushIO{closeTr: true}) } t.mu.Unlock() // In case stream sending and receiving are invoked in separate diff --git a/vendor/google.golang.org/grpc/transport/transport.go b/vendor/google.golang.org/grpc/transport/transport.go index b7a5dbe42..2e7bcaeaa 100644 --- a/vendor/google.golang.org/grpc/transport/transport.go +++ b/vendor/google.golang.org/grpc/transport/transport.go @@ -26,7 +26,6 @@ import ( "io" "net" "sync" - "time" "golang.org/x/net/context" "golang.org/x/net/http2" @@ -249,11 +248,28 @@ type Stream struct { unprocessed bool // set if the server sends a refused stream or GOAWAY including this stream } +func (s *Stream) waitOnHeader() error { + if s.headerChan == nil { + // On the server headerChan is always nil since a stream originates + // only after having received headers. + return nil + } + wc := s.waiters + select { + case <-wc.ctx.Done(): + return ContextErr(wc.ctx.Err()) + case <-wc.goAway: + return errStreamDrain + case <-s.headerChan: + return nil + } +} + // RecvCompress returns the compression algorithm applied to the inbound // message. It is empty string if there is no compression applied. func (s *Stream) RecvCompress() string { - if s.headerChan != nil { - <-s.headerChan + if err := s.waitOnHeader(); err != nil { + return "" } return s.recvCompress } @@ -279,15 +295,7 @@ func (s *Stream) GoAway() <-chan struct{} { // is available. It blocks until i) the metadata is ready or ii) there is no // header metadata or iii) the stream is canceled/expired. func (s *Stream) Header() (metadata.MD, error) { - var err error - select { - case <-s.ctx.Done(): - err = ContextErr(s.ctx.Err()) - case <-s.goAway: - err = errStreamDrain - case <-s.headerChan: - return s.header.Copy(), nil - } + err := s.waitOnHeader() // Even if the stream is closed, header is returned if available. select { case <-s.headerChan: @@ -506,8 +514,8 @@ type TargetInfo struct { // NewClientTransport establishes the transport with the required ConnectOptions // and returns it to the caller. -func NewClientTransport(ctx context.Context, target TargetInfo, opts ConnectOptions, timeout time.Duration) (ClientTransport, error) { - return newHTTP2Client(ctx, target, opts, timeout) +func NewClientTransport(connectCtx, ctx context.Context, target TargetInfo, opts ConnectOptions, onSuccess func()) (ClientTransport, error) { + return newHTTP2Client(connectCtx, ctx, target, opts, onSuccess) } // Options provides additional hints and information for message diff --git a/vendor/google.golang.org/grpc/vet.sh b/vendor/google.golang.org/grpc/vet.sh index 02d4bae39..2ad94fed9 100755 --- a/vendor/google.golang.org/grpc/vet.sh +++ b/vendor/google.golang.org/grpc/vet.sh @@ -23,8 +23,7 @@ if [ "$1" = "-install" ]; then golang.org/x/tools/cmd/goimports \ honnef.co/go/tools/cmd/staticcheck \ github.com/client9/misspell/cmd/misspell \ - github.com/golang/protobuf/protoc-gen-go \ - golang.org/x/tools/cmd/stringer + github.com/golang/protobuf/protoc-gen-go if [[ "$check_proto" = "true" ]]; then if [[ "$TRAVIS" = "true" ]]; then PROTOBUF_VERSION=3.3.0 @@ -52,7 +51,7 @@ fi git ls-files "*.go" | xargs grep -L "\(Copyright [0-9]\{4,\} gRPC authors\)\|DO NOT EDIT" 2>&1 | tee /dev/stderr | (! read) gofmt -s -d -l . 2>&1 | tee /dev/stderr | (! read) goimports -l . 2>&1 | tee /dev/stderr | (! read) -golint ./... 2>&1 | (grep -vE "(_mock|_string|\.pb)\.go:" || true) | tee /dev/stderr | (! read) +golint ./... 2>&1 | (grep -vE "(_mock|\.pb)\.go:" || true) | tee /dev/stderr | (! read) # Undo any edits made by this script. cleanup() { @@ -65,7 +64,7 @@ trap cleanup EXIT git ls-files "*.go" | xargs sed -i 's:"golang.org/x/net/context":"context":' set +o pipefail # TODO: Stop filtering pb.go files once golang/protobuf#214 is fixed. -go tool vet -all . 2>&1 | grep -vF '.pb.go:' | tee /dev/stderr | (! read) +go tool vet -all . 2>&1 | grep -vE '(clientconn|transport\/transport_test).go:.*cancel (function|var)' | grep -vF '.pb.go:' | tee /dev/stderr | (! read) set -o pipefail git reset --hard HEAD diff --git a/vendor/vendor.json b/vendor/vendor.json index a2765c99b..cf042c26e 100644 --- a/vendor/vendor.json +++ b/vendor/vendor.json @@ -335,204 +335,212 @@ "revisionTime": "2017-10-02T23:26:14Z" }, { - "checksumSHA1": "+m79YSIlNtryIAT7xmuAQUz1pN8=", + "checksumSHA1": "LXTQppZOmpZb8/zNBzfXmq3GDEg=", "path": "google.golang.org/grpc", - "revision": "5a9f7b402fe85096d2e1d0383435ee1876e863d0", - "revisionTime": "2017-11-21T19:13:43Z", - "version": "v1.8.0", - "versionExact": "v1.8.0" + "revision": "7cea4cc846bcf00cbb27595b07da5de875ef7de9", + "revisionTime": "2018-01-08T22:01:35Z", + "version": "v1.9.1", + "versionExact": "v1.9.1" }, { - "checksumSHA1": "rbTWD4bqUpDwyuLPqzHwJZYlBbQ=", + "checksumSHA1": "xBhmO0Vn4kzbmySioX+2gBImrkk=", "path": "google.golang.org/grpc/balancer", - "revision": "5a9f7b402fe85096d2e1d0383435ee1876e863d0", - "revisionTime": "2017-11-21T19:13:43Z", - "version": "v1.8.0", - "versionExact": "v1.8.0" + "revision": "7cea4cc846bcf00cbb27595b07da5de875ef7de9", + "revisionTime": "2018-01-08T22:01:35Z", + "version": "v1.9.1", + "versionExact": "v1.9.1" }, { - "checksumSHA1": "L/Rj4jU8ZLOcLvMOD6XuwCoDeFY=", + "checksumSHA1": "CPWX/IgaQSR3+78j4sPrvHNkW+U=", + "path": "google.golang.org/grpc/balancer/base", + "revision": "7cea4cc846bcf00cbb27595b07da5de875ef7de9", + "revisionTime": "2018-01-08T22:01:35Z", + "version": "v1.9.1", + "versionExact": "v1.9.1" + }, + { + "checksumSHA1": "DJ1AtOk4Pu7bqtUMob95Hw8HPNw=", "path": "google.golang.org/grpc/balancer/roundrobin", - "revision": "5a9f7b402fe85096d2e1d0383435ee1876e863d0", - "revisionTime": "2017-11-21T19:13:43Z", - "version": "v1.8.0", - "versionExact": "v1.8.0" + "revision": "7cea4cc846bcf00cbb27595b07da5de875ef7de9", + "revisionTime": "2018-01-08T22:01:35Z", + "version": "v1.9.1", + "versionExact": "v1.9.1" }, { - "checksumSHA1": "m5QNRsnKMZ/3p4V/LDLknFInkGs=", + "checksumSHA1": "bfmh2m3qW8bb6qpfS/D4Wcl4hZE=", "path": "google.golang.org/grpc/codes", - "revision": "5a9f7b402fe85096d2e1d0383435ee1876e863d0", - "revisionTime": "2017-11-21T19:13:43Z", - "version": "v1.8.0", - "versionExact": "v1.8.0" + "revision": "7cea4cc846bcf00cbb27595b07da5de875ef7de9", + "revisionTime": "2018-01-08T22:01:35Z", + "version": "v1.9.1", + "versionExact": "v1.9.1" }, { "checksumSHA1": "XH2WYcDNwVO47zYShREJjcYXm0Y=", "path": "google.golang.org/grpc/connectivity", - "revision": "5a9f7b402fe85096d2e1d0383435ee1876e863d0", - "revisionTime": "2017-11-21T19:13:43Z", - "version": "v1.8.0", - "versionExact": "v1.8.0" + "revision": "7cea4cc846bcf00cbb27595b07da5de875ef7de9", + "revisionTime": "2018-01-08T22:01:35Z", + "version": "v1.9.1", + "versionExact": "v1.9.1" }, { "checksumSHA1": "4DnDX81AOSyVP3UJ5tQmlNcG1MI=", "path": "google.golang.org/grpc/credentials", - "revision": "5a9f7b402fe85096d2e1d0383435ee1876e863d0", - "revisionTime": "2017-11-21T19:13:43Z", - "version": "v1.8.0", - "versionExact": "v1.8.0" + "revision": "7cea4cc846bcf00cbb27595b07da5de875ef7de9", + "revisionTime": "2018-01-08T22:01:35Z", + "version": "v1.9.1", + "versionExact": "v1.9.1" }, { "checksumSHA1": "9DImIDqmAMPO24loHJ77UVJTDxQ=", "path": "google.golang.org/grpc/encoding", - "revision": "5a9f7b402fe85096d2e1d0383435ee1876e863d0", - "revisionTime": "2017-11-21T19:13:43Z", - "version": "v1.8.0", - "versionExact": "v1.8.0" + "revision": "7cea4cc846bcf00cbb27595b07da5de875ef7de9", + "revisionTime": "2018-01-08T22:01:35Z", + "version": "v1.9.1", + "versionExact": "v1.9.1" }, { "checksumSHA1": "H7SuPUqbPcdbNqgl+k3ohuwMAwE=", "path": "google.golang.org/grpc/grpclb/grpc_lb_v1/messages", - "revision": "5a9f7b402fe85096d2e1d0383435ee1876e863d0", - "revisionTime": "2017-11-21T19:13:43Z", - "version": "v1.8.0", - "versionExact": "v1.8.0" + "revision": "7cea4cc846bcf00cbb27595b07da5de875ef7de9", + "revisionTime": "2018-01-08T22:01:35Z", + "version": "v1.9.1", + "versionExact": "v1.9.1" }, { "checksumSHA1": "ntHev01vgZgeIh5VFRmbLx/BSTo=", "path": "google.golang.org/grpc/grpclog", - "revision": "5a9f7b402fe85096d2e1d0383435ee1876e863d0", - "revisionTime": "2017-11-21T19:13:43Z", - "version": "v1.8.0", - "versionExact": "v1.8.0" + "revision": "7cea4cc846bcf00cbb27595b07da5de875ef7de9", + "revisionTime": "2018-01-08T22:01:35Z", + "version": "v1.9.1", + "versionExact": "v1.9.1" }, { "checksumSHA1": "DyM0uqLtknaI4THSc3spn9XlL+g=", "path": "google.golang.org/grpc/health", - "revision": "5a9f7b402fe85096d2e1d0383435ee1876e863d0", - "revisionTime": "2017-11-21T19:13:43Z", - "version": "v1.8.0", - "versionExact": "v1.8.0" + "revision": "7cea4cc846bcf00cbb27595b07da5de875ef7de9", + "revisionTime": "2018-01-08T22:01:35Z", + "version": "v1.9.1", + "versionExact": "v1.9.1" }, { "checksumSHA1": "6vY7tYjV84pnr3sDctzx53Bs8b0=", "path": "google.golang.org/grpc/health/grpc_health_v1", - "revision": "5a9f7b402fe85096d2e1d0383435ee1876e863d0", - "revisionTime": "2017-11-21T19:13:43Z", - "version": "v1.8.0", - "versionExact": "v1.8.0" + "revision": "7cea4cc846bcf00cbb27595b07da5de875ef7de9", + "revisionTime": "2018-01-08T22:01:35Z", + "version": "v1.9.1", + "versionExact": "v1.9.1" }, { "checksumSHA1": "Qvf3zdmRCSsiM/VoBv0qB/naHtU=", "path": "google.golang.org/grpc/internal", - "revision": "5a9f7b402fe85096d2e1d0383435ee1876e863d0", - "revisionTime": "2017-11-21T19:13:43Z", - "version": "v1.8.0", - "versionExact": "v1.8.0" + "revision": "7cea4cc846bcf00cbb27595b07da5de875ef7de9", + "revisionTime": "2018-01-08T22:01:35Z", + "version": "v1.9.1", + "versionExact": "v1.9.1" }, { "checksumSHA1": "hcuHgKp8W0wIzoCnNfKI8NUss5o=", "path": "google.golang.org/grpc/keepalive", - "revision": "5a9f7b402fe85096d2e1d0383435ee1876e863d0", - "revisionTime": "2017-11-21T19:13:43Z", - "version": "v1.8.0", - "versionExact": "v1.8.0" + "revision": "7cea4cc846bcf00cbb27595b07da5de875ef7de9", + "revisionTime": "2018-01-08T22:01:35Z", + "version": "v1.9.1", + "versionExact": "v1.9.1" }, { "checksumSHA1": "KeUmTZV+2X46C49cKyjp+xM7fvw=", "path": "google.golang.org/grpc/metadata", - "revision": "5a9f7b402fe85096d2e1d0383435ee1876e863d0", - "revisionTime": "2017-11-21T19:13:43Z", - "version": "v1.8.0", - "versionExact": "v1.8.0" + "revision": "7cea4cc846bcf00cbb27595b07da5de875ef7de9", + "revisionTime": "2018-01-08T22:01:35Z", + "version": "v1.9.1", + "versionExact": "v1.9.1" }, { - "checksumSHA1": "dgwdT20kXe4ZbXBOFbTwVQt8rmA=", + "checksumSHA1": "5dwF592DPvhF2Wcex3m7iV6aGRQ=", "path": "google.golang.org/grpc/naming", - "revision": "5a9f7b402fe85096d2e1d0383435ee1876e863d0", - "revisionTime": "2017-11-21T19:13:43Z", - "version": "v1.8.0", - "versionExact": "v1.8.0" + "revision": "7cea4cc846bcf00cbb27595b07da5de875ef7de9", + "revisionTime": "2018-01-08T22:01:35Z", + "version": "v1.9.1", + "versionExact": "v1.9.1" }, { "checksumSHA1": "n5EgDdBqFMa2KQFhtl+FF/4gIFo=", "path": "google.golang.org/grpc/peer", - "revision": "5a9f7b402fe85096d2e1d0383435ee1876e863d0", - "revisionTime": "2017-11-21T19:13:43Z", - "version": "v1.8.0", - "versionExact": "v1.8.0" + "revision": "7cea4cc846bcf00cbb27595b07da5de875ef7de9", + "revisionTime": "2018-01-08T22:01:35Z", + "version": "v1.9.1", + "versionExact": "v1.9.1" }, { "checksumSHA1": "JF/KBFCo5JwVtXfrZ2kJnFRC6W8=", "path": "google.golang.org/grpc/reflection", - "revision": "5a9f7b402fe85096d2e1d0383435ee1876e863d0", - "revisionTime": "2017-11-21T19:13:43Z", - "version": "v1.8.0", - "versionExact": "v1.8.0" + "revision": "7cea4cc846bcf00cbb27595b07da5de875ef7de9", + "revisionTime": "2018-01-08T22:01:35Z", + "version": "v1.9.1", + "versionExact": "v1.9.1" }, { "checksumSHA1": "7Ax2K0St9CIi1rkA9Ju+2ERfe9E=", "path": "google.golang.org/grpc/reflection/grpc_reflection_v1alpha", - "revision": "5a9f7b402fe85096d2e1d0383435ee1876e863d0", - "revisionTime": "2017-11-21T19:13:43Z", - "version": "v1.8.0", - "versionExact": "v1.8.0" + "revision": "7cea4cc846bcf00cbb27595b07da5de875ef7de9", + "revisionTime": "2018-01-08T22:01:35Z", + "version": "v1.9.1", + "versionExact": "v1.9.1" }, { - "checksumSHA1": "H7VyP18nJ9MmoB5r9+I7EKVEeVM=", + "checksumSHA1": "y8Ta+ctMP9CUTiPyPyxiD154d8w=", "path": "google.golang.org/grpc/resolver", - "revision": "5a9f7b402fe85096d2e1d0383435ee1876e863d0", - "revisionTime": "2017-11-21T19:13:43Z", - "version": "v1.8.0", - "versionExact": "v1.8.0" + "revision": "7cea4cc846bcf00cbb27595b07da5de875ef7de9", + "revisionTime": "2018-01-08T22:01:35Z", + "version": "v1.9.1", + "versionExact": "v1.9.1" }, { "checksumSHA1": "WpWF+bDzObsHf+bjoGpb/abeFxo=", "path": "google.golang.org/grpc/resolver/dns", - "revision": "5a9f7b402fe85096d2e1d0383435ee1876e863d0", - "revisionTime": "2017-11-21T19:13:43Z", - "version": "v1.8.0", - "versionExact": "v1.8.0" + "revision": "7cea4cc846bcf00cbb27595b07da5de875ef7de9", + "revisionTime": "2018-01-08T22:01:35Z", + "version": "v1.9.1", + "versionExact": "v1.9.1" }, { "checksumSHA1": "zs9M4xE8Lyg4wvuYvR00XoBxmuw=", "path": "google.golang.org/grpc/resolver/passthrough", - "revision": "5a9f7b402fe85096d2e1d0383435ee1876e863d0", - "revisionTime": "2017-11-21T19:13:43Z", - "version": "v1.8.0", - "versionExact": "v1.8.0" + "revision": "7cea4cc846bcf00cbb27595b07da5de875ef7de9", + "revisionTime": "2018-01-08T22:01:35Z", + "version": "v1.9.1", + "versionExact": "v1.9.1" }, { "checksumSHA1": "G9lgXNi7qClo5sM2s6TbTHLFR3g=", "path": "google.golang.org/grpc/stats", - "revision": "5a9f7b402fe85096d2e1d0383435ee1876e863d0", - "revisionTime": "2017-11-21T19:13:43Z", - "version": "v1.8.0", - "versionExact": "v1.8.0" + "revision": "7cea4cc846bcf00cbb27595b07da5de875ef7de9", + "revisionTime": "2018-01-08T22:01:35Z", + "version": "v1.9.1", + "versionExact": "v1.9.1" }, { - "checksumSHA1": "3Dwz4RLstDHMPyDA7BUsYe+JP4w=", + "checksumSHA1": "tUo+M0Cb0W9ZEIt5BH30wJz/Kjc=", "path": "google.golang.org/grpc/status", - "revision": "5a9f7b402fe85096d2e1d0383435ee1876e863d0", - "revisionTime": "2017-11-21T19:13:43Z", - "version": "v1.8.0", - "versionExact": "v1.8.0" + "revision": "7cea4cc846bcf00cbb27595b07da5de875ef7de9", + "revisionTime": "2018-01-08T22:01:35Z", + "version": "v1.9.1", + "versionExact": "v1.9.1" }, { "checksumSHA1": "qvArRhlrww5WvRmbyMF2mUfbJew=", "path": "google.golang.org/grpc/tap", - "revision": "5a9f7b402fe85096d2e1d0383435ee1876e863d0", - "revisionTime": "2017-11-21T19:13:43Z", - "version": "v1.8.0", - "versionExact": "v1.8.0" + "revision": "7cea4cc846bcf00cbb27595b07da5de875ef7de9", + "revisionTime": "2018-01-08T22:01:35Z", + "version": "v1.9.1", + "versionExact": "v1.9.1" }, { - "checksumSHA1": "cp2boGt5b7B2G7mIkI4my6r6JdE=", + "checksumSHA1": "4PldZ/0JjX6SpJYaMByY1ozywnY=", "path": "google.golang.org/grpc/transport", - "revision": "5a9f7b402fe85096d2e1d0383435ee1876e863d0", - "revisionTime": "2017-11-21T19:13:43Z", - "version": "v1.8.0", - "versionExact": "v1.8.0" + "revision": "7cea4cc846bcf00cbb27595b07da5de875ef7de9", + "revisionTime": "2018-01-08T22:01:35Z", + "version": "v1.9.1", + "versionExact": "v1.9.1" } ], "rootPath": "gitlab.com/gitlab-org/gitaly" |