diff options
-rw-r--r-- | domain.go | 25 | ||||
-rw-r--r-- | domain_config.go | 28 | ||||
-rw-r--r-- | domains.go | 128 | ||||
-rw-r--r-- | main.go | 23 | ||||
-rw-r--r-- | server.go | 6 |
5 files changed, 142 insertions, 68 deletions
@@ -1,18 +1,18 @@ package main import ( + "crypto/tls" + "errors" "net/http" "os" "path/filepath" "strings" - "crypto/tls" - "errors" ) type domain struct { Group string Project string - CNAME bool + Config *domainConfig certificate *tls.Certificate } @@ -63,7 +63,7 @@ func (d *domain) tryFile(w http.ResponseWriter, r *http.Request, projectName, su return true } -func (d *domain) serverGroup(w http.ResponseWriter, r *http.Request) { +func (d *domain) serveFromGroup(w http.ResponseWriter, r *http.Request) { // The Path always contains "/" at the beggining split := strings.SplitN(r.URL.Path, "/", 3) @@ -84,7 +84,7 @@ func (d *domain) serverGroup(w http.ResponseWriter, r *http.Request) { d.notFound(w, r) } -func (d *domain) serveCNAME(w http.ResponseWriter, r *http.Request) { +func (d *domain) serveFromConfig(w http.ResponseWriter, r *http.Request) { if d.tryFile(w, r, d.Project, r.URL.Path) { return } @@ -93,18 +93,15 @@ func (d *domain) serveCNAME(w http.ResponseWriter, r *http.Request) { } func (d *domain) ensureCertificate() (*tls.Certificate, error) { - if !d.CNAME { - return nil, errors.New("tls certificates can be loaded only for pages with CNAME") + if !d.Config { + return nil, errors.New("tls certificates can be loaded only for pages with configuration") } if d.certificate != nil { return d.certificate, nil } - // Load keypair from shared/pages/group/project/domain.{crt,key} - certificateFile := filepath.Join(*pagesRoot, d.Group, d.Project, "domain.crt") - keyFile := filepath.Join(*pagesRoot, d.Group, d.Project, "domain.key") - tls, err := tls.LoadX509KeyPair(certificateFile, keyFile) + tls, err := tls.X509KeyPair([]byte(d.Config.Certificate), []byte(d.Config.Key)) if err != nil { return nil, err } @@ -114,9 +111,9 @@ func (d *domain) ensureCertificate() (*tls.Certificate, error) { } func (d *domain) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if d.CNAME { - d.serveCNAME(w, r) + if d.Config != nil { + d.serveFromConfig(w, r) } else { - d.serverGroup(w, r) + d.serveFromGroup(w, r) } } diff --git a/domain_config.go b/domain_config.go new file mode 100644 index 00000000..d16ca9ad --- /dev/null +++ b/domain_config.go @@ -0,0 +1,28 @@ +package main + +import ( + "encoding/json" + "os" + "path/filepath" +) + +type domainConfig struct { + Domain string + Certificate string + Key string +} + +type domainsConfig struct { + Domains []domainConfig +} + +func (c *domainsConfig) Read(group, project string) (err error) { + configFile, err := os.Open(filepath.Join(*pagesRoot, project, group, "config.json")) + if err != nil { + return nil + } + defer configFile.Close() + + err = json.NewDecoder(configFile).Decode(c) + return +} @@ -1,6 +1,9 @@ package main import ( + "bytes" + "encoding/json" + "errors" "io/ioutil" "log" "os" @@ -13,76 +16,119 @@ type domains map[string]domain type domainsUpdater func(domains domains) -func readGroups(domains domains) error { - groups, err := filepath.Glob(filepath.Join(*pagesRoot, "*/")) - if err != nil { - return err +func isDomainAllowed(domain string) bool { + if domain == "" { + return false + } + // TODO: better sanitize domain + domain = strings.ToLower(domain) + pagesDomain = "." + strings.ToLower(pagesDomain) + return !strings.HasPrefix(domain, pagesDomain) +} + +func (d domains) addDomain(group, project string, config *domainConfig) error { + newDomain := &domain{ + Group: group, + Project: project, + Config: domainConfig, } - for _, groupDir := range groups { - group := filepath.Base(groupDir) - groupName := strings.ToLower(group) - domains[groupName+"."+*pagesDomain] = domain{ - Group: group, - CNAME: false, + if config != nil { + if !isDomainAllowed(domainConfig.Domain) { + return errors.New("domain name is not allowed") } + + d[config.Domain] = newDomain + } else { + domainName := group + "." + *pagesDomain + d[domainName] = newDomain } - return nil + return } -func readCnames(domains domains) error { - cnames, err := filepath.Glob(filepath.Join(*pagesRoot, "*/*/CNAME")) +func (d domains) readProjects(group string) (count int) { + projects, err := os.Open(filepath.Join(*pagesRoot, group)) if err != nil { - return err + return + } + defer projects.Close() + + fis, err := projects.Readdir(0) + if err != nil { + log.Println("Failed to Readdir for ", *pagesRoot, ":", err) } - for _, cnamePath := range cnames { - cnameData, err := ioutil.ReadFile(cnamePath) + for _, project := range fis { + if !project.IsDir() { + continue + } + if strings.HasPrefix(project.Name(), ".") { + continue + } + + count++ + + var config domainsConfig + err := config.Read(group, project.Name()) if err != nil { continue } - for _, cname := range strings.Fields(string(cnameData)) { - cname := strings.ToLower(cname) - if strings.HasSuffix(cname, "."+*pagesDomain) { - continue - } + for _, domainConfig := range domainsConfig.Domains { + d.addDomain(group, project.Name(), &domainConfig) + } + } + return +} - domains[cname] = domain{ - // TODO: make it nicer - Group: filepath.Base(filepath.Dir(filepath.Dir(cnamePath))), - Project: filepath.Base(filepath.Dir(cnamePath)), - CNAME: true, - } +func (d domains) ReadGroups() error { + groups, err := os.Open(*pagesRoot) + if err != nil { + return err + } + defer groups.Close() + + fis, err := groups.Readdir(0) + if err != nil { + log.Println("Failed to Readdir for ", *pagesRoot, ":", err) + } + + for _, group := range fis { + if !group.IsDir() { + continue + } + if strings.HasPrefix(group.Name(), ".") { + continue + } + + count := d.readProjects(group.Name()) + if count > 0 { + d.addDomain(group, "", &domainConfig) } } return nil } func watchDomains(updater domainsUpdater) { - var lastModified time.Time + lastUpdate := "no-configuration" for { - fi, err := os.Stat(*pagesRoot) - if err != nil || !fi.IsDir() { - log.Println("Failed to read domains from", *pagesRoot, "due to:", err, fi.IsDir()) - time.Sleep(time.Second) - continue - } - - // If directory did not get modified we will reload - if !lastModified.Before(fi.ModTime()) { - time.Sleep(time.Second) + update, err := ioutil.ReadFile(filepath.Join(*pagesRoot, ".update")) + if bytes.Equal(lastUpdate, update) { + if err != nil { + log.Println("Failed to read update timestamp:", err) + time.Sleep(time.Second) + } continue } - lastModified = fi.ModTime() + lastUpdate = update started := time.Now() domains := make(domains) - readGroups(domains) - readCnames(domains) + domains.ReadGroups() duration := time.Since(started) log.Println("Updated", len(domains), "domains in", duration) + if updater != nil { updater(domains) } @@ -8,9 +8,13 @@ import ( "net/http" "strings" "sync" + "syscall" ) +// VERSION stores the information about the semantic version of application var VERSION = "dev" + +// REVISION stores the information about the git revision of application var REVISION = "HEAD" var listenHTTP = flag.String("listen-http", ":80", "The address to listen for HTTP requests") @@ -23,8 +27,8 @@ var serverHTTP = flag.Bool("serve-http", true, "Serve the pages under HTTP") var http2proto = flag.Bool("http2", true, "Enable HTTP2 support") var pagesRoot = flag.String("pages-root", "shared/pages", "The directory where pages are stored") -const XForwardedProto = "X-Forwarded-Proto" -const XForwardedProtoHttps = "https" +const xForwardedProto = "X-Forwarded-Proto" +const xForwardedProtoHTTPS = "https" type theApp struct { domains domains @@ -77,8 +81,8 @@ func (a *theApp) ServeHTTP(ww http.ResponseWriter, r *http.Request) { } func (a *theApp) ServeProxy(ww http.ResponseWriter, r *http.Request) { - forwardedProto := r.Header.Get(XForwardedProto) - https := forwardedProto == XForwardedProtoHttps + forwardedProto := r.Header.Get(xForwardedProto) + https := forwardedProto == xForwardedProtoHTTPS a.serveContent(ww, r, https) } @@ -91,9 +95,8 @@ func main() { var wg sync.WaitGroup var app theApp - fmt.Println("GitLab Pages Daemon %s (%s)", VERSION, REVISION) - fmt.Println("URL: https://gitlab.com/gitlab-org/gitlab-pages") - + fmt.Printf("GitLab Pages Daemon %s (%s)", VERSION, REVISION) + fmt.Printf("URL: https://gitlab.com/gitlab-org/gitlab-pages") flag.Parse() // Listen for HTTP @@ -101,7 +104,7 @@ func main() { wg.Add(1) go func() { defer wg.Done() - err := ListenAndServe(*listenHTTP, app.ServeHTTP) + err := listenAndServe(*listenHTTP, app.ServeHTTP) if err != nil { log.Fatal(err) } @@ -113,7 +116,7 @@ func main() { wg.Add(1) go func() { defer wg.Done() - err := ListenAndServeTLS(*listenHTTPS, *pagesRootCert, *pagesRootKey, app.ServeHTTP, app.ServeTLS) + err := listenAndServeTLS(*listenHTTPS, *pagesRootCert, *pagesRootKey, app.ServeHTTP, app.ServeTLS) if err != nil { log.Fatal(err) } @@ -125,7 +128,7 @@ func main() { wg.Add(1) go func() { defer wg.Done() - err := ListenAndServe(*listenProxy, app.ServeProxy) + err := listenAndServe(*listenProxy, app.ServeProxy) if err != nil { log.Fatal(err) } @@ -6,9 +6,9 @@ import ( "net/http" ) -type TLSHandlerFunc func(*tls.ClientHelloInfo) (*tls.Certificate, error) +type tlsHandlerFunc func(*tls.ClientHelloInfo) (*tls.Certificate, error) -func ListenAndServe(addr string, handler http.HandlerFunc) error { +func listenAndServe(addr string, handler http.HandlerFunc) error { // create server server := &http.Server{Addr: addr, Handler: handler} @@ -22,7 +22,7 @@ func ListenAndServe(addr string, handler http.HandlerFunc) error { return server.ListenAndServe() } -func ListenAndServeTLS(addr string, certFile, keyFile string, handler http.HandlerFunc, tlsHandler TLSHandlerFunc) error { +func listenAndServeTLS(addr string, certFile, keyFile string, handler http.HandlerFunc, tlsHandler tlsHandlerFunc) error { // create server server := &http.Server{Addr: addr, Handler: handler} server.TLSConfig = &tls.Config{} |