diff options
author | Karthik Nayak <knayak@gitlab.com> | 2022-08-15 10:44:34 +0300 |
---|---|---|
committer | Karthik Nayak <knayak@gitlab.com> | 2022-08-21 21:57:38 +0300 |
commit | 20ca1904446185977817bd7d010f50431cc64466 (patch) | |
tree | d8d8192b3f051e97bf2e3388e052aee34cf71f74 /internal/gitlab/client | |
parent | 28b52c3b64fb857de6b1c62af662fe54249fa225 (diff) |
gitlab: Copy client code from 'gitlab-shell'
Currently 'gitlab-shell' contains the code for the gitlab client. We use
this in our 'gitaly' code, but this causes a cyclic dependency since
'gitlab-shell' also imports parts of our code.
To remove this cyclic dependency, we decided to move most of the code
in-house. In this commit we do this by copying over the following files:
1. gitlabnet.go
2. httpclient.go
from 'gitlab-shell' without modification to a new 'client' package under
'internal/gitlab'.
We do not copy over the tests, since our repo has its own helpers for
tests which we'll leverage to rewrite the tests.
Diffstat (limited to 'internal/gitlab/client')
-rw-r--r-- | internal/gitlab/client/gitlabnet.go | 194 | ||||
-rw-r--r-- | internal/gitlab/client/httpclient.go | 191 |
2 files changed, 385 insertions, 0 deletions
diff --git a/internal/gitlab/client/gitlabnet.go b/internal/gitlab/client/gitlabnet.go new file mode 100644 index 000000000..c34f148f3 --- /dev/null +++ b/internal/gitlab/client/gitlabnet.go @@ -0,0 +1,194 @@ +package client + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/golang-jwt/jwt/v4" + + "gitlab.com/gitlab-org/labkit/log" +) + +const ( + internalApiPath = "/api/v4/internal" + secretHeaderName = "Gitlab-Shared-Secret" + apiSecretHeaderName = "Gitlab-Shell-Api-Request" + defaultUserAgent = "GitLab-Shell" + jwtTTL = time.Minute + jwtIssuer = "gitlab-shell" +) + +type ErrorResponse struct { + Message string `json:"message"` +} + +type GitlabNetClient struct { + httpClient *HttpClient + user string + password string + secret string + userAgent string +} + +type ApiError struct { + Msg string +} + +// To use as the key in a Context to set an X-Forwarded-For header in a request +type OriginalRemoteIPContextKey struct{} + +func (e *ApiError) Error() string { + return e.Msg +} + +func NewGitlabNetClient( + user, + password, + secret string, + httpClient *HttpClient, +) (*GitlabNetClient, error) { + + if httpClient == nil { + return nil, fmt.Errorf("Unsupported protocol") + } + + return &GitlabNetClient{ + httpClient: httpClient, + user: user, + password: password, + secret: secret, + userAgent: defaultUserAgent, + }, nil +} + +// SetUserAgent overrides the default user agent for the User-Agent header field +// for subsequent requests for the GitlabNetClient +func (c *GitlabNetClient) SetUserAgent(ua string) { + c.userAgent = ua +} + +func normalizePath(path string) string { + if !strings.HasPrefix(path, "/") { + path = "/" + path + } + + if !strings.HasPrefix(path, internalApiPath) { + path = internalApiPath + path + } + return path +} + +func newRequest(ctx context.Context, method, host, path string, data interface{}) (*http.Request, error) { + var jsonReader io.Reader + if data != nil { + jsonData, err := json.Marshal(data) + if err != nil { + return nil, err + } + + jsonReader = bytes.NewReader(jsonData) + } + + request, err := http.NewRequestWithContext(ctx, method, host+path, jsonReader) + if err != nil { + return nil, err + } + + return request, nil +} + +func parseError(resp *http.Response) error { + if resp.StatusCode >= 200 && resp.StatusCode <= 399 { + return nil + } + defer resp.Body.Close() + parsedResponse := &ErrorResponse{} + + if err := json.NewDecoder(resp.Body).Decode(parsedResponse); err != nil { + return &ApiError{fmt.Sprintf("Internal API error (%v)", resp.StatusCode)} + } else { + return &ApiError{parsedResponse.Message} + } + +} + +func (c *GitlabNetClient) Get(ctx context.Context, path string) (*http.Response, error) { + return c.DoRequest(ctx, http.MethodGet, normalizePath(path), nil) +} + +func (c *GitlabNetClient) Post(ctx context.Context, path string, data interface{}) (*http.Response, error) { + return c.DoRequest(ctx, http.MethodPost, normalizePath(path), data) +} + +func (c *GitlabNetClient) DoRequest(ctx context.Context, method, path string, data interface{}) (*http.Response, error) { + request, err := newRequest(ctx, method, c.httpClient.Host, path, data) + if err != nil { + return nil, err + } + + user, password := c.user, c.password + if user != "" && password != "" { + request.SetBasicAuth(user, password) + } + secretBytes := []byte(c.secret) + + encodedSecret := base64.StdEncoding.EncodeToString(secretBytes) + request.Header.Set(secretHeaderName, encodedSecret) + + claims := jwt.RegisteredClaims{ + Issuer: jwtIssuer, + IssuedAt: jwt.NewNumericDate(time.Now()), + ExpiresAt: jwt.NewNumericDate(time.Now().Add(jwtTTL)), + } + tokenString, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(secretBytes) + if err != nil { + return nil, err + } + request.Header.Set(apiSecretHeaderName, tokenString) + + originalRemoteIP, ok := ctx.Value(OriginalRemoteIPContextKey{}).(string) + if ok { + request.Header.Add("X-Forwarded-For", originalRemoteIP) + } + + request.Header.Add("Content-Type", "application/json") + request.Header.Add("User-Agent", c.userAgent) + request.Close = true + + start := time.Now() + response, err := c.httpClient.Do(request) + fields := log.Fields{ + "method": method, + "url": request.URL.String(), + "duration_ms": time.Since(start) / time.Millisecond, + } + logger := log.WithContextFields(ctx, fields) + + if err != nil { + logger.WithError(err).Error("Internal API unreachable") + return nil, &ApiError{"Internal API unreachable"} + } + + if response != nil { + logger = logger.WithField("status", response.StatusCode) + } + if err := parseError(response); err != nil { + logger.WithError(err).Error("Internal API error") + return nil, err + } + + if response.ContentLength >= 0 { + logger = logger.WithField("content_length_bytes", response.ContentLength) + } + + logger.Info("Finished HTTP request") + + return response, nil +} diff --git a/internal/gitlab/client/httpclient.go b/internal/gitlab/client/httpclient.go new file mode 100644 index 000000000..bd00b6b80 --- /dev/null +++ b/internal/gitlab/client/httpclient.go @@ -0,0 +1,191 @@ +package client + +import ( + "context" + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "net" + "net/http" + "os" + "path/filepath" + "strings" + "time" + + "gitlab.com/gitlab-org/labkit/correlation" + "gitlab.com/gitlab-org/labkit/tracing" +) + +const ( + socketBaseUrl = "http://unix" + unixSocketProtocol = "http+unix://" + httpProtocol = "http://" + httpsProtocol = "https://" + defaultReadTimeoutSeconds = 300 +) + +var ( + ErrCafileNotFound = errors.New("cafile not found") +) + +type HttpClient struct { + *http.Client + Host string +} + +type httpClientCfg struct { + keyPath, certPath string + caFile, caPath string +} + +func (hcc httpClientCfg) HaveCertAndKey() bool { return hcc.keyPath != "" && hcc.certPath != "" } + +// HTTPClientOpt provides options for configuring an HttpClient +type HTTPClientOpt func(*httpClientCfg) + +// WithClientCert will configure the HttpClient to provide client certificates +// when connecting to a server. +func WithClientCert(certPath, keyPath string) HTTPClientOpt { + return func(hcc *httpClientCfg) { + hcc.keyPath = keyPath + hcc.certPath = certPath + } +} + +func validateCaFile(filename string) error { + if filename == "" { + return nil + } + + if _, err := os.Stat(filename); err != nil { + if os.IsNotExist(err) { + return fmt.Errorf("cannot find cafile '%s': %w", filename, ErrCafileNotFound) + } + + return err + } + + return nil +} + +// NewHTTPClientWithOpts builds an HTTP client using the provided options +func NewHTTPClientWithOpts(gitlabURL, gitlabRelativeURLRoot, caFile, caPath string, readTimeoutSeconds uint64, opts []HTTPClientOpt) (*HttpClient, error) { + var transport *http.Transport + var host string + var err error + if strings.HasPrefix(gitlabURL, unixSocketProtocol) { + transport, host = buildSocketTransport(gitlabURL, gitlabRelativeURLRoot) + } else if strings.HasPrefix(gitlabURL, httpProtocol) { + transport, host = buildHttpTransport(gitlabURL) + } else if strings.HasPrefix(gitlabURL, httpsProtocol) { + err = validateCaFile(caFile) + if err != nil { + return nil, err + } + + hcc := &httpClientCfg{ + caFile: caFile, + caPath: caPath, + } + + for _, opt := range opts { + opt(hcc) + } + + transport, host, err = buildHttpsTransport(*hcc, gitlabURL) + if err != nil { + return nil, err + } + } else { + return nil, errors.New("unknown GitLab URL prefix") + } + + c := &http.Client{ + Transport: correlation.NewInstrumentedRoundTripper(tracing.NewRoundTripper(transport)), + Timeout: readTimeout(readTimeoutSeconds), + } + + client := &HttpClient{Client: c, Host: host} + + return client, nil +} + +func buildSocketTransport(gitlabURL, gitlabRelativeURLRoot string) (*http.Transport, string) { + socketPath := strings.TrimPrefix(gitlabURL, unixSocketProtocol) + + transport := &http.Transport{ + DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) { + dialer := net.Dialer{} + return dialer.DialContext(ctx, "unix", socketPath) + }, + } + + host := socketBaseUrl + gitlabRelativeURLRoot = strings.Trim(gitlabRelativeURLRoot, "/") + if gitlabRelativeURLRoot != "" { + host = host + "/" + gitlabRelativeURLRoot + } + + return transport, host +} + +func buildHttpsTransport(hcc httpClientCfg, gitlabURL string) (*http.Transport, string, error) { + certPool, err := x509.SystemCertPool() + + if err != nil { + certPool = x509.NewCertPool() + } + + if hcc.caFile != "" { + addCertToPool(certPool, hcc.caFile) + } + + if hcc.caPath != "" { + fis, _ := os.ReadDir(hcc.caPath) + for _, fi := range fis { + if fi.IsDir() { + continue + } + + addCertToPool(certPool, filepath.Join(hcc.caPath, fi.Name())) + } + } + tlsConfig := &tls.Config{ + RootCAs: certPool, + MinVersion: tls.VersionTLS12, + } + + if hcc.HaveCertAndKey() { + cert, err := tls.LoadX509KeyPair(hcc.certPath, hcc.keyPath) + if err != nil { + return nil, "", err + } + tlsConfig.Certificates = []tls.Certificate{cert} + } + + transport := &http.Transport{ + TLSClientConfig: tlsConfig, + } + + return transport, gitlabURL, err +} + +func addCertToPool(certPool *x509.CertPool, fileName string) { + cert, err := os.ReadFile(fileName) + if err == nil { + certPool.AppendCertsFromPEM(cert) + } +} + +func buildHttpTransport(gitlabURL string) (*http.Transport, string) { + return &http.Transport{}, gitlabURL +} + +func readTimeout(timeoutSeconds uint64) time.Duration { + if timeoutSeconds == 0 { + timeoutSeconds = defaultReadTimeoutSeconds + } + + return time.Duration(timeoutSeconds) * time.Second +} |