From 7c724d0b16a739dd617e698606d641e3696d7947 Mon Sep 17 00:00:00 2001 From: Josh Smith Date: Tue, 9 Jul 2024 01:29:23 -0400 Subject: [PATCH] Add db transaction to distributed clans write and fix lint errors (#51) * Fix lint errors * revert * unnest & use transaction * add transaction to missing case --- app/v1/clan.go | 115 ++++++++++++++++++++++++++--------------------- app/v1/friend.go | 5 +-- app/v1/rap.go | 24 ---------- app/v1/token.go | 2 +- 4 files changed, 66 insertions(+), 80 deletions(-) delete mode 100644 app/v1/rap.go diff --git a/app/v1/clan.go b/app/v1/clan.go index 59d4416..b57ebc2 100644 --- a/app/v1/clan.go +++ b/app/v1/clan.go @@ -115,7 +115,7 @@ func ClanLeaderboardGET(md common.MethodData) common.CodeMessager { for i := 1; rows.Next(); i++ { clan := clanLbData{} var pp float64 - rows.Scan(&pp, &clan.ChosenMode.RankedScore, &clan.ChosenMode.TotalScore, &clan.ChosenMode.PlayCount, &clan.ChosenMode.Accuracy, &clan.Name, &clan.ID) + err = rows.Scan(&pp, &clan.ChosenMode.RankedScore, &clan.ChosenMode.TotalScore, &clan.ChosenMode.PlayCount, &clan.ChosenMode.Accuracy, &clan.Name, &clan.ID) if err != nil { md.Err(err) return Err500 @@ -134,8 +134,6 @@ func ClanLeaderboardGET(md common.MethodData) common.CodeMessager { return r } -var dbmode = [...]string{"std", "taiko", "ctb", "mania"} - func ClanStatsGET(md common.MethodData) common.CodeMessager { if md.Query("id") == "" { return ErrMissingField("id") @@ -232,13 +230,6 @@ func ResolveInviteGET(md common.MethodData) common.CodeMessager { return r } -func resolveInvite(c string, md *common.MethodData) (id int, err error) { - row := md.DB.QueryRow("SELECT id FROM clans where invite = ?", c) - err = row.Scan(&id) - - return -} - func ClanJoinPOST(md common.MethodData) common.CodeMessager { if md.ID() == 0 { return common.SimpleResponse(401, "not authorised") @@ -272,7 +263,8 @@ func ClanJoinPOST(md common.MethodData) common.CodeMessager { var hasInvite bool if u.Invite != "" { - u.ID, err = resolveInvite(u.Invite, &md) + row := md.DB.QueryRow("SELECT id FROM clans where invite = ?", u.Invite) + err = row.Scan(&u.ID) if err != nil { if err == sql.ErrNoRows { @@ -285,45 +277,65 @@ func ClanJoinPOST(md common.MethodData) common.CodeMessager { hasInvite = true } - if u.ID > 0 { - c, err := getClan(u.ID, md) - if err != nil { - if err == sql.ErrNoRows { - return common.SimpleResponse(404, "clan not found") - } - md.Err(err) - return Err500 - } + if u.ID <= 0 { + return common.SimpleResponse(400, "invalid id parameter") + } - if c.Status == 0 || (c.Status == 2 && !hasInvite) { - return common.SimpleResponse(403, "closed") + c, err := getClan(u.ID, md) + if err != nil { + if err == sql.ErrNoRows { + return common.SimpleResponse(404, "clan not found") } + md.Err(err) + return Err500 + } + + if c.Status == 0 || (c.Status == 2 && !hasInvite) { + return common.SimpleResponse(403, "closed") + } + + var count int + err = md.DB.QueryRow("SELECT COUNT(id) FROM users WHERE clan_id = ?", c.ID).Scan(&count) + if err != nil { + md.Err(err) + return Err500 + } + + if count >= clanMemberLimit { + return common.SimpleResponse(403, "clan is full") + } + + tx, err := md.DB.Begin() + if err != nil { + md.Err(err) + return Err500 + } - var count int - err = md.DB.QueryRow("SELECT COUNT(id) FROM users WHERE clan_id = ?", c.ID).Scan(&count) + if c.Status == 3 { + _, err = tx.Exec("INSERT INTO clan_requests VALUES (?, ?, DEFAULT) ON DUPLICATE KEY UPDATE time = NOW()", c.ID, md.ID()) if err != nil { + tx.Rollback() md.Err(err) return Err500 } - if count >= clanMemberLimit { - return common.SimpleResponse(403, "clan is full") - } + return common.SimpleResponse(200, "join request sent") + } + _, err = tx.Exec("UPDATE users SET clan_id = ? WHERE id = ?", c.ID, md.ID()) + if err != nil { + tx.Rollback() + md.Err(err) + return Err500 + } - if c.Status == 3 { - _, err = md.DB.Exec("INSERT INTO clan_requests VALUES (?, ?, DEFAULT) ON DUPLICATE KEY UPDATE time = NOW()", c.ID, md.ID()) - return common.SimpleResponse(200, "join request sent") - } - _, err = md.DB.Exec("UPDATE users SET clan_id = ? WHERE id = ?", c.ID, md.ID()) - r.Clan = c - r.Code = 200 + tx.Commit() - md.R.Publish("api:update_user_clan", strconv.Itoa(md.ID())) + r.Clan = c + r.Code = 200 - return r - } else { - return common.SimpleResponse(400, "invalid id parameter") - } + md.R.Publish("api:update_user_clan", strconv.Itoa(md.ID())) + + return r } func ClanLeavePOST(md common.MethodData) common.CodeMessager { @@ -346,28 +358,39 @@ func ClanLeavePOST(md common.MethodData) common.CodeMessager { return Err500 } + tx, err := md.DB.Begin() + if err != nil { + md.Err(err) + return Err500 + } + disbanded := false if clan.Owner == md.ID() { - _, err = md.DB.Exec("UPDATE users SET clan_id = 0 WHERE clan_id = ?", clan.ID) + _, err = tx.Exec("UPDATE users SET clan_id = 0 WHERE clan_id = ?", clan.ID) if err != nil { + tx.Rollback() md.Err(err) return Err500 } err := disbandClan(clan.ID, md) if err != nil { + tx.Rollback() md.Err(err) return Err500 } disbanded = true } else { - _, err := md.DB.Exec("UPDATE users SET clan_id = 0 WHERE id = ?", md.ID()) + _, err := tx.Exec("UPDATE users SET clan_id = 0 WHERE id = ?", md.ID()) if err != nil { + tx.Rollback() md.Err(err) return Err500 } } + tx.Commit() + md.R.Publish("api:update_user_clan", strconv.Itoa(md.ID())) message := "success" @@ -597,13 +620,3 @@ func getClan(id int, md common.MethodData) (Clan, error) { return c, err } - -func getUserData(id int, md common.MethodData) (userData, error) { - u := userData{} - if id == 0 { - return u, nil - } - err := md.DB.QueryRow("SELECT id, username, register_datetime, privileges, latest_activity, username_aka, country FROM users WHERE id = ?", id).Scan(&u.ID, &u.Username, &u.RegisteredOn, &u.Privileges, &u.LatestActivity, &u.UsernameAKA, &u.Country) - - return u, err -} diff --git a/app/v1/friend.go b/app/v1/friend.go index e1f7082..f445977 100644 --- a/app/v1/friend.go +++ b/app/v1/friend.go @@ -90,14 +90,11 @@ AND privileges & 1 } func friendPuts(md common.MethodData, row *sql.Rows) (user friendData) { - var err error - - err = row.Scan(&user.ID, &user.Username, &user.RegisteredOn, &user.Privileges, &user.LatestActivity, &user.UsernameAKA, &user.Country) + err := row.Scan(&user.ID, &user.Username, &user.RegisteredOn, &user.Privileges, &user.LatestActivity, &user.UsernameAKA, &user.Country) if err != nil { md.Err(err) return } - return } diff --git a/app/v1/rap.go b/app/v1/rap.go deleted file mode 100644 index ca33fd6..0000000 --- a/app/v1/rap.go +++ /dev/null @@ -1,24 +0,0 @@ -package v1 - -import ( - "time" - - "github.com/osuAkatsuki/akatsuki-api/common" -) - -func rapLog(md common.MethodData, message string) { - ua := string(md.Ctx.UserAgent()) - if len(ua) > 20 { - ua = ua[:20] + "…" - } - through := "API" - if ua != "" { - through += " (" + ua + ")" - } - - _, err := md.DB.Exec("INSERT INTO rap_logs(userid, text, datetime, through) VALUES (?, ?, ?, ?)", - md.User.UserID, message, time.Now().Unix(), through) - if err != nil { - md.Err(err) - } -} diff --git a/app/v1/token.go b/app/v1/token.go index 3b501d2..a07d527 100644 --- a/app/v1/token.go +++ b/app/v1/token.go @@ -85,7 +85,7 @@ func (o *oauthClient) Scan(src interface{}) error { case []byte: s = x default: - return errors.New("Can't scan non-string") + return errors.New("can't scan non-string") } var vals [3]string