Skip to content

Commit

Permalink
Add ability to restart bot.
Browse files Browse the repository at this point in the history
  • Loading branch information
airforce270 committed Nov 21, 2023
1 parent 72e83b5 commit 70bee85
Show file tree
Hide file tree
Showing 12 changed files with 388 additions and 86 deletions.
13 changes: 10 additions & 3 deletions apiclients/supinic/supinic.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package supinic

import (
"context"
"encoding/json"
"fmt"
"io"
Expand Down Expand Up @@ -31,10 +32,16 @@ const pingInterval = 15 * time.Minute
// StartPinging starts a background task to ping the Supinic API regularly
// to make sure the API knows the bot is still online.
// This function blocks and should be run within a goroutine.
func (c *Client) StartPinging() {
func (c *Client) StartPinging(ctx context.Context) {
pingTimer := time.NewTicker(pingInterval)
for {
go c.pingAPI()
time.Sleep(pingInterval)
select {
case <-ctx.Done():
log.Print("Stopping pinging Supinic API, context cancelled")
return
case <-pingTimer.C:
go c.pingAPI()
}
}
}

Expand Down
12 changes: 9 additions & 3 deletions cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ type Cache interface {
const (
// Cache key for the last sent Twitch message.
KeyLastSentTwitchMessage = "twitch_last_sent_message"
// Cache key for the platform that bot restart was requested from.
KeyRestartRequestedOnPlatform = "restart_requested_on_platform"
// Cache key for the channel that bot restart was requested from.
KeyRestartRequestedInChannel = "restart_requested_from_channel"
// Cache key for the ID of the message that requested the bot restart.
KeyRestartRequestedByMessageID = "restart_requested_by_message"
)

// GlobalSlowmodeKey returns the global slowmode cache key for a platform.
Expand All @@ -78,10 +84,10 @@ func (c *Redis) StoreExpiringBool(key string, value bool, expiration time.Durati
}
func (c *Redis) FetchBool(key string) (bool, error) {
resp, err := c.r.Get(context.Background(), key).Bool()
if errors.Is(err, redis.Nil) {
return false, nil
}
if err != nil {
if errors.Is(err, redis.Nil) {
return false, nil
}
return false, err
}
return resp, nil
Expand Down
23 changes: 23 additions & 0 deletions commands/admin/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/airforce270/airbot/permission"
twitchplatform "github.com/airforce270/airbot/platforms/twitch"
"github.com/airforce270/airbot/utils"
"github.com/airforce270/airbot/utils/restart"
)

// Commands contains this package's commands.
Expand All @@ -30,6 +31,7 @@ var Commands = [...]basecommand.Command{
leaveCommand,
leaveOtherCommand,
reloadConfigCommand,
restartCommand,
setPrefixCommand,
}

Expand Down Expand Up @@ -138,6 +140,13 @@ var (
Handler: reloadConfig,
}

restartCommand = basecommand.Command{
Name: "restart",
Desc: "Restarts the bot. Does not restart the database, etc.",
Permission: permission.Admin,
Handler: restartBot,
}

setPrefixCommand = basecommand.Command{
Name: "setprefix",
Desc: "Sets the bot's prefix in the channel.",
Expand Down Expand Up @@ -346,6 +355,20 @@ func reloadConfig(msg *base.IncomingMessage, args []arg.Arg) ([]*base.Message, e
}, nil
}

func restartBot(msg *base.IncomingMessage, args []arg.Arg) ([]*base.Message, error) {
go restart.WriteRequester(msg.Platform.Name(), msg.Message.Channel, msg.Message.ID)

const delay = 100 * time.Millisecond
time.AfterFunc(delay, func() { restart.C <- true })

return []*base.Message{
{
Channel: msg.Message.Channel,
Text: "Restarting Airbot.",
},
}, nil
}

func setPrefix(msg *base.IncomingMessage, args []arg.Arg) ([]*base.Message, error) {
prefixArg := args[0]
if !prefixArg.Present {
Expand Down
39 changes: 37 additions & 2 deletions commands/commands_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,41 @@ func TestCommands(t *testing.T) {
},
want: nil,
},
{
input: base.IncomingMessage{
Message: base.Message{
Text: "$restart",
UserID: "user1",
User: "user1",
Channel: "user2",
Time: time.Date(2020, 5, 15, 10, 7, 0, 0, time.UTC),
},
Prefix: "$",
PermissionLevel: permission.Owner,
Platform: twitch.NewForTesting("forsen", databasetest.NewFakeDBConn()),
},
want: []*base.Message{
{
Text: "Restarting Airbot.",
Channel: "user2",
},
},
},
{
input: base.IncomingMessage{
Message: base.Message{
Text: "$restart",
UserID: "user1",
User: "user1",
Channel: "user2",
Time: time.Date(2020, 5, 15, 10, 7, 0, 0, time.UTC),
},
Prefix: "$",
PermissionLevel: permission.Normal,
Platform: twitch.NewForTesting("forsen", databasetest.NewFakeDBConn()),
},
want: nil,
},
{
input: base.IncomingMessage{
Message: base.Message{
Expand Down Expand Up @@ -2983,7 +3018,7 @@ func setFakes(url string, db *gorm.DB) {
kick.BaseURL = url
pastebin.FetchPasteURLOverride = url
seventv.BaseURL = url
twitch.Conn = twitch.NewForTesting(url, db)
twitch.SetInstance(twitch.NewForTesting(url, db))
}

func resetFakes() {
Expand All @@ -2995,7 +3030,7 @@ func resetFakes() {
kick.BaseURL = savedKickURL
pastebin.FetchPasteURLOverride = ""
seventv.BaseURL = saved7TVURL
twitch.Conn = twitch.NewForTesting(helix.DefaultAPIBaseURL, nil)
twitch.SetInstance(twitch.NewForTesting(helix.DefaultAPIBaseURL, nil))
}

func joinOtherUser1() error {
Expand Down
6 changes: 4 additions & 2 deletions database/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
package database

import (
"context"
"fmt"
"strings"
"sync"
Expand Down Expand Up @@ -36,10 +37,10 @@ var (
)

// Connect creates a connection to the database.
func Connect(dbname, user, password string) (*gorm.DB, error) {
func Connect(ctx context.Context, dbName, user, password string) (*gorm.DB, error) {
settings := map[string]string{
"host": "database",
"dbname": dbname,
"dbname": dbName,
"user": user,
"password": password,
"port": "5432",
Expand All @@ -51,6 +52,7 @@ func Connect(dbname, user, password string) (*gorm.DB, error) {
if err != nil {
return nil, fmt.Errorf("failed to open DB connection: %w", err)
}
gormDB.WithContext(ctx)

db, err := gormDB.DB()
if err != nil {
Expand Down
6 changes: 6 additions & 0 deletions docs/commands.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ If it's wrapped in `[square brackets]`, it's an **optional** parameter.
- > Usage: `$reloadconfig`
- > Minimum permission level: `Admin`
### $restart

- Restarts the bot. Does not restart the database, etc.
- > Usage: `$restart`
- > Minimum permission level: `Admin`
### $setprefix

- Sets the bot's prefix in the channel.
Expand Down
13 changes: 10 additions & 3 deletions gamba/gamba.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
package gamba

import (
"context"
"errors"
"fmt"
"log"
Expand All @@ -23,10 +24,16 @@ var (

// StartGrantingPoints starts a loop to grant points to all chatters on an interval.
// This function blocks and should be run within a goroutine.
func StartGrantingPoints(ps map[string]base.Platform, db *gorm.DB) {
func StartGrantingPoints(ctx context.Context, ps map[string]base.Platform, db *gorm.DB) {
timer := time.NewTicker(grantInterval)
for {
go grantPoints(ps, db)
time.Sleep(grantInterval)
select {
case <-ctx.Done():
log.Print("Stopping point granting, context cancelled")
return
case <-timer.C:
go grantPoints(ps, db)
}
}
}

Expand Down
Loading

0 comments on commit 70bee85

Please sign in to comment.