diff options
| author | Vladislav Tupikin <MrRefactoring@yandex.ru> | 2026-04-19 22:24:24 +0300 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2026-04-19 22:24:24 +0300 |
| commit | 7466916e0206d55826d74f37c251bb5e40182c00 (patch) | |
| tree | 2c887558c71f34b76e541c7bf7d57a420ed3c9ed /web/service | |
| parent | 96b568b8389fd5a3ce228d5fb82ec9742d145b15 (diff) | |
Add custom geosite/geoip URL sources (#3980)
* feat: add custom geosite/geoip URL sources
Register DB model, panel API, index/xray UI, and i18n.
* fix
Diffstat (limited to 'web/service')
| -rw-r--r-- | web/service/custom_geo.go | 603 | ||||
| -rw-r--r-- | web/service/custom_geo_test.go | 330 |
2 files changed, 933 insertions, 0 deletions
diff --git a/web/service/custom_geo.go b/web/service/custom_geo.go new file mode 100644 index 00000000..a8b7456b --- /dev/null +++ b/web/service/custom_geo.go @@ -0,0 +1,603 @@ +package service + +import ( + "errors" + "fmt" + "io" + "net/http" + "net/url" + "os" + "path/filepath" + "regexp" + "strings" + "time" + + "github.com/mhsanaei/3x-ui/v2/config" + "github.com/mhsanaei/3x-ui/v2/database" + "github.com/mhsanaei/3x-ui/v2/database/model" + "github.com/mhsanaei/3x-ui/v2/logger" +) + +const ( + customGeoTypeGeosite = "geosite" + customGeoTypeGeoip = "geoip" + minDatBytes = 64 + customGeoProbeTimeout = 12 * time.Second +) + +var ( + customGeoAliasPattern = regexp.MustCompile(`^[a-z0-9_-]+$`) + reservedCustomAliases = map[string]struct{}{ + "geoip": {}, "geosite": {}, + "geoip_ir": {}, "geosite_ir": {}, + "geoip_ru": {}, "geosite_ru": {}, + } + ErrCustomGeoInvalidType = errors.New("custom_geo_invalid_type") + ErrCustomGeoAliasRequired = errors.New("custom_geo_alias_required") + ErrCustomGeoAliasPattern = errors.New("custom_geo_alias_pattern") + ErrCustomGeoAliasReserved = errors.New("custom_geo_alias_reserved") + ErrCustomGeoURLRequired = errors.New("custom_geo_url_required") + ErrCustomGeoInvalidURL = errors.New("custom_geo_invalid_url") + ErrCustomGeoURLScheme = errors.New("custom_geo_url_scheme") + ErrCustomGeoURLHost = errors.New("custom_geo_url_host") + ErrCustomGeoDuplicateAlias = errors.New("custom_geo_duplicate_alias") + ErrCustomGeoNotFound = errors.New("custom_geo_not_found") + ErrCustomGeoDownload = errors.New("custom_geo_download") +) + +type CustomGeoUpdateAllItem struct { + Id int `json:"id"` + Alias string `json:"alias"` + FileName string `json:"fileName"` +} + +type CustomGeoUpdateAllFailure struct { + Id int `json:"id"` + Alias string `json:"alias"` + FileName string `json:"fileName"` + Err string `json:"error"` +} + +type CustomGeoUpdateAllResult struct { + Succeeded []CustomGeoUpdateAllItem `json:"succeeded"` + Failed []CustomGeoUpdateAllFailure `json:"failed"` +} + +type CustomGeoService struct { + serverService *ServerService + updateAllGetAll func() ([]model.CustomGeoResource, error) + updateAllApply func(id int, onStartup bool) (string, error) + updateAllRestart func() error +} + +func NewCustomGeoService() *CustomGeoService { + s := &CustomGeoService{ + serverService: &ServerService{}, + } + s.updateAllGetAll = s.GetAll + s.updateAllApply = s.applyDownloadAndPersist + s.updateAllRestart = func() error { return s.serverService.RestartXrayService() } + return s +} + +func NormalizeAliasKey(alias string) string { + return strings.ToLower(strings.ReplaceAll(alias, "-", "_")) +} + +func (s *CustomGeoService) fileNameFor(typ, alias string) string { + if typ == customGeoTypeGeoip { + return fmt.Sprintf("geoip_%s.dat", alias) + } + return fmt.Sprintf("geosite_%s.dat", alias) +} + +func (s *CustomGeoService) validateType(typ string) error { + if typ != customGeoTypeGeosite && typ != customGeoTypeGeoip { + return ErrCustomGeoInvalidType + } + return nil +} + +func (s *CustomGeoService) validateAlias(alias string) error { + if alias == "" { + return ErrCustomGeoAliasRequired + } + if !customGeoAliasPattern.MatchString(alias) { + return ErrCustomGeoAliasPattern + } + if _, ok := reservedCustomAliases[NormalizeAliasKey(alias)]; ok { + return ErrCustomGeoAliasReserved + } + return nil +} + +func (s *CustomGeoService) validateURL(raw string) error { + if raw == "" { + return ErrCustomGeoURLRequired + } + u, err := url.Parse(raw) + if err != nil { + return ErrCustomGeoInvalidURL + } + if u.Scheme != "http" && u.Scheme != "https" { + return ErrCustomGeoURLScheme + } + if u.Host == "" { + return ErrCustomGeoURLHost + } + return nil +} + +func localDatFileNeedsRepair(path string) bool { + fi, err := os.Stat(path) + if err != nil { + return true + } + if fi.IsDir() { + return true + } + return fi.Size() < int64(minDatBytes) +} + +func CustomGeoLocalFileNeedsRepair(path string) bool { + return localDatFileNeedsRepair(path) +} + +func probeCustomGeoURLWithGET(rawURL string) error { + client := &http.Client{Timeout: customGeoProbeTimeout} + req, err := http.NewRequest(http.MethodGet, rawURL, nil) + if err != nil { + return err + } + req.Header.Set("Range", "bytes=0-0") + resp, err := client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + _, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 256)) + switch resp.StatusCode { + case http.StatusOK, http.StatusPartialContent: + return nil + default: + return fmt.Errorf("get range status %d", resp.StatusCode) + } +} + +func probeCustomGeoURL(rawURL string) error { + client := &http.Client{Timeout: customGeoProbeTimeout} + req, err := http.NewRequest(http.MethodHead, rawURL, nil) + if err != nil { + return err + } + resp, err := client.Do(req) + if err != nil { + return err + } + _ = resp.Body.Close() + sc := resp.StatusCode + if sc >= 200 && sc < 300 { + return nil + } + if sc == http.StatusMethodNotAllowed || sc == http.StatusNotImplemented { + return probeCustomGeoURLWithGET(rawURL) + } + return fmt.Errorf("head status %d", sc) +} + +func (s *CustomGeoService) EnsureOnStartup() { + list, err := s.GetAll() + if err != nil { + logger.Warning("custom geo startup: load list:", err) + return + } + n := len(list) + if n == 0 { + logger.Info("custom geo startup: no custom geofiles configured") + return + } + logger.Infof("custom geo startup: checking %d custom geofile(s)", n) + for i := range list { + r := &list[i] + if err := s.validateURL(r.Url); err != nil { + logger.Warningf("custom geo startup id=%d: invalid url: %v", r.Id, err) + continue + } + s.syncLocalPath(r) + localPath := r.LocalPath + if !localDatFileNeedsRepair(localPath) { + logger.Infof("custom geo startup id=%d alias=%s path=%s: present", r.Id, r.Alias, localPath) + continue + } + logger.Infof("custom geo startup id=%d alias=%s path=%s: missing or needs repair, probing source", r.Id, r.Alias, localPath) + if err := probeCustomGeoURL(r.Url); err != nil { + logger.Warningf("custom geo startup id=%d alias=%s url=%s: probe: %v (attempting download anyway)", r.Id, r.Alias, r.Url, err) + } + _, _ = s.applyDownloadAndPersist(r.Id, true) + } +} + +func (s *CustomGeoService) downloadToPath(resourceURL, destPath string, lastModifiedHeader string) (skipped bool, newLastModified string, err error) { + skipped, lm, err := s.downloadToPathOnce(resourceURL, destPath, lastModifiedHeader, false) + if err != nil { + return false, "", err + } + if skipped { + if _, statErr := os.Stat(destPath); statErr == nil && !localDatFileNeedsRepair(destPath) { + return true, lm, nil + } + return s.downloadToPathOnce(resourceURL, destPath, lastModifiedHeader, true) + } + return false, lm, nil +} + +func (s *CustomGeoService) downloadToPathOnce(resourceURL, destPath string, lastModifiedHeader string, forceFull bool) (skipped bool, newLastModified string, err error) { + var req *http.Request + req, err = http.NewRequest(http.MethodGet, resourceURL, nil) + if err != nil { + return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, err) + } + + if !forceFull { + if fi, statErr := os.Stat(destPath); statErr == nil && !localDatFileNeedsRepair(destPath) { + if !fi.ModTime().IsZero() { + req.Header.Set("If-Modified-Since", fi.ModTime().UTC().Format(http.TimeFormat)) + } else if lastModifiedHeader != "" { + if t, perr := time.Parse(http.TimeFormat, lastModifiedHeader); perr == nil { + req.Header.Set("If-Modified-Since", t.UTC().Format(http.TimeFormat)) + } + } + } + } + + client := &http.Client{Timeout: 10 * time.Minute} + resp, err := client.Do(req) + if err != nil { + return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, err) + } + defer resp.Body.Close() + + var serverModTime time.Time + if lm := resp.Header.Get("Last-Modified"); lm != "" { + if parsed, perr := time.Parse(http.TimeFormat, lm); perr == nil { + serverModTime = parsed + newLastModified = lm + } + } + + updateModTime := func() { + if !serverModTime.IsZero() { + _ = os.Chtimes(destPath, serverModTime, serverModTime) + } + } + + if resp.StatusCode == http.StatusNotModified { + if forceFull { + return false, "", fmt.Errorf("%w: unexpected 304 on unconditional get", ErrCustomGeoDownload) + } + updateModTime() + return true, newLastModified, nil + } + if resp.StatusCode != http.StatusOK { + return false, "", fmt.Errorf("%w: unexpected status %d", ErrCustomGeoDownload, resp.StatusCode) + } + + binDir := filepath.Dir(destPath) + if err = os.MkdirAll(binDir, 0o755); err != nil { + return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, err) + } + + tmpPath := destPath + ".tmp" + out, err := os.Create(tmpPath) + if err != nil { + return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, err) + } + n, err := io.Copy(out, resp.Body) + closeErr := out.Close() + if err != nil { + _ = os.Remove(tmpPath) + return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, err) + } + if closeErr != nil { + _ = os.Remove(tmpPath) + return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, closeErr) + } + if n < minDatBytes { + _ = os.Remove(tmpPath) + return false, "", fmt.Errorf("%w: file too small", ErrCustomGeoDownload) + } + + if err = os.Rename(tmpPath, destPath); err != nil { + _ = os.Remove(tmpPath) + return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, err) + } + + updateModTime() + if newLastModified == "" && resp.Header.Get("Last-Modified") != "" { + newLastModified = resp.Header.Get("Last-Modified") + } + return false, newLastModified, nil +} + +func (s *CustomGeoService) resolveDestPath(r *model.CustomGeoResource) string { + if r.LocalPath != "" { + return r.LocalPath + } + return filepath.Join(config.GetBinFolderPath(), s.fileNameFor(r.Type, r.Alias)) +} + +func (s *CustomGeoService) syncLocalPath(r *model.CustomGeoResource) { + p := filepath.Join(config.GetBinFolderPath(), s.fileNameFor(r.Type, r.Alias)) + r.LocalPath = p +} + +func (s *CustomGeoService) Create(r *model.CustomGeoResource) error { + if err := s.validateType(r.Type); err != nil { + return err + } + if err := s.validateAlias(r.Alias); err != nil { + return err + } + if err := s.validateURL(r.Url); err != nil { + return err + } + var existing int64 + database.GetDB().Model(&model.CustomGeoResource{}). + Where("geo_type = ? AND alias = ?", r.Type, r.Alias).Count(&existing) + if existing > 0 { + return ErrCustomGeoDuplicateAlias + } + s.syncLocalPath(r) + skipped, lm, err := s.downloadToPath(r.Url, r.LocalPath, r.LastModified) + if err != nil { + return err + } + now := time.Now().Unix() + r.LastUpdatedAt = now + r.LastModified = lm + if err = database.GetDB().Create(r).Error; err != nil { + _ = os.Remove(r.LocalPath) + return err + } + logger.Infof("custom geo created id=%d type=%s alias=%s skipped=%v", r.Id, r.Type, r.Alias, skipped) + if err = s.serverService.RestartXrayService(); err != nil { + logger.Warning("custom geo create: restart xray:", err) + } + return nil +} + +func (s *CustomGeoService) Update(id int, r *model.CustomGeoResource) error { + var cur model.CustomGeoResource + if err := database.GetDB().First(&cur, id).Error; err != nil { + if database.IsNotFound(err) { + return ErrCustomGeoNotFound + } + return err + } + if err := s.validateType(r.Type); err != nil { + return err + } + if err := s.validateAlias(r.Alias); err != nil { + return err + } + if err := s.validateURL(r.Url); err != nil { + return err + } + if cur.Type != r.Type || cur.Alias != r.Alias { + var cnt int64 + database.GetDB().Model(&model.CustomGeoResource{}). + Where("geo_type = ? AND alias = ? AND id <> ?", r.Type, r.Alias, id). + Count(&cnt) + if cnt > 0 { + return ErrCustomGeoDuplicateAlias + } + } + oldPath := s.resolveDestPath(&cur) + s.syncLocalPath(r) + r.Id = id + r.LocalPath = filepath.Join(config.GetBinFolderPath(), s.fileNameFor(r.Type, r.Alias)) + if oldPath != r.LocalPath && oldPath != "" { + if _, err := os.Stat(oldPath); err == nil { + _ = os.Remove(oldPath) + } + } + _, lm, err := s.downloadToPath(r.Url, r.LocalPath, cur.LastModified) + if err != nil { + return err + } + r.LastUpdatedAt = time.Now().Unix() + r.LastModified = lm + err = database.GetDB().Model(&model.CustomGeoResource{}).Where("id = ?", id).Updates(map[string]any{ + "geo_type": r.Type, + "alias": r.Alias, + "url": r.Url, + "local_path": r.LocalPath, + "last_updated_at": r.LastUpdatedAt, + "last_modified": r.LastModified, + }).Error + if err != nil { + return err + } + logger.Infof("custom geo updated id=%d", id) + if err = s.serverService.RestartXrayService(); err != nil { + logger.Warning("custom geo update: restart xray:", err) + } + return nil +} + +func (s *CustomGeoService) Delete(id int) (displayName string, err error) { + var r model.CustomGeoResource + if err := database.GetDB().First(&r, id).Error; err != nil { + if database.IsNotFound(err) { + return "", ErrCustomGeoNotFound + } + return "", err + } + displayName = s.fileNameFor(r.Type, r.Alias) + p := s.resolveDestPath(&r) + if err := database.GetDB().Delete(&model.CustomGeoResource{}, id).Error; err != nil { + return displayName, err + } + if p != "" { + if _, err := os.Stat(p); err == nil { + if rmErr := os.Remove(p); rmErr != nil { + logger.Warningf("custom geo delete file %s: %v", p, rmErr) + } + } + } + logger.Infof("custom geo deleted id=%d", id) + if err := s.serverService.RestartXrayService(); err != nil { + logger.Warning("custom geo delete: restart xray:", err) + } + return displayName, nil +} + +func (s *CustomGeoService) GetAll() ([]model.CustomGeoResource, error) { + var list []model.CustomGeoResource + err := database.GetDB().Order("id asc").Find(&list).Error + return list, err +} + +func (s *CustomGeoService) applyDownloadAndPersist(id int, onStartup bool) (displayName string, err error) { + var r model.CustomGeoResource + if err := database.GetDB().First(&r, id).Error; err != nil { + if database.IsNotFound(err) { + return "", ErrCustomGeoNotFound + } + return "", err + } + displayName = s.fileNameFor(r.Type, r.Alias) + s.syncLocalPath(&r) + skipped, lm, err := s.downloadToPath(r.Url, r.LocalPath, r.LastModified) + if err != nil { + if onStartup { + logger.Warningf("custom geo startup download id=%d: %v", id, err) + } else { + logger.Warningf("custom geo manual update id=%d: %v", id, err) + } + return displayName, err + } + now := time.Now().Unix() + updates := map[string]any{ + "last_modified": lm, + "local_path": r.LocalPath, + "last_updated_at": now, + } + if err = database.GetDB().Model(&model.CustomGeoResource{}).Where("id = ?", id).Updates(updates).Error; err != nil { + if onStartup { + logger.Warningf("custom geo startup id=%d: persist metadata: %v", id, err) + } else { + logger.Warningf("custom geo manual update id=%d: persist metadata: %v", id, err) + } + return displayName, err + } + if skipped { + if onStartup { + logger.Infof("custom geo startup download skipped (not modified) id=%d", id) + } else { + logger.Infof("custom geo manual update skipped (not modified) id=%d", id) + } + } else { + if onStartup { + logger.Infof("custom geo startup download ok id=%d", id) + } else { + logger.Infof("custom geo manual update ok id=%d", id) + } + } + return displayName, nil +} + +func (s *CustomGeoService) TriggerUpdate(id int) (string, error) { + displayName, err := s.applyDownloadAndPersist(id, false) + if err != nil { + return displayName, err + } + if err = s.serverService.RestartXrayService(); err != nil { + logger.Warning("custom geo manual update: restart xray:", err) + } + return displayName, nil +} + +func (s *CustomGeoService) TriggerUpdateAll() (*CustomGeoUpdateAllResult, error) { + var list []model.CustomGeoResource + var err error + if s.updateAllGetAll != nil { + list, err = s.updateAllGetAll() + } else { + list, err = s.GetAll() + } + if err != nil { + return nil, err + } + res := &CustomGeoUpdateAllResult{} + if len(list) == 0 { + return res, nil + } + for _, r := range list { + var name string + var applyErr error + if s.updateAllApply != nil { + name, applyErr = s.updateAllApply(r.Id, false) + } else { + name, applyErr = s.applyDownloadAndPersist(r.Id, false) + } + if applyErr != nil { + res.Failed = append(res.Failed, CustomGeoUpdateAllFailure{ + Id: r.Id, Alias: r.Alias, FileName: name, Err: applyErr.Error(), + }) + continue + } + res.Succeeded = append(res.Succeeded, CustomGeoUpdateAllItem{ + Id: r.Id, Alias: r.Alias, FileName: name, + }) + } + if len(res.Succeeded) > 0 { + var restartErr error + if s.updateAllRestart != nil { + restartErr = s.updateAllRestart() + } else { + restartErr = s.serverService.RestartXrayService() + } + if restartErr != nil { + logger.Warning("custom geo update all: restart xray:", restartErr) + } + } + return res, nil +} + +type CustomGeoAliasItem struct { + Alias string `json:"alias"` + Type string `json:"type"` + FileName string `json:"fileName"` + ExtExample string `json:"extExample"` +} + +type CustomGeoAliasesResponse struct { + Geosite []CustomGeoAliasItem `json:"geosite"` + Geoip []CustomGeoAliasItem `json:"geoip"` +} + +func (s *CustomGeoService) GetAliasesForUI() (CustomGeoAliasesResponse, error) { + list, err := s.GetAll() + if err != nil { + logger.Warning("custom geo GetAliasesForUI:", err) + return CustomGeoAliasesResponse{}, err + } + var out CustomGeoAliasesResponse + for _, r := range list { + fn := s.fileNameFor(r.Type, r.Alias) + ex := fmt.Sprintf("ext:%s:tag", fn) + item := CustomGeoAliasItem{ + Alias: r.Alias, + Type: r.Type, + FileName: fn, + ExtExample: ex, + } + if r.Type == customGeoTypeGeoip { + out.Geoip = append(out.Geoip, item) + } else { + out.Geosite = append(out.Geosite, item) + } + } + return out, nil +} diff --git a/web/service/custom_geo_test.go b/web/service/custom_geo_test.go new file mode 100644 index 00000000..811a0f62 --- /dev/null +++ b/web/service/custom_geo_test.go @@ -0,0 +1,330 @@ +package service + +import ( + "errors" + "fmt" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/mhsanaei/3x-ui/v2/database/model" +) + +func TestNormalizeAliasKey(t *testing.T) { + if got := NormalizeAliasKey("GeoIP-IR"); got != "geoip_ir" { + t.Fatalf("got %q", got) + } + if got := NormalizeAliasKey("a-b_c"); got != "a_b_c" { + t.Fatalf("got %q", got) + } +} + +func TestNewCustomGeoService(t *testing.T) { + s := NewCustomGeoService() + if err := s.validateAlias("ok_alias-1"); err != nil { + t.Fatal(err) + } +} + +func TestTriggerUpdateAllAllSuccess(t *testing.T) { + s := CustomGeoService{} + s.updateAllGetAll = func() ([]model.CustomGeoResource, error) { + return []model.CustomGeoResource{ + {Id: 1, Alias: "a"}, + {Id: 2, Alias: "b"}, + }, nil + } + s.updateAllApply = func(id int, onStartup bool) (string, error) { + return fmt.Sprintf("geo_%d.dat", id), nil + } + restartCalls := 0 + s.updateAllRestart = func() error { + restartCalls++ + return nil + } + + res, err := s.TriggerUpdateAll() + if err != nil { + t.Fatal(err) + } + if len(res.Succeeded) != 2 || len(res.Failed) != 0 { + t.Fatalf("unexpected result: %+v", res) + } + if restartCalls != 1 { + t.Fatalf("expected 1 restart, got %d", restartCalls) + } +} + +func TestTriggerUpdateAllPartialSuccess(t *testing.T) { + s := CustomGeoService{} + s.updateAllGetAll = func() ([]model.CustomGeoResource, error) { + return []model.CustomGeoResource{ + {Id: 1, Alias: "ok"}, + {Id: 2, Alias: "bad"}, + }, nil + } + s.updateAllApply = func(id int, onStartup bool) (string, error) { + if id == 2 { + return "geo_2.dat", ErrCustomGeoDownload + } + return "geo_1.dat", nil + } + restartCalls := 0 + s.updateAllRestart = func() error { + restartCalls++ + return nil + } + + res, err := s.TriggerUpdateAll() + if err != nil { + t.Fatal(err) + } + if len(res.Succeeded) != 1 || len(res.Failed) != 1 { + t.Fatalf("unexpected result: %+v", res) + } + if restartCalls != 1 { + t.Fatalf("expected 1 restart, got %d", restartCalls) + } +} + +func TestTriggerUpdateAllAllFailure(t *testing.T) { + s := CustomGeoService{} + s.updateAllGetAll = func() ([]model.CustomGeoResource, error) { + return []model.CustomGeoResource{ + {Id: 1, Alias: "a"}, + {Id: 2, Alias: "b"}, + }, nil + } + s.updateAllApply = func(id int, onStartup bool) (string, error) { + return fmt.Sprintf("geo_%d.dat", id), ErrCustomGeoDownload + } + restartCalls := 0 + s.updateAllRestart = func() error { + restartCalls++ + return nil + } + + res, err := s.TriggerUpdateAll() + if err != nil { + t.Fatal(err) + } + if len(res.Succeeded) != 0 || len(res.Failed) != 2 { + t.Fatalf("unexpected result: %+v", res) + } + if restartCalls != 0 { + t.Fatalf("expected 0 restart, got %d", restartCalls) + } +} + +func TestCustomGeoValidateAlias(t *testing.T) { + s := CustomGeoService{} + if err := s.validateAlias(""); !errors.Is(err, ErrCustomGeoAliasRequired) { + t.Fatal("empty alias") + } + if err := s.validateAlias("Bad"); !errors.Is(err, ErrCustomGeoAliasPattern) { + t.Fatal("uppercase") + } + if err := s.validateAlias("a b"); !errors.Is(err, ErrCustomGeoAliasPattern) { + t.Fatal("space") + } + if err := s.validateAlias("ok_alias-1"); err != nil { + t.Fatal(err) + } + if err := s.validateAlias("geoip"); !errors.Is(err, ErrCustomGeoAliasReserved) { + t.Fatal("reserved") + } +} + +func TestCustomGeoValidateURL(t *testing.T) { + s := CustomGeoService{} + if err := s.validateURL(""); !errors.Is(err, ErrCustomGeoURLRequired) { + t.Fatal("empty") + } + if err := s.validateURL("ftp://x"); !errors.Is(err, ErrCustomGeoURLScheme) { + t.Fatal("ftp") + } + if err := s.validateURL("https://example.com/a.dat"); err != nil { + t.Fatal(err) + } +} + +func TestCustomGeoValidateType(t *testing.T) { + s := CustomGeoService{} + if err := s.validateType("geosite"); err != nil { + t.Fatal(err) + } + if err := s.validateType("x"); !errors.Is(err, ErrCustomGeoInvalidType) { + t.Fatal("bad type") + } +} + +func TestCustomGeoDownloadToPath(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Test", "1") + if r.Header.Get("If-Modified-Since") != "" { + w.WriteHeader(http.StatusNotModified) + return + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write(make([]byte, minDatBytes+1)) + })) + defer ts.Close() + dir := t.TempDir() + t.Setenv("XUI_BIN_FOLDER", dir) + dest := filepath.Join(dir, "geoip_t.dat") + s := CustomGeoService{} + skipped, _, err := s.downloadToPath(ts.URL, dest, "") + if err != nil { + t.Fatal(err) + } + if skipped { + t.Fatal("expected download") + } + st, err := os.Stat(dest) + if err != nil || st.Size() < minDatBytes { + t.Fatalf("file %v", err) + } + skipped2, _, err2 := s.downloadToPath(ts.URL, dest, "") + if err2 != nil || !skipped2 { + t.Fatalf("304 expected skipped=%v err=%v", skipped2, err2) + } +} + +func TestCustomGeoDownloadToPath_missingLocalSendsNoIMSFromDB(t *testing.T) { + lm := "Wed, 21 Oct 2015 07:28:00 GMT" + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("If-Modified-Since") != "" { + w.WriteHeader(http.StatusNotModified) + return + } + w.Header().Set("Last-Modified", lm) + w.WriteHeader(http.StatusOK) + _, _ = w.Write(make([]byte, minDatBytes+1)) + })) + defer ts.Close() + dir := t.TempDir() + t.Setenv("XUI_BIN_FOLDER", dir) + dest := filepath.Join(dir, "geoip_rebuild.dat") + s := CustomGeoService{} + skipped, _, err := s.downloadToPath(ts.URL, dest, lm) + if err != nil { + t.Fatal(err) + } + if skipped { + t.Fatal("must not treat as not-modified when local file is missing") + } + if _, err := os.Stat(dest); err != nil { + t.Fatal("file should exist after container-style rebuild") + } +} + +func TestCustomGeoDownloadToPath_repairSkipsConditional(t *testing.T) { + lm := "Wed, 21 Oct 2015 07:28:00 GMT" + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("If-Modified-Since") != "" { + w.WriteHeader(http.StatusNotModified) + return + } + w.Header().Set("Last-Modified", lm) + w.WriteHeader(http.StatusOK) + _, _ = w.Write(make([]byte, minDatBytes+1)) + })) + defer ts.Close() + dir := t.TempDir() + t.Setenv("XUI_BIN_FOLDER", dir) + dest := filepath.Join(dir, "geoip_bad.dat") + if err := os.WriteFile(dest, make([]byte, minDatBytes-1), 0o644); err != nil { + t.Fatal(err) + } + s := CustomGeoService{} + skipped, _, err := s.downloadToPath(ts.URL, dest, lm) + if err != nil { + t.Fatal(err) + } + if skipped { + t.Fatal("corrupt local file must be re-downloaded, not 304") + } + st, err := os.Stat(dest) + if err != nil || st.Size() < minDatBytes { + t.Fatalf("file repaired: %v", err) + } +} + +func TestCustomGeoFileNameFor(t *testing.T) { + s := CustomGeoService{} + if s.fileNameFor("geoip", "a") != "geoip_a.dat" { + t.Fatal("geoip name") + } + if s.fileNameFor("geosite", "b") != "geosite_b.dat" { + t.Fatal("geosite name") + } +} + +func TestLocalDatFileNeedsRepair(t *testing.T) { + dir := t.TempDir() + if !localDatFileNeedsRepair(filepath.Join(dir, "missing.dat")) { + t.Fatal("missing") + } + smallPath := filepath.Join(dir, "small.dat") + if err := os.WriteFile(smallPath, make([]byte, minDatBytes-1), 0o644); err != nil { + t.Fatal(err) + } + if !localDatFileNeedsRepair(smallPath) { + t.Fatal("small") + } + okPath := filepath.Join(dir, "ok.dat") + if err := os.WriteFile(okPath, make([]byte, minDatBytes), 0o644); err != nil { + t.Fatal(err) + } + if localDatFileNeedsRepair(okPath) { + t.Fatal("ok size") + } + dirPath := filepath.Join(dir, "isdir.dat") + if err := os.Mkdir(dirPath, 0o755); err != nil { + t.Fatal(err) + } + if !localDatFileNeedsRepair(dirPath) { + t.Fatal("dir should need repair") + } + if !CustomGeoLocalFileNeedsRepair(dirPath) { + t.Fatal("exported wrapper dir") + } + if CustomGeoLocalFileNeedsRepair(okPath) { + t.Fatal("exported wrapper ok file") + } +} + +func TestProbeCustomGeoURL_HEADOK(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodHead { + w.WriteHeader(http.StatusOK) + return + } + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + if err := probeCustomGeoURL(ts.URL); err != nil { + t.Fatal(err) + } +} + +func TestProbeCustomGeoURL_HEAD405GETRange(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodHead { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + if r.Method == http.MethodGet && r.Header.Get("Range") != "" { + w.WriteHeader(http.StatusPartialContent) + _, _ = w.Write([]byte{0}) + return + } + w.WriteHeader(http.StatusBadRequest) + })) + defer ts.Close() + if err := probeCustomGeoURL(ts.URL); err != nil { + t.Fatal(err) + } +} |
