Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/mpolden/echoip.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/http
diff options
context:
space:
mode:
authorMartin Polden <mpolden@mpolden.no>2018-07-25 22:05:08 +0300
committerMartin Polden <mpolden@mpolden.no>2018-07-25 22:05:08 +0300
commit91f0c17c94ec935bbc480c9eaabc91b5f131aa33 (patch)
treead1b2b2a75dd366321ae13fd0bc383b55b3d8e70 /http
parente282ac2729811c8e931154027000a488924346c7 (diff)
Add support for multiple trusted headers
Diffstat (limited to 'http')
-rw-r--r--http/http.go18
-rw-r--r--http/http_test.go21
2 files changed, 23 insertions, 16 deletions
diff --git a/http/http.go b/http/http.go
index c04f514..cac1010 100644
--- a/http/http.go
+++ b/http/http.go
@@ -22,7 +22,7 @@ const (
type Server struct {
Template string
- IPHeader string
+ IPHeaders []string
LookupAddr func(net.IP) (string, error)
LookupPort func(net.IP, uint64) error
db database.Client
@@ -47,8 +47,14 @@ func New(db database.Client) *Server {
return &Server{db: db}
}
-func ipFromRequest(header string, r *http.Request) (net.IP, error) {
- remoteIP := r.Header.Get(header)
+func ipFromRequest(headers []string, r *http.Request) (net.IP, error) {
+ remoteIP := ""
+ for _, header := range headers {
+ remoteIP = r.Header.Get(header)
+ if remoteIP != "" {
+ break
+ }
+ }
if remoteIP == "" {
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
@@ -64,7 +70,7 @@ func ipFromRequest(header string, r *http.Request) (net.IP, error) {
}
func (s *Server) newResponse(r *http.Request) (Response, error) {
- ip, err := ipFromRequest(s.IPHeader, r)
+ ip, err := ipFromRequest(s.IPHeaders, r)
if err != nil {
return Response{}, err
}
@@ -91,7 +97,7 @@ func (s *Server) newPortResponse(r *http.Request) (PortResponse, error) {
if err != nil || port < 1 || port > 65355 {
return PortResponse{Port: port}, fmt.Errorf("invalid port: %d", port)
}
- ip, err := ipFromRequest(s.IPHeader, r)
+ ip, err := ipFromRequest(s.IPHeaders, r)
if err != nil {
return PortResponse{Port: port}, err
}
@@ -104,7 +110,7 @@ func (s *Server) newPortResponse(r *http.Request) (PortResponse, error) {
}
func (s *Server) CLIHandler(w http.ResponseWriter, r *http.Request) *appError {
- ip, err := ipFromRequest(s.IPHeader, r)
+ ip, err := ipFromRequest(s.IPHeaders, r)
if err != nil {
return internalServerError(err)
}
diff --git a/http/http_test.go b/http/http_test.go
index 66a3027..206d9ff 100644
--- a/http/http_test.go
+++ b/http/http_test.go
@@ -149,16 +149,17 @@ func TestJSONHandlers(t *testing.T) {
func TestIPFromRequest(t *testing.T) {
var tests = []struct {
- remoteAddr string
- headerKey string
- headerValue string
- trustedHeader string
- out string
+ remoteAddr string
+ headerKey string
+ headerValue string
+ trustedHeaders []string
+ out string
}{
- {"127.0.0.1:9999", "", "", "", "127.0.0.1"}, // No header given
- {"127.0.0.1:9999", "X-Real-IP", "1.3.3.7", "", "127.0.0.1"}, // Trusted header is empty
- {"127.0.0.1:9999", "X-Real-IP", "1.3.3.7", "X-Foo-Bar", "127.0.0.1"}, // Trusted header does not match
- {"127.0.0.1:9999", "X-Real-IP", "1.3.3.7", "X-Real-IP", "1.3.3.7"}, // Trusted header matches
+ {"127.0.0.1:9999", "", "", nil, "127.0.0.1"}, // No header given
+ {"127.0.0.1:9999", "X-Real-IP", "1.3.3.7", nil, "127.0.0.1"}, // Trusted header is empty
+ {"127.0.0.1:9999", "X-Real-IP", "1.3.3.7", []string{"X-Foo-Bar"}, "127.0.0.1"}, // Trusted header does not match
+ {"127.0.0.1:9999", "X-Real-IP", "1.3.3.7", []string{"X-Real-IP", "X-Forwarded-For"}, "1.3.3.7"}, // Trusted header matches
+ {"127.0.0.1:9999", "X-Forwarded-For", "1.3.3.7", []string{"X-Real-IP", "X-Forwarded-For"}, "1.3.3.7"}, // Second trusted header matches
}
for _, tt := range tests {
r := &http.Request{
@@ -166,7 +167,7 @@ func TestIPFromRequest(t *testing.T) {
Header: http.Header{},
}
r.Header.Add(tt.headerKey, tt.headerValue)
- ip, err := ipFromRequest(tt.trustedHeader, r)
+ ip, err := ipFromRequest(tt.trustedHeaders, r)
if err != nil {
t.Fatal(err)
}