diff --git a/dotnet/src/Client.cs b/dotnet/src/Client.cs index 99c0eff0..1410f064 100644 --- a/dotnet/src/Client.cs +++ b/dotnet/src/Client.cs @@ -79,6 +79,7 @@ public sealed partial class CopilotClient : IDisposable, IAsyncDisposable private readonly List> _lifecycleHandlers = []; private readonly Dictionary>> _typedLifecycleHandlers = []; private readonly object _lifecycleHandlersLock = new(); + private readonly ConcurrentDictionary _shellProcessMap = new(); private ServerRpc? _rpc; /// @@ -473,6 +474,9 @@ public async Task CreateSessionAsync(SessionConfig config, Cance session.On(config.OnEvent); } _sessions[sessionId] = session; + session.SetShellProcessCallbacks( + (processId, s) => _shellProcessMap[processId] = s, + processId => _shellProcessMap.TryRemove(processId, out _)); try { @@ -587,6 +591,9 @@ public async Task ResumeSessionAsync(string sessionId, ResumeSes session.On(config.OnEvent); } _sessions[sessionId] = session; + session.SetShellProcessCallbacks( + (processId, s) => _shellProcessMap[processId] = s, + processId => _shellProcessMap.TryRemove(processId, out _)); try { @@ -1272,6 +1279,8 @@ private async Task ConnectToServerAsync(Process? cliProcess, string? rpc.AddLocalRpcMethod("permission.request", handler.OnPermissionRequestV2); rpc.AddLocalRpcMethod("userInput.request", handler.OnUserInputRequest); rpc.AddLocalRpcMethod("hooks.invoke", handler.OnHooksInvoke); + rpc.AddLocalRpcMethod("shell.output", handler.OnShellOutput); + rpc.AddLocalRpcMethod("shell.exit", handler.OnShellExit); rpc.AddLocalRpcMethod("systemMessage.transform", handler.OnSystemMessageTransform); rpc.StartListening(); @@ -1508,6 +1517,58 @@ public async Task OnPermissionRequestV2(string sess }); } } + + public void OnShellOutput(string processId, string stream, string data, string? sessionId = null) + { + CopilotSession? session = null; + if (!string.IsNullOrEmpty(sessionId)) + { + session = client.GetSession(sessionId!); + } + + if (session is null) + { + client._shellProcessMap.TryGetValue(processId, out session); + } + + if (session is not null) + { + session.DispatchShellOutput(new ShellOutputNotification + { + SessionId = sessionId, + ProcessId = processId, + Stream = stream, + Data = data, + }); + } + } + + public void OnShellExit(string processId, int exitCode, string? sessionId = null) + { + CopilotSession? session = null; + if (!string.IsNullOrEmpty(sessionId)) + { + session = client.GetSession(sessionId!); + } + + if (session is null) + { + client._shellProcessMap.TryGetValue(processId, out session); + } + + if (session is not null) + { + session.DispatchShellExit(new ShellExitNotification + { + SessionId = sessionId, + ProcessId = processId, + ExitCode = exitCode, + }); + // Clean up the mapping after exit + client._shellProcessMap.TryRemove(processId, out _); + session.UntrackShellProcess(processId); + } + } } private class Connection( diff --git a/dotnet/src/Generated/Rpc.cs b/dotnet/src/Generated/Rpc.cs index fabe4817..5624b89f 100644 --- a/dotnet/src/Generated/Rpc.cs +++ b/dotnet/src/Generated/Rpc.cs @@ -1325,11 +1325,13 @@ public class SessionRpc { private readonly JsonRpc _rpc; private readonly string _sessionId; + private readonly Action? _onShellExec; - internal SessionRpc(JsonRpc rpc, string sessionId) + internal SessionRpc(JsonRpc rpc, string sessionId, Action? onShellExec = null) { _rpc = rpc; _sessionId = sessionId; + _onShellExec = onShellExec; Model = new ModelApi(rpc, sessionId); Mode = new ModeApi(rpc, sessionId); Plan = new PlanApi(rpc, sessionId); @@ -1345,7 +1347,7 @@ internal SessionRpc(JsonRpc rpc, string sessionId) Commands = new CommandsApi(rpc, sessionId); Ui = new UiApi(rpc, sessionId); Permissions = new PermissionsApi(rpc, sessionId); - Shell = new ShellApi(rpc, sessionId); + Shell = new ShellApi(rpc, sessionId, _onShellExec); } /// Model APIs. @@ -1849,18 +1851,22 @@ public class ShellApi { private readonly JsonRpc _rpc; private readonly string _sessionId; + private readonly Action? _onExec; - internal ShellApi(JsonRpc rpc, string sessionId) + internal ShellApi(JsonRpc rpc, string sessionId, Action? onExec = null) { _rpc = rpc; _sessionId = sessionId; + _onExec = onExec; } /// Calls "session.shell.exec". public async Task ExecAsync(string command, string? cwd = null, double? timeout = null, CancellationToken cancellationToken = default) { var request = new SessionShellExecRequest { SessionId = _sessionId, Command = command, Cwd = cwd, Timeout = timeout }; - return await CopilotClient.InvokeRpcAsync(_rpc, "session.shell.exec", [request], cancellationToken); + var result = await CopilotClient.InvokeRpcAsync(_rpc, "session.shell.exec", [request], cancellationToken); + _onExec?.Invoke(result.ProcessId); + return result; } /// Calls "session.shell.kill". diff --git a/dotnet/src/Session.cs b/dotnet/src/Session.cs index 675a3e0c..89d7c9e3 100644 --- a/dotnet/src/Session.cs +++ b/dotnet/src/Session.cs @@ -69,6 +69,12 @@ public sealed partial class CopilotSession : IAsyncDisposable private readonly SemaphoreSlim _transformCallbacksLock = new(1, 1); private SessionRpc? _sessionRpc; private int _isDisposed; + private event Action? ShellOutputHandlers; + private event Action? ShellExitHandlers; + private readonly HashSet _trackedProcessIds = []; + private readonly object _trackedProcessIdsLock = new(); + private Action? _registerShellProcess; + private Action? _unregisterShellProcess; /// /// Channel that serializes event dispatch. enqueues; @@ -87,7 +93,7 @@ public sealed partial class CopilotSession : IAsyncDisposable /// /// Gets the typed RPC client for session-scoped methods. /// - public SessionRpc Rpc => _sessionRpc ??= new SessionRpc(_rpc, SessionId); + public SessionRpc Rpc => _sessionRpc ??= new SessionRpc(_rpc, SessionId, TrackShellProcess); /// /// Gets the path to the session workspace directory when infinite sessions are enabled. @@ -284,6 +290,52 @@ public IDisposable On(SessionEventHandler handler) return new ActionDisposable(() => ImmutableInterlocked.Update(ref _eventHandlers, array => array.Remove(handler))); } + /// + /// Subscribes to shell output notifications for this session. + /// + /// A callback that receives shell output notifications. + /// An that unsubscribes the handler when disposed. + /// + /// Shell output notifications are streamed in chunks when commands started + /// via session.Rpc.Shell.ExecAsync() produce stdout or stderr output. + /// + /// + /// + /// using var sub = session.OnShellOutput(n => + /// { + /// Console.WriteLine($"[{n.ProcessId}:{n.Stream}] {n.Data}"); + /// }); + /// + /// + public IDisposable OnShellOutput(Action handler) + { + ShellOutputHandlers += handler; + return new ActionDisposable(() => ShellOutputHandlers -= handler); + } + + /// + /// Subscribes to shell exit notifications for this session. + /// + /// A callback that receives shell exit notifications. + /// An that unsubscribes the handler when disposed. + /// + /// Shell exit notifications are sent when commands started via + /// session.Rpc.Shell.ExecAsync() complete (after all output has been streamed). + /// + /// + /// + /// using var sub = session.OnShellExit(n => + /// { + /// Console.WriteLine($"Process {n.ProcessId} exited with code {n.ExitCode}"); + /// }); + /// + /// + public IDisposable OnShellExit(Action handler) + { + ShellExitHandlers += handler; + return new ActionDisposable(() => ShellExitHandlers -= handler); + } + /// /// Enqueues an event for serial dispatch to all registered handlers. /// @@ -329,6 +381,57 @@ private async Task ProcessEventsAsync() } } + /// + /// Dispatches a shell output notification to all registered handlers. + /// + internal void DispatchShellOutput(ShellOutputNotification notification) + { + ShellOutputHandlers?.Invoke(notification); + } + + /// + /// Dispatches a shell exit notification to all registered handlers. + /// + internal void DispatchShellExit(ShellExitNotification notification) + { + ShellExitHandlers?.Invoke(notification); + } + + /// + /// Track a shell process ID so notifications are routed to this session. + /// + internal void TrackShellProcess(string processId) + { + lock (_trackedProcessIdsLock) + { + _trackedProcessIds.Add(processId); + } + _registerShellProcess?.Invoke(processId, this); + } + + /// + /// Stop tracking a shell process ID. + /// + internal void UntrackShellProcess(string processId) + { + lock (_trackedProcessIdsLock) + { + _trackedProcessIds.Remove(processId); + } + _unregisterShellProcess?.Invoke(processId); + } + + /// + /// Set the registration callbacks for shell process tracking. + /// + internal void SetShellProcessCallbacks( + Action register, + Action unregister) + { + _registerShellProcess = register; + _unregisterShellProcess = unregister; + } + /// /// Registers custom tool handlers for this session. /// @@ -889,6 +992,18 @@ await InvokeRpcAsync( } _eventHandlers = ImmutableInterlocked.InterlockedExchange(ref _eventHandlers, ImmutableArray.Empty); + ShellOutputHandlers = null; + ShellExitHandlers = null; + + lock (_trackedProcessIdsLock) + { + foreach (var processId in _trackedProcessIds) + { + _unregisterShellProcess?.Invoke(processId); + } + _trackedProcessIds.Clear(); + } + _toolHandlers.Clear(); _permissionHandler = null; diff --git a/dotnet/src/Types.cs b/dotnet/src/Types.cs index d6530f9c..b82dd07e 100644 --- a/dotnet/src/Types.cs +++ b/dotnet/src/Types.cs @@ -2083,6 +2083,66 @@ public class SessionLifecycleEvent public SessionLifecycleEventMetadata? Metadata { get; set; } } +// ============================================================================ +// Shell Notification Types +// ============================================================================ + +/// +/// Notification sent when a shell command produces output. +/// Streamed in chunks (up to 64KB per notification). +/// +public class ShellOutputNotification +{ + /// + /// Process identifier returned by shell.exec. + /// + [JsonPropertyName("processId")] + public string ProcessId { get; set; } = string.Empty; + + /// + /// Identifier of the session that produced this notification, when provided by the runtime. + /// + [JsonPropertyName("sessionId")] + public string? SessionId { get; set; } + + /// + /// Which output stream produced this chunk ("stdout" or "stderr"). + /// + [JsonPropertyName("stream")] + public string Stream { get; set; } = string.Empty; + + /// + /// The output data (UTF-8 string). + /// + [JsonPropertyName("data")] + public string Data { get; set; } = string.Empty; +} + +/// +/// Notification sent when a shell command exits. +/// Sent after all output has been streamed. +/// +public class ShellExitNotification +{ + /// + /// Process identifier returned by shell.exec. + /// + [JsonPropertyName("processId")] + public string ProcessId { get; set; } = string.Empty; + + /// + /// Identifier of the session that produced this notification, when provided by the runtime. + /// + [JsonPropertyName("sessionId")] + public string? SessionId { get; set; } + + /// + /// Process exit code (0 = success). + /// + [JsonPropertyName("exitCode")] + public int ExitCode { get; set; } +} + /// /// Response from session.getForeground /// @@ -2171,6 +2231,8 @@ public class SystemMessageTransformRpcResponse [JsonSerializable(typeof(SessionContext))] [JsonSerializable(typeof(SessionLifecycleEvent))] [JsonSerializable(typeof(SessionLifecycleEventMetadata))] +[JsonSerializable(typeof(ShellExitNotification))] +[JsonSerializable(typeof(ShellOutputNotification))] [JsonSerializable(typeof(SessionListFilter))] [JsonSerializable(typeof(SectionOverride))] [JsonSerializable(typeof(SessionMetadata))] diff --git a/dotnet/test/MultiClientTests.cs b/dotnet/test/MultiClientTests.cs index bdd264a4..d8a0803c 100644 --- a/dotnet/test/MultiClientTests.cs +++ b/dotnet/test/MultiClientTests.cs @@ -169,9 +169,23 @@ public async Task One_Client_Approves_Permission_And_Both_See_The_Result() var client1Events = new ConcurrentBag(); var client2Events = new ConcurrentBag(); + var client1Requested = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var client2Requested = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var client1Completed = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var client2Completed = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - using var sub1 = session1.On(evt => client1Events.Add(evt)); - using var sub2 = session2.On(evt => client2Events.Add(evt)); + using var sub1 = session1.On(evt => + { + client1Events.Add(evt); + if (evt is PermissionRequestedEvent) client1Requested.TrySetResult(true); + if (evt is PermissionCompletedEvent) client1Completed.TrySetResult(true); + }); + using var sub2 = session2.On(evt => + { + client2Events.Add(evt); + if (evt is PermissionRequestedEvent) client2Requested.TrySetResult(true); + if (evt is PermissionCompletedEvent) client2Completed.TrySetResult(true); + }); var response = await session1.SendAndWaitAsync(new MessageOptions { @@ -181,6 +195,10 @@ public async Task One_Client_Approves_Permission_And_Both_See_The_Result() Assert.NotNull(response); Assert.NotEmpty(client1PermissionRequests); + await Task.WhenAll( + client1Requested.Task, client2Requested.Task, + client1Completed.Task, client2Completed.Task).WaitAsync(TimeSpan.FromSeconds(10)); + Assert.Contains(client1Events, e => e is PermissionRequestedEvent); Assert.Contains(client2Events, e => e is PermissionRequestedEvent); Assert.Contains(client1Events, e => e is PermissionCompletedEvent); @@ -214,9 +232,23 @@ public async Task One_Client_Rejects_Permission_And_Both_See_The_Result() var client1Events = new ConcurrentBag(); var client2Events = new ConcurrentBag(); + var client1Requested = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var client2Requested = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var client1Completed = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var client2Completed = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - using var sub1 = session1.On(evt => client1Events.Add(evt)); - using var sub2 = session2.On(evt => client2Events.Add(evt)); + using var sub1 = session1.On(evt => + { + client1Events.Add(evt); + if (evt is PermissionRequestedEvent) client1Requested.TrySetResult(true); + if (evt is PermissionCompletedEvent) client1Completed.TrySetResult(true); + }); + using var sub2 = session2.On(evt => + { + client2Events.Add(evt); + if (evt is PermissionRequestedEvent) client2Requested.TrySetResult(true); + if (evt is PermissionCompletedEvent) client2Completed.TrySetResult(true); + }); // Write a file so the agent has something to edit await File.WriteAllTextAsync(Path.Combine(Ctx.WorkDir, "protected.txt"), "protected content"); @@ -230,6 +262,10 @@ await session1.SendAndWaitAsync(new MessageOptions var content = await File.ReadAllTextAsync(Path.Combine(Ctx.WorkDir, "protected.txt")); Assert.Equal("protected content", content); + await Task.WhenAll( + client1Requested.Task, client2Requested.Task, + client1Completed.Task, client2Completed.Task).WaitAsync(TimeSpan.FromSeconds(10)); + Assert.Contains(client1Events, e => e is PermissionRequestedEvent); Assert.Contains(client2Events, e => e is PermissionRequestedEvent); diff --git a/go/client.go b/go/client.go index 22be47ec..59807a2c 100644 --- a/go/client.go +++ b/go/client.go @@ -91,6 +91,8 @@ type Client struct { lifecycleHandlers []SessionLifecycleHandler typedLifecycleHandlers map[SessionLifecycleEventType][]SessionLifecycleHandler lifecycleHandlersMux sync.Mutex + shellProcessMap map[string]*Session + shellProcessMapMux sync.Mutex startStopMux sync.RWMutex // protects process and state during start/[force]stop processDone chan struct{} processErrorPtr *error @@ -130,6 +132,7 @@ func NewClient(options *ClientOptions) *Client { options: opts, state: StateDisconnected, sessions: make(map[string]*Session), + shellProcessMap: make(map[string]*Session), actualHost: "localhost", isExternalServer: false, useStdio: true, @@ -571,6 +574,7 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses // Create and register the session before issuing the RPC so that // events emitted by the CLI (e.g. session.start) are not dropped. session := newSession(sessionID, c.client, "") + session.setShellProcessCallbacks(c.registerShellProcess, c.unregisterShellProcess) session.registerTools(config.Tools) session.registerPermissionHandler(config.OnPermissionRequest) @@ -692,6 +696,7 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string, // Create and register the session before issuing the RPC so that // events emitted by the CLI (e.g. session.start) are not dropped. session := newSession(sessionID, c.client, "") + session.setShellProcessCallbacks(c.registerShellProcess, c.unregisterShellProcess) session.registerTools(config.Tools) session.registerPermissionHandler(config.OnPermissionRequest) @@ -1441,6 +1446,8 @@ func (c *Client) setupNotificationHandler() { c.client.SetRequestHandler("permission.request", jsonrpc2.RequestHandlerFor(c.handlePermissionRequestV2)) c.client.SetRequestHandler("userInput.request", jsonrpc2.RequestHandlerFor(c.handleUserInputRequest)) c.client.SetRequestHandler("hooks.invoke", jsonrpc2.RequestHandlerFor(c.handleHooksInvoke)) + c.client.SetRequestHandler("shell.output", jsonrpc2.NotificationHandlerFor(c.handleShellOutput)) + c.client.SetRequestHandler("shell.exit", jsonrpc2.NotificationHandlerFor(c.handleShellExit)) c.client.SetRequestHandler("systemMessage.transform", jsonrpc2.RequestHandlerFor(c.handleSystemMessageTransform)) } @@ -1458,6 +1465,68 @@ func (c *Client) handleSessionEvent(req sessionEventRequest) { } } +func (c *Client) handleShellOutput(notification ShellOutputNotification) { + var sessionID string + if notification.SessionID != nil { + sessionID = *notification.SessionID + } + session, ok := c.getShellNotificationSession(sessionID, notification.ProcessID) + + if ok { + session.dispatchShellOutput(notification) + } +} + +func (c *Client) handleShellExit(notification ShellExitNotification) { + var sessionID string + if notification.SessionID != nil { + sessionID = *notification.SessionID + } + session, ok := c.getShellNotificationSession(sessionID, notification.ProcessID) + + if ok { + session.dispatchShellExit(notification) + if notification.ProcessID != "" { + c.shellProcessMapMux.Lock() + delete(c.shellProcessMap, notification.ProcessID) + c.shellProcessMapMux.Unlock() + session.untrackShellProcess(notification.ProcessID) + } + } +} + +func (c *Client) getShellNotificationSession(sessionID, processID string) (*Session, bool) { + if sessionID != "" { + c.sessionsMux.Lock() + session, ok := c.sessions[sessionID] + c.sessionsMux.Unlock() + if ok { + return session, true + } + } + + if processID != "" { + c.shellProcessMapMux.Lock() + session, ok := c.shellProcessMap[processID] + c.shellProcessMapMux.Unlock() + return session, ok + } + + return nil, false +} + +func (c *Client) registerShellProcess(processID string, session *Session) { + c.shellProcessMapMux.Lock() + c.shellProcessMap[processID] = session + c.shellProcessMapMux.Unlock() +} + +func (c *Client) unregisterShellProcess(processID string) { + c.shellProcessMapMux.Lock() + delete(c.shellProcessMap, processID) + c.shellProcessMapMux.Unlock() +} + // handleUserInputRequest handles a user input request from the CLI server. func (c *Client) handleUserInputRequest(req userInputRequest) (*userInputResponse, *jsonrpc2.Error) { if req.SessionID == "" || req.Question == "" { diff --git a/go/rpc/generated_rpc.go b/go/rpc/generated_rpc.go index b9ba408b..a9fbefc4 100644 --- a/go/rpc/generated_rpc.go +++ b/go/rpc/generated_rpc.go @@ -1383,6 +1383,7 @@ func (a *PermissionsRpcApi) HandlePendingPermissionRequest(ctx context.Context, type ShellRpcApi struct { client *jsonrpc2.Client sessionID string + onExec func(string) } func (a *ShellRpcApi) Exec(ctx context.Context, params *SessionShellExecParams) (*SessionShellExecResult, error) { @@ -1404,6 +1405,9 @@ func (a *ShellRpcApi) Exec(ctx context.Context, params *SessionShellExecParams) if err := json.Unmarshal(raw, &result); err != nil { return nil, err } + if a.onExec != nil { + a.onExec(result.ProcessID) + } return &result, nil } @@ -1473,7 +1477,11 @@ func (a *SessionRpc) Log(ctx context.Context, params *SessionLogParams) (*Sessio return &result, nil } -func NewSessionRpc(client *jsonrpc2.Client, sessionID string) *SessionRpc { +func NewSessionRpc(client *jsonrpc2.Client, sessionID string, onShellExec ...func(string)) *SessionRpc { + var shellExecHandler func(string) + if len(onShellExec) > 0 { + shellExecHandler = onShellExec[0] + } return &SessionRpc{client: client, sessionID: sessionID, Model: &ModelRpcApi{client: client, sessionID: sessionID}, Mode: &ModeRpcApi{client: client, sessionID: sessionID}, @@ -1490,6 +1498,6 @@ func NewSessionRpc(client *jsonrpc2.Client, sessionID string) *SessionRpc { Commands: &CommandsRpcApi{client: client, sessionID: sessionID}, Ui: &UiRpcApi{client: client, sessionID: sessionID}, Permissions: &PermissionsRpcApi{client: client, sessionID: sessionID}, - Shell: &ShellRpcApi{client: client, sessionID: sessionID}, + Shell: &ShellRpcApi{client: client, sessionID: sessionID, onExec: shellExecHandler}, } } diff --git a/go/session.go b/go/session.go index 3a94a818..644b7646 100644 --- a/go/session.go +++ b/go/session.go @@ -17,6 +17,16 @@ type sessionHandler struct { fn SessionEventHandler } +type shellOutputHandlerEntry struct { + id uint64 + fn ShellOutputHandler +} + +type shellExitHandlerEntry struct { + id uint64 + fn ShellExitHandler +} + // Session represents a single conversation session with the Copilot CLI. // // A session maintains conversation state, handles events, and manages tool execution. @@ -72,6 +82,15 @@ type Session struct { eventCh chan SessionEvent closeOnce sync.Once // guards eventCh close so Disconnect is safe to call more than once + shellOutputHandlers []shellOutputHandlerEntry + shellExitHandlers []shellExitHandlerEntry + shellHandlerMux sync.RWMutex + nextShellHandlerID uint64 + trackedProcessIDs map[string]struct{} + trackedProcessMux sync.Mutex + registerShellProc func(processID string, session *Session) + unregisterShellProc func(processID string) + // RPC provides typed session-scoped RPC methods. RPC *rpc.SessionRpc } @@ -86,14 +105,15 @@ func (s *Session) WorkspacePath() string { // newSession creates a new session wrapper with the given session ID and client. func newSession(sessionID string, client *jsonrpc2.Client, workspacePath string) *Session { s := &Session{ - SessionID: sessionID, - workspacePath: workspacePath, - client: client, - handlers: make([]sessionHandler, 0), - toolHandlers: make(map[string]ToolHandler), - eventCh: make(chan SessionEvent, 128), - RPC: rpc.NewSessionRpc(client, sessionID), + SessionID: sessionID, + workspacePath: workspacePath, + client: client, + handlers: make([]sessionHandler, 0), + toolHandlers: make(map[string]ToolHandler), + eventCh: make(chan SessionEvent, 128), + trackedProcessIDs: make(map[string]struct{}), } + s.RPC = rpc.NewSessionRpc(client, sessionID, s.trackShellProcess) go s.processEvents() return s } @@ -269,6 +289,74 @@ func (s *Session) On(handler SessionEventHandler) func() { } } +// OnShellOutput subscribes to shell output notifications for this session. +// +// Shell output notifications are streamed in chunks when commands started +// via session.RPC.Shell.Exec() produce stdout or stderr output. +// +// The returned function can be called to unsubscribe the handler. +// +// Example: +// +// unsubscribe := session.OnShellOutput(func(n copilot.ShellOutputNotification) { +// fmt.Printf("[%s:%s] %s", n.ProcessID, n.Stream, n.Data) +// }) +// defer unsubscribe() +func (s *Session) OnShellOutput(handler ShellOutputHandler) func() { + s.shellHandlerMux.Lock() + defer s.shellHandlerMux.Unlock() + + id := s.nextShellHandlerID + s.nextShellHandlerID++ + s.shellOutputHandlers = append(s.shellOutputHandlers, shellOutputHandlerEntry{id: id, fn: handler}) + + return func() { + s.shellHandlerMux.Lock() + defer s.shellHandlerMux.Unlock() + + for i, h := range s.shellOutputHandlers { + if h.id == id { + s.shellOutputHandlers = append(s.shellOutputHandlers[:i], s.shellOutputHandlers[i+1:]...) + break + } + } + } +} + +// OnShellExit subscribes to shell exit notifications for this session. +// +// Shell exit notifications are sent when commands started via +// session.RPC.Shell.Exec() complete (after all output has been streamed). +// +// The returned function can be called to unsubscribe the handler. +// +// Example: +// +// unsubscribe := session.OnShellExit(func(n copilot.ShellExitNotification) { +// fmt.Printf("Process %s exited with code %d\n", n.ProcessID, n.ExitCode) +// }) +// defer unsubscribe() +func (s *Session) OnShellExit(handler ShellExitHandler) func() { + s.shellHandlerMux.Lock() + defer s.shellHandlerMux.Unlock() + + id := s.nextShellHandlerID + s.nextShellHandlerID++ + s.shellExitHandlers = append(s.shellExitHandlers, shellExitHandlerEntry{id: id, fn: handler}) + + return func() { + s.shellHandlerMux.Lock() + defer s.shellHandlerMux.Unlock() + + for i, h := range s.shellExitHandlers { + if h.id == id { + s.shellExitHandlers = append(s.shellExitHandlers[:i], s.shellExitHandlers[i+1:]...) + break + } + } + } +} + // registerTools registers tool handlers for this session. // // Tools allow the assistant to execute custom functions. When the assistant @@ -544,6 +632,77 @@ func (s *Session) processEvents() { } } +// dispatchShellOutput dispatches a shell output notification to all registered handlers. +func (s *Session) dispatchShellOutput(notification ShellOutputNotification) { + s.shellHandlerMux.RLock() + handlers := make([]ShellOutputHandler, 0, len(s.shellOutputHandlers)) + for _, h := range s.shellOutputHandlers { + handlers = append(handlers, h.fn) + } + s.shellHandlerMux.RUnlock() + + for _, handler := range handlers { + func() { + defer func() { + if r := recover(); r != nil { + fmt.Printf("Error in shell output handler: %v\n", r) + } + }() + handler(notification) + }() + } +} + +// dispatchShellExit dispatches a shell exit notification to all registered handlers. +func (s *Session) dispatchShellExit(notification ShellExitNotification) { + s.shellHandlerMux.RLock() + handlers := make([]ShellExitHandler, 0, len(s.shellExitHandlers)) + for _, h := range s.shellExitHandlers { + handlers = append(handlers, h.fn) + } + s.shellHandlerMux.RUnlock() + + for _, handler := range handlers { + func() { + defer func() { + if r := recover(); r != nil { + fmt.Printf("Error in shell exit handler: %v\n", r) + } + }() + handler(notification) + }() + } +} + +// trackShellProcess starts tracking a shell process ID. +func (s *Session) trackShellProcess(processID string) { + s.trackedProcessMux.Lock() + s.trackedProcessIDs[processID] = struct{}{} + s.trackedProcessMux.Unlock() + if s.registerShellProc != nil { + s.registerShellProc(processID, s) + } +} + +// untrackShellProcess stops tracking a shell process ID. +func (s *Session) untrackShellProcess(processID string) { + s.trackedProcessMux.Lock() + delete(s.trackedProcessIDs, processID) + s.trackedProcessMux.Unlock() + if s.unregisterShellProc != nil { + s.unregisterShellProc(processID) + } +} + +// setShellProcessCallbacks sets the registration callbacks for shell process tracking. +func (s *Session) setShellProcessCallbacks( + register func(processID string, session *Session), + unregister func(processID string), +) { + s.registerShellProc = register + s.unregisterShellProc = unregister +} + // handleBroadcastEvent handles broadcast request events by executing local handlers // and responding via RPC. This implements the protocol v3 broadcast model where tool // calls and permission requests are broadcast as session events to all clients. @@ -748,6 +907,20 @@ func (s *Session) Disconnect() error { s.permissionHandler = nil s.permissionMux.Unlock() + s.shellHandlerMux.Lock() + s.shellOutputHandlers = nil + s.shellExitHandlers = nil + s.shellHandlerMux.Unlock() + + s.trackedProcessMux.Lock() + for processID := range s.trackedProcessIDs { + if s.unregisterShellProc != nil { + s.unregisterShellProc(processID) + } + } + s.trackedProcessIDs = nil + s.trackedProcessMux.Unlock() + return nil } diff --git a/go/types.go b/go/types.go index 502d61c1..f79efb6e 100644 --- a/go/types.go +++ b/go/types.go @@ -740,6 +740,46 @@ type SessionLifecycleEventMetadata struct { // SessionLifecycleHandler is a callback for session lifecycle events type SessionLifecycleHandler func(event SessionLifecycleEvent) +// ShellOutputStream represents the output stream identifier for shell notifications. +type ShellOutputStream string + +const ( + // ShellStreamStdout represents standard output. + ShellStreamStdout ShellOutputStream = "stdout" + // ShellStreamStderr represents standard error. + ShellStreamStderr ShellOutputStream = "stderr" +) + +// ShellOutputNotification is sent when a shell command produces output. +// Streamed in chunks (up to 64KB per notification). +type ShellOutputNotification struct { + // ProcessID is the process identifier returned by shell.exec. + ProcessID string `json:"processId"` + // SessionID is the optional session identifier for direct routing when provided by the runtime. + SessionID *string `json:"sessionId,omitempty"` + // Stream indicates which output stream produced this chunk. + Stream ShellOutputStream `json:"stream"` + // Data is the output data (UTF-8 string). + Data string `json:"data"` +} + +// ShellExitNotification is sent when a shell command exits. +// Sent after all output has been streamed. +type ShellExitNotification struct { + // ProcessID is the process identifier returned by shell.exec. + ProcessID string `json:"processId"` + // SessionID is the optional session identifier for direct routing when provided by the runtime. + SessionID *string `json:"sessionId,omitempty"` + // ExitCode is the process exit code (0 = success). + ExitCode int `json:"exitCode"` +} + +// ShellOutputHandler is a callback for shell output notifications. +type ShellOutputHandler func(notification ShellOutputNotification) + +// ShellExitHandler is a callback for shell exit notifications. +type ShellExitHandler func(notification ShellExitNotification) + // createSessionRequest is the request for session.create type createSessionRequest struct { Model string `json:"model,omitempty"` diff --git a/nodejs/src/client.ts b/nodejs/src/client.ts index 9b8af3dd..0464ad62 100644 --- a/nodejs/src/client.ts +++ b/nodejs/src/client.ts @@ -45,6 +45,8 @@ import type { SessionLifecycleHandler, SessionListFilter, SessionMetadata, + ShellExitNotification, + ShellOutputNotification, SystemMessageCustomizeConfig, TelemetryConfig, Tool, @@ -236,6 +238,7 @@ export class CopilotClient { Set<(event: SessionLifecycleEvent) => void> > = new Map(); private _rpc: ReturnType | null = null; + private shellProcessMap: Map = new Map(); private processExitPromise: Promise | null = null; // Rejects when CLI process exits private negotiatedProtocolVersion: number | null = null; @@ -638,6 +641,10 @@ export class CopilotClient { undefined, this.onGetTraceContext ); + session._setShellProcessCallbacks( + (processId, trackedSession) => this.shellProcessMap.set(processId, trackedSession), + (processId) => this.shellProcessMap.delete(processId) + ); session.registerTools(config.tools); session.registerPermissionHandler(config.onPermissionRequest); if (config.onUserInputRequest) { @@ -753,6 +760,10 @@ export class CopilotClient { undefined, this.onGetTraceContext ); + session._setShellProcessCallbacks( + (processId, trackedSession) => this.shellProcessMap.set(processId, trackedSession), + (processId) => this.shellProcessMap.delete(processId) + ); session.registerTools(config.tools); session.registerPermissionHandler(config.onPermissionRequest); if (config.onUserInputRequest) { @@ -1497,6 +1508,14 @@ export class CopilotClient { this.handleSessionLifecycleNotification(notification); }); + this.connection.onNotification("shell.output", (notification: unknown) => { + this.handleShellOutputNotification(notification); + }); + + this.connection.onNotification("shell.exit", (notification: unknown) => { + this.handleShellExitNotification(notification); + }); + // Protocol v3 servers send tool calls and permission requests as broadcast events // (external_tool.requested / permission.requested) handled in CopilotSession._dispatchEvent. // Protocol v2 servers use the older tool.call / permission.request RPC model instead. @@ -1607,6 +1626,53 @@ export class CopilotClient { } } + private handleShellOutputNotification(notification: unknown): void { + if (typeof notification !== "object" || !notification) { + return; + } + + const session = this.getSessionForShellNotification( + notification as { sessionId?: unknown; processId?: unknown } + ); + if (session) { + session._dispatchShellOutput(notification as ShellOutputNotification); + } + } + + private handleShellExitNotification(notification: unknown): void { + if (typeof notification !== "object" || !notification) { + return; + } + + const typedNotification = notification as { sessionId?: unknown; processId?: unknown }; + const session = this.getSessionForShellNotification(typedNotification); + if (session) { + session._dispatchShellExit(notification as ShellExitNotification); + if (typeof typedNotification.processId === "string") { + this.shellProcessMap.delete(typedNotification.processId); + session._untrackShellProcess(typedNotification.processId); + } + } + } + + private getSessionForShellNotification(notification: { + sessionId?: unknown; + processId?: unknown; + }): CopilotSession | undefined { + if (typeof notification.sessionId === "string") { + const session = this.sessions.get(notification.sessionId); + if (session) { + return session; + } + } + + if (typeof notification.processId === "string") { + return this.shellProcessMap.get(notification.processId); + } + + return undefined; + } + private async handleUserInputRequest(params: { sessionId: string; question: string; diff --git a/nodejs/src/index.ts b/nodejs/src/index.ts index f3788e16..301a1056 100644 --- a/nodejs/src/index.ts +++ b/nodejs/src/index.ts @@ -45,6 +45,11 @@ export type { SessionContext, SessionListFilter, SessionMetadata, + ShellExitHandler, + ShellExitNotification, + ShellOutputHandler, + ShellOutputNotification, + ShellOutputStream, SystemMessageAppendConfig, SystemMessageConfig, SystemMessageCustomizeConfig, diff --git a/nodejs/src/session.ts b/nodejs/src/session.ts index 122f4ece..3d45837d 100644 --- a/nodejs/src/session.ts +++ b/nodejs/src/session.ts @@ -23,6 +23,10 @@ import type { SessionEventPayload, SessionEventType, SessionHooks, + ShellExitHandler, + ShellExitNotification, + ShellOutputHandler, + ShellOutputNotification, Tool, ToolHandler, TraceContextProvider, @@ -71,6 +75,11 @@ export class CopilotSession { private permissionHandler?: PermissionHandler; private userInputHandler?: UserInputHandler; private hooks?: SessionHooks; + private shellOutputHandlers: Set = new Set(); + private shellExitHandlers: Set = new Set(); + private trackedProcessIds: Set = new Set(); + private _registerShellProcess?: (processId: string, session: CopilotSession) => void; + private _unregisterShellProcess?: (processId: string) => void; private transformCallbacks?: Map; private _rpc: ReturnType | null = null; private traceContextProvider?: TraceContextProvider; @@ -98,7 +107,14 @@ export class CopilotSession { */ get rpc(): ReturnType { if (!this._rpc) { - this._rpc = createSessionRpc(this.connection, this.sessionId); + const rpc = createSessionRpc(this.connection, this.sessionId); + const exec = rpc.shell.exec; + rpc.shell.exec = async (params) => { + const result = await exec(params); + this._trackShellProcess(result.processId); + return result; + }; + this._rpc = rpc; } return this._rpc; } @@ -297,6 +313,52 @@ export class CopilotSession { }; } + /** + * Subscribe to shell output notifications for this session. + * + * Shell output notifications are streamed in chunks when commands started + * via `session.rpc.shell.exec()` produce stdout or stderr output. + * + * @param handler - Callback receiving shell output notifications + * @returns A function that, when called, unsubscribes the handler + * + * @example + * ```typescript + * const unsubscribe = session.onShellOutput((notification) => { + * console.log(`[${notification.processId}:${notification.stream}] ${notification.data}`); + * }); + * ``` + */ + onShellOutput(handler: ShellOutputHandler): () => void { + this.shellOutputHandlers.add(handler); + return () => { + this.shellOutputHandlers.delete(handler); + }; + } + + /** + * Subscribe to shell exit notifications for this session. + * + * Shell exit notifications are sent when commands started via + * `session.rpc.shell.exec()` complete (after all output has been streamed). + * + * @param handler - Callback receiving shell exit notifications + * @returns A function that, when called, unsubscribes the handler + * + * @example + * ```typescript + * const unsubscribe = session.onShellExit((notification) => { + * console.log(`Process ${notification.processId} exited with code ${notification.exitCode}`); + * }); + * ``` + */ + onShellExit(handler: ShellExitHandler): () => void { + this.shellExitHandlers.add(handler); + return () => { + this.shellExitHandlers.delete(handler); + }; + } + /** * Dispatches an event to all registered handlers. * Also handles broadcast request events internally (external tool calls, permissions). @@ -330,6 +392,59 @@ export class CopilotSession { } } + /** @internal */ + _dispatchShellOutput(notification: ShellOutputNotification): void { + for (const handler of this.shellOutputHandlers) { + try { + handler(notification); + } catch { + // Ignore handler errors + } + } + } + + /** @internal */ + _dispatchShellExit(notification: ShellExitNotification): void { + for (const handler of this.shellExitHandlers) { + try { + handler(notification); + } catch { + // Ignore handler errors + } + } + } + + /** + * Track a shell process ID so notifications are routed to this session. + * @internal + */ + _trackShellProcess(processId: string): void { + this.trackedProcessIds.add(processId); + this._registerShellProcess?.(processId, this); + } + + /** + * Stop tracking a shell process ID. + * @internal + */ + _untrackShellProcess(processId: string): void { + this.trackedProcessIds.delete(processId); + this._unregisterShellProcess?.(processId); + } + + /** + * Set the registration callbacks for shell process tracking. + * Called by the client when setting up the session. + * @internal + */ + _setShellProcessCallbacks( + register: (processId: string, session: CopilotSession) => void, + unregister: (processId: string) => void + ): void { + this._registerShellProcess = register; + this._unregisterShellProcess = unregister; + } + /** * Handles broadcast request events by executing local handlers and responding via RPC. * Handlers are dispatched as fire-and-forget — rejections propagate as unhandled promise @@ -712,6 +827,13 @@ export class CopilotSession { this.typedEventHandlers.clear(); this.toolHandlers.clear(); this.permissionHandler = undefined; + this.shellOutputHandlers.clear(); + this.shellExitHandlers.clear(); + // Unregister all tracked processes + for (const processId of this.trackedProcessIds) { + this._unregisterShellProcess?.(processId); + } + this.trackedProcessIds.clear(); } /** diff --git a/nodejs/src/types.ts b/nodejs/src/types.ts index 992dbdb9..7a35a3f1 100644 --- a/nodejs/src/types.ts +++ b/nodejs/src/types.ts @@ -1262,6 +1262,49 @@ export type TypedSessionLifecycleHandler = event: SessionLifecycleEvent & { type: K } ) => void; +// ============================================================================ +// Shell Notification Types +// ============================================================================ + +/** + * Output stream identifier for shell notifications + */ +export type ShellOutputStream = "stdout" | "stderr"; + +/** + * Notification sent when a shell command produces output. + * Streamed in chunks (up to 64KB per notification). + */ +export interface ShellOutputNotification { + /** Process identifier returned by shell.exec */ + processId: string; + /** Which output stream produced this chunk */ + stream: ShellOutputStream; + /** The output data (UTF-8 string) */ + data: string; +} + +/** + * Notification sent when a shell command exits. + * Sent after all output has been streamed. + */ +export interface ShellExitNotification { + /** Process identifier returned by shell.exec */ + processId: string; + /** Process exit code (0 = success) */ + exitCode: number; +} + +/** + * Handler for shell output notifications + */ +export type ShellOutputHandler = (notification: ShellOutputNotification) => void; + +/** + * Handler for shell exit notifications + */ +export type ShellExitHandler = (notification: ShellExitNotification) => void; + /** * Information about the foreground session in TUI+server mode */ diff --git a/nodejs/test/session-shell.test.ts b/nodejs/test/session-shell.test.ts new file mode 100644 index 00000000..9d39bc93 --- /dev/null +++ b/nodejs/test/session-shell.test.ts @@ -0,0 +1,149 @@ +import { describe, it, expect } from "vitest"; +import { CopilotSession } from "../src/session.js"; +import type { ShellOutputNotification, ShellExitNotification } from "../src/types.js"; + +// Create a minimal mock session for testing +function createMockSession(): CopilotSession { + const mockConnection = { + sendRequest: async () => ({}), + } as any; + return new CopilotSession("test-session", mockConnection); +} + +describe("CopilotSession shell notifications", () => { + describe("onShellOutput", () => { + it("should register and dispatch shell output notifications", () => { + const session = createMockSession(); + const received: ShellOutputNotification[] = []; + + session.onShellOutput((notification) => { + received.push(notification); + }); + + const notification: ShellOutputNotification = { + processId: "proc-1", + stream: "stdout", + data: "hello world\n", + }; + + session._dispatchShellOutput(notification); + + expect(received).toHaveLength(1); + expect(received[0]).toEqual(notification); + }); + + it("should support multiple handlers", () => { + const session = createMockSession(); + const received1: ShellOutputNotification[] = []; + const received2: ShellOutputNotification[] = []; + + session.onShellOutput((n) => received1.push(n)); + session.onShellOutput((n) => received2.push(n)); + + const notification: ShellOutputNotification = { + processId: "proc-1", + stream: "stderr", + data: "error output", + }; + + session._dispatchShellOutput(notification); + + expect(received1).toHaveLength(1); + expect(received2).toHaveLength(1); + }); + + it("should unsubscribe when the returned function is called", () => { + const session = createMockSession(); + const received: ShellOutputNotification[] = []; + + const unsubscribe = session.onShellOutput((n) => received.push(n)); + + session._dispatchShellOutput({ + processId: "proc-1", + stream: "stdout", + data: "first", + }); + + unsubscribe(); + + session._dispatchShellOutput({ + processId: "proc-1", + stream: "stdout", + data: "second", + }); + + expect(received).toHaveLength(1); + expect(received[0].data).toBe("first"); + }); + + it("should not crash when a handler throws", () => { + const session = createMockSession(); + const received: ShellOutputNotification[] = []; + + session.onShellOutput(() => { + throw new Error("handler error"); + }); + session.onShellOutput((n) => received.push(n)); + + session._dispatchShellOutput({ + processId: "proc-1", + stream: "stdout", + data: "test", + }); + + expect(received).toHaveLength(1); + }); + }); + + describe("onShellExit", () => { + it("should register and dispatch shell exit notifications", () => { + const session = createMockSession(); + const received: ShellExitNotification[] = []; + + session.onShellExit((notification) => { + received.push(notification); + }); + + const notification: ShellExitNotification = { + processId: "proc-1", + exitCode: 0, + }; + + session._dispatchShellExit(notification); + + expect(received).toHaveLength(1); + expect(received[0]).toEqual(notification); + }); + + it("should unsubscribe when the returned function is called", () => { + const session = createMockSession(); + const received: ShellExitNotification[] = []; + + const unsubscribe = session.onShellExit((n) => received.push(n)); + + session._dispatchShellExit({ processId: "proc-1", exitCode: 0 }); + unsubscribe(); + session._dispatchShellExit({ processId: "proc-2", exitCode: 1 }); + + expect(received).toHaveLength(1); + }); + }); + + describe("shell process tracking", () => { + it("should track and untrack process IDs via callbacks", () => { + const session = createMockSession(); + const registered = new Map(); + + session._setShellProcessCallbacks( + (processId, s) => registered.set(processId, s), + (processId) => registered.delete(processId) + ); + + session._trackShellProcess("proc-1"); + expect(registered.has("proc-1")).toBe(true); + + session._untrackShellProcess("proc-1"); + expect(registered.has("proc-1")).toBe(false); + }); + }); +}); diff --git a/python/copilot/__init__.py b/python/copilot/__init__.py index 92764c0e..6e736152 100644 --- a/python/copilot/__init__.py +++ b/python/copilot/__init__.py @@ -4,8 +4,62 @@ JSON-RPC based SDK for programmatic control of GitHub Copilot CLI """ -from .client import CopilotClient, ExternalServerConfig, SubprocessConfig -from .session import CopilotSession +from .client import ( + ConnectionState, + CopilotClient, + ExternalServerConfig, + GetAuthStatusResponse, + GetStatusResponse, + ModelBilling, + ModelCapabilities, + ModelInfo, + ModelPolicy, + PingResponse, + SessionContext, + SessionLifecycleEvent, + SessionLifecycleEventType, + SessionLifecycleHandler, + SessionListFilter, + SessionMetadata, + StopError, + SubprocessConfig, + TelemetryConfig, +) +from .generated.session_events import PermissionRequest, SessionEvent +from .session import ( + SYSTEM_PROMPT_SECTIONS, + Attachment, + AzureProviderOptions, + BlobAttachment, + CopilotSession, + CustomAgentConfig, + DirectoryAttachment, + FileAttachment, + MCPLocalServerConfig, + MCPRemoteServerConfig, + MCPServerConfig, + PermissionHandler, + PermissionRequestResult, + ProviderConfig, + SectionOverride, + SectionOverrideAction, + SectionTransformFn, + SelectionAttachment, + ShellExitHandler, + ShellExitNotification, + ShellOutputHandler, + ShellOutputNotification, + ShellOutputStream, + SystemMessageAppendConfig, + SystemMessageConfig, + SystemMessageCustomizeConfig, + SystemMessageReplaceConfig, + SystemPromptSection, + Tool, + ToolHandler, + ToolInvocation, + ToolResult, +) from .tools import define_tool __version__ = "0.1.0" @@ -13,7 +67,56 @@ __all__ = [ "CopilotClient", "CopilotSession", + "ConnectionState", "ExternalServerConfig", + "GetAuthStatusResponse", + "GetStatusResponse", + "PingResponse", + "StopError", + "FileAttachment", + "DirectoryAttachment", + "SelectionAttachment", + "BlobAttachment", + "Attachment", + "TelemetryConfig", + "AzureProviderOptions", + "ProviderConfig", + "SessionContext", + "SessionListFilter", + "SessionMetadata", + "SessionLifecycleEvent", + "SessionLifecycleEventType", + "SessionLifecycleHandler", + "MCPLocalServerConfig", + "MCPRemoteServerConfig", + "MCPServerConfig", + "CustomAgentConfig", + "ModelBilling", + "ModelCapabilities", + "ModelInfo", + "ModelPolicy", + "PermissionHandler", + "PermissionRequest", + "PermissionRequestResult", + "SectionOverride", + "SectionOverrideAction", + "SectionTransformFn", + "SessionEvent", + "ShellExitHandler", + "ShellExitNotification", + "ShellOutputHandler", + "ShellOutputNotification", + "ShellOutputStream", "SubprocessConfig", + "SystemMessageAppendConfig", + "SystemMessageConfig", + "SystemMessageCustomizeConfig", + "SystemMessageReplaceConfig", + "SystemPromptSection", + "SYSTEM_PROMPT_SECTIONS", + "Tool", + "ToolHandler", + "ToolInvocation", + "ToolResult", "define_tool", ] diff --git a/python/copilot/client.py b/python/copilot/client.py index c3bb0b29..a544ff33 100644 --- a/python/copilot/client.py +++ b/python/copilot/client.py @@ -42,6 +42,8 @@ ReasoningEffort, SectionTransformFn, SessionHooks, + ShellExitNotification, + ShellOutputNotification, SystemMessageConfig, UserInputHandler, _PermissionHandlerFn, @@ -814,6 +816,8 @@ def __init__( self._state: ConnectionState = "disconnected" self._sessions: dict[str, CopilotSession] = {} self._sessions_lock = threading.Lock() + self._shell_process_map: dict[str, CopilotSession] = {} + self._shell_process_map_lock = threading.Lock() self._models_cache: list[ModelInfo] | None = None self._models_cache_lock = asyncio.Lock() self._lifecycle_handlers: list[SessionLifecycleHandler] = [] @@ -1263,6 +1267,10 @@ async def create_session( session._register_transform_callbacks(transform_callbacks) if on_event: session.on(on_event) + session._set_shell_process_callbacks( + register=self._register_shell_process, + unregister=self._unregister_shell_process, + ) with self._sessions_lock: self._sessions[actual_session_id] = session @@ -1467,6 +1475,10 @@ async def resume_session( session._register_transform_callbacks(transform_callbacks) if on_event: session.on(on_event) + session._set_shell_process_callbacks( + register=self._register_shell_process, + unregister=self._unregister_shell_process, + ) with self._sessions_lock: self._sessions[session_id] = session @@ -1843,6 +1855,16 @@ def _dispatch_lifecycle_event(self, event: SessionLifecycleEvent) -> None: except Exception: pass # Ignore handler errors + def _register_shell_process(self, process_id: str, session: CopilotSession) -> None: + """Register a shell process ID mapping to a session.""" + with self._shell_process_map_lock: + self._shell_process_map[process_id] = session + + def _unregister_shell_process(self, process_id: str) -> None: + """Unregister a shell process ID mapping.""" + with self._shell_process_map_lock: + self._shell_process_map.pop(process_id, None) + async def _verify_protocol_version(self) -> None: """Verify that the server's protocol version is within the supported range and store the negotiated version.""" @@ -2100,6 +2122,26 @@ def handle_notification(method: str, params: dict): # Handle session lifecycle events lifecycle_event = SessionLifecycleEvent.from_dict(params) self._dispatch_lifecycle_event(lifecycle_event) + elif method == "shell.output": + process_id = params.get("processId") + if process_id: + with self._shell_process_map_lock: + session = self._shell_process_map.get(process_id) + if session: + notification = ShellOutputNotification.from_dict(params) + session._dispatch_shell_output(notification) + elif method == "shell.exit": + process_id = params.get("processId") + if process_id: + with self._shell_process_map_lock: + session = self._shell_process_map.get(process_id) + if session: + notification = ShellExitNotification.from_dict(params) + session._dispatch_shell_exit(notification) + # Clean up the mapping after exit + with self._shell_process_map_lock: + self._shell_process_map.pop(process_id, None) + session._untrack_shell_process(process_id) self._client.set_notification_handler(handle_notification) # Protocol v3 servers send tool calls / permission requests as broadcast events. @@ -2189,6 +2231,26 @@ def handle_notification(method: str, params: dict): # Handle session lifecycle events lifecycle_event = SessionLifecycleEvent.from_dict(params) self._dispatch_lifecycle_event(lifecycle_event) + elif method == "shell.output": + process_id = params.get("processId") + if process_id: + with self._shell_process_map_lock: + session = self._shell_process_map.get(process_id) + if session: + notification = ShellOutputNotification.from_dict(params) + session._dispatch_shell_output(notification) + elif method == "shell.exit": + process_id = params.get("processId") + if process_id: + with self._shell_process_map_lock: + session = self._shell_process_map.get(process_id) + if session: + notification = ShellExitNotification.from_dict(params) + session._dispatch_shell_exit(notification) + # Clean up the mapping after exit + with self._shell_process_map_lock: + self._shell_process_map.pop(process_id, None) + session._untrack_shell_process(process_id) self._client.set_notification_handler(handle_notification) # Protocol v3 servers send tool calls / permission requests as broadcast events. diff --git a/python/copilot/generated/rpc.py b/python/copilot/generated/rpc.py index 14ae307d..3281ab75 100644 --- a/python/copilot/generated/rpc.py +++ b/python/copilot/generated/rpc.py @@ -2836,14 +2836,18 @@ async def handle_pending_permission_request(self, params: SessionPermissionsHand class ShellApi: - def __init__(self, client: "JsonRpcClient", session_id: str): + def __init__(self, client: "JsonRpcClient", session_id: str, on_exec: Callable[[str], None] | None = None): self._client = client self._session_id = session_id + self._on_exec = on_exec async def exec(self, params: SessionShellExecParams, *, timeout: float | None = None) -> SessionShellExecResult: params_dict = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id - return SessionShellExecResult.from_dict(await self._client.request("session.shell.exec", params_dict, **_timeout_kwargs(timeout))) + result = SessionShellExecResult.from_dict(await self._client.request("session.shell.exec", params_dict, **_timeout_kwargs(timeout))) + if self._on_exec is not None: + self._on_exec(result.process_id) + return result async def kill(self, params: SessionShellKillParams, *, timeout: float | None = None) -> SessionShellKillResult: params_dict = {k: v for k, v in params.to_dict().items() if v is not None} @@ -2853,7 +2857,7 @@ async def kill(self, params: SessionShellKillParams, *, timeout: float | None = class SessionRpc: """Typed session-scoped RPC methods.""" - def __init__(self, client: "JsonRpcClient", session_id: str): + def __init__(self, client: "JsonRpcClient", session_id: str, on_shell_exec: Callable[[str], None] | None = None): self._client = client self._session_id = session_id self.model = ModelApi(client, session_id) @@ -2871,7 +2875,7 @@ def __init__(self, client: "JsonRpcClient", session_id: str): self.commands = CommandsApi(client, session_id) self.ui = UiApi(client, session_id) self.permissions = PermissionsApi(client, session_id) - self.shell = ShellApi(client, session_id) + self.shell = ShellApi(client, session_id, on_shell_exec) async def log(self, params: SessionLogParams, *, timeout: float | None = None) -> SessionLogResult: params_dict = {k: v for k, v in params.to_dict().items() if v is not None} diff --git a/python/copilot/session.py b/python/copilot/session.py index d57105ea..a5e30a98 100644 --- a/python/copilot/session.py +++ b/python/copilot/session.py @@ -246,6 +246,49 @@ class UserInputResponse(TypedDict): UserInputResponse | Awaitable[UserInputResponse], ] +# ============================================================================ +# Shell Notification Types +# ============================================================================ + +ShellOutputStream = Literal["stdout", "stderr"] +"""Output stream identifier for shell notifications.""" + + +@dataclass +class ShellOutputNotification: + """Notification sent when a shell command produces output.""" + + processId: str + stream: ShellOutputStream + data: str + + @staticmethod + def from_dict(data: dict[str, Any]) -> ShellOutputNotification: + return ShellOutputNotification( + processId=str(data.get("processId", "")), + stream=cast(ShellOutputStream, data.get("stream", "stdout")), + data=str(data.get("data", "")), + ) + + +@dataclass +class ShellExitNotification: + """Notification sent when a shell command exits.""" + + processId: str + exitCode: int + + @staticmethod + def from_dict(data: dict[str, Any]) -> ShellExitNotification: + return ShellExitNotification( + processId=str(data.get("processId", "")), + exitCode=int(data.get("exitCode", 1)), + ) + + +ShellOutputHandler = Callable[[ShellOutputNotification], None] +ShellExitHandler = Callable[[ShellExitNotification], None] + # ============================================================================ # Hook Types # ============================================================================ @@ -666,6 +709,14 @@ def __init__(self, session_id: str, client: Any, workspace_path: str | None = No self._user_input_handler_lock = threading.Lock() self._hooks: SessionHooks | None = None self._hooks_lock = threading.Lock() + self._shell_output_handlers: set[ShellOutputHandler] = set() + self._shell_exit_handlers: set[ShellExitHandler] = set() + self._shell_output_handlers_lock = threading.Lock() + self._shell_exit_handlers_lock = threading.Lock() + self._tracked_process_ids: set[str] = set() + self._tracked_process_ids_lock = threading.Lock() + self._register_shell_process: Callable[[str, CopilotSession], None] | None = None + self._unregister_shell_process_fn: Callable[[str], None] | None = None self._transform_callbacks: dict[str, SectionTransformFn] | None = None self._transform_callbacks_lock = threading.Lock() self._rpc: SessionRpc | None = None @@ -674,7 +725,7 @@ def __init__(self, session_id: str, client: Any, workspace_path: str | None = No def rpc(self) -> SessionRpc: """Typed session-scoped RPC methods.""" if self._rpc is None: - self._rpc = SessionRpc(self._client, self.session_id) + self._rpc = SessionRpc(self._client, self.session_id, self._track_shell_process) return self._rpc @property @@ -829,6 +880,106 @@ def unsubscribe(): return unsubscribe + def on_shell_output(self, handler: ShellOutputHandler) -> Callable[[], None]: + """Subscribe to shell output notifications for this session. + + Shell output notifications are streamed in chunks when commands started + via ``session.rpc.shell.exec()`` produce stdout or stderr output. + + Args: + handler: A callback that receives shell output notifications. + + Returns: + A function that, when called, unsubscribes the handler. + + Example: + >>> def handle_output(notification): + ... print(f"[{notification.processId}:{notification.stream}] {notification.data}") + >>> unsubscribe = session.on_shell_output(handle_output) + """ + with self._shell_output_handlers_lock: + self._shell_output_handlers.add(handler) + + def unsubscribe(): + with self._shell_output_handlers_lock: + self._shell_output_handlers.discard(handler) + + return unsubscribe + + def on_shell_exit(self, handler: ShellExitHandler) -> Callable[[], None]: + """Subscribe to shell exit notifications for this session. + + Shell exit notifications are sent when commands started via + ``session.rpc.shell.exec()`` complete (after all output has been streamed). + + Args: + handler: A callback that receives shell exit notifications. + + Returns: + A function that, when called, unsubscribes the handler. + + Example: + >>> def handle_exit(notification): + ... print(f"Process {notification.processId} exited: {notification.exitCode}") + >>> unsubscribe = session.on_shell_exit(handle_exit) + """ + with self._shell_exit_handlers_lock: + self._shell_exit_handlers.add(handler) + + def unsubscribe(): + with self._shell_exit_handlers_lock: + self._shell_exit_handlers.discard(handler) + + return unsubscribe + + def _dispatch_shell_output(self, notification: ShellOutputNotification) -> None: + """Dispatch a shell output notification to all registered handlers.""" + with self._shell_output_handlers_lock: + handlers = list(self._shell_output_handlers) + + for handler in handlers: + try: + handler(notification) + except Exception as e: + print(f"Error in shell output handler: {e}") + + def _dispatch_shell_exit(self, notification: ShellExitNotification) -> None: + """Dispatch a shell exit notification to all registered handlers.""" + with self._shell_exit_handlers_lock: + handlers = list(self._shell_exit_handlers) + + for handler in handlers: + try: + handler(notification) + except Exception as e: + print(f"Error in shell exit handler: {e}") + + def _track_shell_process(self, process_id: str) -> None: + """Track a shell process ID so notifications are routed to this session.""" + with self._tracked_process_ids_lock: + self._tracked_process_ids.add(process_id) + if self._register_shell_process is not None: + self._register_shell_process(process_id, self) + + def _untrack_shell_process(self, process_id: str) -> None: + """Stop tracking a shell process ID.""" + with self._tracked_process_ids_lock: + self._tracked_process_ids.discard(process_id) + if self._unregister_shell_process_fn is not None: + self._unregister_shell_process_fn(process_id) + + def _set_shell_process_callbacks( + self, + register: Callable[[str, CopilotSession], None], + unregister: Callable[[str], None], + ) -> None: + """Set the registration callbacks for shell process tracking. + + Called by the client when setting up the session. + """ + self._register_shell_process = register + self._unregister_shell_process_fn = unregister + def _dispatch_event(self, event: SessionEvent) -> None: """ Dispatch an event to all registered handlers. @@ -1286,6 +1437,15 @@ async def disconnect(self) -> None: self._tool_handlers.clear() with self._permission_handler_lock: self._permission_handler = None + with self._shell_output_handlers_lock: + self._shell_output_handlers.clear() + with self._shell_exit_handlers_lock: + self._shell_exit_handlers.clear() + with self._tracked_process_ids_lock: + for process_id in list(self._tracked_process_ids): + if self._unregister_shell_process_fn is not None: + self._unregister_shell_process_fn(process_id) + self._tracked_process_ids.clear() async def destroy(self) -> None: """ diff --git a/python/pyproject.toml b/python/pyproject.toml index 7c1f8bbf..6b711b4c 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -39,6 +39,7 @@ dev = [ "pytest-asyncio>=0.21.0", "pytest-timeout>=2.0.0", "httpx>=0.24.0", + "opentelemetry-api>=1.0.0", ] telemetry = [ "opentelemetry-api>=1.0.0", diff --git a/scripts/codegen/csharp.ts b/scripts/codegen/csharp.ts index a48ed47b..9be015d6 100644 --- a/scripts/codegen/csharp.ts +++ b/scripts/codegen/csharp.ts @@ -828,9 +828,11 @@ function emitSessionRpcClasses(node: Record, classes: string[]) const groups = Object.entries(node).filter(([, v]) => typeof v === "object" && v !== null && !isRpcMethod(v)); const topLevelMethods = Object.entries(node).filter(([, v]) => isRpcMethod(v)); - const srLines = [`/// Provides typed session-scoped RPC methods.`, `public class SessionRpc`, `{`, ` private readonly JsonRpc _rpc;`, ` private readonly string _sessionId;`, ""]; - srLines.push(` internal SessionRpc(JsonRpc rpc, string sessionId)`, ` {`, ` _rpc = rpc;`, ` _sessionId = sessionId;`); - for (const [groupName] of groups) srLines.push(` ${toPascalCase(groupName)} = new ${toPascalCase(groupName)}Api(rpc, sessionId);`); + const srLines = [`/// Provides typed session-scoped RPC methods.`, `public class SessionRpc`, `{`, ` private readonly JsonRpc _rpc;`, ` private readonly string _sessionId;`, ` private readonly Action? _onShellExec;`, ""]; + srLines.push(` internal SessionRpc(JsonRpc rpc, string sessionId, Action? onShellExec = null)`, ` {`, ` _rpc = rpc;`, ` _sessionId = sessionId;`, ` _onShellExec = onShellExec;`); + for (const [groupName] of groups) srLines.push( + ` ${toPascalCase(groupName)} = new ${toPascalCase(groupName)}Api(rpc, sessionId${groupName === "shell" ? ", _onShellExec" : ""});` + ); srLines.push(` }`); for (const [groupName] of groups) srLines.push("", ` /// ${toPascalCase(groupName)} APIs.`, ` public ${toPascalCase(groupName)}Api ${toPascalCase(groupName)} { get; }`); @@ -896,15 +898,22 @@ function emitSessionMethod(key: string, method: RpcMethod, lines: string[], clas lines.push(`${indent}public async Task<${resultClassName}> ${methodName}Async(${sigParams.join(", ")})`); lines.push(`${indent}{`, `${indent} var request = new ${requestClassName} { ${bodyAssignments.join(", ")} };`); - lines.push(`${indent} return await CopilotClient.InvokeRpcAsync<${resultClassName}>(_rpc, "${method.rpcMethod}", [request], cancellationToken);`, `${indent}}`); + if (method.rpcMethod === "session.shell.exec") { + lines.push(`${indent} var result = await CopilotClient.InvokeRpcAsync<${resultClassName}>(_rpc, "${method.rpcMethod}", [request], cancellationToken);`); + lines.push(`${indent} _onExec?.Invoke(result.ProcessId);`); + lines.push(`${indent} return result;`, `${indent}}`); + } else { + lines.push(`${indent} return await CopilotClient.InvokeRpcAsync<${resultClassName}>(_rpc, "${method.rpcMethod}", [request], cancellationToken);`, `${indent}}`); + } } function emitSessionApiClass(className: string, node: Record, classes: string[]): string { const displayName = className.replace(/Api$/, ""); const groupExperimental = isNodeFullyExperimental(node); const experimentalAttr = groupExperimental ? `[Experimental(Diagnostics.Experimental)]\n` : ""; - const lines = [`/// Provides session-scoped ${displayName} APIs.`, `${experimentalAttr}public class ${className}`, `{`, ` private readonly JsonRpc _rpc;`, ` private readonly string _sessionId;`, ""]; - lines.push(` internal ${className}(JsonRpc rpc, string sessionId)`, ` {`, ` _rpc = rpc;`, ` _sessionId = sessionId;`, ` }`); + const ctorSuffix = className === "ShellApi" ? ", Action? onExec = null" : ""; + const lines = [`/// Provides session-scoped ${displayName} APIs.`, `${experimentalAttr}public class ${className}`, `{`, ` private readonly JsonRpc _rpc;`, ` private readonly string _sessionId;`, ...(className === "ShellApi" ? [` private readonly Action? _onExec;`] : []), ""]; + lines.push(` internal ${className}(JsonRpc rpc, string sessionId${ctorSuffix})`, ` {`, ` _rpc = rpc;`, ` _sessionId = sessionId;`, ...(className === "ShellApi" ? [` _onExec = onExec;`] : []), ` }`); for (const [key, value] of Object.entries(node)) { if (!isRpcMethod(value)) continue; diff --git a/scripts/codegen/go.ts b/scripts/codegen/go.ts index 59abee29..3610d432 100644 --- a/scripts/codegen/go.ts +++ b/scripts/codegen/go.ts @@ -315,6 +315,9 @@ function emitRpcWrapper(lines: string[], node: Record, isSessio if (isSession) { lines.push(`\tclient *jsonrpc2.Client`); lines.push(`\tsessionID string`); + if (groupName === "shell") { + lines.push(`\tonExec func(string)`); + } } else { lines.push(`\tclient *jsonrpc2.Client`); } @@ -355,14 +358,22 @@ function emitRpcWrapper(lines: string[], node: Record, isSessio const padKey = (name: string) => (name + ":").padEnd(maxKeyLen + 1); // +1 for min trailing space // Constructor - const ctorParams = isSession ? "client *jsonrpc2.Client, sessionID string" : "client *jsonrpc2.Client"; + const ctorParams = isSession ? "client *jsonrpc2.Client, sessionID string, onShellExec ...func(string)" : "client *jsonrpc2.Client"; const ctorFields = isSession ? "client: client, sessionID: sessionID," : "client: client,"; lines.push(`func New${wrapperName}(${ctorParams}) *${wrapperName} {`); + if (isSession) { + lines.push(`\tvar shellExecHandler func(string)`); + lines.push(`\tif len(onShellExec) > 0 {`); + lines.push(`\t\tshellExecHandler = onShellExec[0]`); + lines.push(`\t}`); + } lines.push(`\treturn &${wrapperName}{${ctorFields}`); for (const [groupName] of groups) { const prefix = isSession ? "" : "Server"; const apiInit = isSession - ? `&${toPascalCase(groupName)}${apiSuffix}{client: client, sessionID: sessionID}` + ? groupName === "shell" + ? `&${toPascalCase(groupName)}${apiSuffix}{client: client, sessionID: sessionID, onExec: shellExecHandler}` + : `&${toPascalCase(groupName)}${apiSuffix}{client: client, sessionID: sessionID}` : `&${prefix}${toPascalCase(groupName)}${apiSuffix}{client: client}`; lines.push(`\t\t${padKey(toPascalCase(groupName))}${apiInit},`); } @@ -421,6 +432,11 @@ function emitMethod(lines: string[], receiver: string, name: string, method: Rpc lines.push(`\tif err := json.Unmarshal(raw, &result); err != nil {`); lines.push(`\t\treturn nil, err`); lines.push(`\t}`); + if (method.rpcMethod === "session.shell.exec") { + lines.push(`\tif a.onExec != nil {`); + lines.push(`\t\ta.onExec(result.ProcessID)`); + lines.push(`\t}`); + } lines.push(`\treturn &result, nil`); lines.push(`}`); lines.push(``); diff --git a/scripts/codegen/python.ts b/scripts/codegen/python.ts index 0340cf1f..08bae060 100644 --- a/scripts/codegen/python.ts +++ b/scripts/codegen/python.ts @@ -319,9 +319,16 @@ function emitRpcWrapper(lines: string[], node: Record, isSessio lines.push(`# Experimental: this API group is experimental and may change or be removed.`); } lines.push(`class ${apiName}:`); - lines.push(` def __init__(self, client: "JsonRpcClient", session_id: str):`); + if (groupName === "shell") { + lines.push(` def __init__(self, client: "JsonRpcClient", session_id: str, on_exec: Callable[[str], None] | None = None):`); + } else { + lines.push(` def __init__(self, client: "JsonRpcClient", session_id: str):`); + } lines.push(` self._client = client`); lines.push(` self._session_id = session_id`); + if (groupName === "shell") { + lines.push(` self._on_exec = on_exec`); + } } else { if (groupExperimental) { lines.push(`# Experimental: this API group is experimental and may change or be removed.`); @@ -342,11 +349,15 @@ function emitRpcWrapper(lines: string[], node: Record, isSessio if (isSession) { lines.push(`class ${wrapperName}:`); lines.push(` """Typed session-scoped RPC methods."""`); - lines.push(` def __init__(self, client: "JsonRpcClient", session_id: str):`); + lines.push(` def __init__(self, client: "JsonRpcClient", session_id: str, on_shell_exec: Callable[[str], None] | None = None):`); lines.push(` self._client = client`); lines.push(` self._session_id = session_id`); for (const [groupName] of groups) { - lines.push(` self.${toSnakeCase(groupName)} = ${toPascalCase(groupName)}Api(client, session_id)`); + if (groupName === "shell") { + lines.push(` self.${toSnakeCase(groupName)} = ${toPascalCase(groupName)}Api(client, session_id, on_shell_exec)`); + } else { + lines.push(` self.${toSnakeCase(groupName)} = ${toPascalCase(groupName)}Api(client, session_id)`); + } } } else { lines.push(`class ${wrapperName}:`); @@ -392,7 +403,14 @@ function emitMethod(lines: string[], name: string, method: RpcMethod, isSession: if (hasParams) { lines.push(` params_dict = {k: v for k, v in params.to_dict().items() if v is not None}`); lines.push(` params_dict["sessionId"] = self._session_id`); - lines.push(` return ${resultType}.from_dict(await self._client.request("${method.rpcMethod}", params_dict, **_timeout_kwargs(timeout)))`); + if (method.rpcMethod === "session.shell.exec") { + lines.push(` result = ${resultType}.from_dict(await self._client.request("${method.rpcMethod}", params_dict, **_timeout_kwargs(timeout)))`); + lines.push(` if self._on_exec is not None:`); + lines.push(` self._on_exec(result.process_id)`); + lines.push(` return result`); + } else { + lines.push(` return ${resultType}.from_dict(await self._client.request("${method.rpcMethod}", params_dict, **_timeout_kwargs(timeout)))`); + } } else { lines.push(` return ${resultType}.from_dict(await self._client.request("${method.rpcMethod}", {"sessionId": self._session_id}, **_timeout_kwargs(timeout)))`); }