diff options
author | Martin Polden <mpolden@mpolden.no> | 2018-02-10 15:24:32 +0300 |
---|---|---|
committer | Martin Polden <mpolden@mpolden.no> | 2018-02-10 15:24:32 +0300 |
commit | 35061bfe83e5f71f7e4f9e74f8a23e1589d89dee (patch) | |
tree | b3328f90df268fa41a17d786be6311d636db9a2b /http | |
parent | 7362c9043aaa2791a46f8d401cdc15be04d8d878 (diff) |
Restructure
Diffstat (limited to 'http')
-rw-r--r-- | http/error.go | 40 | ||||
-rw-r--r-- | http/http.go | 285 | ||||
-rw-r--r-- | http/http_test.go | 186 | ||||
-rw-r--r-- | http/oracle.go | 163 |
4 files changed, 674 insertions, 0 deletions
diff --git a/http/error.go b/http/error.go new file mode 100644 index 0000000..72c6fce --- /dev/null +++ b/http/error.go @@ -0,0 +1,40 @@ +package http + +import "net/http" + +type appError struct { + Error error + Message string + Code int + ContentType string +} + +func internalServerError(err error) *appError { + return &appError{ + Error: err, + Message: "Internal server error", + Code: http.StatusInternalServerError, + } +} + +func notFound(err error) *appError { + return &appError{Error: err, Code: http.StatusNotFound} +} + +func badRequest(err error) *appError { + return &appError{Error: err, Code: http.StatusBadRequest} +} + +func (e *appError) AsJSON() *appError { + e.ContentType = jsonMediaType + return e +} + +func (e *appError) WithMessage(message string) *appError { + e.Message = message + return e +} + +func (e *appError) IsJSON() bool { + return e.ContentType == jsonMediaType +} diff --git a/http/http.go b/http/http.go new file mode 100644 index 0000000..41f8a65 --- /dev/null +++ b/http/http.go @@ -0,0 +1,285 @@ +package http + +import ( + "encoding/json" + "fmt" + "html/template" + + "github.com/mpolden/ipd/useragent" + "github.com/sirupsen/logrus" + + "math/big" + "net" + "net/http" + "path/filepath" + "strconv" + "strings" + + "github.com/gorilla/mux" +) + +const ( + jsonMediaType = "application/json" + textMediaType = "text/plain" +) + +type Server struct { + Template string + IPHeader string + oracle Oracle + log *logrus.Logger +} + +type Response struct { + IP net.IP `json:"ip"` + IPDecimal *big.Int `json:"ip_decimal"` + Country string `json:"country,omitempty"` + CountryISO string `json:"country_iso,omitempty"` + City string `json:"city,omitempty"` + Hostname string `json:"hostname,omitempty"` +} + +type PortResponse struct { + IP net.IP `json:"ip"` + Port uint64 `json:"port"` + Reachable bool `json:"reachable"` +} + +func New(oracle Oracle, logger *logrus.Logger) *Server { + return &Server{oracle: oracle, log: logger} +} + +func ipToDecimal(ip net.IP) *big.Int { + i := big.NewInt(0) + if to4 := ip.To4(); to4 != nil { + i.SetBytes(to4) + } else { + i.SetBytes(ip) + } + return i +} + +func ipFromRequest(header string, r *http.Request) (net.IP, error) { + remoteIP := r.Header.Get(header) + if remoteIP == "" { + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + return nil, err + } + remoteIP = host + } + ip := net.ParseIP(remoteIP) + if ip == nil { + return nil, fmt.Errorf("could not parse IP: %s", remoteIP) + } + return ip, nil +} + +func (s *Server) newResponse(r *http.Request) (Response, error) { + ip, err := ipFromRequest(s.IPHeader, r) + if err != nil { + return Response{}, err + } + ipDecimal := ipToDecimal(ip) + country, err := s.oracle.LookupCountry(ip) + if err != nil { + s.log.Debug(err) + } + countryISO, err := s.oracle.LookupCountryISO(ip) + if err != nil { + s.log.Debug(err) + } + city, err := s.oracle.LookupCity(ip) + if err != nil { + s.log.Debug(err) + } + hostnames, err := s.oracle.LookupAddr(ip) + if err != nil { + s.log.Debug(err) + } + return Response{ + IP: ip, + IPDecimal: ipDecimal, + Country: country, + CountryISO: countryISO, + City: city, + Hostname: strings.Join(hostnames, " "), + }, nil +} + +func (s *Server) newPortResponse(r *http.Request) (PortResponse, error) { + vars := mux.Vars(r) + port, err := strconv.ParseUint(vars["port"], 10, 16) + if err != nil { + return PortResponse{Port: port}, err + } + if port < 1 || port > 65355 { + return PortResponse{Port: port}, fmt.Errorf("invalid port: %d", port) + } + ip, err := ipFromRequest(s.IPHeader, r) + if err != nil { + return PortResponse{Port: port}, err + } + err = s.oracle.LookupPort(ip, port) + return PortResponse{ + IP: ip, + Port: port, + Reachable: err == nil, + }, nil +} + +func (s *Server) CLIHandler(w http.ResponseWriter, r *http.Request) *appError { + ip, err := ipFromRequest(s.IPHeader, r) + if err != nil { + return internalServerError(err) + } + fmt.Fprintln(w, ip.String()) + return nil +} + +func (s *Server) CLICountryHandler(w http.ResponseWriter, r *http.Request) *appError { + response, err := s.newResponse(r) + if err != nil { + return internalServerError(err) + } + fmt.Fprintln(w, response.Country) + return nil +} + +func (s *Server) CLICountryISOHandler(w http.ResponseWriter, r *http.Request) *appError { + response, err := s.newResponse(r) + if err != nil { + return internalServerError(err) + } + fmt.Fprintln(w, response.CountryISO) + return nil +} + +func (s *Server) CLICityHandler(w http.ResponseWriter, r *http.Request) *appError { + response, err := s.newResponse(r) + if err != nil { + return internalServerError(err) + } + fmt.Fprintln(w, response.City) + return nil +} + +func (s *Server) JSONHandler(w http.ResponseWriter, r *http.Request) *appError { + response, err := s.newResponse(r) + if err != nil { + return internalServerError(err).AsJSON() + } + b, err := json.Marshal(response) + if err != nil { + return internalServerError(err).AsJSON() + } + w.Header().Set("Content-Type", jsonMediaType) + w.Write(b) + return nil +} + +func (s *Server) PortHandler(w http.ResponseWriter, r *http.Request) *appError { + response, err := s.newPortResponse(r) + if err != nil { + return badRequest(err).WithMessage(fmt.Sprintf("Invalid port: %d", response.Port)).AsJSON() + } + b, err := json.Marshal(response) + if err != nil { + return internalServerError(err).AsJSON() + } + w.Header().Set("Content-Type", jsonMediaType) + w.Write(b) + return nil +} + +func (s *Server) DefaultHandler(w http.ResponseWriter, r *http.Request) *appError { + response, err := s.newResponse(r) + if err != nil { + return internalServerError(err) + } + t, err := template.New(filepath.Base(s.Template)).ParseFiles(s.Template) + if err != nil { + return internalServerError(err) + } + var data = struct { + Host string + Response + Oracle + }{r.Host, response, s.oracle} + if err := t.Execute(w, &data); err != nil { + return internalServerError(err) + } + return nil +} + +func (s *Server) NotFoundHandler(w http.ResponseWriter, r *http.Request) *appError { + err := notFound(nil).WithMessage("404 page not found") + if r.Header.Get("accept") == jsonMediaType { + err = err.AsJSON() + } + return err +} + +func cliMatcher(r *http.Request, rm *mux.RouteMatch) bool { + ua := useragent.Parse(r.UserAgent()) + switch ua.Product { + case "curl", "HTTPie", "Wget", "fetch libfetch", "Go", "Go-http-client", "ddclient": + return true + } + return false +} + +type appHandler func(http.ResponseWriter, *http.Request) *appError + +func (fn appHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if e := fn(w, r); e != nil { // e is *appError + // When Content-Type for error is JSON, we need to marshal the response into JSON + if e.IsJSON() { + var data = struct { + Error string `json:"error"` + }{e.Message} + b, err := json.Marshal(data) + if err != nil { + panic(err) + } + e.Message = string(b) + } + // Set Content-Type of response if set in error + if e.ContentType != "" { + w.Header().Set("Content-Type", e.ContentType) + } + w.WriteHeader(e.Code) + fmt.Fprint(w, e.Message) + } +} + +func (s *Server) Handler() http.Handler { + r := mux.NewRouter() + + // JSON + r.Handle("/", appHandler(s.JSONHandler)).Methods("GET").Headers("Accept", jsonMediaType) + r.Handle("/json", appHandler(s.JSONHandler)).Methods("GET") + + // CLI + r.Handle("/", appHandler(s.CLIHandler)).Methods("GET").MatcherFunc(cliMatcher) + r.Handle("/", appHandler(s.CLIHandler)).Methods("GET").Headers("Accept", textMediaType) + r.Handle("/ip", appHandler(s.CLIHandler)).Methods("GET") + r.Handle("/country", appHandler(s.CLICountryHandler)).Methods("GET") + r.Handle("/country-iso", appHandler(s.CLICountryISOHandler)).Methods("GET") + r.Handle("/city", appHandler(s.CLICityHandler)).Methods("GET") + + // Browser + r.Handle("/", appHandler(s.DefaultHandler)).Methods("GET") + + // Port testing + r.Handle("/port/{port:[0-9]+}", appHandler(s.PortHandler)).Methods("GET") + + // Not found handler which returns JSON when appropriate + r.NotFoundHandler = appHandler(s.NotFoundHandler) + + return r +} + +func (s *Server) ListenAndServe(addr string) error { + return http.ListenAndServe(addr, s.Handler()) +} diff --git a/http/http_test.go b/http/http_test.go new file mode 100644 index 0000000..f4ce1a3 --- /dev/null +++ b/http/http_test.go @@ -0,0 +1,186 @@ +package http + +import ( + "io/ioutil" + "log" + "math/big" + "net" + "net/http" + "net/http/httptest" + "testing" +) + +type mockOracle struct{} + +func (r *mockOracle) LookupAddr(net.IP) ([]string, error) { return []string{"localhost"}, nil } +func (r *mockOracle) LookupCountry(net.IP) (string, error) { return "Elbonia", nil } +func (r *mockOracle) LookupCountryISO(net.IP) (string, error) { return "EB", nil } +func (r *mockOracle) LookupCity(net.IP) (string, error) { return "Bornyasherk", nil } +func (r *mockOracle) LookupPort(net.IP, uint64) error { return nil } +func (r *mockOracle) IsLookupAddrEnabled() bool { return true } +func (r *mockOracle) IsLookupCountryEnabled() bool { return true } +func (r *mockOracle) IsLookupCityEnabled() bool { return true } +func (r *mockOracle) IsLookupPortEnabled() bool { return true } + +func newTestAPI() *Server { + return &Server{oracle: &mockOracle{}} +} + +func httpGet(url string, acceptMediaType string, userAgent string) (string, int, error) { + r, err := http.NewRequest("GET", url, nil) + if err != nil { + return "", 0, err + } + if acceptMediaType != "" { + r.Header.Set("Accept", acceptMediaType) + } + r.Header.Set("User-Agent", userAgent) + res, err := http.DefaultClient.Do(r) + if err != nil { + return "", 0, err + } + defer res.Body.Close() + data, err := ioutil.ReadAll(res.Body) + if err != nil { + return "", 0, err + } + return string(data), res.StatusCode, nil +} + +func TestCLIHandlers(t *testing.T) { + log.SetOutput(ioutil.Discard) + s := httptest.NewServer(newTestAPI().Handler()) + + var tests = []struct { + url string + out string + status int + userAgent string + acceptMediaType string + }{ + {s.URL, "127.0.0.1\n", 200, "curl/7.43.0", ""}, + {s.URL, "127.0.0.1\n", 200, "foo/bar", textMediaType}, + {s.URL + "/ip", "127.0.0.1\n", 200, "", ""}, + {s.URL + "/country", "Elbonia\n", 200, "", ""}, + {s.URL + "/country-iso", "EB\n", 200, "", ""}, + {s.URL + "/city", "Bornyasherk\n", 200, "", ""}, + {s.URL + "/foo", "404 page not found", 404, "", ""}, + } + + for _, tt := range tests { + out, status, err := httpGet(tt.url, tt.acceptMediaType, tt.userAgent) + if err != nil { + t.Fatal(err) + } + if status != tt.status { + t.Errorf("Expected %d, got %d", tt.status, status) + } + if out != tt.out { + t.Errorf("Expected %q, got %q", tt.out, out) + } + } +} + +func TestJSONHandlers(t *testing.T) { + log.SetOutput(ioutil.Discard) + s := httptest.NewServer(newTestAPI().Handler()) + + var tests = []struct { + url string + out string + status int + }{ + {s.URL, `{"ip":"127.0.0.1","ip_decimal":2130706433,"country":"Elbonia","country_iso":"EB","city":"Bornyasherk","hostname":"localhost"}`, 200}, + {s.URL + "/port/foo", `{"error":"404 page not found"}`, 404}, + {s.URL + "/port/0", `{"error":"Invalid port: 0"}`, 400}, + {s.URL + "/port/65356", `{"error":"Invalid port: 65356"}`, 400}, + {s.URL + "/port/31337", `{"ip":"127.0.0.1","port":31337,"reachable":true}`, 200}, + {s.URL + "/foo", `{"error":"404 page not found"}`, 404}, + } + + for _, tt := range tests { + out, status, err := httpGet(tt.url, jsonMediaType, "curl/7.2.6.0") + if err != nil { + t.Fatal(err) + } + if status != tt.status { + t.Errorf("Expected %d, got %d", tt.status, status) + } + if out != tt.out { + t.Errorf("Expected %q, got %q", tt.out, out) + } + } +} + +func TestIPFromRequest(t *testing.T) { + var tests = []struct { + remoteAddr string + headerKey string + headerValue string + trustedHeader string + out string + }{ + {"127.0.0.1:9999", "", "", "", "127.0.0.1"}, // No header given + {"127.0.0.1:9999", "X-Real-IP", "1.3.3.7", "", "127.0.0.1"}, // Trusted header is empty + {"127.0.0.1:9999", "X-Real-IP", "1.3.3.7", "X-Foo-Bar", "127.0.0.1"}, // Trusted header does not match + {"127.0.0.1:9999", "X-Real-IP", "1.3.3.7", "X-Real-IP", "1.3.3.7"}, // Trusted header matches + } + for _, tt := range tests { + r := &http.Request{ + RemoteAddr: tt.remoteAddr, + Header: http.Header{}, + } + r.Header.Add(tt.headerKey, tt.headerValue) + ip, err := ipFromRequest(tt.trustedHeader, r) + if err != nil { + t.Fatal(err) + } + out := net.ParseIP(tt.out) + if !ip.Equal(out) { + t.Errorf("Expected %s, got %s", out, ip) + } + } +} + +func TestCLIMatcher(t *testing.T) { + browserUserAgent := "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_4) " + + "AppleWebKit/537.36 (KHTML, like Gecko) Chrome/30.0.1599.28 " + + "Safari/537.36" + var tests = []struct { + in string + out bool + }{ + {"curl/7.26.0", true}, + {"Wget/1.13.4 (linux-gnu)", true}, + {"Wget", true}, + {"fetch libfetch/2.0", true}, + {"HTTPie/0.9.3", true}, + {"Go 1.1 package http", true}, + {"Go-http-client/1.1", true}, + {"Go-http-client/2.0", true}, + {"ddclient/3.8.3", true}, + {browserUserAgent, false}, + } + for _, tt := range tests { + r := &http.Request{Header: http.Header{"User-Agent": []string{tt.in}}} + if got := cliMatcher(r, nil); got != tt.out { + t.Errorf("Expected %t, got %t for %q", tt.out, got, tt.in) + } + } +} + +func TestIPToDecimal(t *testing.T) { + var tests = []struct { + in string + out *big.Int + }{ + {"127.0.0.1", big.NewInt(2130706433)}, + {"::1", big.NewInt(1)}, + } + for _, tt := range tests { + i := ipToDecimal(net.ParseIP(tt.in)) + if i.Cmp(tt.out) != 0 { + t.Errorf("Expected %d, got %d for IP %s", tt.out, i, tt.in) + } + } +} diff --git a/http/oracle.go b/http/oracle.go new file mode 100644 index 0000000..1654758 --- /dev/null +++ b/http/oracle.go @@ -0,0 +1,163 @@ +package http + +import ( + "fmt" + "net" + "strings" + "time" + + "github.com/oschwald/geoip2-golang" +) + +type Oracle interface { + LookupAddr(net.IP) ([]string, error) + LookupCountry(net.IP) (string, error) + LookupCountryISO(net.IP) (string, error) + LookupCity(net.IP) (string, error) + LookupPort(net.IP, uint64) error + IsLookupAddrEnabled() bool + IsLookupCountryEnabled() bool + IsLookupCityEnabled() bool + IsLookupPortEnabled() bool +} + +type DefaultOracle struct { + lookupAddr func(net.IP) ([]string, error) + lookupCountry func(net.IP) (string, error) + lookupCountryISO func(net.IP) (string, error) + lookupCity func(net.IP) (string, error) + lookupPort func(net.IP, uint64) error + lookupAddrEnabled bool + lookupCountryEnabled bool + lookupCityEnabled bool + lookupPortEnabled bool +} + +func NewOracle() *DefaultOracle { + return &DefaultOracle{ + lookupAddr: func(net.IP) ([]string, error) { return nil, nil }, + lookupCountry: func(net.IP) (string, error) { return "", nil }, + lookupCountryISO: func(net.IP) (string, error) { return "", nil }, + lookupCity: func(net.IP) (string, error) { return "", nil }, + lookupPort: func(net.IP, uint64) error { return nil }, + } +} + +func (r *DefaultOracle) LookupAddr(ip net.IP) ([]string, error) { + return r.lookupAddr(ip) +} + +func (r *DefaultOracle) LookupCountry(ip net.IP) (string, error) { + return r.lookupCountry(ip) +} + +func (r *DefaultOracle) LookupCountryISO(ip net.IP) (string, error) { + return r.lookupCountryISO(ip) +} + +func (r *DefaultOracle) LookupCity(ip net.IP) (string, error) { + return r.lookupCity(ip) +} + +func (r *DefaultOracle) LookupPort(ip net.IP, port uint64) error { + return r.lookupPort(ip, port) +} + +func (r *DefaultOracle) EnableLookupAddr() { + r.lookupAddr = lookupAddr + r.lookupAddrEnabled = true +} + +func (r *DefaultOracle) EnableLookupCountry(filepath string) error { + db, err := geoip2.Open(filepath) + if err != nil { + return err + } + r.lookupCountry = func(ip net.IP) (string, error) { + return lookupCountry(db, ip) + } + r.lookupCountryISO = func(ip net.IP) (string, error) { + return lookupCountryISO(db, ip) + } + r.lookupCountryEnabled = true + return nil +} + +func (r *DefaultOracle) EnableLookupCity(filepath string) error { + db, err := geoip2.Open(filepath) + if err != nil { + return err + } + r.lookupCity = func(ip net.IP) (string, error) { + return lookupCity(db, ip) + } + r.lookupCityEnabled = true + return nil +} + +func (r *DefaultOracle) EnableLookupPort() { + r.lookupPort = lookupPort + r.lookupPortEnabled = true +} + +func (r *DefaultOracle) IsLookupAddrEnabled() bool { return r.lookupAddrEnabled } +func (r *DefaultOracle) IsLookupCountryEnabled() bool { return r.lookupCountryEnabled } +func (r *DefaultOracle) IsLookupCityEnabled() bool { return r.lookupCityEnabled } +func (r *DefaultOracle) IsLookupPortEnabled() bool { return r.lookupPortEnabled } + +func lookupAddr(ip net.IP) ([]string, error) { + names, err := net.LookupAddr(ip.String()) + for i, _ := range names { + names[i] = strings.TrimRight(names[i], ".") // Always return unrooted name + } + return names, err +} + +func lookupPort(ip net.IP, port uint64) error { + address := fmt.Sprintf("[%s]:%d", ip, port) + conn, err := net.DialTimeout("tcp", address, 2*time.Second) + if err != nil { + return err + } + defer conn.Close() + return nil +} + +func lookupCountry(db *geoip2.Reader, ip net.IP) (string, error) { + record, err := db.Country(ip) + if err != nil { + return "", err + } + if country, exists := record.Country.Names["en"]; exists { + return country, nil + } + if country, exists := record.RegisteredCountry.Names["en"]; exists { + return country, nil + } + return "Unknown", fmt.Errorf("could not determine country for IP: %s", ip) +} + +func lookupCountryISO(db *geoip2.Reader, ip net.IP) (string, error) { + record, err := db.City(ip) + if err != nil { + return "", err + } + if record.Country.IsoCode != "" { + return record.Country.IsoCode, nil + } + if record.RegisteredCountry.IsoCode != "" { + return record.RegisteredCountry.IsoCode, nil + } + return "Unknown", fmt.Errorf("could not determine country ISO Code for IP: %s", ip) +} + +func lookupCity(db *geoip2.Reader, ip net.IP) (string, error) { + record, err := db.City(ip) + if err != nil { + return "", err + } + if city, exists := record.City.Names["en"]; exists { + return city, nil + } + return "Unknown", fmt.Errorf("could not determine city for IP: %s", ip) +} |