Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Edible Scripts Backend #25739

Merged
merged 9 commits into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/24602-editable-scripts
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
- Added API endpoint for updating script contents
25 changes: 25 additions & 0 deletions server/datastore/mysql/scripts.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,31 @@ func (ds *Datastore) NewScript(ctx context.Context, script *fleet.Script) (*flee
return ds.getScriptDB(ctx, ds.writer(ctx), uint(id)) //nolint:gosec // dismiss G115
}

func (ds *Datastore) UpdateScriptContents(ctx context.Context, scriptID uint, scriptContents string) (*fleet.Script, error) {
const stmt = `
UPDATE script_contents
INNER JOIN
scripts ON scripts.script_content_id = script_contents.id
SET
contents = ?,
md5_checksum = UNHEX(?)
WHERE
scripts.id = ?
`
md5Checksum := md5ChecksumScriptContent(scriptContents)

_, err := ds.writer(ctx).ExecContext(ctx, stmt, scriptContents, md5Checksum, scriptID)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "updating script contents")
}

if _, err := ds.writer(ctx).ExecContext(ctx, "UPDATE scripts SET updated_at = NOW() WHERE id = ?", scriptID); err != nil {
return nil, ctxerr.Wrap(ctx, err, "updating script updated_at time")
}

return ds.Script(ctx, scriptID)
}

func insertScript(ctx context.Context, tx sqlx.ExtContext, script *fleet.Script, scriptContentsID uint) (sql.Result, error) {
const insertStmt = `
INSERT INTO
Expand Down
40 changes: 40 additions & 0 deletions server/datastore/mysql/scripts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ func TestScripts(t *testing.T) {
{"TestGetAnyScriptContents", testGetAnyScriptContents},
{"TestDeleteScriptsAssignedToPolicy", testDeleteScriptsAssignedToPolicy},
{"TestDeletePendingHostScriptExecutionsForPolicy", testDeletePendingHostScriptExecutionsForPolicy},
{"UpdateScriptContents", testUpdateScriptContents},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
Expand Down Expand Up @@ -1586,3 +1587,42 @@ func testDeletePendingHostScriptExecutionsForPolicy(t *testing.T, ds *Datastore)
)
require.Equal(t, 1, count)
}

func testUpdateScriptContents(t *testing.T, ds *Datastore) {
ctx := context.Background()

originalScript, err := ds.NewScript(ctx, &fleet.Script{
Name: "script1",
ScriptContents: "hello world",
})
require.NoError(t, err)

originalContents, err := ds.GetScriptContents(ctx, originalScript.ScriptContentID)
require.NoError(t, err)
require.Equal(t, "hello world", string(originalContents))

ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error {
_, err := q.ExecContext(ctx, "UPDATE scripts SET updated_at = ? WHERE id = ?", time.Now().Add(-2*time.Minute), originalScript.ID)
if err != nil {
return err
}
return nil
})

// Make sure updated_at was changed correctly, but the script is the same
oldScript, err := ds.Script(ctx, originalScript.ID)
require.Equal(t, originalScript.ScriptContentID, oldScript.ScriptContentID)
require.NoError(t, err)
require.NotEqual(t, originalScript.UpdatedAt, oldScript.UpdatedAt)

// Modify the script
updatedScript, err := ds.UpdateScriptContents(ctx, originalScript.ID, "updated script")
require.NoError(t, err)
require.Equal(t, originalScript.ID, updatedScript.ID)
require.Equal(t, originalScript.ScriptContentID, updatedScript.ScriptContentID)

updatedContents, err := ds.GetScriptContents(ctx, originalScript.ScriptContentID)
require.NoError(t, err)
require.Equal(t, "updated script", string(updatedContents))
require.NotEqual(t, oldScript.UpdatedAt, updatedScript.UpdatedAt)
}
22 changes: 22 additions & 0 deletions server/fleet/activities.go
Original file line number Diff line number Diff line change
Expand Up @@ -1332,6 +1332,28 @@ func (a ActivityTypeAddedScript) Documentation() (activity, details, detailsExam
}`
}

type ActivityTypeUpdatedScript struct {
ScriptName string `json:"script_name"`
TeamID *uint `json:"team_id"`
TeamName *string `json:"team_name"`
}

func (a ActivityTypeUpdatedScript) ActivityName() string {
return "updated_script"
}

func (a ActivityTypeUpdatedScript) Documentation() (activity, details, detailsExample string) {
return `Generated when a script is updated.`,
`This activity contains the following fields:
- "script_name": Name of the script.
- "team_id": The ID of the team that the script applies to, ` + "`null`" + ` if it applies to devices that are not in a team.
- "team_name": The name of the team that the script applies to, ` + "`null`" + ` if it applies to devices that are not in a team.`, `{
"script_name": "set-timezones.sh",
"team_id": 123,
"team_name": "Workstations"
}`
}

