Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitlab-foss.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'workhorse/internal/upload/saved_file_tracker.go')
-rw-r--r--workhorse/internal/upload/saved_file_tracker.go55
1 files changed, 55 insertions, 0 deletions
diff --git a/workhorse/internal/upload/saved_file_tracker.go b/workhorse/internal/upload/saved_file_tracker.go
new file mode 100644
index 00000000000..7b6cade4faa
--- /dev/null
+++ b/workhorse/internal/upload/saved_file_tracker.go
@@ -0,0 +1,55 @@
+package upload
+
+import (
+ "context"
+ "fmt"
+ "mime/multipart"
+ "net/http"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/filestore"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/secret"
+)
+
+type SavedFileTracker struct {
+ Request *http.Request
+ rewrittenFields map[string]string
+}
+
+func (s *SavedFileTracker) Track(fieldName string, localPath string) {
+ if s.rewrittenFields == nil {
+ s.rewrittenFields = make(map[string]string)
+ }
+ s.rewrittenFields[fieldName] = localPath
+}
+
+func (s *SavedFileTracker) Count() int {
+ return len(s.rewrittenFields)
+}
+
+func (s *SavedFileTracker) ProcessFile(_ context.Context, fieldName string, file *filestore.FileHandler, _ *multipart.Writer) error {
+ s.Track(fieldName, file.LocalPath)
+ return nil
+}
+
+func (s *SavedFileTracker) ProcessField(_ context.Context, _ string, _ *multipart.Writer) error {
+ return nil
+}
+
+func (s *SavedFileTracker) Finalize(_ context.Context) error {
+ if s.rewrittenFields == nil {
+ return nil
+ }
+
+ claims := MultipartClaims{RewrittenFields: s.rewrittenFields, StandardClaims: secret.DefaultClaims}
+ tokenString, err := secret.JWTTokenString(claims)
+ if err != nil {
+ return fmt.Errorf("savedFileTracker.Finalize: %v", err)
+ }
+
+ s.Request.Header.Set(RewrittenFieldsHeader, tokenString)
+ return nil
+}
+
+func (s *SavedFileTracker) Name() string {
+ return "accelerate"
+}