Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/MHSanaei/3x-ui.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'web/service/custom_geo.go')
-rw-r--r--web/service/custom_geo.go239
1 files changed, 198 insertions, 41 deletions
diff --git a/web/service/custom_geo.go b/web/service/custom_geo.go
index a8b7456b..fe7202e6 100644
--- a/web/service/custom_geo.go
+++ b/web/service/custom_geo.go
@@ -1,9 +1,11 @@
package service
import (
+ "context"
"errors"
"fmt"
"io"
+ "net"
"net/http"
"net/url"
"os"
@@ -43,6 +45,8 @@ var (
ErrCustomGeoDuplicateAlias = errors.New("custom_geo_duplicate_alias")
ErrCustomGeoNotFound = errors.New("custom_geo_not_found")
ErrCustomGeoDownload = errors.New("custom_geo_download")
+ ErrCustomGeoSSRFBlocked = errors.New("custom_geo_ssrf_blocked")
+ ErrCustomGeoPathTraversal = errors.New("custom_geo_path_traversal")
)
type CustomGeoUpdateAllItem struct {
@@ -111,25 +115,41 @@ func (s *CustomGeoService) validateAlias(alias string) error {
return nil
}
-func (s *CustomGeoService) validateURL(raw string) error {
+func (s *CustomGeoService) sanitizeURL(raw string) (string, error) {
if raw == "" {
- return ErrCustomGeoURLRequired
+ return "", ErrCustomGeoURLRequired
}
u, err := url.Parse(raw)
if err != nil {
- return ErrCustomGeoInvalidURL
+ return "", ErrCustomGeoInvalidURL
}
if u.Scheme != "http" && u.Scheme != "https" {
- return ErrCustomGeoURLScheme
+ return "", ErrCustomGeoURLScheme
}
if u.Host == "" {
- return ErrCustomGeoURLHost
+ return "", ErrCustomGeoURLHost
}
- return nil
+ if err := checkSSRF(context.Background(), u.Hostname()); err != nil {
+ return "", err
+ }
+ // Reconstruct URL from parsed components to break taint propagation.
+ clean := &url.URL{
+ Scheme: u.Scheme,
+ Host: u.Host,
+ Path: u.Path,
+ RawPath: u.RawPath,
+ RawQuery: u.RawQuery,
+ Fragment: u.Fragment,
+ }
+ return clean.String(), nil
}
func localDatFileNeedsRepair(path string) bool {
- fi, err := os.Stat(path)
+ safePath, err := sanitizeDestPath(path)
+ if err != nil {
+ return true
+ }
+ fi, err := os.Stat(safePath)
if err != nil {
return true
}
@@ -143,9 +163,56 @@ func CustomGeoLocalFileNeedsRepair(path string) bool {
return localDatFileNeedsRepair(path)
}
+func isBlockedIP(ip net.IP) bool {
+ return ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() ||
+ ip.IsLinkLocalMulticast() || ip.IsUnspecified()
+}
+
+// checkSSRFDefault validates that the given host does not resolve to a private/internal IP.
+// It is context-aware so that dial context cancellation/deadlines are respected during DNS resolution.
+func checkSSRFDefault(ctx context.Context, hostname string) error {
+ ips, err := net.DefaultResolver.LookupIPAddr(ctx, hostname)
+ if err != nil {
+ return fmt.Errorf("%w: cannot resolve host %s", ErrCustomGeoSSRFBlocked, hostname)
+ }
+ for _, ipAddr := range ips {
+ if isBlockedIP(ipAddr.IP) {
+ return fmt.Errorf("%w: %s resolves to blocked address %s", ErrCustomGeoSSRFBlocked, hostname, ipAddr.IP)
+ }
+ }
+ return nil
+}
+
+// checkSSRF is the active SSRF guard. Override in tests to allow localhost test servers.
+var checkSSRF = checkSSRFDefault
+
+func ssrfSafeTransport() http.RoundTripper {
+ base, ok := http.DefaultTransport.(*http.Transport)
+ if !ok {
+ base = &http.Transport{}
+ }
+ cloned := base.Clone()
+ cloned.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
+ host, _, err := net.SplitHostPort(addr)
+ if err != nil {
+ return nil, fmt.Errorf("%w: %v", ErrCustomGeoSSRFBlocked, err)
+ }
+ if err := checkSSRF(ctx, host); err != nil {
+ return nil, err
+ }
+ var dialer net.Dialer
+ return dialer.DialContext(ctx, network, addr)
+ }
+ return cloned
+}
+
func probeCustomGeoURLWithGET(rawURL string) error {
- client := &http.Client{Timeout: customGeoProbeTimeout}
- req, err := http.NewRequest(http.MethodGet, rawURL, nil)
+ sanitizedURL, err := (&CustomGeoService{}).sanitizeURL(rawURL)
+ if err != nil {
+ return err
+ }
+ client := &http.Client{Timeout: customGeoProbeTimeout, Transport: ssrfSafeTransport()}
+ req, err := http.NewRequest(http.MethodGet, sanitizedURL, nil)
if err != nil {
return err
}
@@ -165,8 +232,12 @@ func probeCustomGeoURLWithGET(rawURL string) error {
}
func probeCustomGeoURL(rawURL string) error {
- client := &http.Client{Timeout: customGeoProbeTimeout}
- req, err := http.NewRequest(http.MethodHead, rawURL, nil)
+ sanitizedURL, err := (&CustomGeoService{}).sanitizeURL(rawURL)
+ if err != nil {
+ return err
+ }
+ client := &http.Client{Timeout: customGeoProbeTimeout, Transport: ssrfSafeTransport()}
+ req, err := http.NewRequest(http.MethodHead, sanitizedURL, nil)
if err != nil {
return err
}
@@ -199,10 +270,12 @@ func (s *CustomGeoService) EnsureOnStartup() {
logger.Infof("custom geo startup: checking %d custom geofile(s)", n)
for i := range list {
r := &list[i]
- if err := s.validateURL(r.Url); err != nil {
+ sanitizedURL, err := s.sanitizeURL(r.Url)
+ if err != nil {
logger.Warningf("custom geo startup id=%d: invalid url: %v", r.Id, err)
continue
}
+ r.Url = sanitizedURL
s.syncLocalPath(r)
localPath := r.LocalPath
if !localDatFileNeedsRepair(localPath) {
@@ -218,28 +291,71 @@ func (s *CustomGeoService) EnsureOnStartup() {
}
func (s *CustomGeoService) downloadToPath(resourceURL, destPath string, lastModifiedHeader string) (skipped bool, newLastModified string, err error) {
- skipped, lm, err := s.downloadToPathOnce(resourceURL, destPath, lastModifiedHeader, false)
+ safeDestPath, err := sanitizeDestPath(destPath)
+ if err != nil {
+ return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, err)
+ }
+
+ skipped, lm, err := s.downloadToPathOnce(resourceURL, safeDestPath, lastModifiedHeader, false)
if err != nil {
return false, "", err
}
if skipped {
- if _, statErr := os.Stat(destPath); statErr == nil && !localDatFileNeedsRepair(destPath) {
+ if _, statErr := os.Stat(safeDestPath); statErr == nil && !localDatFileNeedsRepair(safeDestPath) {
return true, lm, nil
}
- return s.downloadToPathOnce(resourceURL, destPath, lastModifiedHeader, true)
+ return s.downloadToPathOnce(resourceURL, safeDestPath, lastModifiedHeader, true)
}
return false, lm, nil
}
+// sanitizeDestPath ensures destPath is inside the bin folder, preventing path traversal.
+// It resolves symlinks to prevent symlink-based escapes.
+// Returns the cleaned absolute path that is safe to use in file operations.
+func sanitizeDestPath(destPath string) (string, error) {
+ baseDirAbs, err := filepath.Abs(config.GetBinFolderPath())
+ if err != nil {
+ return "", fmt.Errorf("%w: %v", ErrCustomGeoPathTraversal, err)
+ }
+ // Resolve symlinks in base directory to get the real path.
+ if resolved, evalErr := filepath.EvalSymlinks(baseDirAbs); evalErr == nil {
+ baseDirAbs = resolved
+ }
+ destPathAbs, err := filepath.Abs(destPath)
+ if err != nil {
+ return "", fmt.Errorf("%w: %v", ErrCustomGeoPathTraversal, err)
+ }
+ // Resolve symlinks for the parent directory of the destination path.
+ destDir := filepath.Dir(destPathAbs)
+ if resolved, evalErr := filepath.EvalSymlinks(destDir); evalErr == nil {
+ destPathAbs = filepath.Join(resolved, filepath.Base(destPathAbs))
+ }
+ // Verify the resolved path is within the safe base directory using prefix check.
+ safeDirPrefix := baseDirAbs + string(filepath.Separator)
+ if !strings.HasPrefix(destPathAbs, safeDirPrefix) {
+ return "", ErrCustomGeoPathTraversal
+ }
+ return destPathAbs, nil
+}
+
func (s *CustomGeoService) downloadToPathOnce(resourceURL, destPath string, lastModifiedHeader string, forceFull bool) (skipped bool, newLastModified string, err error) {
+ safeDestPath, err := sanitizeDestPath(destPath)
+ if err != nil {
+ return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, err)
+ }
+ sanitizedURL, err := s.sanitizeURL(resourceURL)
+ if err != nil {
+ return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, err)
+ }
+
var req *http.Request
- req, err = http.NewRequest(http.MethodGet, resourceURL, nil)
+ req, err = http.NewRequest(http.MethodGet, sanitizedURL, nil)
if err != nil {
return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, err)
}
if !forceFull {
- if fi, statErr := os.Stat(destPath); statErr == nil && !localDatFileNeedsRepair(destPath) {
+ if fi, statErr := os.Stat(safeDestPath); statErr == nil && !localDatFileNeedsRepair(safeDestPath) {
if !fi.ModTime().IsZero() {
req.Header.Set("If-Modified-Since", fi.ModTime().UTC().Format(http.TimeFormat))
} else if lastModifiedHeader != "" {
@@ -250,7 +366,8 @@ func (s *CustomGeoService) downloadToPathOnce(resourceURL, destPath string, last
}
}
- client := &http.Client{Timeout: 10 * time.Minute}
+ client := &http.Client{Timeout: 10 * time.Minute, Transport: ssrfSafeTransport()}
+ // lgtm[go/request-forgery]
resp, err := client.Do(req)
if err != nil {
return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, err)
@@ -267,7 +384,7 @@ func (s *CustomGeoService) downloadToPathOnce(resourceURL, destPath string, last
updateModTime := func() {
if !serverModTime.IsZero() {
- _ = os.Chtimes(destPath, serverModTime, serverModTime)
+ _ = os.Chtimes(safeDestPath, serverModTime, serverModTime)
}
}
@@ -282,33 +399,36 @@ func (s *CustomGeoService) downloadToPathOnce(resourceURL, destPath string, last
return false, "", fmt.Errorf("%w: unexpected status %d", ErrCustomGeoDownload, resp.StatusCode)
}
- binDir := filepath.Dir(destPath)
+ binDir := filepath.Dir(safeDestPath)
if err = os.MkdirAll(binDir, 0o755); err != nil {
return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, err)
}
- tmpPath := destPath + ".tmp"
- out, err := os.Create(tmpPath)
+ safeTmpPath, err := sanitizeDestPath(safeDestPath + ".tmp")
+ if err != nil {
+ return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, err)
+ }
+ out, err := os.Create(safeTmpPath)
if err != nil {
return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, err)
}
n, err := io.Copy(out, resp.Body)
closeErr := out.Close()
if err != nil {
- _ = os.Remove(tmpPath)
+ _ = os.Remove(safeTmpPath)
return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, err)
}
if closeErr != nil {
- _ = os.Remove(tmpPath)
+ _ = os.Remove(safeTmpPath)
return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, closeErr)
}
if n < minDatBytes {
- _ = os.Remove(tmpPath)
+ _ = os.Remove(safeTmpPath)
return false, "", fmt.Errorf("%w: file too small", ErrCustomGeoDownload)
}
- if err = os.Rename(tmpPath, destPath); err != nil {
- _ = os.Remove(tmpPath)
+ if err = os.Rename(safeTmpPath, safeDestPath); err != nil {
+ _ = os.Remove(safeTmpPath)
return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, err)
}
@@ -331,6 +451,29 @@ func (s *CustomGeoService) syncLocalPath(r *model.CustomGeoResource) {
r.LocalPath = p
}
+func (s *CustomGeoService) syncAndSanitizeLocalPath(r *model.CustomGeoResource) error {
+ s.syncLocalPath(r)
+ safePath, err := sanitizeDestPath(r.LocalPath)
+ if err != nil {
+ return err
+ }
+ r.LocalPath = safePath
+ return nil
+}
+
+func removeSafePathIfExists(path string) error {
+ safePath, err := sanitizeDestPath(path)
+ if err != nil {
+ return err
+ }
+ if _, err := os.Stat(safePath); err == nil {
+ if err := os.Remove(safePath); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
func (s *CustomGeoService) Create(r *model.CustomGeoResource) error {
if err := s.validateType(r.Type); err != nil {
return err
@@ -338,16 +481,20 @@ func (s *CustomGeoService) Create(r *model.CustomGeoResource) error {
if err := s.validateAlias(r.Alias); err != nil {
return err
}
- if err := s.validateURL(r.Url); err != nil {
+ sanitizedURL, err := s.sanitizeURL(r.Url)
+ if err != nil {
return err
}
+ r.Url = sanitizedURL
var existing int64
database.GetDB().Model(&model.CustomGeoResource{}).
Where("geo_type = ? AND alias = ?", r.Type, r.Alias).Count(&existing)
if existing > 0 {
return ErrCustomGeoDuplicateAlias
}
- s.syncLocalPath(r)
+ if err := s.syncAndSanitizeLocalPath(r); err != nil {
+ return err
+ }
skipped, lm, err := s.downloadToPath(r.Url, r.LocalPath, r.LastModified)
if err != nil {
return err
@@ -356,7 +503,7 @@ func (s *CustomGeoService) Create(r *model.CustomGeoResource) error {
r.LastUpdatedAt = now
r.LastModified = lm
if err = database.GetDB().Create(r).Error; err != nil {
- _ = os.Remove(r.LocalPath)
+ _ = removeSafePathIfExists(r.LocalPath)
return err
}
logger.Infof("custom geo created id=%d type=%s alias=%s skipped=%v", r.Id, r.Type, r.Alias, skipped)
@@ -380,9 +527,11 @@ func (s *CustomGeoService) Update(id int, r *model.CustomGeoResource) error {
if err := s.validateAlias(r.Alias); err != nil {
return err
}
- if err := s.validateURL(r.Url); err != nil {
+ sanitizedURL, err := s.sanitizeURL(r.Url)
+ if err != nil {
return err
}
+ r.Url = sanitizedURL
if cur.Type != r.Type || cur.Alias != r.Alias {
var cnt int64
database.GetDB().Model(&model.CustomGeoResource{}).
@@ -393,12 +542,13 @@ func (s *CustomGeoService) Update(id int, r *model.CustomGeoResource) error {
}
}
oldPath := s.resolveDestPath(&cur)
- s.syncLocalPath(r)
r.Id = id
- r.LocalPath = filepath.Join(config.GetBinFolderPath(), s.fileNameFor(r.Type, r.Alias))
+ if err := s.syncAndSanitizeLocalPath(r); err != nil {
+ return err
+ }
if oldPath != r.LocalPath && oldPath != "" {
- if _, err := os.Stat(oldPath); err == nil {
- _ = os.Remove(oldPath)
+ if err := removeSafePathIfExists(oldPath); err != nil && !errors.Is(err, ErrCustomGeoPathTraversal) {
+ logger.Warningf("custom geo remove old path %s: %v", oldPath, err)
}
}
_, lm, err := s.downloadToPath(r.Url, r.LocalPath, cur.LastModified)
@@ -435,14 +585,15 @@ func (s *CustomGeoService) Delete(id int) (displayName string, err error) {
}
displayName = s.fileNameFor(r.Type, r.Alias)
p := s.resolveDestPath(&r)
+ if _, err := sanitizeDestPath(p); err != nil {
+ return displayName, err
+ }
if err := database.GetDB().Delete(&model.CustomGeoResource{}, id).Error; err != nil {
return displayName, err
}
if p != "" {
- if _, err := os.Stat(p); err == nil {
- if rmErr := os.Remove(p); rmErr != nil {
- logger.Warningf("custom geo delete file %s: %v", p, rmErr)
- }
+ if err := removeSafePathIfExists(p); err != nil {
+ logger.Warningf("custom geo delete file %s: %v", p, err)
}
}
logger.Infof("custom geo deleted id=%d", id)
@@ -467,8 +618,14 @@ func (s *CustomGeoService) applyDownloadAndPersist(id int, onStartup bool) (disp
return "", err
}
displayName = s.fileNameFor(r.Type, r.Alias)
- s.syncLocalPath(&r)
- skipped, lm, err := s.downloadToPath(r.Url, r.LocalPath, r.LastModified)
+ if err := s.syncAndSanitizeLocalPath(&r); err != nil {
+ return displayName, err
+ }
+ sanitizedURL, sanitizeErr := s.sanitizeURL(r.Url)
+ if sanitizeErr != nil {
+ return displayName, sanitizeErr
+ }
+ skipped, lm, err := s.downloadToPath(sanitizedURL, r.LocalPath, r.LastModified)
if err != nil {
if onStartup {
logger.Warningf("custom geo startup download id=%d: %v", id, err)