diff options
Diffstat (limited to 'web/service/inbound.go')
| -rw-r--r-- | web/service/inbound.go | 167 |
1 files changed, 106 insertions, 61 deletions
diff --git a/web/service/inbound.go b/web/service/inbound.go index b7eb6789..c3f92e5a 100644 --- a/web/service/inbound.go +++ b/web/service/inbound.go @@ -64,28 +64,45 @@ func (s *InboundService) getClients(inbound *model.Inbound) ([]model.Client, err return clients, nil } -func (s *InboundService) checkEmailsExist(emails map[string]bool, ignoreId int) (string, error) { +func (s *InboundService) getAllEmails() ([]string, error) { db := database.GetDB() - var inbounds []*model.Inbound - db = db.Model(model.Inbound{}).Where("Protocol in ?", []model.Protocol{model.VMess, model.VLESS, model.Trojan}) - if ignoreId > 0 { - db = db.Where("id != ?", ignoreId) - } - db = db.Find(&inbounds) - if db.Error != nil { - return "", db.Error + var emails []string + err := db.Raw(` + SELECT JSON_EXTRACT(client.value, '$.email') + FROM inbounds, + JSON_EACH(JSON_EXTRACT(inbounds.settings, '$.clients')) AS client + `).Scan(&emails).Error + + if err != nil { + return nil, err } + return emails, nil +} - for _, inbound := range inbounds { - clients, err := s.getClients(inbound) - if err != nil { - return "", err +func (s *InboundService) contains(slice []string, str string) bool { + for _, s := range slice { + if s == str { + return true } + } + return false +} - for _, client := range clients { - if emails[client.Email] { +func (s *InboundService) checkEmailsExistForClients(clients []model.Client) (string, error) { + allEmails, err := s.getAllEmails() + if err != nil { + return "", err + } + var emails []string + for _, client := range clients { + if client.Email != "" { + if s.contains(emails, client.Email) { + return client.Email, nil + } + if s.contains(allEmails, client.Email) { return client.Email, nil } + emails = append(emails, client.Email) } } return "", nil @@ -96,16 +113,23 @@ func (s *InboundService) checkEmailExistForInbound(inbound *model.Inbound) (stri if err != nil { return "", err } - emails := make(map[string]bool) + allEmails, err := s.getAllEmails() + if err != nil { + return "", err + } + var emails []string for _, client := range clients { if client.Email != "" { - if emails[client.Email] { + if s.contains(emails, client.Email) { + return client.Email, nil + } + if s.contains(allEmails, client.Email) { return client.Email, nil } - emails[client.Email] = true + emails = append(emails, client.Email) } } - return s.checkEmailsExist(emails, inbound.Id) + return "", nil } func (s *InboundService) AddInbound(inbound *model.Inbound) (*model.Inbound, error) { @@ -215,14 +239,6 @@ func (s *InboundService) UpdateInbound(inbound *model.Inbound) (*model.Inbound, return inbound, common.NewError("Port already exists:", inbound.Port) } - existEmail, err := s.checkEmailExistForInbound(inbound) - if err != nil { - return inbound, err - } - if existEmail != "" { - return inbound, common.NewError("Duplicate email:", existEmail) - } - oldInbound, err := s.GetInbound(inbound.Id) if err != nil { return inbound, err @@ -245,8 +261,12 @@ func (s *InboundService) UpdateInbound(inbound *model.Inbound) (*model.Inbound, return inbound, db.Save(oldInbound).Error } -func (s *InboundService) AddInboundClient(inbound *model.Inbound) error { - existEmail, err := s.checkEmailExistForInbound(inbound) +func (s *InboundService) AddInboundClient(data *model.Inbound) error { + clients, err := s.getClients(data) + if err != nil { + return err + } + existEmail, err := s.checkEmailsExistForClients(clients) if err != nil { return err } @@ -255,29 +275,35 @@ func (s *InboundService) AddInboundClient(inbound *model.Inbound) error { return common.NewError("Duplicate email:", existEmail) } - clients, err := s.getClients(inbound) + oldInbound, err := s.GetInbound(data.Id) if err != nil { return err } - oldInbound, err := s.GetInbound(inbound.Id) + var settings map[string]interface{} + err = json.Unmarshal([]byte(oldInbound.Settings), &settings) if err != nil { return err } - oldClients, err := s.getClients(oldInbound) + oldClients := settings["clients"].([]interface{}) + var newClients []interface{} + for _, client := range clients { + newClients = append(newClients, client) + } + + settings["clients"] = append(oldClients, newClients...) + + newSettings, err := json.MarshalIndent(settings, "", " ") if err != nil { return err } - oldInbound.Settings = inbound.Settings + oldInbound.Settings = string(newSettings) - if len(clients[len(clients)-1].Email) > 0 { - s.AddClientStat(inbound.Id, &clients[len(clients)-1]) - } - for i := len(oldClients); i < len(clients); i++ { - if len(clients[i].Email) > 0 { - s.AddClientStat(inbound.Id, &clients[i]) + for _, client := range clients { + if len(client.Email) > 0 { + s.AddClientStat(data.Id, &client) } } db := database.GetDB() @@ -309,37 +335,56 @@ func (s *InboundService) DelInboundClient(inbound *model.Inbound, email string) return db.Save(oldInbound).Error } -func (s *InboundService) UpdateInboundClient(inbound *model.Inbound, index int) error { - existEmail, err := s.checkEmailExistForInbound(inbound) +func (s *InboundService) UpdateInboundClient(data *model.Inbound, index int) error { + clients, err := s.getClients(data) if err != nil { return err } - if existEmail != "" { - return common.NewError("Duplicate email:", existEmail) - } - clients, err := s.getClients(inbound) + oldInbound, err := s.GetInbound(data.Id) if err != nil { return err } - oldInbound, err := s.GetInbound(inbound.Id) + oldClients, err := s.getClients(oldInbound) if err != nil { return err } - oldClients, err := s.getClients(oldInbound) + if len(clients[0].Email) > 0 && clients[0].Email != oldClients[index].Email { + existEmail, err := s.checkEmailsExistForClients(clients) + if err != nil { + return err + } + if existEmail != "" { + return common.NewError("Duplicate email:", existEmail) + } + } + + var settings map[string]interface{} + err = json.Unmarshal([]byte(oldInbound.Settings), &settings) if err != nil { return err } - oldInbound.Settings = inbound.Settings + settingsClients := settings["clients"].([]interface{}) + var newClients []interface{} + newClients = append(newClients, clients[0]) + settingsClients[index] = newClients[0] + + settings["clients"] = settingsClients + + newSettings, err := json.MarshalIndent(settings, "", " ") + if err != nil { + return err + } + oldInbound.Settings = string(newSettings) db := database.GetDB() - if len(clients[index].Email) > 0 { + if len(clients[0].Email) > 0 { if len(oldClients[index].Email) > 0 { - err = s.UpdateClientStat(oldClients[index].Email, &clients[index]) + err = s.UpdateClientStat(oldClients[index].Email, &clients[0]) if err != nil { return err } @@ -348,7 +393,7 @@ func (s *InboundService) UpdateInboundClient(inbound *model.Inbound, index int) return err } } else { - s.AddClientStat(inbound.Id, &clients[index]) + s.AddClientStat(data.Id, &clients[0]) } } else { err = s.DelClientStat(db, oldClients[index].Email) @@ -507,6 +552,16 @@ func (s *InboundService) DisableInvalidInbounds() (int64, error) { count := result.RowsAffected return count, err } +func (s *InboundService) DisableInvalidClients() (int64, error) { + db := database.GetDB() + now := time.Now().Unix() * 1000 + result := db.Model(xray.ClientTraffic{}). + Where("((total > 0 and up + down >= total) or (expiry_time > 0 and expiry_time <= ?)) and enable = ?", now, true). + Update("enable", false) + err := result.Error + count := result.RowsAffected + return count, err +} func (s *InboundService) RemoveOrphanedTraffics() { db := database.GetDB() db.Exec(` @@ -518,16 +573,6 @@ func (s *InboundService) RemoveOrphanedTraffics() { ) `) } -func (s *InboundService) DisableInvalidClients() (int64, error) { - db := database.GetDB() - now := time.Now().Unix() * 1000 - result := db.Model(xray.ClientTraffic{}). - Where("((total > 0 and up + down >= total) or (expiry_time > 0 and expiry_time <= ?)) and enable = ?", now, true). - Update("enable", false) - err := result.Error - count := result.RowsAffected - return count, err -} func (s *InboundService) AddClientStat(inboundId int, client *model.Client) error { db := database.GetDB() |
