Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/MHSanaei/3x-ui.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorVladislav Tupikin <MrRefactoring@yandex.ru>2026-04-19 22:24:24 +0300
committerGitHub <noreply@github.com>2026-04-19 22:24:24 +0300
commit7466916e0206d55826d74f37c251bb5e40182c00 (patch)
tree2c887558c71f34b76e541c7bf7d57a420ed3c9ed /web/service
parent96b568b8389fd5a3ce228d5fb82ec9742d145b15 (diff)
Add custom geosite/geoip URL sources (#3980)
* feat: add custom geosite/geoip URL sources Register DB model, panel API, index/xray UI, and i18n. * fix
Diffstat (limited to 'web/service')
-rw-r--r--web/service/custom_geo.go603
-rw-r--r--web/service/custom_geo_test.go330
2 files changed, 933 insertions, 0 deletions
diff --git a/web/service/custom_geo.go b/web/service/custom_geo.go
new file mode 100644
index 00000000..a8b7456b
--- /dev/null
+++ b/web/service/custom_geo.go
@@ -0,0 +1,603 @@
+package service
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "os"
+ "path/filepath"
+ "regexp"
+ "strings"
+ "time"
+
+ "github.com/mhsanaei/3x-ui/v2/config"
+ "github.com/mhsanaei/3x-ui/v2/database"
+ "github.com/mhsanaei/3x-ui/v2/database/model"
+ "github.com/mhsanaei/3x-ui/v2/logger"
+)
+
+const (
+ customGeoTypeGeosite = "geosite"
+ customGeoTypeGeoip = "geoip"
+ minDatBytes = 64
+ customGeoProbeTimeout = 12 * time.Second
+)
+
+var (
+ customGeoAliasPattern = regexp.MustCompile(`^[a-z0-9_-]+$`)
+ reservedCustomAliases = map[string]struct{}{
+ "geoip": {}, "geosite": {},
+ "geoip_ir": {}, "geosite_ir": {},
+ "geoip_ru": {}, "geosite_ru": {},
+ }
+ ErrCustomGeoInvalidType = errors.New("custom_geo_invalid_type")
+ ErrCustomGeoAliasRequired = errors.New("custom_geo_alias_required")
+ ErrCustomGeoAliasPattern = errors.New("custom_geo_alias_pattern")
+ ErrCustomGeoAliasReserved = errors.New("custom_geo_alias_reserved")
+ ErrCustomGeoURLRequired = errors.New("custom_geo_url_required")
+ ErrCustomGeoInvalidURL = errors.New("custom_geo_invalid_url")
+ ErrCustomGeoURLScheme = errors.New("custom_geo_url_scheme")
+ ErrCustomGeoURLHost = errors.New("custom_geo_url_host")
+ ErrCustomGeoDuplicateAlias = errors.New("custom_geo_duplicate_alias")
+ ErrCustomGeoNotFound = errors.New("custom_geo_not_found")
+ ErrCustomGeoDownload = errors.New("custom_geo_download")
+)
+
+type CustomGeoUpdateAllItem struct {
+ Id int `json:"id"`
+ Alias string `json:"alias"`
+ FileName string `json:"fileName"`
+}
+
+type CustomGeoUpdateAllFailure struct {
+ Id int `json:"id"`
+ Alias string `json:"alias"`
+ FileName string `json:"fileName"`
+ Err string `json:"error"`
+}
+
+type CustomGeoUpdateAllResult struct {
+ Succeeded []CustomGeoUpdateAllItem `json:"succeeded"`
+ Failed []CustomGeoUpdateAllFailure `json:"failed"`
+}
+
+type CustomGeoService struct {
+ serverService *ServerService
+ updateAllGetAll func() ([]model.CustomGeoResource, error)
+ updateAllApply func(id int, onStartup bool) (string, error)
+ updateAllRestart func() error
+}
+
+func NewCustomGeoService() *CustomGeoService {
+ s := &CustomGeoService{
+ serverService: &ServerService{},
+ }
+ s.updateAllGetAll = s.GetAll
+ s.updateAllApply = s.applyDownloadAndPersist
+ s.updateAllRestart = func() error { return s.serverService.RestartXrayService() }
+ return s
+}
+
+func NormalizeAliasKey(alias string) string {
+ return strings.ToLower(strings.ReplaceAll(alias, "-", "_"))
+}
+
+func (s *CustomGeoService) fileNameFor(typ, alias string) string {
+ if typ == customGeoTypeGeoip {
+ return fmt.Sprintf("geoip_%s.dat", alias)
+ }
+ return fmt.Sprintf("geosite_%s.dat", alias)
+}
+
+func (s *CustomGeoService) validateType(typ string) error {
+ if typ != customGeoTypeGeosite && typ != customGeoTypeGeoip {
+ return ErrCustomGeoInvalidType
+ }
+ return nil
+}
+
+func (s *CustomGeoService) validateAlias(alias string) error {
+ if alias == "" {
+ return ErrCustomGeoAliasRequired
+ }
+ if !customGeoAliasPattern.MatchString(alias) {
+ return ErrCustomGeoAliasPattern
+ }
+ if _, ok := reservedCustomAliases[NormalizeAliasKey(alias)]; ok {
+ return ErrCustomGeoAliasReserved
+ }
+ return nil
+}
+
+func (s *CustomGeoService) validateURL(raw string) error {
+ if raw == "" {
+ return ErrCustomGeoURLRequired
+ }
+ u, err := url.Parse(raw)
+ if err != nil {
+ return ErrCustomGeoInvalidURL
+ }
+ if u.Scheme != "http" && u.Scheme != "https" {
+ return ErrCustomGeoURLScheme
+ }
+ if u.Host == "" {
+ return ErrCustomGeoURLHost
+ }
+ return nil
+}
+
+func localDatFileNeedsRepair(path string) bool {
+ fi, err := os.Stat(path)
+ if err != nil {
+ return true
+ }
+ if fi.IsDir() {
+ return true
+ }
+ return fi.Size() < int64(minDatBytes)
+}
+
+func CustomGeoLocalFileNeedsRepair(path string) bool {
+ return localDatFileNeedsRepair(path)
+}
+
+func probeCustomGeoURLWithGET(rawURL string) error {
+ client := &http.Client{Timeout: customGeoProbeTimeout}
+ req, err := http.NewRequest(http.MethodGet, rawURL, nil)
+ if err != nil {
+ return err
+ }
+ req.Header.Set("Range", "bytes=0-0")
+ resp, err := client.Do(req)
+ if err != nil {
+ return err
+ }
+ defer resp.Body.Close()
+ _, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 256))
+ switch resp.StatusCode {
+ case http.StatusOK, http.StatusPartialContent:
+ return nil
+ default:
+ return fmt.Errorf("get range status %d", resp.StatusCode)
+ }
+}
+
+func probeCustomGeoURL(rawURL string) error {
+ client := &http.Client{Timeout: customGeoProbeTimeout}
+ req, err := http.NewRequest(http.MethodHead, rawURL, nil)
+ if err != nil {
+ return err
+ }
+ resp, err := client.Do(req)
+ if err != nil {
+ return err
+ }
+ _ = resp.Body.Close()
+ sc := resp.StatusCode
+ if sc >= 200 && sc < 300 {
+ return nil
+ }
+ if sc == http.StatusMethodNotAllowed || sc == http.StatusNotImplemented {
+ return probeCustomGeoURLWithGET(rawURL)
+ }
+ return fmt.Errorf("head status %d", sc)
+}
+
+func (s *CustomGeoService) EnsureOnStartup() {
+ list, err := s.GetAll()
+ if err != nil {
+ logger.Warning("custom geo startup: load list:", err)
+ return
+ }
+ n := len(list)
+ if n == 0 {
+ logger.Info("custom geo startup: no custom geofiles configured")
+ return
+ }
+ logger.Infof("custom geo startup: checking %d custom geofile(s)", n)
+ for i := range list {
+ r := &list[i]
+ if err := s.validateURL(r.Url); err != nil {
+ logger.Warningf("custom geo startup id=%d: invalid url: %v", r.Id, err)
+ continue
+ }
+ s.syncLocalPath(r)
+ localPath := r.LocalPath
+ if !localDatFileNeedsRepair(localPath) {
+ logger.Infof("custom geo startup id=%d alias=%s path=%s: present", r.Id, r.Alias, localPath)
+ continue
+ }
+ logger.Infof("custom geo startup id=%d alias=%s path=%s: missing or needs repair, probing source", r.Id, r.Alias, localPath)
+ if err := probeCustomGeoURL(r.Url); err != nil {
+ logger.Warningf("custom geo startup id=%d alias=%s url=%s: probe: %v (attempting download anyway)", r.Id, r.Alias, r.Url, err)
+ }
+ _, _ = s.applyDownloadAndPersist(r.Id, true)
+ }
+}
+
+func (s *CustomGeoService) downloadToPath(resourceURL, destPath string, lastModifiedHeader string) (skipped bool, newLastModified string, err error) {
+ skipped, lm, err := s.downloadToPathOnce(resourceURL, destPath, lastModifiedHeader, false)
+ if err != nil {
+ return false, "", err
+ }
+ if skipped {
+ if _, statErr := os.Stat(destPath); statErr == nil && !localDatFileNeedsRepair(destPath) {
+ return true, lm, nil
+ }
+ return s.downloadToPathOnce(resourceURL, destPath, lastModifiedHeader, true)
+ }
+ return false, lm, nil
+}
+
+func (s *CustomGeoService) downloadToPathOnce(resourceURL, destPath string, lastModifiedHeader string, forceFull bool) (skipped bool, newLastModified string, err error) {
+ var req *http.Request
+ req, err = http.NewRequest(http.MethodGet, resourceURL, nil)
+ if err != nil {
+ return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, err)
+ }
+
+ if !forceFull {
+ if fi, statErr := os.Stat(destPath); statErr == nil && !localDatFileNeedsRepair(destPath) {
+ if !fi.ModTime().IsZero() {
+ req.Header.Set("If-Modified-Since", fi.ModTime().UTC().Format(http.TimeFormat))
+ } else if lastModifiedHeader != "" {
+ if t, perr := time.Parse(http.TimeFormat, lastModifiedHeader); perr == nil {
+ req.Header.Set("If-Modified-Since", t.UTC().Format(http.TimeFormat))
+ }
+ }
+ }
+ }
+
+ client := &http.Client{Timeout: 10 * time.Minute}
+ resp, err := client.Do(req)
+ if err != nil {
+ return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, err)
+ }
+ defer resp.Body.Close()
+
+ var serverModTime time.Time
+ if lm := resp.Header.Get("Last-Modified"); lm != "" {
+ if parsed, perr := time.Parse(http.TimeFormat, lm); perr == nil {
+ serverModTime = parsed
+ newLastModified = lm
+ }
+ }
+
+ updateModTime := func() {
+ if !serverModTime.IsZero() {
+ _ = os.Chtimes(destPath, serverModTime, serverModTime)
+ }
+ }
+
+ if resp.StatusCode == http.StatusNotModified {
+ if forceFull {
+ return false, "", fmt.Errorf("%w: unexpected 304 on unconditional get", ErrCustomGeoDownload)
+ }
+ updateModTime()
+ return true, newLastModified, nil
+ }
+ if resp.StatusCode != http.StatusOK {
+ return false, "", fmt.Errorf("%w: unexpected status %d", ErrCustomGeoDownload, resp.StatusCode)
+ }
+
+ binDir := filepath.Dir(destPath)
+ if err = os.MkdirAll(binDir, 0o755); err != nil {
+ return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, err)
+ }
+
+ tmpPath := destPath + ".tmp"
+ out, err := os.Create(tmpPath)
+ if err != nil {
+ return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, err)
+ }
+ n, err := io.Copy(out, resp.Body)
+ closeErr := out.Close()
+ if err != nil {
+ _ = os.Remove(tmpPath)
+ return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, err)
+ }
+ if closeErr != nil {
+ _ = os.Remove(tmpPath)
+ return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, closeErr)
+ }
+ if n < minDatBytes {
+ _ = os.Remove(tmpPath)
+ return false, "", fmt.Errorf("%w: file too small", ErrCustomGeoDownload)
+ }
+
+ if err = os.Rename(tmpPath, destPath); err != nil {
+ _ = os.Remove(tmpPath)
+ return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, err)
+ }
+
+ updateModTime()
+ if newLastModified == "" && resp.Header.Get("Last-Modified") != "" {
+ newLastModified = resp.Header.Get("Last-Modified")
+ }
+ return false, newLastModified, nil
+}
+
+func (s *CustomGeoService) resolveDestPath(r *model.CustomGeoResource) string {
+ if r.LocalPath != "" {
+ return r.LocalPath
+ }
+ return filepath.Join(config.GetBinFolderPath(), s.fileNameFor(r.Type, r.Alias))
+}
+
+func (s *CustomGeoService) syncLocalPath(r *model.CustomGeoResource) {
+ p := filepath.Join(config.GetBinFolderPath(), s.fileNameFor(r.Type, r.Alias))
+ r.LocalPath = p
+}
+
+func (s *CustomGeoService) Create(r *model.CustomGeoResource) error {
+ if err := s.validateType(r.Type); err != nil {
+ return err
+ }
+ if err := s.validateAlias(r.Alias); err != nil {
+ return err
+ }
+ if err := s.validateURL(r.Url); err != nil {
+ return err
+ }
+ var existing int64
+ database.GetDB().Model(&model.CustomGeoResource{}).
+ Where("geo_type = ? AND alias = ?", r.Type, r.Alias).Count(&existing)
+ if existing > 0 {
+ return ErrCustomGeoDuplicateAlias
+ }
+ s.syncLocalPath(r)
+ skipped, lm, err := s.downloadToPath(r.Url, r.LocalPath, r.LastModified)
+ if err != nil {
+ return err
+ }
+ now := time.Now().Unix()
+ r.LastUpdatedAt = now
+ r.LastModified = lm
+ if err = database.GetDB().Create(r).Error; err != nil {
+ _ = os.Remove(r.LocalPath)
+ return err
+ }
+ logger.Infof("custom geo created id=%d type=%s alias=%s skipped=%v", r.Id, r.Type, r.Alias, skipped)
+ if err = s.serverService.RestartXrayService(); err != nil {
+ logger.Warning("custom geo create: restart xray:", err)
+ }
+ return nil
+}
+
+func (s *CustomGeoService) Update(id int, r *model.CustomGeoResource) error {
+ var cur model.CustomGeoResource
+ if err := database.GetDB().First(&cur, id).Error; err != nil {
+ if database.IsNotFound(err) {
+ return ErrCustomGeoNotFound
+ }
+ return err
+ }
+ if err := s.validateType(r.Type); err != nil {
+ return err
+ }
+ if err := s.validateAlias(r.Alias); err != nil {
+ return err
+ }
+ if err := s.validateURL(r.Url); err != nil {
+ return err
+ }
+ if cur.Type != r.Type || cur.Alias != r.Alias {
+ var cnt int64
+ database.GetDB().Model(&model.CustomGeoResource{}).
+ Where("geo_type = ? AND alias = ? AND id <> ?", r.Type, r.Alias, id).
+ Count(&cnt)
+ if cnt > 0 {
+ return ErrCustomGeoDuplicateAlias
+ }
+ }
+ oldPath := s.resolveDestPath(&cur)
+ s.syncLocalPath(r)
+ r.Id = id
+ r.LocalPath = filepath.Join(config.GetBinFolderPath(), s.fileNameFor(r.Type, r.Alias))
+ if oldPath != r.LocalPath && oldPath != "" {
+ if _, err := os.Stat(oldPath); err == nil {
+ _ = os.Remove(oldPath)
+ }
+ }
+ _, lm, err := s.downloadToPath(r.Url, r.LocalPath, cur.LastModified)
+ if err != nil {
+ return err
+ }
+ r.LastUpdatedAt = time.Now().Unix()
+ r.LastModified = lm
+ err = database.GetDB().Model(&model.CustomGeoResource{}).Where("id = ?", id).Updates(map[string]any{
+ "geo_type": r.Type,
+ "alias": r.Alias,
+ "url": r.Url,
+ "local_path": r.LocalPath,
+ "last_updated_at": r.LastUpdatedAt,
+ "last_modified": r.LastModified,
+ }).Error
+ if err != nil {
+ return err
+ }
+ logger.Infof("custom geo updated id=%d", id)
+ if err = s.serverService.RestartXrayService(); err != nil {
+ logger.Warning("custom geo update: restart xray:", err)
+ }
+ return nil
+}
+
+func (s *CustomGeoService) Delete(id int) (displayName string, err error) {
+ var r model.CustomGeoResource
+ if err := database.GetDB().First(&r, id).Error; err != nil {
+ if database.IsNotFound(err) {
+ return "", ErrCustomGeoNotFound
+ }
+ return "", err
+ }
+ displayName = s.fileNameFor(r.Type, r.Alias)
+ p := s.resolveDestPath(&r)
+ if err := database.GetDB().Delete(&model.CustomGeoResource{}, id).Error; err != nil {
+ return displayName, err
+ }
+ if p != "" {
+ if _, err := os.Stat(p); err == nil {
+ if rmErr := os.Remove(p); rmErr != nil {
+ logger.Warningf("custom geo delete file %s: %v", p, rmErr)
+ }
+ }
+ }
+ logger.Infof("custom geo deleted id=%d", id)
+ if err := s.serverService.RestartXrayService(); err != nil {
+ logger.Warning("custom geo delete: restart xray:", err)
+ }
+ return displayName, nil
+}
+
+func (s *CustomGeoService) GetAll() ([]model.CustomGeoResource, error) {
+ var list []model.CustomGeoResource
+ err := database.GetDB().Order("id asc").Find(&list).Error
+ return list, err
+}
+
+func (s *CustomGeoService) applyDownloadAndPersist(id int, onStartup bool) (displayName string, err error) {
+ var r model.CustomGeoResource
+ if err := database.GetDB().First(&r, id).Error; err != nil {
+ if database.IsNotFound(err) {
+ return "", ErrCustomGeoNotFound
+ }
+ return "", err
+ }
+ displayName = s.fileNameFor(r.Type, r.Alias)
+ s.syncLocalPath(&r)
+ skipped, lm, err := s.downloadToPath(r.Url, r.LocalPath, r.LastModified)
+ if err != nil {
+ if onStartup {
+ logger.Warningf("custom geo startup download id=%d: %v", id, err)
+ } else {
+ logger.Warningf("custom geo manual update id=%d: %v", id, err)
+ }
+ return displayName, err
+ }
+ now := time.Now().Unix()
+ updates := map[string]any{
+ "last_modified": lm,
+ "local_path": r.LocalPath,
+ "last_updated_at": now,
+ }
+ if err = database.GetDB().Model(&model.CustomGeoResource{}).Where("id = ?", id).Updates(updates).Error; err != nil {
+ if onStartup {
+ logger.Warningf("custom geo startup id=%d: persist metadata: %v", id, err)
+ } else {
+ logger.Warningf("custom geo manual update id=%d: persist metadata: %v", id, err)
+ }
+ return displayName, err
+ }
+ if skipped {
+ if onStartup {
+ logger.Infof("custom geo startup download skipped (not modified) id=%d", id)
+ } else {
+ logger.Infof("custom geo manual update skipped (not modified) id=%d", id)
+ }
+ } else {
+ if onStartup {
+ logger.Infof("custom geo startup download ok id=%d", id)
+ } else {
+ logger.Infof("custom geo manual update ok id=%d", id)
+ }
+ }
+ return displayName, nil
+}
+
+func (s *CustomGeoService) TriggerUpdate(id int) (string, error) {
+ displayName, err := s.applyDownloadAndPersist(id, false)
+ if err != nil {
+ return displayName, err
+ }
+ if err = s.serverService.RestartXrayService(); err != nil {
+ logger.Warning("custom geo manual update: restart xray:", err)
+ }
+ return displayName, nil
+}
+
+func (s *CustomGeoService) TriggerUpdateAll() (*CustomGeoUpdateAllResult, error) {
+ var list []model.CustomGeoResource
+ var err error
+ if s.updateAllGetAll != nil {
+ list, err = s.updateAllGetAll()
+ } else {
+ list, err = s.GetAll()
+ }
+ if err != nil {
+ return nil, err
+ }
+ res := &CustomGeoUpdateAllResult{}
+ if len(list) == 0 {
+ return res, nil
+ }
+ for _, r := range list {
+ var name string
+ var applyErr error
+ if s.updateAllApply != nil {
+ name, applyErr = s.updateAllApply(r.Id, false)
+ } else {
+ name, applyErr = s.applyDownloadAndPersist(r.Id, false)
+ }
+ if applyErr != nil {
+ res.Failed = append(res.Failed, CustomGeoUpdateAllFailure{
+ Id: r.Id, Alias: r.Alias, FileName: name, Err: applyErr.Error(),
+ })
+ continue
+ }
+ res.Succeeded = append(res.Succeeded, CustomGeoUpdateAllItem{
+ Id: r.Id, Alias: r.Alias, FileName: name,
+ })
+ }
+ if len(res.Succeeded) > 0 {
+ var restartErr error
+ if s.updateAllRestart != nil {
+ restartErr = s.updateAllRestart()
+ } else {
+ restartErr = s.serverService.RestartXrayService()
+ }
+ if restartErr != nil {
+ logger.Warning("custom geo update all: restart xray:", restartErr)
+ }
+ }
+ return res, nil
+}
+
+type CustomGeoAliasItem struct {
+ Alias string `json:"alias"`
+ Type string `json:"type"`
+ FileName string `json:"fileName"`
+ ExtExample string `json:"extExample"`
+}
+
+type CustomGeoAliasesResponse struct {
+ Geosite []CustomGeoAliasItem `json:"geosite"`
+ Geoip []CustomGeoAliasItem `json:"geoip"`
+}
+
+func (s *CustomGeoService) GetAliasesForUI() (CustomGeoAliasesResponse, error) {
+ list, err := s.GetAll()
+ if err != nil {
+ logger.Warning("custom geo GetAliasesForUI:", err)
+ return CustomGeoAliasesResponse{}, err
+ }
+ var out CustomGeoAliasesResponse
+ for _, r := range list {
+ fn := s.fileNameFor(r.Type, r.Alias)
+ ex := fmt.Sprintf("ext:%s:tag", fn)
+ item := CustomGeoAliasItem{
+ Alias: r.Alias,
+ Type: r.Type,
+ FileName: fn,
+ ExtExample: ex,
+ }
+ if r.Type == customGeoTypeGeoip {
+ out.Geoip = append(out.Geoip, item)
+ } else {
+ out.Geosite = append(out.Geosite, item)
+ }
+ }
+ return out, nil
+}
diff --git a/web/service/custom_geo_test.go b/web/service/custom_geo_test.go
new file mode 100644
index 00000000..811a0f62
--- /dev/null
+++ b/web/service/custom_geo_test.go
@@ -0,0 +1,330 @@
+package service
+
+import (
+ "errors"
+ "fmt"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "path/filepath"
+ "testing"
+
+ "github.com/mhsanaei/3x-ui/v2/database/model"
+)
+
+func TestNormalizeAliasKey(t *testing.T) {
+ if got := NormalizeAliasKey("GeoIP-IR"); got != "geoip_ir" {
+ t.Fatalf("got %q", got)
+ }
+ if got := NormalizeAliasKey("a-b_c"); got != "a_b_c" {
+ t.Fatalf("got %q", got)
+ }
+}
+
+func TestNewCustomGeoService(t *testing.T) {
+ s := NewCustomGeoService()
+ if err := s.validateAlias("ok_alias-1"); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestTriggerUpdateAllAllSuccess(t *testing.T) {
+ s := CustomGeoService{}
+ s.updateAllGetAll = func() ([]model.CustomGeoResource, error) {
+ return []model.CustomGeoResource{
+ {Id: 1, Alias: "a"},
+ {Id: 2, Alias: "b"},
+ }, nil
+ }
+ s.updateAllApply = func(id int, onStartup bool) (string, error) {
+ return fmt.Sprintf("geo_%d.dat", id), nil
+ }
+ restartCalls := 0
+ s.updateAllRestart = func() error {
+ restartCalls++
+ return nil
+ }
+
+ res, err := s.TriggerUpdateAll()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(res.Succeeded) != 2 || len(res.Failed) != 0 {
+ t.Fatalf("unexpected result: %+v", res)
+ }
+ if restartCalls != 1 {
+ t.Fatalf("expected 1 restart, got %d", restartCalls)
+ }
+}
+
+func TestTriggerUpdateAllPartialSuccess(t *testing.T) {
+ s := CustomGeoService{}
+ s.updateAllGetAll = func() ([]model.CustomGeoResource, error) {
+ return []model.CustomGeoResource{
+ {Id: 1, Alias: "ok"},
+ {Id: 2, Alias: "bad"},
+ }, nil
+ }
+ s.updateAllApply = func(id int, onStartup bool) (string, error) {
+ if id == 2 {
+ return "geo_2.dat", ErrCustomGeoDownload
+ }
+ return "geo_1.dat", nil
+ }
+ restartCalls := 0
+ s.updateAllRestart = func() error {
+ restartCalls++
+ return nil
+ }
+
+ res, err := s.TriggerUpdateAll()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(res.Succeeded) != 1 || len(res.Failed) != 1 {
+ t.Fatalf("unexpected result: %+v", res)
+ }
+ if restartCalls != 1 {
+ t.Fatalf("expected 1 restart, got %d", restartCalls)
+ }
+}
+
+func TestTriggerUpdateAllAllFailure(t *testing.T) {
+ s := CustomGeoService{}
+ s.updateAllGetAll = func() ([]model.CustomGeoResource, error) {
+ return []model.CustomGeoResource{
+ {Id: 1, Alias: "a"},
+ {Id: 2, Alias: "b"},
+ }, nil
+ }
+ s.updateAllApply = func(id int, onStartup bool) (string, error) {
+ return fmt.Sprintf("geo_%d.dat", id), ErrCustomGeoDownload
+ }
+ restartCalls := 0
+ s.updateAllRestart = func() error {
+ restartCalls++
+ return nil
+ }
+
+ res, err := s.TriggerUpdateAll()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(res.Succeeded) != 0 || len(res.Failed) != 2 {
+ t.Fatalf("unexpected result: %+v", res)
+ }
+ if restartCalls != 0 {
+ t.Fatalf("expected 0 restart, got %d", restartCalls)
+ }
+}
+
+func TestCustomGeoValidateAlias(t *testing.T) {
+ s := CustomGeoService{}
+ if err := s.validateAlias(""); !errors.Is(err, ErrCustomGeoAliasRequired) {
+ t.Fatal("empty alias")
+ }
+ if err := s.validateAlias("Bad"); !errors.Is(err, ErrCustomGeoAliasPattern) {
+ t.Fatal("uppercase")
+ }
+ if err := s.validateAlias("a b"); !errors.Is(err, ErrCustomGeoAliasPattern) {
+ t.Fatal("space")
+ }
+ if err := s.validateAlias("ok_alias-1"); err != nil {
+ t.Fatal(err)
+ }
+ if err := s.validateAlias("geoip"); !errors.Is(err, ErrCustomGeoAliasReserved) {
+ t.Fatal("reserved")
+ }
+}
+
+func TestCustomGeoValidateURL(t *testing.T) {
+ s := CustomGeoService{}
+ if err := s.validateURL(""); !errors.Is(err, ErrCustomGeoURLRequired) {
+ t.Fatal("empty")
+ }
+ if err := s.validateURL("ftp://x"); !errors.Is(err, ErrCustomGeoURLScheme) {
+ t.Fatal("ftp")
+ }
+ if err := s.validateURL("https://example.com/a.dat"); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestCustomGeoValidateType(t *testing.T) {
+ s := CustomGeoService{}
+ if err := s.validateType("geosite"); err != nil {
+ t.Fatal(err)
+ }
+ if err := s.validateType("x"); !errors.Is(err, ErrCustomGeoInvalidType) {
+ t.Fatal("bad type")
+ }
+}
+
+func TestCustomGeoDownloadToPath(t *testing.T) {
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("X-Test", "1")
+ if r.Header.Get("If-Modified-Since") != "" {
+ w.WriteHeader(http.StatusNotModified)
+ return
+ }
+ w.WriteHeader(http.StatusOK)
+ _, _ = w.Write(make([]byte, minDatBytes+1))
+ }))
+ defer ts.Close()
+ dir := t.TempDir()
+ t.Setenv("XUI_BIN_FOLDER", dir)
+ dest := filepath.Join(dir, "geoip_t.dat")
+ s := CustomGeoService{}
+ skipped, _, err := s.downloadToPath(ts.URL, dest, "")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if skipped {
+ t.Fatal("expected download")
+ }
+ st, err := os.Stat(dest)
+ if err != nil || st.Size() < minDatBytes {
+ t.Fatalf("file %v", err)
+ }
+ skipped2, _, err2 := s.downloadToPath(ts.URL, dest, "")
+ if err2 != nil || !skipped2 {
+ t.Fatalf("304 expected skipped=%v err=%v", skipped2, err2)
+ }
+}
+
+func TestCustomGeoDownloadToPath_missingLocalSendsNoIMSFromDB(t *testing.T) {
+ lm := "Wed, 21 Oct 2015 07:28:00 GMT"
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.Header.Get("If-Modified-Since") != "" {
+ w.WriteHeader(http.StatusNotModified)
+ return
+ }
+ w.Header().Set("Last-Modified", lm)
+ w.WriteHeader(http.StatusOK)
+ _, _ = w.Write(make([]byte, minDatBytes+1))
+ }))
+ defer ts.Close()
+ dir := t.TempDir()
+ t.Setenv("XUI_BIN_FOLDER", dir)
+ dest := filepath.Join(dir, "geoip_rebuild.dat")
+ s := CustomGeoService{}
+ skipped, _, err := s.downloadToPath(ts.URL, dest, lm)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if skipped {
+ t.Fatal("must not treat as not-modified when local file is missing")
+ }
+ if _, err := os.Stat(dest); err != nil {
+ t.Fatal("file should exist after container-style rebuild")
+ }
+}
+
+func TestCustomGeoDownloadToPath_repairSkipsConditional(t *testing.T) {
+ lm := "Wed, 21 Oct 2015 07:28:00 GMT"
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.Header.Get("If-Modified-Since") != "" {
+ w.WriteHeader(http.StatusNotModified)
+ return
+ }
+ w.Header().Set("Last-Modified", lm)
+ w.WriteHeader(http.StatusOK)
+ _, _ = w.Write(make([]byte, minDatBytes+1))
+ }))
+ defer ts.Close()
+ dir := t.TempDir()
+ t.Setenv("XUI_BIN_FOLDER", dir)
+ dest := filepath.Join(dir, "geoip_bad.dat")
+ if err := os.WriteFile(dest, make([]byte, minDatBytes-1), 0o644); err != nil {
+ t.Fatal(err)
+ }
+ s := CustomGeoService{}
+ skipped, _, err := s.downloadToPath(ts.URL, dest, lm)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if skipped {
+ t.Fatal("corrupt local file must be re-downloaded, not 304")
+ }
+ st, err := os.Stat(dest)
+ if err != nil || st.Size() < minDatBytes {
+ t.Fatalf("file repaired: %v", err)
+ }
+}
+
+func TestCustomGeoFileNameFor(t *testing.T) {
+ s := CustomGeoService{}
+ if s.fileNameFor("geoip", "a") != "geoip_a.dat" {
+ t.Fatal("geoip name")
+ }
+ if s.fileNameFor("geosite", "b") != "geosite_b.dat" {
+ t.Fatal("geosite name")
+ }
+}
+
+func TestLocalDatFileNeedsRepair(t *testing.T) {
+ dir := t.TempDir()
+ if !localDatFileNeedsRepair(filepath.Join(dir, "missing.dat")) {
+ t.Fatal("missing")
+ }
+ smallPath := filepath.Join(dir, "small.dat")
+ if err := os.WriteFile(smallPath, make([]byte, minDatBytes-1), 0o644); err != nil {
+ t.Fatal(err)
+ }
+ if !localDatFileNeedsRepair(smallPath) {
+ t.Fatal("small")
+ }
+ okPath := filepath.Join(dir, "ok.dat")
+ if err := os.WriteFile(okPath, make([]byte, minDatBytes), 0o644); err != nil {
+ t.Fatal(err)
+ }
+ if localDatFileNeedsRepair(okPath) {
+ t.Fatal("ok size")
+ }
+ dirPath := filepath.Join(dir, "isdir.dat")
+ if err := os.Mkdir(dirPath, 0o755); err != nil {
+ t.Fatal(err)
+ }
+ if !localDatFileNeedsRepair(dirPath) {
+ t.Fatal("dir should need repair")
+ }
+ if !CustomGeoLocalFileNeedsRepair(dirPath) {
+ t.Fatal("exported wrapper dir")
+ }
+ if CustomGeoLocalFileNeedsRepair(okPath) {
+ t.Fatal("exported wrapper ok file")
+ }
+}
+
+func TestProbeCustomGeoURL_HEADOK(t *testing.T) {
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.Method == http.MethodHead {
+ w.WriteHeader(http.StatusOK)
+ return
+ }
+ w.WriteHeader(http.StatusOK)
+ }))
+ defer ts.Close()
+ if err := probeCustomGeoURL(ts.URL); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestProbeCustomGeoURL_HEAD405GETRange(t *testing.T) {
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.Method == http.MethodHead {
+ w.WriteHeader(http.StatusMethodNotAllowed)
+ return
+ }
+ if r.Method == http.MethodGet && r.Header.Get("Range") != "" {
+ w.WriteHeader(http.StatusPartialContent)
+ _, _ = w.Write([]byte{0})
+ return
+ }
+ w.WriteHeader(http.StatusBadRequest)
+ }))
+ defer ts.Close()
+ if err := probeCustomGeoURL(ts.URL); err != nil {
+ t.Fatal(err)
+ }
+}