Skip to content

Commit

Permalink
修复连接mysql失败时的空指针
Browse files Browse the repository at this point in the history
  • Loading branch information
Jrohy committed Mar 23, 2020
1 parent a4ac41e commit 201cb67
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 16 deletions.
29 changes: 24 additions & 5 deletions core/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package core
import (
"crypto/sha256"
"database/sql"
"errors"
"fmt"
// mysql sql驱动
_ "github.com/go-sql-driver/mysql"
Expand Down Expand Up @@ -64,6 +65,9 @@ CREATE TABLE IF NOT EXISTS users (
// CreateUser 创建Trojan用户
func (mysql *Mysql) CreateUser(username string, password string) error {
db := mysql.GetDB()
if db == nil {
return errors.New("can't connect mysql")
}
defer db.Close()
encryPass := sha256.Sum224([]byte(password))
if _, err := db.Exec(fmt.Sprintf("INSERT INTO users(username, password, quota) VALUES ('%s', '%x', -1);", username, encryPass)); err != nil {
Expand All @@ -80,8 +84,14 @@ func (mysql *Mysql) CreateUser(username string, password string) error {
// DeleteUser 删除用户
func (mysql *Mysql) DeleteUser(id uint) error {
db := mysql.GetDB()
if db == nil {
return errors.New("can't connect mysql")
}
defer db.Close()
userList := *mysql.GetData(strconv.Itoa(int(id)))
userList := mysql.GetData(strconv.Itoa(int(id)))
if userList == nil {
return errors.New("can't connnect mysql")
}
if userList[0].Username != "admin" {
_ = DelValue(userList[0].Username + "_pass")
}
Expand All @@ -95,6 +105,9 @@ func (mysql *Mysql) DeleteUser(id uint) error {
// SetQuota 限制流量
func (mysql *Mysql) SetQuota(id uint, quota int) error {
db := mysql.GetDB()
if db == nil {
return errors.New("can't connect mysql")
}
defer db.Close()
if _, err := db.Exec(fmt.Sprintf("UPDATE users SET quota=%d WHERE id=%d;", quota, id)); err != nil {
fmt.Println(err)
Expand All @@ -106,6 +119,9 @@ func (mysql *Mysql) SetQuota(id uint, quota int) error {
// CleanData 清空流量统计
func (mysql *Mysql) CleanData(id uint) error {
db := mysql.GetDB()
if db == nil {
return errors.New("can't connect mysql")
}
defer db.Close()
if _, err := db.Exec(fmt.Sprintf("UPDATE users SET download=0, upload=0 WHERE id=%d;", id)); err != nil {
fmt.Println(err)
Expand All @@ -115,10 +131,13 @@ func (mysql *Mysql) CleanData(id uint) error {
}

// GetData 获取用户记录
func (mysql *Mysql) GetData(ids ...string) *[]User {
var dataList []User
func (mysql *Mysql) GetData(ids ...string) []*User {
var dataList []*User
querySQL := "SELECT * FROM users"
db := mysql.GetDB()
if db == nil {
return nil
}
defer db.Close()
if len(ids) > 0 {
querySQL = querySQL + " WHERE id in (" + strings.Join(ids, ",") + ")"
Expand Down Expand Up @@ -146,7 +165,7 @@ func (mysql *Mysql) GetData(ids ...string) *[]User {
if err != nil {
password = ""
}
dataList = append(dataList, User{ID: id, Username: username, Password: password, Download: download, Upload: upload, Quota: quota})
dataList = append(dataList, &User{ID: id, Username: username, Password: password, Download: download, Upload: upload, Quota: quota})
}
return &dataList
return dataList
}
10 changes: 7 additions & 3 deletions trojan/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,20 @@ func GenClientJson() {
domain = ""
}
mysql := core.GetMysql()
userList := *mysql.GetData()
userList := mysql.GetData()
if userList == nil {
fmt.Println("连接mysql失败!")
return
}
if len(userList) == 1 {
user = userList[0]
user = *userList[0]
} else {
UserList()
choice := util.LoopInput("请选择要生成配置文件的用户序号: ", userList, true)
if choice < 0 {
return
}
user = userList[choice-1]
user = *userList[choice-1]
}
password, err := core.GetValue(user.Username + "_pass")
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion trojan/install.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ func InstallMysql() {
mysql.Database = "trojan"
mysql.CreateTable()
core.WriterMysql(&mysql)
if len(*mysql.GetData()) == 0 {
if len(mysql.GetData()) == 0 {
AddUser()
}
fmt.Println()
Expand Down
16 changes: 10 additions & 6 deletions trojan/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func AddUser() {

// DelUser 删除用户
func DelUser() {
userList := *UserList()
userList := UserList()
mysql := core.GetMysql()
choice := util.LoopInput("请选择要删除的用户序号: ", userList, true)
if mysql.DeleteUser(userList[choice-1].ID) == nil {
Expand All @@ -55,7 +55,7 @@ func SetUserQuota() {
limit int
err error
)
userList := *UserList()
userList := UserList()
mysql := core.GetMysql()
choice := util.LoopInput("请选择要限制流量的用户序号: ", userList, true)
if choice == -1 {
Expand All @@ -77,7 +77,7 @@ func SetUserQuota() {

// CleanData 清空用户流量
func CleanData() {
userList := *UserList()
userList := UserList()
mysql := core.GetMysql()
choice := util.LoopInput("请选择要清空流量的用户序号: ", userList, true)
if mysql.CleanData(userList[choice-1].ID) == nil {
Expand All @@ -86,9 +86,13 @@ func CleanData() {
}

// UserList 获取用户列表并打印显示
func UserList(ids ...string) *[]core.User {
func UserList(ids ...string) []*core.User {
mysql := core.GetMysql()
userList := *mysql.GetData(ids...)
userList := mysql.GetData(ids...)
if userList == nil {
fmt.Println("连接mysql失败!")
return nil
}
domain, err := core.GetValue("domain")
if err != nil {
domain = ""
Expand All @@ -107,5 +111,5 @@ func UserList(ids ...string) *[]core.User {
fmt.Println("分享链接: " + util.Green(fmt.Sprintf("trojan://%s@%s:443", k.Password, domain)))
fmt.Println()
}
return &userList
return userList
}
6 changes: 5 additions & 1 deletion web/controller/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@ func UserList() *ResponseBody {
responseBody := ResponseBody{Msg: "success"}
defer TimeCost(time.Now(), &responseBody)
mysql := core.GetMysql()
userList := *mysql.GetData()
userList := mysql.GetData()
if userList == nil {
responseBody.Msg = "连接mysql失败!"
return &responseBody
}
domain, err := core.GetValue("domain")
if err != nil {
domain = ""
Expand Down

0 comments on commit 201cb67

Please sign in to comment.