diff options
author | Jaime Martinez <jmartinez@gitlab.com> | 2022-04-08 04:03:59 +0300 |
---|---|---|
committer | Jaime Martinez <jmartinez@gitlab.com> | 2022-04-08 04:12:11 +0300 |
commit | 4b70ccd75479dea8cf14a755e3a2ecba7797a449 (patch) | |
tree | bbc0aaa1112156e77829945516e814b118ebad35 | |
parent | 4b1afecbb6ae1886bfd3a31f256909ca2770bce4 (diff) |
fix: handle context canceled gracefully for auth and artifacts
Changelog: changed
-rw-r--r-- | internal/artifact/artifact.go | 6 | ||||
-rw-r--r-- | internal/artifact/artifact_test.go | 22 | ||||
-rw-r--r-- | internal/auth/auth.go | 9 | ||||
-rw-r--r-- | internal/auth/auth_test.go | 20 |
4 files changed, 57 insertions, 0 deletions
diff --git a/internal/artifact/artifact.go b/internal/artifact/artifact.go index 0a4f0ab4..dab1fb91 100644 --- a/internal/artifact/artifact.go +++ b/internal/artifact/artifact.go @@ -1,6 +1,7 @@ package artifact import ( + "context" "errors" "fmt" "io" @@ -89,6 +90,11 @@ func (a *Artifact) makeRequest(w http.ResponseWriter, r *http.Request, reqURL *u } resp, err := a.client.Do(req) if err != nil { + if errors.Is(err, context.Canceled) { + httperrors.Serve404(w) + return + } + logging.LogRequest(r).WithError(err).Error(artifactRequestErrMsg) errortracking.CaptureErrWithReqAndStackTrace(err, r) httperrors.Serve502(w) diff --git a/internal/artifact/artifact_test.go b/internal/artifact/artifact_test.go index ab25d16f..db7a2a84 100644 --- a/internal/artifact/artifact_test.go +++ b/internal/artifact/artifact_test.go @@ -1,6 +1,7 @@ package artifact_test import ( + "context" "fmt" "net/http" "net/http/httptest" @@ -275,3 +276,24 @@ func TestBuildURL(t *testing.T) { }) } } + +func TestContextCanceled(t *testing.T) { + content := "<!DOCTYPE html><html><head><title>Title of the document</title></head><body></body></html>" + contentType := "text/html; charset=utf-8" + testServer := makeArtifactServerStub(t, content, contentType) + t.Cleanup(testServer.Close) + + result := httptest.NewRecorder() + reqURL, err := url.Parse("/-/subgroup/project/-/jobs/1/artifacts/200.html") + require.NoError(t, err) + r := &http.Request{URL: reqURL} + ctx, cancel := context.WithCancel(context.Background()) + r = r.WithContext(ctx) + // cancel context explictly + cancel() + art := artifact.New(testServer.URL, 1, "gitlab-example.io") + + require.True(t, art.TryMakeRequest("group.gitlab-example.io", result, r, "", func(resp *http.Response) bool { return false })) + require.Equal(t, http.StatusNotFound, result.Code) + +} diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 21398e56..681e3199 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -359,6 +359,10 @@ func (a *Auth) fetchAccessToken(ctx context.Context, code string) (tokenResponse // Request token resp, err := a.apiClient.Do(req) if err != nil { + if errors.Is(err, context.Canceled) { + return token, nil + } + return token, err } @@ -476,6 +480,11 @@ func (a *Auth) checkAuthentication(w http.ResponseWriter, r *http.Request, domai req.Header.Add("Authorization", "Bearer "+session.Values["access_token"].(string)) resp, err := a.apiClient.Do(req) if err != nil { + if errors.Is(err, context.Canceled) { + httperrors.Serve404(w) + return true + } + logRequest(r).WithError(err).Error("Failed to retrieve info with token") errortracking.CaptureErrWithReqAndStackTrace(err, r) // call serve404 handler when auth fails diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go index 586c3839..4236d695 100644 --- a/internal/auth/auth_test.go +++ b/internal/auth/auth_test.go @@ -270,6 +270,26 @@ func TestCheckAuthenticationWhenAccess(t *testing.T) { require.Equal(t, http.StatusOK, result.Code) } +func TestCheckAuthenticationWhenContextCanceled(t *testing.T) { + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) + t.Cleanup(apiServer.Close) + + auth := createTestAuth(t, apiServer.URL, "") + + result := httptest.NewRecorder() + r, err := http.NewRequest("Get", "https://example.com/", nil) + require.NoError(t, err) + ctx, cancel := context.WithCancel(r.Context()) + r = r.WithContext(ctx) + setSessionValues(t, r, auth, map[interface{}]interface{}{"access_token": "abc"}) + + // cancel context explicitly and expect not found + cancel() + contentServed := auth.CheckAuthentication(result, r, &domainMock{projectID: 1000}) + require.True(t, contentServed) + require.Equal(t, http.StatusNotFound, result.Code) +} + func TestCheckAuthenticationWhenNoAccess(t *testing.T) { apiServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { |