Skip to content

Commit

Permalink
Split commands_test.go.
Browse files Browse the repository at this point in the history
  • Loading branch information
airforce270 committed Nov 22, 2023
1 parent a5bf31e commit 19e5e8a
Show file tree
Hide file tree
Showing 32 changed files with 3,609 additions and 3,398 deletions.
2 changes: 1 addition & 1 deletion apiclients/bible/bible_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func TestFetchUser(t *testing.T) {
}

for _, tc := range tests {
server.Resp = tc.useResp
server.Resps = []string{tc.useResp}
t.Run(tc.desc, func(t *testing.T) {
got, err := FetchVerses("Philippians 4:8")
if err != nil {
Expand Down
10 changes: 5 additions & 5 deletions apiclients/ivr/ivr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ func TestFetchUser(t *testing.T) {
}

for _, tc := range tests {
server.Resp = tc.useResp
server.Resps = []string{tc.useResp}
t.Run(tc.desc, func(t *testing.T) {
got, err := FetchUsers("fake-username")
if err != nil {
Expand Down Expand Up @@ -355,7 +355,7 @@ func TestFetchModsAndVIPs(t *testing.T) {
}

for _, tc := range tests {
server.Resp = tc.useResp
server.Resps = []string{tc.useResp}
t.Run(tc.desc, func(t *testing.T) {
got, err := FetchModsAndVIPs("fakeusername")
if err != nil {
Expand Down Expand Up @@ -470,7 +470,7 @@ func TestFetchFounders(t *testing.T) {
}

for _, tc := range tests {
server.Resp = tc.useResp
server.Resps = []string{tc.useResp}
t.Run(tc.desc, func(t *testing.T) {
got, err := FetchFounders("fakeusername")
if err != nil {
Expand Down Expand Up @@ -676,7 +676,7 @@ func TestFetchSubAge(t *testing.T) {
}

for _, tc := range tests {
server.Resp = tc.useResp
server.Resps = []string{tc.useResp}
t.Run(tc.desc, func(t *testing.T) {
got, err := FetchSubAge("fakeuser", "fakechannel")
if err != nil && tc.wantErr == nil {
Expand Down Expand Up @@ -723,7 +723,7 @@ func TestIsVerifiedBot(t *testing.T) {
}

for _, tc := range tests {
server.Resp = tc.useResp
server.Resps = []string{tc.useResp}
t.Run(tc.desc, func(t *testing.T) {
users, err := FetchUsers("fake-username")
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion apiclients/kick/kick_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@ func TestFetchChannel(t *testing.T) {
}

for _, tc := range tests {
server.Resp = tc.useResp
server.Resps = []string{tc.useResp}
t.Run(tc.desc, func(t *testing.T) {
got, err := kick.FetchChannel("user1")
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion apiclients/pastebin/pastebin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func TestFetchPaste(t *testing.T) {
t.Parallel()
server := fakeserver.New()
defer server.Close()
server.Resp = tc.useResp
server.Resps = []string{tc.useResp}
got, err := FetchPaste(server.URL())
if err != nil {
t.Fatalf("FetchPaste() unexpected error: %v", err)
Expand Down
2 changes: 1 addition & 1 deletion apiclients/seventv/seventv_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ func TestFetchUserConnectionByTwitchUserId(t *testing.T) {
}

for _, tc := range tests {
server.Resp = tc.useResp
server.Resps = []string{tc.useResp}
t.Run(tc.desc, func(t *testing.T) {
got, err := seventv.FetchUserConnectionByTwitchUserId("user1")
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion apiclients/supinic/supinic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func TestUpdateBotActivity(t *testing.T) {
t.Parallel()
server := fakeserver.New()
defer server.Close()
server.Resp = tc.useResp
server.Resps = []string{tc.useResp}
client := NewClientForTesting(server.URL())

err := client.updateBotActivity()
Expand Down
32 changes: 25 additions & 7 deletions base/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
package base

import (
"crypto/rand"
"errors"
"io"
"strings"
"time"

"github.com/airforce270/airbot/cache"
"github.com/airforce270/airbot/database/models"
"github.com/airforce270/airbot/permission"
"gorm.io/gorm"

exprand "golang.org/x/exp/rand"
)
Expand Down Expand Up @@ -79,8 +81,8 @@ type IncomingMessage struct {
Prefix string
// PermissionLevel is the permission level of the user that sent the message.
PermissionLevel permission.Level
// Platform is the platform the message was sent on.
Platform Platform
// Resources contains resources available to an incoming message.
Resources Resources
}

// MessageTextWithoutPrefix returns the message's text without the prefix.
Expand All @@ -100,7 +102,23 @@ type OutgoingMessage struct {
ReplyToID string
}

var (
RandReader = rand.Reader
RandSource exprand.Source = nil
)
// Resources contains references to app-level resources.
type Resources struct {
// Platform is the current platform.
Platform Platform
// DB is a reference to the database.
DB *gorm.DB
// Cache is a reference to the cache.
Cache cache.Cache
// Rand is a reference to random sources.
Rand RandResources
}

// RandResources contains references to random number resources.
type RandResources struct {
// Reader will be used as the reader for random values.
Reader io.Reader
// Source is a source of random numbers.
// Optional - a default will be used if not provided.
Source exprand.Source
}
29 changes: 2 additions & 27 deletions cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,36 +4,11 @@ package cache
import (
"context"
"errors"
"sync"
"time"

"github.com/airforce270/airbot/base"

"github.com/redis/go-redis/v9"
)

func Instance() Cache {
connMtx.RLock()
defer connMtx.RUnlock()
if conn == nil {
panic("cache.Conn is nil!")
}
return conn
}

func SetInstance(c Cache) {
connMtx.Lock()
conn = c
connMtx.Unlock()
}

var (
// Conn is an instance of the cache.
conn Cache = nil

connMtx sync.RWMutex // protects Conn
)

// A Cache stores and retrieves simple key-value data quickly.
type Cache interface {
// StoreBool stores a bool value with no expiration.
Expand Down Expand Up @@ -67,8 +42,8 @@ const (
)

// GlobalSlowmodeKey returns the global slowmode cache key for a platform.
func GlobalSlowmodeKey(p base.Platform) string {
return "global_slowmode_" + p.Name()
func GlobalSlowmodeKey(platformName string) string {
return "global_slowmode_" + platformName
}

// Redis implements Cache for a real Redis database.
Expand Down
48 changes: 19 additions & 29 deletions commands/admin/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,13 +157,11 @@ var (
)

func botSlowmode(msg *base.IncomingMessage, args []arg.Arg) ([]*base.Message, error) {
cdb := cache.Instance()

enableArg := args[0]
key := cache.GlobalSlowmodeKey(msg.Platform)
key := cache.GlobalSlowmodeKey(msg.Resources.Platform.Name())

if !enableArg.Present {
enabled, err := cdb.FetchBool(key)
enabled, err := msg.Resources.Cache.FetchBool(key)
if err != nil {
return nil, fmt.Errorf("failed to fetch cache key %s (bool): %w", key, err)
}
Expand All @@ -174,22 +172,22 @@ func botSlowmode(msg *base.IncomingMessage, args []arg.Arg) ([]*base.Message, er
return []*base.Message{
{
Channel: msg.Message.Channel,
Text: fmt.Sprintf("Bot slowmode is currently %s on %s", enabledMsg, msg.Platform.Name()),
Text: fmt.Sprintf("Bot slowmode is currently %s on %s", enabledMsg, msg.Resources.Platform.Name()),
},
}, nil
}

enable := enableArg.BoolValue

if err := cdb.StoreBool(key, enable); err != nil {
if err := msg.Resources.Cache.StoreBool(key, enable); err != nil {
failureMsgStart := "Failed to enable"
if !enable {
failureMsgStart = "Failed to disable"
}
return []*base.Message{
{
Channel: msg.Message.Channel,
Text: fmt.Sprintf("%s bot slowmode on %s", failureMsgStart, msg.Platform.Name()),
Text: fmt.Sprintf("%s bot slowmode on %s", failureMsgStart, msg.Resources.Platform.Name()),
},
}, nil
}
Expand All @@ -201,19 +199,17 @@ func botSlowmode(msg *base.IncomingMessage, args []arg.Arg) ([]*base.Message, er
return []*base.Message{
{
Channel: msg.Message.Channel,
Text: fmt.Sprintf("%s bot slowmode on %s", outMsgStart, msg.Platform.Name()),
Text: fmt.Sprintf("%s bot slowmode on %s", outMsgStart, msg.Resources.Platform.Name()),
},
}, nil
}

const defaultPrefix = "$"

func joinChannel(msg *base.IncomingMessage, targetChannel, prefix string) ([]*base.Message, error) {
db := database.Instance()

var channels []models.JoinedChannel
db.Where(models.JoinedChannel{
Platform: msg.Platform.Name(),
msg.Resources.DB.Where(models.JoinedChannel{
Platform: msg.Resources.Platform.Name(),
Channel: strings.ToLower(targetChannel),
}).Find(&channels)

Expand All @@ -227,16 +223,16 @@ func joinChannel(msg *base.IncomingMessage, targetChannel, prefix string) ([]*ba
}

channelRecord := models.JoinedChannel{
Platform: msg.Platform.Name(),
Platform: msg.Resources.Platform.Name(),
Channel: targetChannel,
Prefix: prefix,
JoinedAt: time.Now(),
}
if err := db.Create(&channelRecord).Error; err != nil {
if err := msg.Resources.DB.Create(&channelRecord).Error; err != nil {
return nil, fmt.Errorf("failed to join channel %s: %w", targetChannel, err)
}

err := msg.Platform.Join(targetChannel, prefix)
err := msg.Resources.Platform.Join(targetChannel, prefix)

if errors.Is(err, twitchplatform.ErrChannelNotFound) {
return []*base.Message{
Expand Down Expand Up @@ -273,10 +269,8 @@ func joinChannel(msg *base.IncomingMessage, targetChannel, prefix string) ([]*ba
const maxUsersPerMessage = 15

func joined(msg *base.IncomingMessage, args []arg.Arg) ([]*base.Message, error) {
db := database.Instance()

var joinedChannels []*models.JoinedChannel
if err := db.Find(&joinedChannels).Error; err != nil {
if err := msg.Resources.DB.Find(&joinedChannels).Error; err != nil {
return nil, fmt.Errorf("failed to find channels: %w", err)
}
var channels []string
Expand Down Expand Up @@ -305,9 +299,7 @@ func joined(msg *base.IncomingMessage, args []arg.Arg) ([]*base.Message, error)
}

func leaveChannel(msg *base.IncomingMessage, targetChannel string) ([]*base.Message, error) {
db := database.Instance()

err := database.LeaveChannel(db, msg.Platform.Name(), targetChannel)
err := database.LeaveChannel(msg.Resources.DB, msg.Resources.Platform.Name(), targetChannel)

if err != nil {
return []*base.Message{
Expand All @@ -320,7 +312,7 @@ func leaveChannel(msg *base.IncomingMessage, targetChannel string) ([]*base.Mess

go func() {
time.Sleep(time.Millisecond * 500)
if err := msg.Platform.Leave(targetChannel); err != nil {
if err := msg.Resources.Platform.Leave(targetChannel); err != nil {
log.Printf("failed to leave channel %s: %v", targetChannel, err)
}
}()
Expand Down Expand Up @@ -356,7 +348,7 @@ func reloadConfig(msg *base.IncomingMessage, args []arg.Arg) ([]*base.Message, e
}

func restartBot(msg *base.IncomingMessage, args []arg.Arg) ([]*base.Message, error) {
go restart.WriteRequester(msg.Platform.Name(), msg.Message.Channel, msg.Message.ID)
go restart.WriteRequester(msg.Resources.Cache, msg.Resources.Platform.Name(), msg.Message.Channel, msg.Message.ID)

const delay = 100 * time.Millisecond
time.AfterFunc(delay, func() { restart.C <- true })
Expand All @@ -376,18 +368,16 @@ func setPrefix(msg *base.IncomingMessage, args []arg.Arg) ([]*base.Message, erro
}
newPrefix := prefixArg.StringValue

db := database.Instance()

var channels []models.JoinedChannel
err := db.Where("platform = ? AND LOWER(channel) = ?", msg.Platform.Name(), strings.ToLower(msg.Message.Channel)).Find(&channels).Error
err := msg.Resources.DB.Where("platform = ? AND LOWER(channel) = ?", msg.Resources.Platform.Name(), strings.ToLower(msg.Message.Channel)).Find(&channels).Error
if err != nil {
return nil, fmt.Errorf("failed to fetch channels matching %s/%s: %w", msg.Platform.Name(), strings.ToLower(msg.Message.Channel), err)
return nil, fmt.Errorf("failed to fetch channels matching %s/%s: %w", msg.Resources.Platform.Name(), strings.ToLower(msg.Message.Channel), err)
}

for _, channel := range channels {
channel.Prefix = newPrefix

result := db.Save(&channel)
result := msg.Resources.DB.Save(&channel)
if err := result.Error; err != nil {
return nil, fmt.Errorf("failed to save new prefix %s for channel %s: %w", newPrefix, channel.Channel, err)
}
Expand All @@ -403,7 +393,7 @@ func setPrefix(msg *base.IncomingMessage, args []arg.Arg) ([]*base.Message, erro
}
}

if err := msg.Platform.SetPrefix(msg.Message.Channel, newPrefix); err != nil {
if err := msg.Resources.Platform.SetPrefix(msg.Message.Channel, newPrefix); err != nil {
log.Printf("Failed to update prefix: %v", err)
return []*base.Message{
{
Expand Down
Loading

0 comments on commit 19e5e8a

Please sign in to comment.