type ActivityTypeDeletedScript struct {
ScriptName string `json:"script_name"`
TeamID *uint `json:"team_id"`
Expand Down
3 changes: 3 additions & 0 deletions server/fleet/datastore.go
Original file line number Diff line number Diff line change
Expand Up @@ -1640,6 +1640,9 @@ type Datastore interface {
// NewScript creates a new saved script.
NewScript(ctx context.Context, script *Script) (*Script, error)

// UpdateScriptContents replaces the script contents of a script
UpdateScriptContents(ctx context.Context, scriptID uint, scriptContents string) (*Script, error)

// Script returns the saved script corresponding to id.
Script(ctx context.Context, id uint) (*Script, error)

Expand Down
3 changes: 3 additions & 0 deletions server/fleet/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -1123,6 +1123,9 @@ type Service interface {
// io.Reader r.
NewScript(ctx context.Context, teamID *uint, name string, r io.Reader) (*Script, error)

// UpdateScript updates a saved script with the contents of io.Reader r
UpdateScript(ctx context.Context, scriptID uint, r io.Reader) (*Script, error)

// DeleteScript deletes an existing (saved) script.
DeleteScript(ctx context.Context, scriptID uint) error

Expand Down
12 changes: 12 additions & 0 deletions server/mock/datastore_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -1047,6 +1047,8 @@ type ListPendingHostScriptExecutionsFunc func(ctx context.Context, hostID uint,

type NewScriptFunc func(ctx context.Context, script *fleet.Script) (*fleet.Script, error)

type UpdateScriptContentsFunc func(ctx context.Context, scriptID uint, scriptContents string) (*fleet.Script, error)

type ScriptFunc func(ctx context.Context, id uint) (*fleet.Script, error)

type GetScriptContentsFunc func(ctx context.Context, id uint) ([]byte, error)
Expand Down Expand Up @@ -2747,6 +2749,9 @@ type DataStore struct {
NewScriptFunc NewScriptFunc
NewScriptFuncInvoked bool

UpdateScriptContentsFunc UpdateScriptContentsFunc
UpdateScriptContentsFuncInvoked bool

ScriptFunc ScriptFunc
ScriptFuncInvoked bool

Expand Down Expand Up @@ -6581,6 +6586,13 @@ func (s *DataStore) NewScript(ctx context.Context, script *fleet.Script) (*fleet
return s.NewScriptFunc(ctx, script)
}

func (s *DataStore) UpdateScriptContents(ctx context.Context, scriptID uint, scriptContents string) (*fleet.Script, error) {
s.mu.Lock()
s.UpdateScriptContentsFuncInvoked = true
s.mu.Unlock()
return s.UpdateScriptContentsFunc(ctx, scriptID, scriptContents)
}

func (s *DataStore) Script(ctx context.Context, id uint) (*fleet.Script, error) {
s.mu.Lock()
s.ScriptFuncInvoked = true
Expand Down
1 change: 1 addition & 0 deletions server/service/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,7 @@ func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetC
ue.POST("/api/_version_/fleet/scripts", createScriptEndpoint, createScriptRequest{})
ue.GET("/api/_version_/fleet/scripts", listScriptsEndpoint, listScriptsRequest{})
ue.GET("/api/_version_/fleet/scripts/{script_id:[0-9]+}", getScriptEndpoint, getScriptRequest{})
ue.PATCH("/api/_version_/fleet/scripts/{script_id:[0-9]+}", updateScriptEndpoint, updateScriptRequest{})
ue.DELETE("/api/_version_/fleet/scripts/{script_id:[0-9]+}", deleteScriptEndpoint, deleteScriptRequest{})
ue.POST("/api/_version_/fleet/scripts/batch", batchSetScriptsEndpoint, batchSetScriptsRequest{})

Expand Down
27 changes: 27 additions & 0 deletions server/service/integration_enterprise_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7200,6 +7200,33 @@ func (s *integrationEnterpriseTestSuite) TestSavedScripts() {
require.NotEqual(t, tmScriptID, newScriptResp.ScriptID)
s.lastActivityMatches("added_script", fmt.Sprintf(`{"script_name": %q, "team_name": %q, "team_id": %d}`, "script2.sh", tm.Name, tm.ID), 0)

// Update a script
updateScriptRep := updateScriptResponse{}
body, headers = generateNewScriptMultipartRequest(t,
"script1.sh", []byte(`echo "updated script"`), s.token, map[string][]string{"id": {fmt.Sprintf("%d", tmScriptID)}})
res = s.DoRawWithHeaders("PATCH", fmt.Sprintf("/api/latest/fleet/scripts/%d", tmScriptID), body.Bytes(), http.StatusOK, headers)
err = json.NewDecoder(res.Body).Decode(&updateScriptRep)
require.NoError(t, err)
require.NotZero(t, newScriptResp.ScriptID)
require.Equal(t, tmScriptID, updateScriptRep.ScriptID)
s.lastActivityMatches("updated_script", fmt.Sprintf(`{"script_name": %q, "team_name": %q, "team_id": %d}`, "script1.sh", tm.Name, tm.ID), 0)

// Download the updated script
res = s.Do("GET", fmt.Sprintf("/api/latest/fleet/scripts/%d", tmScriptID), nil, http.StatusOK, "alt", "media")
b, err = io.ReadAll(res.Body)
require.NoError(t, err)
require.Equal(t, `echo "updated script"`, string(b))
require.Equal(t, int64(len(`echo "updated script"`)), res.ContentLength)
require.Equal(t, fmt.Sprintf("attachment;filename=\"%s %s\"", time.Now().Format(time.DateOnly), "script1.sh"), res.Header.Get("Content-Disposition"))

// Try updating a non-existant script
updateScriptRep = updateScriptResponse{}
body, headers = generateNewScriptMultipartRequest(t,
"script1.sh", []byte(`echo "updated script"`), s.token, map[string][]string{"id": {fmt.Sprintf("%d", 99999999999)}})
res = s.DoRawWithHeaders("PATCH", fmt.Sprintf("/api/latest/fleet/scripts/%d", tmScriptID), body.Bytes(), http.StatusNotFound, headers)
err = json.NewDecoder(res.Body).Decode(&updateScriptRep)
require.NoError(t, err)

// delete the no-team script
s.Do("DELETE", fmt.Sprintf("/api/latest/fleet/scripts/%d", noTeamScriptID), nil, http.StatusNoContent)
s.lastActivityMatches("deleted_script", fmt.Sprintf(`{"script_name": %q, "team_name": null, "team_id": null}`, "script1.sh"), 0)
Expand Down
108 changes: 108 additions & 0 deletions server/service/scripts.go
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,114 @@ func (svc *Service) GetScript(ctx context.Context, scriptID uint, withContent bo
return script, content, nil
}

////////////////////////////////////////////////////////////////////////////////
// Update Script Contents
////////////////////////////////////////////////////////////////////////////////

type updateScriptRequest struct {
Script *multipart.FileHeader
ScriptID uint
dantecatalfamo marked this conversation as resolved.
Show resolved Hide resolved
}

func (updateScriptRequest) DecodeRequest(ctx context.Context, r *http.Request) (interface{}, error) {
var decoded updateScriptRequest

err := r.ParseMultipartForm(512 * units.MiB) // same in-memory size as for other multipart requests we have
if err != nil {
return nil, &fleet.BadRequestError{
Message: "failed to parse multipart form",
InternalErr: err,
}
}

fhs, ok := r.MultipartForm.File["script"]
if !ok || len(fhs) < 1 {
return nil, &fleet.BadRequestError{Message: "no file headers for script"}
}
decoded.Script = fhs[0]

return &decoded, nil
}

type updateScriptResponse struct {
Err error `json:"error,omitempty"`
ScriptID uint `json:"script_id,omitempty"`
}

func (r updateScriptResponse) error() error { return r.Err }

func updateScriptEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (errorer, error) {
req := request.(*updateScriptRequest)

scriptFile, err := req.Script.Open()
if err != nil {
return &updateScriptResponse{Err: err}, nil
}
defer scriptFile.Close()

script, err := svc.UpdateScript(ctx, req.ScriptID, scriptFile)
if err != nil {
return updateScriptResponse{Err: err}, nil
}
return updateScriptResponse{ScriptID: script.ID}, nil
}

func (svc *Service) UpdateScript(ctx context.Context, scriptID uint, r io.Reader) (*fleet.Script, error) {
script, err := svc.ds.Script(ctx, scriptID)
if err != nil {
svc.authz.SkipAuthorization(ctx)
return nil, ctxerr.Wrap(ctx, err, "finding original script to update")
}

if err := svc.authz.Authorize(ctx, &fleet.Script{TeamID: script.TeamID}, fleet.ActionWrite); err != nil {
return nil, err
}

b, err := io.ReadAll(r)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "read script contents")
}

scriptContents := file.Dos2UnixNewlines(string(b))

if err := svc.ds.ValidateEmbeddedSecrets(ctx, []string{scriptContents}); err != nil {
return nil, fleet.NewInvalidArgumentError("script", err.Error())
}

if err := fleet.ValidateHostScriptContents(scriptContents, true); err != nil {
return nil, fleet.NewInvalidArgumentError("script", err.Error())
}

// Update the script
savedScript, err := svc.ds.UpdateScriptContents(ctx, scriptID, scriptContents)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "updating script contents")
}

var teamName *string
if script.TeamID != nil && *script.TeamID != 0 {
tm, err := svc.EnterpriseOverrides.TeamByIDOrName(ctx, script.TeamID, nil)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "get team name for create script activity")
}
teamName = &tm.Name
}

if err := svc.NewActivity(
ctx,
authz.UserFromContext(ctx),
fleet.ActivityTypeUpdatedScript{
TeamID: script.TeamID,
TeamName: teamName,
ScriptName: script.Name,
},
); err != nil {
return nil, ctxerr.Wrap(ctx, err, "new activity for update script")
}

return savedScript, nil
}

////////////////////////////////////////////////////////////////////////////////
// Get Host Script Details
////////////////////////////////////////////////////////////////////////////////
Expand Down
Loading