diff options
| author | MHSanaei <mc.sanaei@gmail.com> | 2023-02-09 22:18:06 +0300 |
|---|---|---|
| committer | MHSanaei <mc.sanaei@gmail.com> | 2023-02-09 22:18:06 +0300 |
| commit | b73e4173a3c1e69e02ad6b4e3b43e425e57a5be9 (patch) | |
| tree | d95d2f5e903d97082e11eb9f9023c165b1bde388 /web/service | |
3x-ui
Diffstat (limited to 'web/service')
| -rw-r--r-- | web/service/config.json | 75 | ||||
| -rw-r--r-- | web/service/inbound.go | 417 | ||||
| -rw-r--r-- | web/service/panel.go | 26 | ||||
| -rw-r--r-- | web/service/server.go | 302 | ||||
| -rw-r--r-- | web/service/setting.go | 303 | ||||
| -rw-r--r-- | web/service/user.go | 73 | ||||
| -rw-r--r-- | web/service/xray.go | 163 |
7 files changed, 1359 insertions, 0 deletions
diff --git a/web/service/config.json b/web/service/config.json new file mode 100644 index 00000000..5370fcf4 --- /dev/null +++ b/web/service/config.json @@ -0,0 +1,75 @@ +{ + "log": { + "loglevel": "warning", + "access": "./access.log" + }, + + "api": { + "services": [ + "HandlerService", + "LoggerService", + "StatsService" + ], + "tag": "api" + }, + "inbounds": [ + { + "listen": "127.0.0.1", + "port": 62789, + "protocol": "dokodemo-door", + "settings": { + "address": "127.0.0.1" + }, + "tag": "api" + } + ], + "outbounds": [ + { + "protocol": "freedom", + "settings": {} + }, + { + "protocol": "blackhole", + "settings": {}, + "tag": "blocked" + } + ], + "policy": { + "levels": { + "0": { + "statsUserUplink": true, + "statsUserDownlink": true + } + }, + "system": { + "statsInboundDownlink": true, + "statsInboundUplink": true + } + }, + "routing": { + "rules": [ + { + "inboundTag": [ + "api" + ], + "outboundTag": "api", + "type": "field" + }, + { + "ip": [ + "geoip:private" + ], + "outboundTag": "blocked", + "type": "field" + }, + { + "outboundTag": "blocked", + "protocol": [ + "bittorrent" + ], + "type": "field" + } + ] + }, + "stats": {} +} diff --git a/web/service/inbound.go b/web/service/inbound.go new file mode 100644 index 00000000..37888729 --- /dev/null +++ b/web/service/inbound.go @@ -0,0 +1,417 @@ +package service + +import ( + "fmt" + "time" + "x-ui/database" + "encoding/json" + "x-ui/database/model" + "x-ui/util/common" + "x-ui/xray" + "x-ui/logger" + + "gorm.io/gorm" +) + +type InboundService struct { +} + +func (s *InboundService) GetInbounds(userId int) ([]*model.Inbound, error) { + db := database.GetDB() + var inbounds []*model.Inbound + err := db.Model(model.Inbound{}).Preload("ClientStats").Where("user_id = ?", userId).Find(&inbounds).Error + if err != nil && err != gorm.ErrRecordNotFound { + return nil, err + } + return inbounds, nil +} + +func (s *InboundService) GetAllInbounds() ([]*model.Inbound, error) { + db := database.GetDB() + var inbounds []*model.Inbound + err := db.Model(model.Inbound{}).Preload("ClientStats").Find(&inbounds).Error + if err != nil && err != gorm.ErrRecordNotFound { + return nil, err + } + return inbounds, nil +} + +func (s *InboundService) checkPortExist(port int, ignoreId int) (bool, error) { + db := database.GetDB() + db = db.Model(model.Inbound{}).Where("port = ?", port) + if ignoreId > 0 { + db = db.Where("id != ?", ignoreId) + } + var count int64 + err := db.Count(&count).Error + if err != nil { + return false, err + } + return count > 0, nil +} + +func (s *InboundService) getClients(inbound *model.Inbound) ([]model.Client, error) { + settings := map[string][]model.Client{} + json.Unmarshal([]byte(inbound.Settings), &settings) + if settings == nil { + return nil, fmt.Errorf("Setting is null") + } + + clients := settings["clients"] + if clients == nil { + return nil, nil + } + return clients, nil +} + +func (s *InboundService) checkEmailsExist(emails map[string] bool, ignoreId int) (string, error) { + db := database.GetDB() + var inbounds []*model.Inbound + db = db.Model(model.Inbound{}).Where("Protocol in ?", []model.Protocol{model.VMess, model.VLESS}) + if (ignoreId > 0) { + db = db.Where("id != ?", ignoreId) + } + db = db.Find(&inbounds) + if db.Error != nil { + return "", db.Error + } + + for _, inbound := range inbounds { + clients, err := s.getClients(inbound) + if err != nil { + return "", err + } + + for _, client := range clients { + if emails[client.Email] { + return client.Email, nil + } + } + } + return "", nil +} + +func (s *InboundService) checkEmailExistForInbound(inbound *model.Inbound) (string, error) { + clients, err := s.getClients(inbound) + if err != nil { + return "", err + } + emails := make(map[string] bool) + for _, client := range clients { + if (client.Email != "") { + if emails[client.Email] { + return client.Email, nil + } + emails[client.Email] = true; + } + } + return s.checkEmailsExist(emails, inbound.Id) +} + +func (s *InboundService) AddInbound(inbound *model.Inbound) (*model.Inbound,error) { + exist, err := s.checkPortExist(inbound.Port, 0) + if err != nil { + return inbound, err + } + if exist { + return inbound, common.NewError("端口已存在:", inbound.Port) + } + + existEmail, err := s.checkEmailExistForInbound(inbound) + if err != nil { + return inbound, err + } + if existEmail != "" { + return inbound, common.NewError("Duplicate email:", existEmail) + } + + db := database.GetDB() + + err = db.Save(inbound).Error + if err == nil { + s.UpdateClientStat(inbound.Id,inbound.Settings) + } + return inbound, err +} + +func (s *InboundService) AddInbounds(inbounds []*model.Inbound) error { + for _, inbound := range inbounds { + exist, err := s.checkPortExist(inbound.Port, 0) + if err != nil { + return err + } + if exist { + return common.NewError("端口已存在:", inbound.Port) + } + } + + db := database.GetDB() + tx := db.Begin() + var err error + defer func() { + if err == nil { + tx.Commit() + } else { + tx.Rollback() + } + }() + + for _, inbound := range inbounds { + err = tx.Save(inbound).Error + if err != nil { + return err + } + } + + return nil +} + +func (s *InboundService) DelInbound(id int) error { + db := database.GetDB() + return db.Delete(model.Inbound{}, id).Error +} + +func (s *InboundService) GetInbound(id int) (*model.Inbound, error) { + db := database.GetDB() + inbound := &model.Inbound{} + err := db.Model(model.Inbound{}).First(inbound, id).Error + if err != nil { + return nil, err + } + return inbound, nil +} + +func (s *InboundService) UpdateInbound(inbound *model.Inbound) (*model.Inbound, error) { + exist, err := s.checkPortExist(inbound.Port, inbound.Id) + if err != nil { + return inbound, err + } + if exist { + return inbound, common.NewError("端口已存在:", inbound.Port) + } + + existEmail, err := s.checkEmailExistForInbound(inbound) + if err != nil { + return inbound, err + } + if existEmail != "" { + return inbound, common.NewError("Duplicate email:", existEmail) + } + + oldInbound, err := s.GetInbound(inbound.Id) + if err != nil { + return inbound, err + } + oldInbound.Up = inbound.Up + oldInbound.Down = inbound.Down + oldInbound.Total = inbound.Total + oldInbound.Remark = inbound.Remark + oldInbound.Enable = inbound.Enable + oldInbound.ExpiryTime = inbound.ExpiryTime + oldInbound.Listen = inbound.Listen + oldInbound.Port = inbound.Port + oldInbound.Protocol = inbound.Protocol + oldInbound.Settings = inbound.Settings + oldInbound.StreamSettings = inbound.StreamSettings + oldInbound.Sniffing = inbound.Sniffing + oldInbound.Tag = fmt.Sprintf("inbound-%v", inbound.Port) + + s.UpdateClientStat(inbound.Id,inbound.Settings) + db := database.GetDB() + return inbound, db.Save(oldInbound).Error +} + +func (s *InboundService) AddTraffic(traffics []*xray.Traffic) (err error) { + if len(traffics) == 0 { + return nil + } + db := database.GetDB() + db = db.Model(model.Inbound{}) + tx := db.Begin() + defer func() { + if err != nil { + tx.Rollback() + } else { + tx.Commit() + } + }() + for _, traffic := range traffics { + if traffic.IsInbound { + err = tx.Where("tag = ?", traffic.Tag). + UpdateColumn("up", gorm.Expr("up + ?", traffic.Up)). + UpdateColumn("down", gorm.Expr("down + ?", traffic.Down)). + Error + if err != nil { + return + } + } + } + return +} +func (s *InboundService) AddClientTraffic(traffics []*xray.ClientTraffic) (err error) { + if len(traffics) == 0 { + return nil + } + db := database.GetDB() + dbInbound := db.Model(model.Inbound{}) + + db = db.Model(xray.ClientTraffic{}) + tx := db.Begin() + defer func() { + if err != nil { + tx.Rollback() + } else { + tx.Commit() + } + }() + txInbound := dbInbound.Begin() + defer func() { + if err != nil { + txInbound.Rollback() + } else { + txInbound.Commit() + } + }() + + for _, traffic := range traffics { + inbound := &model.Inbound{} + + err := txInbound.Where("settings like ?", "%" + traffic.Email + "%").First(inbound).Error + traffic.InboundId = inbound.Id + if err != nil { + if err == gorm.ErrRecordNotFound { + // delete removed client record + clientErr := s.DelClientStat(tx, traffic.Email) + logger.Warning(err, traffic.Email,clientErr) + + } + continue + } + // get settings clients + settings := map[string][]model.Client{} + json.Unmarshal([]byte(inbound.Settings), &settings) + clients := settings["clients"] + for _, client := range clients { + if traffic.Email == client.Email { + traffic.ExpiryTime = client.ExpiryTime + traffic.Total = client.TotalGB + } + } + if tx.Where("inbound_id = ?", inbound.Id).Where("email = ?", traffic.Email). + UpdateColumn("enable", true). + UpdateColumn("expiry_time", traffic.ExpiryTime). + UpdateColumn("total",traffic.Total). + UpdateColumn("up", gorm.Expr("up + ?", traffic.Up)). + UpdateColumn("down", gorm.Expr("down + ?", traffic.Down)).RowsAffected == 0 { + err = tx.Create(traffic).Error + } + + if err != nil { + logger.Warning("AddClientTraffic update data ", err) + continue + } + + } + return +} + +func (s *InboundService) DisableInvalidInbounds() (int64, error) { + db := database.GetDB() + now := time.Now().Unix() * 1000 + result := db.Model(model.Inbound{}). + Where("((total > 0 and up + down >= total) or (expiry_time > 0 and expiry_time <= ?)) and enable = ?", now, true). + Update("enable", false) + err := result.Error + count := result.RowsAffected + return count, err +} +func (s *InboundService) DisableInvalidClients() (int64, error) { + db := database.GetDB() + now := time.Now().Unix() * 1000 + result := db.Model(xray.ClientTraffic{}). + Where("((total > 0 and up + down >= total) or (expiry_time > 0 and expiry_time <= ?)) and enable = ?", now, true). + Update("enable", false) + err := result.Error + count := result.RowsAffected + return count, err +} +func (s *InboundService) UpdateClientStat(inboundId int, inboundSettings string) (error) { + db := database.GetDB() + + // get settings clients + settings := map[string][]model.Client{} + json.Unmarshal([]byte(inboundSettings), &settings) + clients := settings["clients"] + for _, client := range clients { + result := db.Model(xray.ClientTraffic{}). + Where("inbound_id = ? and email = ?", inboundId, client.Email). + Updates(map[string]interface{}{"enable": true, "total": client.TotalGB, "expiry_time": client.ExpiryTime}) + if result.RowsAffected == 0 { + clientTraffic := xray.ClientTraffic{} + clientTraffic.InboundId = inboundId + clientTraffic.Email = client.Email + clientTraffic.Total = client.TotalGB + clientTraffic.ExpiryTime = client.ExpiryTime + clientTraffic.Enable = true + clientTraffic.Up = 0 + clientTraffic.Down = 0 + db.Create(&clientTraffic) + } + err := result.Error + if err != nil { + return err + } + + } + return nil +} +func (s *InboundService) DelClientStat(tx *gorm.DB, email string) error { + return tx.Where("email = ?", email).Delete(xray.ClientTraffic{}).Error +} + +func (s *InboundService) ResetClientTraffic(clientEmail string) (error) { + db := database.GetDB() + + result := db.Model(xray.ClientTraffic{}). + Where("email = ?", clientEmail). + Update("up", 0). + Update("down", 0) + + err := result.Error + + + if err != nil { + return err + } + return nil +} +func (s *InboundService) GetClientTrafficById(uuid string) (traffic *xray.ClientTraffic, err error) { + db := database.GetDB() + inbound := &model.Inbound{} + traffic = &xray.ClientTraffic{} + + err = db.Model(model.Inbound{}).Where("settings like ?", "%" + uuid + "%").First(inbound).Error + if err != nil { + if err == gorm.ErrRecordNotFound { + logger.Warning(err) + return nil, err + } + } + traffic.InboundId = inbound.Id + + // get settings clients + settings := map[string][]model.Client{} + json.Unmarshal([]byte(inbound.Settings), &settings) + clients := settings["clients"] + for _, client := range clients { + if uuid == client.ID { + traffic.Email = client.Email + } + } + err = db.Model(xray.ClientTraffic{}).Where("email = ?", traffic.Email).First(traffic).Error + if err != nil { + logger.Warning(err) + return nil, err + } + return traffic, err +} diff --git a/web/service/panel.go b/web/service/panel.go new file mode 100644 index 00000000..f90d3e66 --- /dev/null +++ b/web/service/panel.go @@ -0,0 +1,26 @@ +package service + +import ( + "os" + "syscall" + "time" + "x-ui/logger" +) + +type PanelService struct { +} + +func (s *PanelService) RestartPanel(delay time.Duration) error { + p, err := os.FindProcess(syscall.Getpid()) + if err != nil { + return err + } + go func() { + time.Sleep(delay) + err := p.Signal(syscall.SIGHUP) + if err != nil { + logger.Error("send signal SIGHUP failed:", err) + } + }() + return nil +} diff --git a/web/service/server.go b/web/service/server.go new file mode 100644 index 00000000..efd985e6 --- /dev/null +++ b/web/service/server.go @@ -0,0 +1,302 @@ +package service + +import ( + "archive/zip" + "bytes" + "encoding/json" + "fmt" + "io" + "io/fs" + "net/http" + "os" + "runtime" + "time" + "x-ui/logger" + "x-ui/util/sys" + "x-ui/xray" + + "github.com/shirou/gopsutil/cpu" + "github.com/shirou/gopsutil/disk" + "github.com/shirou/gopsutil/host" + "github.com/shirou/gopsutil/load" + "github.com/shirou/gopsutil/mem" + "github.com/shirou/gopsutil/net" +) + +type ProcessState string + +const ( + Running ProcessState = "running" + Stop ProcessState = "stop" + Error ProcessState = "error" +) + +type Status struct { + T time.Time `json:"-"` + Cpu float64 `json:"cpu"` + Mem struct { + Current uint64 `json:"current"` + Total uint64 `json:"total"` + } `json:"mem"` + Swap struct { + Current uint64 `json:"current"` + Total uint64 `json:"total"` + } `json:"swap"` + Disk struct { + Current uint64 `json:"current"` + Total uint64 `json:"total"` + } `json:"disk"` + Xray struct { + State ProcessState `json:"state"` + ErrorMsg string `json:"errorMsg"` + Version string `json:"version"` + } `json:"xray"` + Uptime uint64 `json:"uptime"` + Loads []float64 `json:"loads"` + TcpCount int `json:"tcpCount"` + UdpCount int `json:"udpCount"` + NetIO struct { + Up uint64 `json:"up"` + Down uint64 `json:"down"` + } `json:"netIO"` + NetTraffic struct { + Sent uint64 `json:"sent"` + Recv uint64 `json:"recv"` + } `json:"netTraffic"` +} + +type Release struct { + TagName string `json:"tag_name"` +} + +type ServerService struct { + xrayService XrayService +} + +func (s *ServerService) GetStatus(lastStatus *Status) *Status { + now := time.Now() + status := &Status{ + T: now, + } + + percents, err := cpu.Percent(0, false) + if err != nil { + logger.Warning("get cpu percent failed:", err) + } else { + status.Cpu = percents[0] + } + + upTime, err := host.Uptime() + if err != nil { + logger.Warning("get uptime failed:", err) + } else { + status.Uptime = upTime + } + + memInfo, err := mem.VirtualMemory() + if err != nil { + logger.Warning("get virtual memory failed:", err) + } else { + status.Mem.Current = memInfo.Used + status.Mem.Total = memInfo.Total + } + + swapInfo, err := mem.SwapMemory() + if err != nil { + logger.Warning("get swap memory failed:", err) + } else { + status.Swap.Current = swapInfo.Used + status.Swap.Total = swapInfo.Total + } + + distInfo, err := disk.Usage("/") + if err != nil { + logger.Warning("get dist usage failed:", err) + } else { + status.Disk.Current = distInfo.Used + status.Disk.Total = distInfo.Total + } + + avgState, err := load.Avg() + if err != nil { + logger.Warning("get load avg failed:", err) + } else { + status.Loads = []float64{avgState.Load1, avgState.Load5, avgState.Load15} + } + + ioStats, err := net.IOCounters(false) + if err != nil { + logger.Warning("get io counters failed:", err) + } else if len(ioStats) > 0 { + ioStat := ioStats[0] + status.NetTraffic.Sent = ioStat.BytesSent + status.NetTraffic.Recv = ioStat.BytesRecv + + if lastStatus != nil { + duration := now.Sub(lastStatus.T) + seconds := float64(duration) / float64(time.Second) + up := uint64(float64(status.NetTraffic.Sent-lastStatus.NetTraffic.Sent) / seconds) + down := uint64(float64(status.NetTraffic.Recv-lastStatus.NetTraffic.Recv) / seconds) + status.NetIO.Up = up + status.NetIO.Down = down + } + } else { + logger.Warning("can not find io counters") + } + + status.TcpCount, err = sys.GetTCPCount() + if err != nil { + logger.Warning("get tcp connections failed:", err) + } + + status.UdpCount, err = sys.GetUDPCount() + if err != nil { + logger.Warning("get udp connections failed:", err) + } + + if s.xrayService.IsXrayRunning() { + status.Xray.State = Running + status.Xray.ErrorMsg = "" + } else { + err := s.xrayService.GetXrayErr() + if err != nil { + status.Xray.State = Error + } else { + status.Xray.State = Stop + } + status.Xray.ErrorMsg = s.xrayService.GetXrayResult() + } + status.Xray.Version = s.xrayService.GetXrayVersion() + + return status +} + +func (s *ServerService) GetXrayVersions() ([]string, error) { + url := "https://api.github.com/repos/XTLS/Xray-core/releases" + resp, err := http.Get(url) + if err != nil { + return nil, err + } + + defer resp.Body.Close() + buffer := bytes.NewBuffer(make([]byte, 8192)) + buffer.Reset() + _, err = buffer.ReadFrom(resp.Body) + if err != nil { + return nil, err + } + + releases := make([]Release, 0) + err = json.Unmarshal(buffer.Bytes(), &releases) + if err != nil { + return nil, err + } + versions := make([]string, 0, len(releases)) + for _, release := range releases { + versions = append(versions, release.TagName) + } + return versions, nil +} + +func (s *ServerService) downloadXRay(version string) (string, error) { + osName := runtime.GOOS + arch := runtime.GOARCH + + switch osName { + case "darwin": + osName = "macos" + } + + switch arch { + case "amd64": + arch = "64" + case "arm64": + arch = "arm64-v8a" + } + + fileName := fmt.Sprintf("Xray-%s-%s.zip", osName, arch) + url := fmt.Sprintf("https://github.com/XTLS/Xray-core/releases/download/%s/%s", version, fileName) + resp, err := http.Get(url) + if err != nil { + return "", err + } + defer resp.Body.Close() + + os.Remove(fileName) + file, err := os.Create(fileName) + if err != nil { + return "", err + } + defer file.Close() + + _, err = io.Copy(file, resp.Body) + if err != nil { + return "", err + } + + return fileName, nil +} + +func (s *ServerService) UpdateXray(version string) error { + zipFileName, err := s.downloadXRay(version) + if err != nil { + return err + } + + zipFile, err := os.Open(zipFileName) + if err != nil { + return err + } + defer func() { + zipFile.Close() + os.Remove(zipFileName) + }() + + stat, err := zipFile.Stat() + if err != nil { + return err + } + reader, err := zip.NewReader(zipFile, stat.Size()) + if err != nil { + return err + } + + s.xrayService.StopXray() + defer func() { + err := s.xrayService.RestartXray(true) + if err != nil { + logger.Error("start xray failed:", err) + } + }() + + copyZipFile := func(zipName string, fileName string) error { + zipFile, err := reader.Open(zipName) + if err != nil { + return err + } + os.Remove(fileName) + file, err := os.OpenFile(fileName, os.O_CREATE|os.O_RDWR|os.O_TRUNC, fs.ModePerm) + if err != nil { + return err + } + defer file.Close() + _, err = io.Copy(file, zipFile) + return err + } + + err = copyZipFile("xray", xray.GetBinaryPath()) + if err != nil { + return err + } + err = copyZipFile("geosite.dat", xray.GetGeositePath()) + if err != nil { + return err + } + err = copyZipFile("geoip.dat", xray.GetGeoipPath()) + if err != nil { + return err + } + + return nil + +} diff --git a/web/service/setting.go b/web/service/setting.go new file mode 100644 index 00000000..a0592e93 --- /dev/null +++ b/web/service/setting.go @@ -0,0 +1,303 @@ +package service + +import ( + _ "embed" + "errors" + "fmt" + "reflect" + "strconv" + "strings" + "time" + "x-ui/database" + "x-ui/database/model" + "x-ui/logger" + "x-ui/util/common" + "x-ui/util/random" + "x-ui/util/reflect_util" + "x-ui/web/entity" +) + +//go:embed config.json +var xrayTemplateConfig string + +var defaultValueMap = map[string]string{ + "xrayTemplateConfig": xrayTemplateConfig, + "webListen": "", + "webPort": "54321", + "webCertFile": "", + "webKeyFile": "", + "secret": random.Seq(32), + "webBasePath": "/", + "timeLocation": "Asia/Tehran", + "tgBotEnable": "false", + "tgBotToken": "", + "tgBotChatId": "0", + "tgRunTime": "", +} + +type SettingService struct { +} + +func (s *SettingService) GetAllSetting() (*entity.AllSetting, error) { + db := database.GetDB() + settings := make([]*model.Setting, 0) + err := db.Model(model.Setting{}).Find(&settings).Error + if err != nil { + return nil, err + } + allSetting := &entity.AllSetting{} + t := reflect.TypeOf(allSetting).Elem() + v := reflect.ValueOf(allSetting).Elem() + fields := reflect_util.GetFields(t) + + setSetting := func(key, value string) (err error) { + defer func() { + panicErr := recover() + if panicErr != nil { + err = errors.New(fmt.Sprint(panicErr)) + } + }() + + var found bool + var field reflect.StructField + for _, f := range fields { + if f.Tag.Get("json") == key { + field = f + found = true + break + } + } + + if !found { + // 有些设置自动生成,不需要返回到前端给用户修改 + return nil + } + + fieldV := v.FieldByName(field.Name) + switch t := fieldV.Interface().(type) { + case int: + n, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return err + } + fieldV.SetInt(n) + case string: + fieldV.SetString(value) + case bool: + fieldV.SetBool(value == "true") + default: + return common.NewErrorf("unknown field %v type %v", key, t) + } + return + } + + keyMap := map[string]bool{} + for _, setting := range settings { + err := setSetting(setting.Key, setting.Value) + if err != nil { + return nil, err + } + keyMap[setting.Key] = true + } + + for key, value := range defaultValueMap { + if keyMap[key] { + continue + } + err := setSetting(key, value) + if err != nil { + return nil, err + } + } + + return allSetting, nil +} + +func (s *SettingService) ResetSettings() error { + db := database.GetDB() + return db.Where("1 = 1").Delete(model.Setting{}).Error +} + +func (s *SettingService) getSetting(key string) (*model.Setting, error) { + db := database.GetDB() + setting := &model.Setting{} + err := db.Model(model.Setting{}).Where("key = ?", key).First(setting).Error + if err != nil { + return nil, err + } + return setting, nil +} + +func (s *SettingService) saveSetting(key string, value string) error { + setting, err := s.getSetting(key) + db := database.GetDB() + if database.IsNotFound(err) { + return db.Create(&model.Setting{ + Key: key, + Value: value, + }).Error + } else if err != nil { + return err + } + setting.Key = key + setting.Value = value + return db.Save(setting).Error +} + +func (s *SettingService) getString(key string) (string, error) { + setting, err := s.getSetting(key) + if database.IsNotFound(err) { + value, ok := defaultValueMap[key] + if !ok { + return "", common.NewErrorf("key <%v> not in defaultValueMap", key) + } + return value, nil + } else if err != nil { + return "", err + } + return setting.Value, nil +} + +func (s *SettingService) setString(key string, value string) error { + return s.saveSetting(key, value) +} + +func (s *SettingService) getBool(key string) (bool, error) { + str, err := s.getString(key) + if err != nil { + return false, err + } + return strconv.ParseBool(str) +} + +func (s *SettingService) setBool(key string, value bool) error { + return s.setString(key, strconv.FormatBool(value)) +} + +func (s *SettingService) getInt(key string) (int, error) { + str, err := s.getString(key) + if err != nil { + return 0, err + } + return strconv.Atoi(str) +} + +func (s *SettingService) setInt(key string, value int) error { + return s.setString(key, strconv.Itoa(value)) +} + +func (s *SettingService) GetXrayConfigTemplate() (string, error) { + return s.getString("xrayTemplateConfig") +} + +func (s *SettingService) GetListen() (string, error) { + return s.getString("webListen") +} + +func (s *SettingService) GetTgBotToken() (string, error) { + return s.getString("tgBotToken") +} + +func (s *SettingService) SetTgBotToken(token string) error { + return s.setString("tgBotToken", token) +} + +func (s *SettingService) GetTgBotChatId() (int, error) { + return s.getInt("tgBotChatId") +} + +func (s *SettingService) SetTgBotChatId(chatId int) error { + return s.setInt("tgBotChatId", chatId) +} + +func (s *SettingService) SetTgbotenabled(value bool) error { + return s.setBool("tgBotEnable", value) +} + +func (s *SettingService) GetTgbotenabled() (bool, error) { + return s.getBool("tgBotEnable") +} + +func (s *SettingService) SetTgbotRuntime(time string) error { + return s.setString("tgRunTime", time) +} + +func (s *SettingService) GetTgbotRuntime() (string, error) { + return s.getString("tgRunTime") +} + +func (s *SettingService) GetPort() (int, error) { + return s.getInt("webPort") +} + +func (s *SettingService) SetPort(port int) error { + return s.setInt("webPort", port) +} + +func (s *SettingService) GetCertFile() (string, error) { + return s.getString("webCertFile") +} + +func (s *SettingService) GetKeyFile() (string, error) { + return s.getString("webKeyFile") +} + +func (s *SettingService) GetSecret() ([]byte, error) { + secret, err := s.getString("secret") + if secret == defaultValueMap["secret"] { + err := s.saveSetting("secret", secret) + if err != nil { + logger.Warning("save secret failed:", err) + } + } + return []byte(secret), err +} + +func (s *SettingService) GetBasePath() (string, error) { + basePath, err := s.getString("webBasePath") + if err != nil { + return "", err + } + if !strings.HasPrefix(basePath, "/") { + basePath = "/" + basePath + } + if !strings.HasSuffix(basePath, "/") { + basePath += "/" + } + return basePath, nil +} + +func (s *SettingService) GetTimeLocation() (*time.Location, error) { + l, err := s.getString("timeLocation") + if err != nil { + return nil, err + } + location, err := time.LoadLocation(l) + if err != nil { + defaultLocation := defaultValueMap["timeLocation"] + logger.Errorf("location <%v> not exist, using default location: %v", l, defaultLocation) + return time.LoadLocation(defaultLocation) + } + return location, nil +} + +func (s *SettingService) UpdateAllSetting(allSetting *entity.AllSetting) error { + if err := allSetting.CheckValid(); err != nil { + return err + } + + v := reflect.ValueOf(allSetting).Elem() + t := reflect.TypeOf(allSetting).Elem() + fields := reflect_util.GetFields(t) + errs := make([]error, 0) + for _, field := range fields { + key := field.Tag.Get("json") + fieldV := v.FieldByName(field.Name) + value := fmt.Sprint(fieldV.Interface()) + err := s.saveSetting(key, value) + if err != nil { + errs = append(errs, err) + } + } + return common.Combine(errs...) +} diff --git a/web/service/user.go b/web/service/user.go new file mode 100644 index 00000000..e4e7572d --- /dev/null +++ b/web/service/user.go @@ -0,0 +1,73 @@ +package service + +import ( + "errors" + "x-ui/database" + "x-ui/database/model" + "x-ui/logger" + + "gorm.io/gorm" +) + +type UserService struct { +} + +func (s *UserService) GetFirstUser() (*model.User, error) { + db := database.GetDB() + + user := &model.User{} + err := db.Model(model.User{}). + First(user). + Error + if err != nil { + return nil, err + } + return user, nil +} + +func (s *UserService) CheckUser(username string, password string) *model.User { + db := database.GetDB() + + user := &model.User{} + err := db.Model(model.User{}). + Where("username = ? and password = ?", username, password). + First(user). + Error + if err == gorm.ErrRecordNotFound { + return nil + } else if err != nil { + logger.Warning("check user err:", err) + return nil + } + return user +} + +func (s *UserService) UpdateUser(id int, username string, password string) error { + db := database.GetDB() + return db.Model(model.User{}). + Where("id = ?", id). + Update("username", username). + Update("password", password). + Error +} + +func (s *UserService) UpdateFirstUser(username string, password string) error { + if username == "" { + return errors.New("username can not be empty") + } else if password == "" { + return errors.New("password can not be empty") + } + db := database.GetDB() + user := &model.User{} + err := db.Model(model.User{}).First(user).Error + if database.IsNotFound(err) { + user.Username = username + user.Password = password + return db.Model(model.User{}).Create(user).Error + } else if err != ni
|
