diff --git a/experimental/ssh/cmd/connect.go b/experimental/ssh/cmd/connect.go index 4eca1aee7b..c05430368a 100644 --- a/experimental/ssh/cmd/connect.go +++ b/experimental/ssh/cmd/connect.go @@ -35,6 +35,7 @@ the SSH server and handling the connection proxy. var autoStartCluster bool var userKnownHostsFile string var liteswap string + var skipSettingsCheck bool cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID (for dedicated clusters)") cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", defaultShutdownDelay, "Delay before shutting down the server after the last client disconnects") @@ -64,6 +65,9 @@ the SSH server and handling the connection proxy. cmd.Flags().StringVar(&liteswap, "liteswap", "", "Liteswap header value for traffic routing (dev/test only)") cmd.Flags().MarkHidden("liteswap") + cmd.Flags().BoolVar(&skipSettingsCheck, "skip-settings-check", false, "Skip checking and updating IDE settings") + cmd.Flags().MarkHidden("skip-settings-check") + cmd.PreRunE = func(cmd *cobra.Command, args []string) error { // CLI in the proxy mode is executed by the ssh client and can't prompt for input if proxyMode { @@ -113,6 +117,7 @@ the SSH server and handling the connection proxy. ClientPrivateKeyName: clientPrivateKeyName, UserKnownHostsFile: userKnownHostsFile, Liteswap: liteswap, + SkipSettingsCheck: skipSettingsCheck, AdditionalArgs: args, } return client.Run(ctx, wsClient, opts) diff --git a/experimental/ssh/internal/client/client.go b/experimental/ssh/internal/client/client.go index 940f792f0e..d751833db1 100644 --- a/experimental/ssh/internal/client/client.go +++ b/experimental/ssh/internal/client/client.go @@ -20,6 +20,7 @@ import ( "github.com/databricks/cli/experimental/ssh/internal/keys" "github.com/databricks/cli/experimental/ssh/internal/proxy" "github.com/databricks/cli/experimental/ssh/internal/sshconfig" + "github.com/databricks/cli/experimental/ssh/internal/vscode" sshWorkspace "github.com/databricks/cli/experimental/ssh/internal/workspace" "github.com/databricks/cli/internal/build" "github.com/databricks/cli/libs/cmdio" @@ -92,6 +93,8 @@ type ClientOptions struct { UserKnownHostsFile string // Liteswap header value for traffic routing (dev/test only). Liteswap string + // If true, skip checking and updating IDE settings. + SkipSettingsCheck bool } func (o *ClientOptions) IsServerlessMode() bool { @@ -206,6 +209,26 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt cmdio.LogString(ctx, "Using SSH key: "+keyPath) cmdio.LogString(ctx, fmt.Sprintf("Secrets scope: %s, key name: %s", secretScopeName, opts.ClientPublicKeyName)) + // Check and update IDE settings for serverless mode, where we must set up + // desired server ports (or socket connection mode) for the connection to go through + // (as the majority of the localhost ports on the remote side are blocked by iptable rules). + // Plus the platform (always linux), and extensions (python and jupyter), to make the initial experience smoother. + if opts.IDE != "" && opts.IsServerlessMode() && !opts.ProxyMode && !opts.SkipSettingsCheck && cmdio.IsPromptSupported(ctx) { + err = vscode.CheckAndUpdateSettings(ctx, opts.IDE, opts.ConnectionName) + if err != nil { + cmdio.LogString(ctx, fmt.Sprintf("Failed to update IDE settings: %v", err)) + cmdio.LogString(ctx, vscode.GetManualInstructions(opts.IDE, opts.ConnectionName)) + cmdio.LogString(ctx, "Use --skip-settings-check to bypass IDE settings verification.") + shouldProceed, promptErr := cmdio.AskYesOrNo(ctx, "Do you want to proceed with the connection?") + if promptErr != nil { + return fmt.Errorf("failed to prompt user: %w", promptErr) + } + if !shouldProceed { + return errors.New("aborted: IDE settings need to be updated manually, user declined to proceed") + } + } + } + var userName string var serverPort int var clusterID string diff --git a/experimental/ssh/internal/vscode/settings.go b/experimental/ssh/internal/vscode/settings.go new file mode 100644 index 0000000000..d6ddb15eee --- /dev/null +++ b/experimental/ssh/internal/vscode/settings.go @@ -0,0 +1,368 @@ +package vscode + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "runtime" + "strings" + + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/env" + "github.com/databricks/cli/libs/log" + "github.com/tailscale/hujson" +) + +const ( + portRange = "4000-4005" + remotePlatform = "linux" + pythonExtension = "ms-python.python" + jupyterExtension = "ms-toolsai.jupyter" + serverPickPortsKey = "remote.SSH.serverPickPortsFromRange" + remotePlatformKey = "remote.SSH.remotePlatform" + defaultExtensionsKey = "remote.SSH.defaultExtensions" + listenOnSocketKey = "remote.SSH.remoteServerListenOnSocket" + vscodeIDE = "vscode" + cursorIDE = "cursor" + vscodeName = "VS Code" + cursorName = "Cursor" +) + +func getIDEName(ide string) string { + if ide == cursorIDE { + return cursorName + } + return vscodeName +} + +type missingSettings struct { + portRange bool + platform bool + listenOnSocket bool + extensions []string +} + +func (m *missingSettings) isEmpty() bool { + return !m.portRange && !m.platform && !m.listenOnSocket && len(m.extensions) == 0 +} + +// Builds a JSON Pointer (RFC 6901) from path segments to be used in hujson.Value.Find. +// Escapes "~" → "~0" and "/" → "~1" per spec. +func jsonPtr(segments ...string) string { + var b strings.Builder + r := strings.NewReplacer("~", "~0", "/", "~1") + for _, s := range segments { + b.WriteByte('/') + b.WriteString(r.Replace(s)) + } + return b.String() +} + +type patchOp struct { + Op string `json:"op"` + Path string `json:"path"` + Value any `json:"value,omitempty"` +} + +func logSkippingSettings(ctx context.Context, msg string) { + cmdio.LogString(ctx, msg+"\n\nWARNING: the connection might not work as expected\n") +} + +func CheckAndUpdateSettings(ctx context.Context, ide, connectionName string) error { + if !cmdio.IsPromptSupported(ctx) { + logSkippingSettings(ctx, "Skipping IDE settings check: prompts not supported") + return nil + } + + settingsPath, err := getDefaultSettingsPath(ctx, ide) + if err != nil { + return fmt.Errorf("failed to get settings path: %w", err) + } + + settings, err := loadSettings(settingsPath) + if err != nil { + if os.IsNotExist(err) { + return handleMissingFile(ctx, ide, connectionName, settingsPath) + } + return fmt.Errorf("failed to load settings: %w", err) + } + + missing := validateSettings(settings, connectionName) + if missing.isEmpty() { + log.Debugf(ctx, "IDE settings already correct for %s", connectionName) + return nil + } + + shouldUpdate, err := promptUserForUpdate(ctx, ide, connectionName, missing) + if err != nil { + return fmt.Errorf("failed to prompt user: %w", err) + } + if !shouldUpdate { + logSkippingSettings(ctx, "Skipping IDE settings update") + return nil + } + + if err := backupSettings(ctx, settingsPath); err != nil { + log.Warnf(ctx, "Failed to backup settings: %v. Continuing with update.", err) + } + + if err := updateSettings(&settings, connectionName, missing); err != nil { + return fmt.Errorf("failed to update settings: %w", err) + } + + if err := saveSettings(settingsPath, &settings); err != nil { + return fmt.Errorf("failed to save settings: %w", err) + } + + cmdio.LogString(ctx, fmt.Sprintf("Updated %s settings for '%s'", getIDEName(ide), connectionName)) + return nil +} + +func getDefaultSettingsPath(ctx context.Context, ide string) (string, error) { + home, err := env.UserHomeDir(ctx) + if err != nil { + return "", fmt.Errorf("failed to get home directory: %w", err) + } + + appName := "Code" + if ide == cursorIDE { + appName = "Cursor" + } + + var settingsDir string + switch runtime.GOOS { + case "darwin": + settingsDir = filepath.Join(home, "Library", "Application Support", appName, "User") + case "windows": + appData := env.Get(ctx, "APPDATA") + if appData == "" { + appData = filepath.Join(home, "AppData", "Roaming") + } + settingsDir = filepath.Join(appData, appName, "User") + case "linux": + settingsDir = filepath.Join(home, ".config", appName, "User") + default: + return "", fmt.Errorf("unsupported operating system: %s", runtime.GOOS) + } + + return filepath.Join(settingsDir, "settings.json"), nil +} + +func loadSettings(path string) (hujson.Value, error) { + data, err := os.ReadFile(path) + if err != nil { + return hujson.Value{}, err + } + v, err := hujson.Parse(data) + if err != nil { + return hujson.Value{}, fmt.Errorf("failed to parse settings JSON: %w", err) + } + return v, nil +} + +func hasCorrectPortRange(v hujson.Value, connectionName string) bool { + found := v.Find(jsonPtr(serverPickPortsKey, connectionName)) + if found == nil { + return false + } + lit, ok := found.Value.(hujson.Literal) + return ok && lit.String() == portRange +} + +func hasCorrectPlatform(v hujson.Value, connectionName string) bool { + found := v.Find(jsonPtr(remotePlatformKey, connectionName)) + if found == nil { + return false + } + lit, ok := found.Value.(hujson.Literal) + return ok && lit.String() == remotePlatform +} + +func hasCorrectListenOnSocket(v hujson.Value) bool { + found := v.Find(jsonPtr(listenOnSocketKey)) + if found == nil { + return false + } + lit, ok := found.Value.(hujson.Literal) + return ok && lit.Bool() +} + +func getMissingExtensions(v hujson.Value) []string { + required := []string{pythonExtension, jupyterExtension} + found := v.Find(jsonPtr(defaultExtensionsKey)) + if found == nil { + return required + } + arr, ok := found.Value.(*hujson.Array) + if !ok { + return required + } + existingSet := make(map[string]bool, len(arr.Elements)) + for _, el := range arr.Elements { + if lit, ok := el.Value.(hujson.Literal); ok { + existingSet[lit.String()] = true + } + } + var missing []string + for _, ext := range required { + if !existingSet[ext] { + missing = append(missing, ext) + } + } + return missing +} + +func validateSettings(v hujson.Value, connectionName string) *missingSettings { + return &missingSettings{ + portRange: !hasCorrectPortRange(v, connectionName), + platform: !hasCorrectPlatform(v, connectionName), + listenOnSocket: !hasCorrectListenOnSocket(v), + extensions: getMissingExtensions(v), + } +} + +func settingsMessage(connectionName string, missing *missingSettings) string { + var lines []string + if missing.portRange { + lines = append(lines, fmt.Sprintf(" \"%s\": {\"%s\": \"%s\"}", serverPickPortsKey, connectionName, portRange)) + } + if missing.platform { + lines = append(lines, fmt.Sprintf(" \"%s\": {\"%s\": \"%s\"}", remotePlatformKey, connectionName, remotePlatform)) + } + if missing.listenOnSocket { + lines = append(lines, fmt.Sprintf(" \"%s\": true // Global setting that affects all remote ssh connections", listenOnSocketKey)) + } + if len(missing.extensions) > 0 { + quoted := make([]string, len(missing.extensions)) + for i, ext := range missing.extensions { + quoted[i] = fmt.Sprintf("\"%s\"", ext) + } + lines = append(lines, fmt.Sprintf(" \"%s\": [%s] // Global setting that affects all remote ssh connections", defaultExtensionsKey, strings.Join(quoted, ", "))) + } + return strings.Join(lines, "\n") +} + +func promptUserForUpdate(ctx context.Context, ide, connectionName string, missing *missingSettings) (bool, error) { + question := fmt.Sprintf( + "The following settings will be applied to %s for '%s':\n%s\nApply these settings?", + getIDEName(ide), connectionName, settingsMessage(connectionName, missing)) + return cmdio.AskYesOrNo(ctx, question) +} + +func handleMissingFile(ctx context.Context, ide, connectionName, settingsPath string) error { + missing := &missingSettings{ + portRange: true, + platform: true, + listenOnSocket: true, + extensions: []string{pythonExtension, jupyterExtension}, + } + shouldCreate, err := promptUserForUpdate(ctx, ide, connectionName, missing) + if err != nil { + return fmt.Errorf("failed to prompt user: %w", err) + } + if !shouldCreate { + logSkippingSettings(ctx, "Skipping IDE settings creation") + return nil + } + + settingsDir := filepath.Dir(settingsPath) + if err := os.MkdirAll(settingsDir, 0o755); err != nil { + return fmt.Errorf("failed to create settings directory: %w", err) + } + + v, err := hujson.Parse([]byte("{}")) + if err != nil { + return fmt.Errorf("failed to create settings: %w", err) + } + if err := updateSettings(&v, connectionName, missing); err != nil { + return fmt.Errorf("failed to update settings: %w", err) + } + + if err := saveSettings(settingsPath, &v); err != nil { + return fmt.Errorf("failed to save settings: %w", err) + } + + cmdio.LogString(ctx, fmt.Sprintf("Created %s settings at %s", getIDEName(ide), filepath.ToSlash(settingsPath))) + return nil +} + +func backupSettings(ctx context.Context, path string) error { + data, err := os.ReadFile(path) + if err != nil { + return err + } + if len(data) == 0 { + return nil + } + + originalBak := path + ".original.bak" + latestBak := path + ".latest.bak" + + if _, err := os.Stat(originalBak); os.IsNotExist(err) { + cmdio.LogString(ctx, "Backing up settings to "+filepath.ToSlash(originalBak)) + return os.WriteFile(originalBak, data, 0o600) + } + + cmdio.LogString(ctx, "Backing up settings to "+filepath.ToSlash(latestBak)) + return os.WriteFile(latestBak, data, 0o600) +} + +// subKeyOp returns a patch op that sets key/subKey=value, creating the parent object if absent. +func subKeyOp(v *hujson.Value, key, subKey, value string) patchOp { + if v.Find(jsonPtr(key)) == nil { + return patchOp{"add", jsonPtr(key), map[string]string{subKey: value}} + } + return patchOp{"add", jsonPtr(key, subKey), value} +} + +func updateSettings(v *hujson.Value, connectionName string, missing *missingSettings) error { + var ops []patchOp + if missing.portRange { + ops = append(ops, subKeyOp(v, serverPickPortsKey, connectionName, portRange)) + } + if missing.platform { + ops = append(ops, subKeyOp(v, remotePlatformKey, connectionName, remotePlatform)) + } + if missing.listenOnSocket { + ops = append(ops, patchOp{"add", jsonPtr(listenOnSocketKey), true}) + } + if len(missing.extensions) > 0 { + parent := jsonPtr(defaultExtensionsKey) + if v.Find(parent) == nil { + ops = append(ops, patchOp{"add", parent, missing.extensions}) + } else { + for _, ext := range missing.extensions { + ops = append(ops, patchOp{"add", parent + "/-", ext}) + } + } + } + if len(ops) == 0 { + return nil + } + patchData, err := json.Marshal(ops) + if err != nil { + return fmt.Errorf("failed to marshal patch: %w", err) + } + return v.Patch(patchData) +} + +func saveSettings(path string, v *hujson.Value) error { + if err := os.WriteFile(path, v.Pack(), 0o600); err != nil { + return fmt.Errorf("failed to write settings file: %w", err) + } + return nil +} + +func GetManualInstructions(ide, connectionName string) string { + missing := &missingSettings{ + portRange: true, + platform: true, + listenOnSocket: true, + extensions: []string{pythonExtension, jupyterExtension}, + } + return fmt.Sprintf( + "To ensure the remote connection works as expected, manually add these settings to your %s settings.json:\n%s", + getIDEName(ide), settingsMessage(connectionName, missing)) +} diff --git a/experimental/ssh/internal/vscode/settings_test.go b/experimental/ssh/internal/vscode/settings_test.go new file mode 100644 index 0000000000..267254e093 --- /dev/null +++ b/experimental/ssh/internal/vscode/settings_test.go @@ -0,0 +1,562 @@ +package vscode + +import ( + "context" + "encoding/json" + "os" + "path/filepath" + "runtime" + "testing" + + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/env" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tailscale/hujson" +) + +func parseTestValue(t *testing.T, jsonStr string) hujson.Value { + t.Helper() + v, err := hujson.Parse([]byte(jsonStr)) + require.NoError(t, err) + return v +} + +func findString(t *testing.T, v hujson.Value, ptr string) (string, bool) { + t.Helper() + found := v.Find(ptr) + if found == nil { + return "", false + } + var s string + if err := json.Unmarshal(found.Pack(), &s); err != nil { + return "", false + } + return s, true +} + +func findStringSlice(t *testing.T, v hujson.Value, ptr string) []string { + t.Helper() + found := v.Find(ptr) + if found == nil { + return nil + } + var ss []string + if err := json.Unmarshal(found.Pack(), &ss); err != nil { + return nil + } + return ss +} + +func TestGetDefaultSettingsPath_VSCode_Linux(t *testing.T) { + if runtime.GOOS != "linux" { + t.Skip("Skipping Linux-specific test") + } + + ctx := context.Background() + ctx = env.Set(ctx, "HOME", "/home/testuser") + + path, err := getDefaultSettingsPath(ctx, vscodeIDE) + require.NoError(t, err) + assert.Equal(t, "/home/testuser/.config/Code/User/settings.json", path) +} + +func TestGetDefaultSettingsPath_Cursor_Linux(t *testing.T) { + if runtime.GOOS != "linux" { + t.Skip("Skipping Linux-specific test") + } + + ctx := context.Background() + ctx = env.Set(ctx, "HOME", "/home/testuser") + + path, err := getDefaultSettingsPath(ctx, cursorIDE) + require.NoError(t, err) + assert.Equal(t, "/home/testuser/.config/Cursor/User/settings.json", path) +} + +func TestGetDefaultSettingsPath_VSCode_Darwin(t *testing.T) { + if runtime.GOOS != "darwin" { + t.Skip("Skipping Darwin-specific test") + } + + ctx := context.Background() + ctx = env.Set(ctx, "HOME", "/Users/testuser") + + path, err := getDefaultSettingsPath(ctx, vscodeIDE) + require.NoError(t, err) + assert.Equal(t, "/Users/testuser/Library/Application Support/Code/User/settings.json", path) +} + +func TestGetDefaultSettingsPath_Cursor_Darwin(t *testing.T) { + if runtime.GOOS != "darwin" { + t.Skip("Skipping Darwin-specific test") + } + + ctx := context.Background() + ctx = env.Set(ctx, "HOME", "/Users/testuser") + + path, err := getDefaultSettingsPath(ctx, cursorIDE) + require.NoError(t, err) + assert.Equal(t, "/Users/testuser/Library/Application Support/Cursor/User/settings.json", path) +} + +func TestGetDefaultSettingsPath_VSCode_Windows(t *testing.T) { + if runtime.GOOS != "windows" { + t.Skip("Skipping Windows-specific test") + } + + ctx := context.Background() + ctx = env.Set(ctx, "APPDATA", `C:\Users\testuser\AppData\Roaming`) + + path, err := getDefaultSettingsPath(ctx, vscodeIDE) + require.NoError(t, err) + assert.Equal(t, `C:\Users\testuser\AppData\Roaming\Code\User\settings.json`, path) +} + +func TestGetDefaultSettingsPath_Cursor_Windows(t *testing.T) { + if runtime.GOOS != "windows" { + t.Skip("Skipping Windows-specific test") + } + + ctx := context.Background() + ctx = env.Set(ctx, "APPDATA", `C:\Users\testuser\AppData\Roaming`) + + path, err := getDefaultSettingsPath(ctx, cursorIDE) + require.NoError(t, err) + assert.Equal(t, `C:\Users\testuser\AppData\Roaming\Cursor\User\settings.json`, path) +} + +func TestLoadSettings_Valid(t *testing.T) { + tmpDir := t.TempDir() + settingsPath := filepath.Join(tmpDir, "settings.json") + + settingsData := `{ + "editor.fontSize": 14, + "remote.SSH.serverPickPortsFromRange": { + "test-conn": "4000-4005" + } + }` + err := os.WriteFile(settingsPath, []byte(settingsData), 0o600) + require.NoError(t, err) + + settings, err := loadSettings(settingsPath) + require.NoError(t, err) + assert.NotNil(t, settings.Find("/editor.fontSize")) + assert.NotNil(t, settings.Find("/remote.SSH.serverPickPortsFromRange")) +} + +func TestLoadSettings_Invalid(t *testing.T) { + tmpDir := t.TempDir() + settingsPath := filepath.Join(tmpDir, "settings.json") + + err := os.WriteFile(settingsPath, []byte("invalid json {"), 0o600) + require.NoError(t, err) + + _, err = loadSettings(settingsPath) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to parse settings JSON") +} + +func TestLoadSettings_WithComments(t *testing.T) { + tmpDir := t.TempDir() + settingsPath := filepath.Join(tmpDir, "settings.json") + + // JSONC format with comments and trailing commas (typical VS Code settings) + settingsData := `{ + // Editor settings + "editor.fontSize": 14, + /* Connection settings */ + "remote.SSH.serverPickPortsFromRange": { + "test-conn": "4000-4005" // Port range for SSH + }, + "remote.SSH.remotePlatform": { + "test-conn": "linux", // trailing comma + } + }` + err := os.WriteFile(settingsPath, []byte(settingsData), 0o600) + require.NoError(t, err) + + settings, err := loadSettings(settingsPath) + require.NoError(t, err) + assert.NotNil(t, settings.Find("/editor.fontSize")) + assert.NotNil(t, settings.Find("/remote.SSH.serverPickPortsFromRange")) + + val, ok := findString(t, settings, jsonPtr(serverPickPortsKey, "test-conn")) + assert.True(t, ok) + assert.Equal(t, "4000-4005", val) +} + +func TestLoadSettings_NotExists(t *testing.T) { + tmpDir := t.TempDir() + settingsPath := filepath.Join(tmpDir, "nonexistent.json") + + _, err := loadSettings(settingsPath) + assert.Error(t, err) + assert.True(t, os.IsNotExist(err)) +} + +func TestValidateSettings_Complete(t *testing.T) { + v := parseTestValue(t, `{ + "remote.SSH.serverPickPortsFromRange": {"test-conn": "4000-4005"}, + "remote.SSH.remotePlatform": {"test-conn": "linux"}, + "remote.SSH.remoteServerListenOnSocket": true, + "remote.SSH.defaultExtensions": ["ms-python.python", "ms-toolsai.jupyter"] + }`) + + missing := validateSettings(v, "test-conn") + assert.True(t, missing.isEmpty()) +} + +func TestValidateSettings_Missing(t *testing.T) { + v := parseTestValue(t, `{}`) + + missing := validateSettings(v, "test-conn") + assert.False(t, missing.isEmpty()) + assert.True(t, missing.portRange) + assert.True(t, missing.platform) + assert.Equal(t, []string{"ms-python.python", "ms-toolsai.jupyter"}, missing.extensions) +} + +func TestValidateSettings_IncorrectValues(t *testing.T) { + v := parseTestValue(t, `{ + "remote.SSH.serverPickPortsFromRange": {"test-conn": "5000-5005"}, + "remote.SSH.remotePlatform": {"test-conn": "windows"}, + "remote.SSH.defaultExtensions": ["ms-python.python"] + }`) + + missing := validateSettings(v, "test-conn") + assert.False(t, missing.isEmpty()) + assert.True(t, missing.portRange) + assert.True(t, missing.platform) + assert.Equal(t, []string{"ms-toolsai.jupyter"}, missing.extensions) +} + +func TestValidateSettings_DuplicateExtensionsNotReported(t *testing.T) { + v := parseTestValue(t, `{ + "remote.SSH.serverPickPortsFromRange": {"test-conn": "4000-4005"}, + "remote.SSH.remotePlatform": {"test-conn": "linux"}, + "remote.SSH.remoteServerListenOnSocket": true, + "remote.SSH.defaultExtensions": ["ms-python.python", "ms-python.python", "ms-toolsai.jupyter"] + }`) + + missing := validateSettings(v, "test-conn") + assert.True(t, missing.isEmpty()) +} + +func TestValidateSettings_MissingConnection(t *testing.T) { + v := parseTestValue(t, `{ + "remote.SSH.serverPickPortsFromRange": {"other-conn": "4000-4005"}, + "remote.SSH.remotePlatform": {"other-conn": "linux"}, + "remote.SSH.defaultExtensions": ["ms-python.python", "ms-toolsai.jupyter"] + }`) + + // Validating for a different connection should show port and platform as missing + missing := validateSettings(v, "test-conn") + assert.False(t, missing.isEmpty()) + assert.True(t, missing.portRange) + assert.True(t, missing.platform) + assert.Empty(t, missing.extensions) // Extensions are global, so they're present +} + +func TestUpdateSettings_PreserveExistingConnections(t *testing.T) { + v := parseTestValue(t, `{ + "remote.SSH.serverPickPortsFromRange": { + "conn-a": "5000-5005", + "conn-b": "6000-6005" + }, + "remote.SSH.remotePlatform": { + "conn-a": "linux", + "conn-b": "darwin" + }, + "remote.SSH.defaultExtensions": ["other.extension"] + }`) + + missing := &missingSettings{ + portRange: true, + platform: true, + extensions: []string{"ms-python.python", "ms-toolsai.jupyter"}, + } + + err := updateSettings(&v, "conn-c", missing) + require.NoError(t, err) + + // Check that new connection was added + val, ok := findString(t, v, jsonPtr(serverPickPortsKey, "conn-c")) + assert.True(t, ok) + assert.Equal(t, "4000-4005", val) + + val, ok = findString(t, v, jsonPtr(remotePlatformKey, "conn-c")) + assert.True(t, ok) + assert.Equal(t, "linux", val) + + // Check that existing connections were preserved + val, ok = findString(t, v, jsonPtr(serverPickPortsKey, "conn-a")) + assert.True(t, ok) + assert.Equal(t, "5000-5005", val) + + val, ok = findString(t, v, jsonPtr(serverPickPortsKey, "conn-b")) + assert.True(t, ok) + assert.Equal(t, "6000-6005", val) + + val, ok = findString(t, v, jsonPtr(remotePlatformKey, "conn-a")) + assert.True(t, ok) + assert.Equal(t, "linux", val) + + val, ok = findString(t, v, jsonPtr(remotePlatformKey, "conn-b")) + assert.True(t, ok) + assert.Equal(t, "darwin", val) + + // Check that extensions were merged + exts := findStringSlice(t, v, jsonPtr(defaultExtensionsKey)) + assert.Len(t, exts, 3) + assert.Contains(t, exts, "other.extension") + assert.Contains(t, exts, "ms-python.python") + assert.Contains(t, exts, "ms-toolsai.jupyter") +} + +func TestUpdateSettings_NewConnection(t *testing.T) { + v := parseTestValue(t, `{}`) + + missing := &missingSettings{ + portRange: true, + platform: true, + extensions: []string{"ms-python.python", "ms-toolsai.jupyter"}, + } + + err := updateSettings(&v, "new-conn", missing) + require.NoError(t, err) + + val, ok := findString(t, v, jsonPtr(serverPickPortsKey, "new-conn")) + assert.True(t, ok) + assert.Equal(t, "4000-4005", val) + + val, ok = findString(t, v, jsonPtr(remotePlatformKey, "new-conn")) + assert.True(t, ok) + assert.Equal(t, "linux", val) + + exts := findStringSlice(t, v, jsonPtr(defaultExtensionsKey)) + assert.Len(t, exts, 2) + assert.Contains(t, exts, "ms-python.python") + assert.Contains(t, exts, "ms-toolsai.jupyter") +} + +func TestUpdateSettings_GlobalExtensions(t *testing.T) { + // Verify that extensions are global, not per-connection + v := parseTestValue(t, `{ + "remote.SSH.defaultExtensions": ["ms-python.python"] + }`) + + missing := &missingSettings{ + extensions: []string{"ms-toolsai.jupyter"}, + } + + err := updateSettings(&v, "conn-a", missing) + require.NoError(t, err) + + exts := findStringSlice(t, v, jsonPtr(defaultExtensionsKey)) + assert.Len(t, exts, 2) + assert.Contains(t, exts, "ms-python.python") + assert.Contains(t, exts, "ms-toolsai.jupyter") + + // Update for another connection should use the same global array + missing2 := &missingSettings{ + extensions: []string{"another.extension"}, + } + + err = updateSettings(&v, "conn-b", missing2) + require.NoError(t, err) + + exts = findStringSlice(t, v, jsonPtr(defaultExtensionsKey)) + assert.Len(t, exts, 3) + assert.Contains(t, exts, "ms-python.python") + assert.Contains(t, exts, "ms-toolsai.jupyter") + assert.Contains(t, exts, "another.extension") +} + +func TestUpdateSettings_MergeExtensions(t *testing.T) { + v := parseTestValue(t, `{ + "remote.SSH.defaultExtensions": ["existing.extension", "ms-python.python"] + }`) + + missing := &missingSettings{ + extensions: []string{"ms-toolsai.jupyter"}, // ms-python.python already present + } + + err := updateSettings(&v, "test-conn", missing) + require.NoError(t, err) + + exts := findStringSlice(t, v, jsonPtr(defaultExtensionsKey)) + assert.Len(t, exts, 3) + assert.Contains(t, exts, "existing.extension") + assert.Contains(t, exts, "ms-python.python") + assert.Contains(t, exts, "ms-toolsai.jupyter") +} + +func TestUpdateSettings_PartialUpdate(t *testing.T) { + v := parseTestValue(t, `{ + "remote.SSH.serverPickPortsFromRange": {"test-conn": "4000-4005"}, + "remote.SSH.remotePlatform": {"other-conn": "linux"}, + "remote.SSH.defaultExtensions": ["ms-python.python", "ms-toolsai.jupyter"] + }`) + + missing := &missingSettings{ + portRange: false, // Already set + platform: true, // Needs update + extensions: nil, // Already present + } + + err := updateSettings(&v, "test-conn", missing) + require.NoError(t, err) + + // Port range should not be modified + val, ok := findString(t, v, jsonPtr(serverPickPortsKey, "test-conn")) + assert.True(t, ok) + assert.Equal(t, "4000-4005", val) + + // Platform should be added for test-conn + val, ok = findString(t, v, jsonPtr(remotePlatformKey, "test-conn")) + assert.True(t, ok) + assert.Equal(t, "linux", val) + + val, ok = findString(t, v, jsonPtr(remotePlatformKey, "other-conn")) + assert.True(t, ok) + assert.Equal(t, "linux", val) // Preserve other connection + + // Extensions should not be modified + exts := findStringSlice(t, v, jsonPtr(defaultExtensionsKey)) + assert.Len(t, exts, 2) +} + +func TestBackupSettings(t *testing.T) { + tmpDir := t.TempDir() + settingsPath := filepath.Join(tmpDir, "settings.json") + originalBak := settingsPath + ".original.bak" + latestBak := settingsPath + ".latest.bak" + + originalContent := []byte(`{"key": "value"}`) + err := os.WriteFile(settingsPath, originalContent, 0o600) + require.NoError(t, err) + + ctx, _ := cmdio.NewTestContextWithStderr(context.Background()) + + // First backup: should create .original.bak + err = backupSettings(ctx, settingsPath) + require.NoError(t, err) + + content, err := os.ReadFile(originalBak) + require.NoError(t, err) + assert.Equal(t, originalContent, content) + _, err = os.Stat(latestBak) + assert.True(t, os.IsNotExist(err)) + + // Second backup: .original.bak exists, should create .latest.bak + updatedContent := []byte(`{"key": "updated"}`) + err = os.WriteFile(settingsPath, updatedContent, 0o600) + require.NoError(t, err) + + err = backupSettings(ctx, settingsPath) + require.NoError(t, err) + + // .original.bak must remain unchanged + content, err = os.ReadFile(originalBak) + require.NoError(t, err) + assert.Equal(t, originalContent, content) + + // .latest.bak should have the updated content + content, err = os.ReadFile(latestBak) + require.NoError(t, err) + assert.Equal(t, updatedContent, content) +} + +func TestSaveSettings_Formatting(t *testing.T) { + tmpDir := t.TempDir() + settingsPath := filepath.Join(tmpDir, "settings.json") + + v := parseTestValue(t, `{ + "remote.SSH.serverPickPortsFromRange": {"test-conn": "4000-4005"}, + "editor.fontSize": 14 + }`) + + err := saveSettings(settingsPath, &v) + require.NoError(t, err) + + content, err := os.ReadFile(settingsPath) + require.NoError(t, err) + + // Verify it's valid JSON after standardizing + standardized, err := hujson.Standardize(content) + require.NoError(t, err) + var parsed map[string]any + err = json.Unmarshal(standardized, &parsed) + require.NoError(t, err) + + // Verify permissions + info, err := os.Stat(settingsPath) + require.NoError(t, err) + if runtime.GOOS != "windows" { + assert.Equal(t, os.FileMode(0o600), info.Mode().Perm()) + } +} + +func TestSaveSettings_PreservesComments(t *testing.T) { + tmpDir := t.TempDir() + settingsPath := filepath.Join(tmpDir, "settings.json") + + original := `{ + // This is a comment + "editor.fontSize": 14 +}` + err := os.WriteFile(settingsPath, []byte(original), 0o600) + require.NoError(t, err) + + v, err := loadSettings(settingsPath) + require.NoError(t, err) + + // Add a new setting + missing := &missingSettings{listenOnSocket: true} + err = updateSettings(&v, "test-conn", missing) + require.NoError(t, err) + + err = saveSettings(settingsPath, &v) + require.NoError(t, err) + + content, err := os.ReadFile(settingsPath) + require.NoError(t, err) + assert.Contains(t, string(content), "// This is a comment") +} + +func TestMissingSettings_IsEmpty(t *testing.T) { + empty := &missingSettings{} + assert.True(t, empty.isEmpty()) + + notEmpty := &missingSettings{portRange: true} + assert.False(t, notEmpty.isEmpty()) + + notEmpty2 := &missingSettings{extensions: []string{"ext"}} + assert.False(t, notEmpty2.isEmpty()) +} + +func TestGetManualInstructions_VSCode(t *testing.T) { + instructions := GetManualInstructions(vscodeIDE, "test-conn") + + assert.Contains(t, instructions, "VS Code") + assert.Contains(t, instructions, "test-conn") + assert.Contains(t, instructions, "4000-4005") + assert.Contains(t, instructions, "linux") + assert.Contains(t, instructions, "ms-python.python") + assert.Contains(t, instructions, "ms-toolsai.jupyter") + assert.Contains(t, instructions, "remote.SSH.serverPickPortsFromRange") + assert.Contains(t, instructions, "remote.SSH.remotePlatform") + assert.Contains(t, instructions, "remote.SSH.defaultExtensions") +} + +func TestGetManualInstructions_Cursor(t *testing.T) { + instructions := GetManualInstructions("cursor", "my-connection") + + assert.Contains(t, instructions, "Cursor") + assert.Contains(t, instructions, "my-connection") + assert.Contains(t, instructions, "4000-4005") + assert.Contains(t, instructions, "linux") + assert.Contains(t, instructions, "ms-python.python") + assert.Contains(t, instructions, "ms-toolsai.jupyter") +} diff --git a/go.mod b/go.mod index 1afeb28c07..9b57af4f1c 100644 --- a/go.mod +++ b/go.mod @@ -48,6 +48,9 @@ require github.com/google/jsonschema-go v0.4.2 // MIT require gopkg.in/yaml.v3 v3.0.1 // indirect +// Dependencies for experimental SSH commands +require github.com/tailscale/hujson v0.0.0-20250605163823-992244df8c5a // BSD-3-Clause + require ( cloud.google.com/go/auth v0.18.1 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect diff --git a/go.sum b/go.sum index 1894726358..3bc43fe4fe 100644 --- a/go.sum +++ b/go.sum @@ -216,6 +216,8 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/tailscale/hujson v0.0.0-20250605163823-992244df8c5a h1:a6TNDN9CgG+cYjaeN8l2mc4kSz2iMiCDQxPEyltUV/I= +github.com/tailscale/hujson v0.0.0-20250605163823-992244df8c5a/go.mod h1:EbW0wDK/qEUYI0A5bqq0C2kF8JTQwWONmGDBbzsxxHo= github.com/xanzy/ssh-agent v0.3.3 h1:+/15pJfg/RsTxqYcX6fHqOXZwwMP+2VyYWJeWM2qQFM= github.com/xanzy/ssh-agent v0.3.3/go.mod h1:6dzNDKs0J9rVPHPhaGCukekBHKqfl+L3KghI1Bc68Uw= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=