diff options
author | Nick Thomas <nick@gitlab.com> | 2018-09-03 21:39:24 +0300 |
---|---|---|
committer | Nick Thomas <nick@gitlab.com> | 2018-09-03 21:39:24 +0300 |
commit | a96f9a78f30df440efa682bb84e0c96b247fa138 (patch) | |
tree | e86a8f95e7773d60273bc732e4acbb608ba3755a /internal | |
parent | 036f5bd5f519d54a502ae44e966e6c5dbcefc315 (diff) | |
parent | 5cffa83537890540d74664a43e828cd81a164980 (diff) |
Merge branch 'master' into auth
Diffstat (limited to 'internal')
-rw-r--r-- | internal/domain/domain.go | 20 | ||||
-rw-r--r-- | internal/domain/domain_test.go | 96 | ||||
-rw-r--r-- | internal/domain/map.go | 34 | ||||
-rw-r--r-- | internal/domain/map_test.go | 22 |
4 files changed, 153 insertions, 19 deletions
diff --git a/internal/domain/domain.go b/internal/domain/domain.go index 77429372..1ea008c4 100644 --- a/internal/domain/domain.go +++ b/internal/domain/domain.go @@ -100,10 +100,18 @@ func setContentType(w http.ResponseWriter, fullPath string) { } func (d *D) getProject(r *http.Request) *project { + // Check default domain config (e.g. http://mydomain.gitlab.io) + if groupProject := d.projects[strings.ToLower(r.Host)]; groupProject != nil { + return groupProject + } + + // Check URLs with multiple projects for a group + // (e.g. http://group.gitlab.io/projectA and http://group.gitlab.io/projectB) split := strings.SplitN(r.URL.Path, "/", 3) if len(split) < 2 { return nil } + return d.projects[split[1]] } @@ -114,13 +122,13 @@ func (d *D) IsHTTPSOnly(r *http.Request) bool { return false } + // Check custom domain config (e.g. http://example.com) if d.config != nil { return d.config.HTTPSOnly } - project := d.getProject(r) - - if project != nil { + // Check projects served under the group domain, including the default one + if project := d.getProject(r); project != nil { return project.HTTPSOnly } @@ -133,13 +141,13 @@ func (d *D) IsAccessControlEnabled(r *http.Request) bool { return false } + // Check custom domain config (e.g. http://example.com) if d.config != nil { return d.config.AccessControl } - project := d.getProject(r) - - if project != nil { + // Check projects served under the group domain, including the default one + if project := d.getProject(r); project != nil { return project.AccessControl } diff --git a/internal/domain/domain_test.go b/internal/domain/domain_test.go index 4cbc6cc2..64d11a29 100644 --- a/internal/domain/domain_test.go +++ b/internal/domain/domain_test.go @@ -70,6 +70,102 @@ func TestDomainServeHTTP(t *testing.T) { assert.HTTPError(t, testDomain.ServeHTTP, "GET", "/not-existing-file", nil) } +func TestIsHTTPSOnly(t *testing.T) { + tests := []struct { + name string + domain *D + url string + expected bool + }{ + { + name: "Custom domain with HTTPS-only enabled", + domain: &D{ + group: "group", + projectName: "project", + config: &domainConfig{HTTPSOnly: true}, + }, + url: "http://custom-domain", + expected: true, + }, + { + name: "Custom domain with HTTPS-only disabled", + domain: &D{ + group: "group", + projectName: "project", + config: &domainConfig{HTTPSOnly: false}, + }, + url: "http://custom-domain", + expected: false, + }, + { + name: "Default group domain with HTTPS-only enabled", + domain: &D{ + group: "group", + projectName: "project", + projects: projects{"test-domain": &project{HTTPSOnly: true}}, + }, + url: "http://test-domain", + expected: true, + }, + { + name: "Default group domain with HTTPS-only disabled", + domain: &D{ + group: "group", + projectName: "project", + projects: projects{"test-domain": &project{HTTPSOnly: false}}, + }, + url: "http://test-domain", + expected: false, + }, + { + name: "Case-insensitive default group domain with HTTPS-only enabled", + domain: &D{ + group: "group", + projectName: "project", + projects: projects{"test-domain": &project{HTTPSOnly: true}}, + }, + url: "http://Test-domain", + expected: true, + }, + { + name: "Other group domain with HTTPS-only enabled", + domain: &D{ + group: "group", + projectName: "project", + projects: projects{"project": &project{HTTPSOnly: true}}, + }, + url: "http://test-domain/project", + expected: true, + }, + { + name: "Other group domain with HTTPS-only disabled", + domain: &D{ + group: "group", + projectName: "project", + projects: projects{"project": &project{HTTPSOnly: false}}, + }, + url: "http://test-domain/project", + expected: false, + }, + { + name: "Unknown project", + domain: &D{ + group: "group", + projectName: "project", + }, + url: "http://test-domain/project", + expected: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + req, _ := http.NewRequest(http.MethodGet, test.url, nil) + assert.Equal(t, test.domain.IsHTTPSOnly(req), test.expected) + }) + } +} + func testHTTPGzip(t *testing.T, handler http.HandlerFunc, mode, url string, values url.Values, acceptEncoding string, str interface{}, ungzip bool) { w := httptest.NewRecorder() req, err := http.NewRequest(mode, url+"?"+values.Encode(), nil) diff --git a/internal/domain/map.go b/internal/domain/map.go index 5c6d5f5d..b832ffb3 100644 --- a/internal/domain/map.go +++ b/internal/domain/map.go @@ -20,6 +20,20 @@ type Map map[string]*D type domainsUpdater func(Map) +func (dm Map) updateDomainMap(domainName string, domain *D) { + if old, ok := dm[domainName]; ok { + log.WithFields(log.Fields{ + "domain_name": domainName, + "new_group": domain.group, + "new_project_name": domain.projectName, + "old_group": old.group, + "old_project_name": old.projectName, + }).Error("Duplicate domain") + } + + dm[domainName] = domain +} + func (dm Map) addDomain(rootDomain, group, projectName string, config *domainConfig) { newDomain := &D{ group: group, @@ -29,7 +43,7 @@ func (dm Map) addDomain(rootDomain, group, projectName string, config *domainCon var domainName string domainName = strings.ToLower(config.Domain) - dm[domainName] = newDomain + dm.updateDomainMap(domainName, newDomain) } func (dm Map) updateGroupDomain(rootDomain, group, projectName string, httpsOnly bool, accessControl bool, id uint64) { @@ -43,7 +57,7 @@ func (dm Map) updateGroupDomain(rootDomain, group, projectName string, httpsOnly } } - groupDomain.projects[projectName] = &project{ + groupDomain.projects[strings.ToLower(projectName)] = &project{ HTTPSOnly: httpsOnly, AccessControl: accessControl, ID: id, @@ -121,12 +135,7 @@ type jobResult struct { } // ReadGroups walks the pages directory and populates dm with all the domains it finds. -func (dm Map) ReadGroups(rootDomain string) error { - fis, err := godirwalk.ReadDirents(".", nil) - if err != nil { - return err - } - +func (dm Map) ReadGroups(rootDomain string, fis godirwalk.Dirents) { fanOutGroups := make(chan string) fanIn := make(chan jobResult) wg := &sync.WaitGroup{} @@ -177,7 +186,6 @@ func (dm Map) ReadGroups(rootDomain string) error { close(fanOutGroups) <-done - return nil } const ( @@ -206,9 +214,15 @@ func Watch(rootDomain string, updater domainsUpdater, interval time.Duration) { started := time.Now() dm := make(Map) - if err := dm.ReadGroups(rootDomain); err != nil { + + fis, err := godirwalk.ReadDirents(".", nil) + if err != nil { log.WithError(err).Warn("domain scan failed") + metrics.FailedDomainUpdates.Inc() + continue } + + dm.ReadGroups(rootDomain, fis) duration := time.Since(started).Seconds() var hash string diff --git a/internal/domain/map_test.go b/internal/domain/map_test.go index 45658e95..88b406bf 100644 --- a/internal/domain/map_test.go +++ b/internal/domain/map_test.go @@ -8,16 +8,32 @@ import ( "testing" "time" + "github.com/karrick/godirwalk" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +func getEntries(t *testing.T) godirwalk.Dirents { + fis, err := godirwalk.ReadDirents(".", nil) + + require.NoError(t, err) + + return fis +} + +func getEntriesForBenchmark(t *testing.B) godirwalk.Dirents { + fis, err := godirwalk.ReadDirents(".", nil) + + require.NoError(t, err) + + return fis +} + func TestReadProjects(t *testing.T) { setUpTests() dm := make(Map) - err := dm.ReadGroups("test.io") - require.NoError(t, err) + dm.ReadGroups("test.io", getEntries(t)) var domains []string for d := range dm { @@ -142,7 +158,7 @@ func BenchmarkReadGroups(b *testing.B) { var dm Map for i := 0; i < 2; i++ { dm = make(Map) - require.NoError(b, dm.ReadGroups("example.com")) + dm.ReadGroups("example.com", getEntriesForBenchmark(b)) } b.Logf("found %d domains", len(dm)) }) |