diff --git a/components/ide/jetbrains/launcher/main.go b/components/ide/jetbrains/launcher/main.go index 04deb159c00978..3d5686d2bcb694 100644 --- a/components/ide/jetbrains/launcher/main.go +++ b/components/ide/jetbrains/launcher/main.go @@ -1186,10 +1186,13 @@ func linkRemotePlugin(launchCtx *LaunchContext) error { return safeLink("/ide-desktop-plugins/gitpod-remote", remotePluginDir) } +// safeLink creates a symlink from source to target, removing the old target if it exists func safeLink(source, target string) error { - if _, err := os.Stat(target); err == nil { + if _, err := os.Lstat(target); err == nil { // unlink the old symlink - _ = os.RemoveAll(target) + if err2 := os.RemoveAll(target); err2 != nil { + log.WithError(err).Error("failed to unlink old symlink") + } } return os.Symlink(source, target) } diff --git a/components/ide/jetbrains/launcher/main_test.go b/components/ide/jetbrains/launcher/main_test.go index 0234d406d9b1c1..12c9710898c8ff 100644 --- a/components/ide/jetbrains/launcher/main_test.go +++ b/components/ide/jetbrains/launcher/main_test.go @@ -5,6 +5,9 @@ package main import ( + "fmt" + "os" + "path/filepath" "strings" "testing" @@ -105,3 +108,42 @@ func TestUpdatePlatformProperties(t *testing.T) { } }) } + +func Test_safeLink(t *testing.T) { + type args struct { + changeSource bool + } + t.Log("link folders twice") + tests := []struct { + name string + args args + wantErr bool + }{ + {name: "happy path", args: args{changeSource: true}, wantErr: false}, + {name: "happy path 2", args: args{changeSource: false}, wantErr: false}, + } + for index, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + source := createTempDir(t, fmt.Sprintf("source_%d", index)) + target := createTempDir(t, fmt.Sprintf("target_%d", index)) + safeLink(source, target) + os.RemoveAll(source) + if tt.args.changeSource { + source = createTempDir(t, fmt.Sprintf("source_new_%d", index)) + } else { + source = createTempDir(t, fmt.Sprintf("source_%d", index)) + } + if err := safeLink(source, target); (err != nil) != tt.wantErr { + t.Errorf("safeLink() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func createTempDir(t *testing.T, dir string) string { + path := filepath.Join(t.TempDir(), dir) + if err := os.Mkdir(path, 0o755); err != nil { + t.Fatal(err) + } + return path +}