Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/MHSanaei/3x-ui.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'web/service/setting.go')
-rw-r--r--web/service/setting.go303
1 files changed, 303 insertions, 0 deletions
diff --git a/web/service/setting.go b/web/service/setting.go
new file mode 100644
index 00000000..a0592e93
--- /dev/null
+++ b/web/service/setting.go
@@ -0,0 +1,303 @@
+package service
+
+import (
+ _ "embed"
+ "errors"
+ "fmt"
+ "reflect"
+ "strconv"
+ "strings"
+ "time"
+ "x-ui/database"
+ "x-ui/database/model"
+ "x-ui/logger"
+ "x-ui/util/common"
+ "x-ui/util/random"
+ "x-ui/util/reflect_util"
+ "x-ui/web/entity"
+)
+
+//go:embed config.json
+var xrayTemplateConfig string
+
+var defaultValueMap = map[string]string{
+ "xrayTemplateConfig": xrayTemplateConfig,
+ "webListen": "",
+ "webPort": "54321",
+ "webCertFile": "",
+ "webKeyFile": "",
+ "secret": random.Seq(32),
+ "webBasePath": "/",
+ "timeLocation": "Asia/Tehran",
+ "tgBotEnable": "false",
+ "tgBotToken": "",
+ "tgBotChatId": "0",
+ "tgRunTime": "",
+}
+
+type SettingService struct {
+}
+
+func (s *SettingService) GetAllSetting() (*entity.AllSetting, error) {
+ db := database.GetDB()
+ settings := make([]*model.Setting, 0)
+ err := db.Model(model.Setting{}).Find(&settings).Error
+ if err != nil {
+ return nil, err
+ }
+ allSetting := &entity.AllSetting{}
+ t := reflect.TypeOf(allSetting).Elem()
+ v := reflect.ValueOf(allSetting).Elem()
+ fields := reflect_util.GetFields(t)
+
+ setSetting := func(key, value string) (err error) {
+ defer func() {
+ panicErr := recover()
+ if panicErr != nil {
+ err = errors.New(fmt.Sprint(panicErr))
+ }
+ }()
+
+ var found bool
+ var field reflect.StructField
+ for _, f := range fields {
+ if f.Tag.Get("json") == key {
+ field = f
+ found = true
+ break
+ }
+ }
+
+ if !found {
+ // 有些设置自动生成,不需要返回到前端给用户修改
+ return nil
+ }
+
+ fieldV := v.FieldByName(field.Name)
+ switch t := fieldV.Interface().(type) {
+ case int:
+ n, err := strconv.ParseInt(value, 10, 64)
+ if err != nil {
+ return err
+ }
+ fieldV.SetInt(n)
+ case string:
+ fieldV.SetString(value)
+ case bool:
+ fieldV.SetBool(value == "true")
+ default:
+ return common.NewErrorf("unknown field %v type %v", key, t)
+ }
+ return
+ }
+
+ keyMap := map[string]bool{}
+ for _, setting := range settings {
+ err := setSetting(setting.Key, setting.Value)
+ if err != nil {
+ return nil, err
+ }
+ keyMap[setting.Key] = true
+ }
+
+ for key, value := range defaultValueMap {
+ if keyMap[key] {
+ continue
+ }
+ err := setSetting(key, value)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return allSetting, nil
+}
+
+func (s *SettingService) ResetSettings() error {
+ db := database.GetDB()
+ return db.Where("1 = 1").Delete(model.Setting{}).Error
+}
+
+func (s *SettingService) getSetting(key string) (*model.Setting, error) {
+ db := database.GetDB()
+ setting := &model.Setting{}
+ err := db.Model(model.Setting{}).Where("key = ?", key).First(setting).Error
+ if err != nil {
+ return nil, err
+ }
+ return setting, nil
+}
+
+func (s *SettingService) saveSetting(key string, value string) error {
+ setting, err := s.getSetting(key)
+ db := database.GetDB()
+ if database.IsNotFound(err) {
+ return db.Create(&model.Setting{
+ Key: key,
+ Value: value,
+ }).Error
+ } else if err != nil {
+ return err
+ }
+ setting.Key = key
+ setting.Value = value
+ return db.Save(setting).Error
+}
+
+func (s *SettingService) getString(key string) (string, error) {
+ setting, err := s.getSetting(key)
+ if database.IsNotFound(err) {
+ value, ok := defaultValueMap[key]
+ if !ok {
+ return "", common.NewErrorf("key <%v> not in defaultValueMap", key)
+ }
+ return value, nil
+ } else if err != nil {
+ return "", err
+ }
+ return setting.Value, nil
+}
+
+func (s *SettingService) setString(key string, value string) error {
+ return s.saveSetting(key, value)
+}
+
+func (s *SettingService) getBool(key string) (bool, error) {
+ str, err := s.getString(key)
+ if err != nil {
+ return false, err
+ }
+ return strconv.ParseBool(str)
+}
+
+func (s *SettingService) setBool(key string, value bool) error {
+ return s.setString(key, strconv.FormatBool(value))
+}
+
+func (s *SettingService) getInt(key string) (int, error) {
+ str, err := s.getString(key)
+ if err != nil {
+ return 0, err
+ }
+ return strconv.Atoi(str)
+}
+
+func (s *SettingService) setInt(key string, value int) error {
+ return s.setString(key, strconv.Itoa(value))
+}
+
+func (s *SettingService) GetXrayConfigTemplate() (string, error) {
+ return s.getString("xrayTemplateConfig")
+}
+
+func (s *SettingService) GetListen() (string, error) {
+ return s.getString("webListen")
+}
+
+func (s *SettingService) GetTgBotToken() (string, error) {
+ return s.getString("tgBotToken")
+}
+
+func (s *SettingService) SetTgBotToken(token string) error {
+ return s.setString("tgBotToken", token)
+}
+
+func (s *SettingService) GetTgBotChatId() (int, error) {
+ return s.getInt("tgBotChatId")
+}
+
+func (s *SettingService) SetTgBotChatId(chatId int) error {
+ return s.setInt("tgBotChatId", chatId)
+}
+
+func (s *SettingService) SetTgbotenabled(value bool) error {
+ return s.setBool("tgBotEnable", value)
+}
+
+func (s *SettingService) GetTgbotenabled() (bool, error) {
+ return s.getBool("tgBotEnable")
+}
+
+func (s *SettingService) SetTgbotRuntime(time string) error {
+ return s.setString("tgRunTime", time)
+}
+
+func (s *SettingService) GetTgbotRuntime() (string, error) {
+ return s.getString("tgRunTime")
+}
+
+func (s *SettingService) GetPort() (int, error) {
+ return s.getInt("webPort")
+}
+
+func (s *SettingService) SetPort(port int) error {
+ return s.setInt("webPort", port)
+}
+
+func (s *SettingService) GetCertFile() (string, error) {
+ return s.getString("webCertFile")
+}
+
+func (s *SettingService) GetKeyFile() (string, error) {
+ return s.getString("webKeyFile")
+}
+
+func (s *SettingService) GetSecret() ([]byte, error) {
+ secret, err := s.getString("secret")
+ if secret == defaultValueMap["secret"] {
+ err := s.saveSetting("secret", secret)
+ if err != nil {
+ logger.Warning("save secret failed:", err)
+ }
+ }
+ return []byte(secret), err
+}
+
+func (s *SettingService) GetBasePath() (string, error) {
+ basePath, err := s.getString("webBasePath")
+ if err != nil {
+ return "", err
+ }
+ if !strings.HasPrefix(basePath, "/") {
+ basePath = "/" + basePath
+ }
+ if !strings.HasSuffix(basePath, "/") {
+ basePath += "/"
+ }
+ return basePath, nil
+}
+
+func (s *SettingService) GetTimeLocation() (*time.Location, error) {
+ l, err := s.getString("timeLocation")
+ if err != nil {
+ return nil, err
+ }
+ location, err := time.LoadLocation(l)
+ if err != nil {
+ defaultLocation := defaultValueMap["timeLocation"]
+ logger.Errorf("location <%v> not exist, using default location: %v", l, defaultLocation)
+ return time.LoadLocation(defaultLocation)
+ }
+ return location, nil
+}
+
+func (s *SettingService) UpdateAllSetting(allSetting *entity.AllSetting) error {
+ if err := allSetting.CheckValid(); err != nil {
+ return err
+ }
+
+ v := reflect.ValueOf(allSetting).Elem()
+ t := reflect.TypeOf(allSetting).Elem()
+ fields := reflect_util.GetFields(t)
+ errs := make([]error, 0)
+ for _, field := range fields {
+ key := field.Tag.Get("json")
+ fieldV := v.FieldByName(field.Name)
+ value := fmt.Sprint(fieldV.Interface())
+ err := s.saveSetting(key, value)
+ if err != nil {
+ errs = append(errs, err)
+ }
+ }
+ return common.Combine(errs...)
+}