Skip to content

Commit

Permalink
Prevent bad query strings from LLM (#229)
Browse files Browse the repository at this point in the history
* Prevent bad query strings from LLM

* Use http constants
  • Loading branch information
crspeller authored Aug 8, 2024
1 parent f450d96 commit fa19a71
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 19 deletions.
4 changes: 2 additions & 2 deletions server/ai/anthropic/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func (c *Client) MessageCompletionNoStream(completionRequest MessageRequest) (st
return "", fmt.Errorf("could not marshal completion request: %w", err)
}

req, err := http.NewRequest("POST", MessageEndpoint, bytes.NewReader(reqBodyBytes))
req, err := http.NewRequest(http.MethodPost, MessageEndpoint, bytes.NewReader(reqBodyBytes))
if err != nil {
return "", fmt.Errorf("could not create request: %w", err)
}
Expand Down Expand Up @@ -130,7 +130,7 @@ func (c *Client) MessageCompletion(completionRequest MessageRequest) (*ai.TextSt
return nil, err
}

req, err := http.NewRequest("POST", MessageEndpoint, bytes.NewReader(reqBodyBytes))
req, err := http.NewRequest(http.MethodPost, MessageEndpoint, bytes.NewReader(reqBodyBytes))
if err != nil {
return nil, err
}
Expand Down
10 changes: 5 additions & 5 deletions server/ai/asksage/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func (c *Client) Login(params GetTokenParams) error {
AccessToken string `json:"access_token"`
}
}
err := c.doAuth("POST", "/get-token", &params, &response)
err := c.doAuth(http.MethodPost, "/get-token", &params, &response)
if err != nil {
return err
}
Expand All @@ -90,7 +90,7 @@ func (c *Client) Login(params GetTokenParams) error {

func (c *Client) Query(params QueryParams) (*CompletionResponse, error) {
response := &CompletionResponse{}
if err := c.doServer("POST", "/query", &params, response); err != nil {
if err := c.doServer(http.MethodPost, "/query", &params, response); err != nil {
return nil, err
}

Expand All @@ -99,7 +99,7 @@ func (c *Client) Query(params QueryParams) (*CompletionResponse, error) {

func (c *Client) FollowUpQuestions(params FollowUpParams) (*CompletionResponse, error) {
response := &CompletionResponse{}
if err := c.doServer("POST", "/follow-up-questions", &params, response); err != nil {
if err := c.doServer(http.MethodPost, "/follow-up-questions", &params, response); err != nil {
return nil, err
}
return response, nil
Expand All @@ -109,7 +109,7 @@ func (c *Client) GetPersonas() ([]Persona, error) {
var response struct {
Response []Persona `json:"response"`
}
if err := c.doServer("POST", "/get-personas", nil, &response); err != nil {
if err := c.doServer(http.MethodPost, "/get-personas", nil, &response); err != nil {
return nil, err
}
return response.Response, nil
Expand All @@ -119,7 +119,7 @@ func (c *Client) GetDatasets() ([]Dataset, error) {
var response struct {
Response []Dataset `json:"dataset"`
}
if err := c.doServer("POST", "/get-datasets", nil, &response); err != nil {
if err := c.doServer(http.MethodPost, "/get-datasets", nil, &response); err != nil {
return nil, err
}
return response.Response, nil
Expand Down
22 changes: 11 additions & 11 deletions server/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func TestPostRouter(t *testing.T) {
envSetup func(e *TestEnvironment)
}{
"test no permission to channel": {
request: httptest.NewRequest("POST", url, nil),
request: httptest.NewRequest(http.MethodPost, url, nil),
expectedStatus: http.StatusForbidden,
config: Config{
EnableUseRestrictions: false,
Expand All @@ -49,7 +49,7 @@ func TestPostRouter(t *testing.T) {
},
},
"test user not allowed": {
request: httptest.NewRequest("POST", url, nil),
request: httptest.NewRequest(http.MethodPost, url, nil),
expectedStatus: http.StatusForbidden,
config: Config{
EnableUseRestrictions: true,
Expand All @@ -66,7 +66,7 @@ func TestPostRouter(t *testing.T) {
},
},
"not allowed team": {
request: httptest.NewRequest("POST", url, nil),
request: httptest.NewRequest(http.MethodPost, url, nil),
expectedStatus: http.StatusForbidden,
config: Config{
EnableUseRestrictions: true,
Expand All @@ -82,7 +82,7 @@ func TestPostRouter(t *testing.T) {
},
},
"not on private channels": {
request: httptest.NewRequest("POST", url, nil),
request: httptest.NewRequest(http.MethodPost, url, nil),
expectedStatus: http.StatusForbidden,
config: Config{
EnableUseRestrictions: true,
Expand All @@ -98,7 +98,7 @@ func TestPostRouter(t *testing.T) {
},
},
"not on dms": {
request: httptest.NewRequest("POST", url, nil),
request: httptest.NewRequest(http.MethodPost, url, nil),
expectedStatus: http.StatusForbidden,
config: Config{
EnableUseRestrictions: true,
Expand Down Expand Up @@ -152,7 +152,7 @@ func TestAdminRouter(t *testing.T) {
envSetup func(e *TestEnvironment)
}{
"only admins": {
request: httptest.NewRequest("GET", url, nil),
request: httptest.NewRequest(http.MethodGet, url, nil),
expectedStatus: http.StatusForbidden,
config: Config{
EnableUseRestrictions: false,
Expand Down Expand Up @@ -198,7 +198,7 @@ func TestChannelRouter(t *testing.T) {
envSetup func(e *TestEnvironment)
}{
"test no permission to channel": {
request: httptest.NewRequest("POST", url, nil),
request: httptest.NewRequest(http.MethodPost, url, nil),
expectedStatus: http.StatusForbidden,
config: Config{
EnableUseRestrictions: false,
Expand All @@ -213,7 +213,7 @@ func TestChannelRouter(t *testing.T) {
},
},
"test user not allowed": {
request: httptest.NewRequest("POST", url, nil),
request: httptest.NewRequest(http.MethodPost, url, nil),
expectedStatus: http.StatusForbidden,
config: Config{
EnableUseRestrictions: true,
Expand All @@ -230,7 +230,7 @@ func TestChannelRouter(t *testing.T) {
},
},
"not allowed team": {
request: httptest.NewRequest("POST", url, nil),
request: httptest.NewRequest(http.MethodPost, url, nil),
expectedStatus: http.StatusForbidden,
config: Config{
EnableUseRestrictions: true,
Expand All @@ -246,7 +246,7 @@ func TestChannelRouter(t *testing.T) {
},
},
"not on private channels": {
request: httptest.NewRequest("POST", url, nil),
request: httptest.NewRequest(http.MethodPost, url, nil),
expectedStatus: http.StatusForbidden,
config: Config{
EnableUseRestrictions: true,
Expand All @@ -262,7 +262,7 @@ func TestChannelRouter(t *testing.T) {
},
},
"not on dms": {
request: httptest.NewRequest("POST", url, nil),
request: httptest.NewRequest(http.MethodPost, url, nil),
expectedStatus: http.StatusForbidden,
config: Config{
EnableUseRestrictions: true,
Expand Down
10 changes: 9 additions & 1 deletion server/built_in_tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"regexp"
"strings"
"time"
Expand Down Expand Up @@ -165,7 +166,14 @@ func (p *Plugin) toolGetGithubIssue(context ai.ConversationContext, argsGetter a
return "invalid parameters to function", errors.New("invalid issue number")
}

req, err := http.NewRequest("GET", fmt.Sprintf("/github/api/v1/issue?owner=%s&repo=%s&number=%d", args.RepoOwner, args.RepoName, args.Number), nil)
req, err := http.NewRequest(http.MethodGet,
fmt.Sprintf("/github/api/v1/issue?owner=%s&repo=%s&number=%d",
url.QueryEscape(args.RepoOwner),
url.QueryEscape(args.RepoName),
args.Number,
),
nil,
)
if err != nil {
return "internal failure", fmt.Errorf("failed to create request: %w", err)
}
Expand Down

0 comments on commit fa19a71

Please sign in to comment.