Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitlab-pages.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJaime Martinez <jmartinez@gitlab.com>2022-04-08 04:03:59 +0300
committerJaime Martinez <jmartinez@gitlab.com>2022-04-08 04:12:11 +0300
commit4b70ccd75479dea8cf14a755e3a2ecba7797a449 (patch)
treebbc0aaa1112156e77829945516e814b118ebad35 /internal/auth
parent4b1afecbb6ae1886bfd3a31f256909ca2770bce4 (diff)
fix: handle context canceled gracefully for auth and artifacts
Changelog: changed
Diffstat (limited to 'internal/auth')
-rw-r--r--internal/auth/auth.go9
-rw-r--r--internal/auth/auth_test.go20
2 files changed, 29 insertions, 0 deletions
diff --git a/internal/auth/auth.go b/internal/auth/auth.go
index 21398e56..681e3199 100644
--- a/internal/auth/auth.go
+++ b/internal/auth/auth.go
@@ -359,6 +359,10 @@ func (a *Auth) fetchAccessToken(ctx context.Context, code string) (tokenResponse
// Request token
resp, err := a.apiClient.Do(req)
if err != nil {
+ if errors.Is(err, context.Canceled) {
+ return token, nil
+ }
+
return token, err
}
@@ -476,6 +480,11 @@ func (a *Auth) checkAuthentication(w http.ResponseWriter, r *http.Request, domai
req.Header.Add("Authorization", "Bearer "+session.Values["access_token"].(string))
resp, err := a.apiClient.Do(req)
if err != nil {
+ if errors.Is(err, context.Canceled) {
+ httperrors.Serve404(w)
+ return true
+ }
+
logRequest(r).WithError(err).Error("Failed to retrieve info with token")
errortracking.CaptureErrWithReqAndStackTrace(err, r)
// call serve404 handler when auth fails
diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go
index 586c3839..4236d695 100644
--- a/internal/auth/auth_test.go
+++ b/internal/auth/auth_test.go
@@ -270,6 +270,26 @@ func TestCheckAuthenticationWhenAccess(t *testing.T) {
require.Equal(t, http.StatusOK, result.Code)
}
+func TestCheckAuthenticationWhenContextCanceled(t *testing.T) {
+ apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }))
+ t.Cleanup(apiServer.Close)
+
+ auth := createTestAuth(t, apiServer.URL, "")
+
+ result := httptest.NewRecorder()
+ r, err := http.NewRequest("Get", "https://example.com/", nil)
+ require.NoError(t, err)
+ ctx, cancel := context.WithCancel(r.Context())
+ r = r.WithContext(ctx)
+ setSessionValues(t, r, auth, map[interface{}]interface{}{"access_token": "abc"})
+
+ // cancel context explicitly and expect not found
+ cancel()
+ contentServed := auth.CheckAuthentication(result, r, &domainMock{projectID: 1000})
+ require.True(t, contentServed)
+ require.Equal(t, http.StatusNotFound, result.Code)
+}
+
func TestCheckAuthenticationWhenNoAccess(t *testing.T) {
apiServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {