Skip to content

Commit 2c32f94

Browse files
committed
ssh-based provisioner: Re-enable support for PowerShell
1 parent a7e2c9f commit 2c32f94

File tree

2 files changed

+69
-6
lines changed

2 files changed

+69
-6
lines changed

internal/communicator/ssh/communicator.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,7 @@ func (c *Communicator) Upload(path string, input io.Reader) error {
436436
return scpUploadFile(targetFile, input, w, stdoutR, size)
437437
}
438438

439-
cmd, err := quoteShell([]string{"scp", "-vt", targetDir}, c.connInfo.TargetPlatform)
439+
cmd, err := quoteScpCommand([]string{"scp", "-vt", targetDir}, c.connInfo.TargetPlatform)
440440
if err != nil {
441441
return err
442442
}
@@ -509,7 +509,7 @@ func (c *Communicator) UploadDir(dst string, src string) error {
509509
return uploadEntries()
510510
}
511511

512-
cmd, err := quoteShell([]string{"scp", "-rvt", dst}, c.connInfo.TargetPlatform)
512+
cmd, err := quoteScpCommand([]string{"scp", "-rvt", dst}, c.connInfo.TargetPlatform)
513513
if err != nil {
514514
return err
515515
}
@@ -886,14 +886,15 @@ func (c *bastionConn) Close() error {
886886
return c.Bastion.Close()
887887
}
888888

889-
func quoteShell(args []string, targetPlatform string) (string, error) {
889+
func quoteScpCommand(args []string, targetPlatform string) (string, error) {
890890
if targetPlatform == TargetPlatformUnix {
891891
return shquot.POSIXShell(args), nil
892892
}
893893
if targetPlatform == TargetPlatformWindows {
894-
return shquot.WindowsArgv(args), nil
894+
cmd, args := shquot.WindowsArgvSplit(args)
895+
return fmt.Sprintf("%s %s", cmd, args), nil
895896
}
896897

897-
return "", fmt.Errorf("Cannot quote shell command, target platform unknown: %s", targetPlatform)
898+
return "", fmt.Errorf("Cannot quote scp command, target platform unknown: %s", targetPlatform)
898899

899900
}

internal/communicator/ssh/communicator_test.go

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"testing"
2424
"time"
2525

26+
"github.com/google/go-cmp/cmp"
2627
"github.com/hashicorp/terraform/internal/communicator/remote"
2728
"github.com/zclconf/go-cty/cty"
2829
"golang.org/x/crypto/ssh"
@@ -660,7 +661,7 @@ func TestAccHugeUploadFile(t *testing.T) {
660661
return scpUploadFile(targetFile, source, w, stdoutR, size)
661662
}
662663

663-
cmd, err := quoteShell([]string{"scp", "-vt", targetDir}, c.connInfo.TargetPlatform)
664+
cmd, err := quoteScpCommand([]string{"scp", "-vt", targetDir}, c.connInfo.TargetPlatform)
664665
if err != nil {
665666
t.Fatal(err)
666667
}
@@ -680,6 +681,67 @@ func TestAccHugeUploadFile(t *testing.T) {
680681
}
681682
}
682683

684+
func TestQuoteScpCommand(t *testing.T) {
685+
testCases := []struct {
686+
inputArgs []string
687+
platform string
688+
expectedCmd string
689+
}{
690+
// valid Unix command
691+
{
692+
[]string{"scp", "-vt", "/var/path"},
693+
TargetPlatformUnix,
694+
"'scp' -vt /var/path",
695+
},
696+
697+
// command injection attempt in Unix
698+
{
699+
[]string{"scp", "-vt", "/var/path;rm"},
700+
TargetPlatformUnix,
701+
"'scp' -vt /var/path\\;rm",
702+
},
703+
{
704+
[]string{"scp", "-vt", "/var/path&&rm"},
705+
TargetPlatformUnix,
706+
"'scp' -vt /var/path\\&\\&rm",
707+
},
708+
{
709+
[]string{"scp", "-vt", "/var/path|rm"},
710+
TargetPlatformUnix,
711+
"'scp' -vt /var/path\\|rm",
712+
},
713+
{
714+
[]string{"scp", "-vt", "/var/path; rm"},
715+
TargetPlatformUnix,
716+
"'scp' -vt '/var/path; rm'",
717+
},
718+
719+
// valid Windows command
720+
{
721+
[]string{"scp", "-vt", "C:\\Windows\\Temp"},
722+
TargetPlatformWindows,
723+
"scp -vt C:\\Windows\\Temp",
724+
},
725+
726+
// command injection attempt in Windows
727+
{
728+
[]string{"scp", "-vt", "C:\\Windows\\Temp\";rmdir"},
729+
TargetPlatformWindows,
730+
"scp -vt \"C:\\Windows\\Temp\\\";rmdir\"",
731+
},
732+
}
733+
734+
for _, tc := range testCases {
735+
cmd, err := quoteScpCommand(tc.inputArgs, tc.platform)
736+
if err != nil {
737+
t.Fatal(err)
738+
}
739+
if diff := cmp.Diff(tc.expectedCmd, cmd); diff != "" {
740+
t.Fatalf("unexpected command: %s", diff)
741+
}
742+
}
743+
}
744+
683745
func TestScriptPath(t *testing.T) {
684746
cases := []struct {
685747
Input string

0 commit comments

Comments
 (0)