Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitlab-pages.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJaime Martinez <jmartinez@gitlab.com>2022-04-08 04:03:59 +0300
committerJaime Martinez <jmartinez@gitlab.com>2022-04-08 04:12:11 +0300
commit4b70ccd75479dea8cf14a755e3a2ecba7797a449 (patch)
treebbc0aaa1112156e77829945516e814b118ebad35
parent4b1afecbb6ae1886bfd3a31f256909ca2770bce4 (diff)
fix: handle context canceled gracefully for auth and artifacts
Changelog: changed
-rw-r--r--internal/artifact/artifact.go6
-rw-r--r--internal/artifact/artifact_test.go22
-rw-r--r--internal/auth/auth.go9
-rw-r--r--internal/auth/auth_test.go20
4 files changed, 57 insertions, 0 deletions
diff --git a/internal/artifact/artifact.go b/internal/artifact/artifact.go
index 0a4f0ab4..dab1fb91 100644
--- a/internal/artifact/artifact.go
+++ b/internal/artifact/artifact.go
@@ -1,6 +1,7 @@
package artifact
import (
+ "context"
"errors"
"fmt"
"io"
@@ -89,6 +90,11 @@ func (a *Artifact) makeRequest(w http.ResponseWriter, r *http.Request, reqURL *u
}
resp, err := a.client.Do(req)
if err != nil {
+ if errors.Is(err, context.Canceled) {
+ httperrors.Serve404(w)
+ return
+ }
+
logging.LogRequest(r).WithError(err).Error(artifactRequestErrMsg)
errortracking.CaptureErrWithReqAndStackTrace(err, r)
httperrors.Serve502(w)
diff --git a/internal/artifact/artifact_test.go b/internal/artifact/artifact_test.go
index ab25d16f..db7a2a84 100644
--- a/internal/artifact/artifact_test.go
+++ b/internal/artifact/artifact_test.go
@@ -1,6 +1,7 @@
package artifact_test
import (
+ "context"
"fmt"
"net/http"
"net/http/httptest"
@@ -275,3 +276,24 @@ func TestBuildURL(t *testing.T) {
})
}
}
+
+func TestContextCanceled(t *testing.T) {
+ content := "<!DOCTYPE html><html><head><title>Title of the document</title></head><body></body></html>"
+ contentType := "text/html; charset=utf-8"
+ testServer := makeArtifactServerStub(t, content, contentType)
+ t.Cleanup(testServer.Close)
+
+ result := httptest.NewRecorder()
+ reqURL, err := url.Parse("/-/subgroup/project/-/jobs/1/artifacts/200.html")
+ require.NoError(t, err)
+ r := &http.Request{URL: reqURL}
+ ctx, cancel := context.WithCancel(context.Background())
+ r = r.WithContext(ctx)
+ // cancel context explictly
+ cancel()
+ art := artifact.New(testServer.URL, 1, "gitlab-example.io")
+
+ require.True(t, art.TryMakeRequest("group.gitlab-example.io", result, r, "", func(resp *http.Response) bool { return false }))
+ require.Equal(t, http.StatusNotFound, result.Code)
+
+}
diff --git a/internal/auth/auth.go b/internal/auth/auth.go
index 21398e56..681e3199 100644
--- a/internal/auth/auth.go
+++ b/internal/auth/auth.go
@@ -359,6 +359,10 @@ func (a *Auth) fetchAccessToken(ctx context.Context, code string) (tokenResponse
// Request token
resp, err := a.apiClient.Do(req)
if err != nil {
+ if errors.Is(err, context.Canceled) {
+ return token, nil
+ }
+
return token, err
}
@@ -476,6 +480,11 @@ func (a *Auth) checkAuthentication(w http.ResponseWriter, r *http.Request, domai
req.Header.Add("Authorization", "Bearer "+session.Values["access_token"].(string))
resp, err := a.apiClient.Do(req)
if err != nil {
+ if errors.Is(err, context.Canceled) {
+ httperrors.Serve404(w)
+ return true
+ }
+
logRequest(r).WithError(err).Error("Failed to retrieve info with token")
errortracking.CaptureErrWithReqAndStackTrace(err, r)
// call serve404 handler when auth fails
diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go
index 586c3839..4236d695 100644
--- a/internal/auth/auth_test.go
+++ b/internal/auth/auth_test.go
@@ -270,6 +270,26 @@ func TestCheckAuthenticationWhenAccess(t *testing.T) {
require.Equal(t, http.StatusOK, result.Code)
}
+func TestCheckAuthenticationWhenContextCanceled(t *testing.T) {
+ apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }))
+ t.Cleanup(apiServer.Close)
+
+ auth := createTestAuth(t, apiServer.URL, "")
+
+ result := httptest.NewRecorder()
+ r, err := http.NewRequest("Get", "https://example.com/", nil)
+ require.NoError(t, err)
+ ctx, cancel := context.WithCancel(r.Context())
+ r = r.WithContext(ctx)
+ setSessionValues(t, r, auth, map[interface{}]interface{}{"access_token": "abc"})
+
+ // cancel context explicitly and expect not found
+ cancel()
+ contentServed := auth.CheckAuthentication(result, r, &domainMock{projectID: 1000})
+ require.True(t, contentServed)
+ require.Equal(t, http.StatusNotFound, result.Code)
+}
+
func TestCheckAuthenticationWhenNoAccess(t *testing.T) {
apiServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {