Welcome to mirror list, hosted at ThFree Co, Russian Federation.

extension.go « protoutil « internal - gitlab.com/gitlab-org/gitaly.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 5a2b9cd2dce73b65d1da1113ee4d02a3e087dc07 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
package protoutil

import (
	"errors"
	"fmt"

	"github.com/golang/protobuf/protoc-gen-go/descriptor"
	"gitlab.com/gitlab-org/gitaly/v14/proto/go/gitalypb"
	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/reflect/protoregistry"
	"google.golang.org/protobuf/runtime/protoimpl"
)

// GetOpExtension returns the OperationMsg stored in the op_type extension
// (gitalypb.E_OpType) of the given method descriptor's options. When the
// method does not set that extension, the returned error wraps
// protoregistry.NotFound.
func GetOpExtension(m *descriptor.MethodDescriptorProto) (*gitalypb.OperationMsg, error) {
	opts := m.GetOptions()

	raw, err := getExtension(opts, gitalypb.E_OpType)
	if err != nil {
		return nil, err
	}

	opMsg := raw.(*gitalypb.OperationMsg)
	return opMsg, nil
}

// IsInterceptedService returns whether the service is intercepted by Praefect,
// as indicated by the intercepted extension (gitalypb.E_Intercepted) on the
// service descriptor's options. A service without the extension reports false
// with a nil error.
func IsInterceptedService(s *descriptor.ServiceDescriptorProto) (bool, error) {
	opts := s.GetOptions()
	return getBoolExtension(opts, gitalypb.E_Intercepted)
}

// GetRepositoryExtension reports whether the field descriptor is marked with
// the repository extension (gitalypb.E_Repository). A field without the
// extension reports false with a nil error.
func GetRepositoryExtension(m *descriptor.FieldDescriptorProto) (bool, error) {
	fieldOpts := m.GetOptions()
	return getBoolExtension(fieldOpts, gitalypb.E_Repository)
}

// GetStorageExtension reports whether the field descriptor is marked with the
// storage extension (gitalypb.E_Storage). A field without the extension
// reports false with a nil error.
func GetStorageExtension(m *descriptor.FieldDescriptorProto) (bool, error) {
	fieldOpts := m.GetOptions()
	return getBoolExtension(fieldOpts, gitalypb.E_Storage)
}

// GetTargetRepositoryExtension reports whether the field descriptor is marked
// with the target_repository extension (gitalypb.E_TargetRepository). A field
// without the extension reports false with a nil error.
func GetTargetRepositoryExtension(m *descriptor.FieldDescriptorProto) (bool, error) {
	fieldOpts := m.GetOptions()
	return getBoolExtension(fieldOpts, gitalypb.E_TargetRepository)
}

// GetAdditionalRepositoryExtension reports whether the field descriptor is
// marked with the additional_repository extension
// (gitalypb.E_AdditionalRepository). A field without the extension reports
// false with a nil error.
func GetAdditionalRepositoryExtension(m *descriptor.FieldDescriptorProto) (bool, error) {
	return getBoolExtension(m.GetOptions(), gitalypb.E_AdditionalRepository)
}

// getBoolExtension reads a boolean-valued extension from options. An absent
// extension (an error wrapping protoregistry.NotFound) is not treated as a
// failure: it yields false with a nil error. Any other error is returned as-is.
// The extension is assumed to be declared with a bool value type; the
// type assertion panics otherwise.
func getBoolExtension(options proto.Message, extension *protoimpl.ExtensionInfo) (bool, error) {
	raw, err := getExtension(options, extension)

	switch {
	case errors.Is(err, protoregistry.NotFound):
		// Missing extension means "not set", which for a flag is false.
		return false, nil
	case err != nil:
		return false, err
	}

	flag := raw.(bool)
	return flag, nil
}

// getExtension returns the value of extension from options. When options does
// not have the extension populated (per proto.HasExtension), it returns an
// error that names the extension and wraps protoregistry.NotFound, so callers
// can detect absence with errors.Is.
func getExtension(options proto.Message, extension *protoimpl.ExtensionInfo) (interface{}, error) {
	if proto.HasExtension(options, extension) {
		return proto.GetExtension(options, extension), nil
	}

	extName := extension.TypeDescriptor().FullName()
	return nil, fmt.Errorf("protoutil.getExtension %q: %w", extName, protoregistry.NotFound)
}