diff --git a/.chief/prds/uplink/prd.json b/.chief/prds/uplink/prd.json new file mode 100644 index 0000000..81f0ac3 --- /dev/null +++ b/.chief/prds/uplink/prd.json @@ -0,0 +1,496 @@ +{ + "project": "Uplink — Chief Web App Support", + "description": "Modifications to the Chief Go CLI to support the chiefloop.com web application. Adds device OAuth authentication, a headless `serve` command with WebSocket connectivity, workspace management, Claude session management over WebSocket, and a self-update mechanism. The existing TUI experience is preserved unchanged by refactoring the loop engine into a shared core that both TUI and serve consume.", + "userStories": [ + { + "id": "US-001", + "title": "Migrate CLI to Cobra Framework", + "description": "As a developer, I want Chief to use the Cobra CLI framework so that adding new commands (login, logout, serve, update) is structured and maintainable.", + "acceptanceCriteria": [ + "Replace manual flag parsing in `main.go` with Cobra command tree", + "All existing commands (`new`, `edit`, `status`, `list`, default TUI launch) work identically", + "All existing flags (`--max-iterations`, `--no-sound`, `--no-retry`, `--verbose`, `--merge`, `--force`, `--help`, `--version`) are preserved", + "Cobra is added to `go.mod`", + "Command structure supports future subcommands (`login`, `logout`, `serve`, `update`)", + "`chief --help` output is clean and organized with command grouping" + ], + "priority": 1, + "passes": true, + "inProgress": false + }, + { + "id": "US-002", + "title": "Implement Credential Storage", + "description": "As a developer, I want a credential storage module so that login, logout, and serve can read/write auth tokens consistently.", + "acceptanceCriteria": [ + "New `internal/auth` package with `Credentials` struct: `access_token`, `refresh_token`, `expires_at`, `device_name`, `user`", + "Credentials stored at `~/.chief/credentials.yaml` (YAML format, consistent with config)", + "File written with `0600` permissions", + "`LoadCredentials()` returns credentials or a clear \"not logged in\" error", + "`SaveCredentials()` writes atomically (write to temp file, rename)", + "`DeleteCredentials()` removes the file", + "`IsExpired()` and `IsNearExpiry(duration)` methods on credentials", + "Unit tests for load/save/delete cycle and permission enforcement" + ], + "priority": 2, + "passes": true, + "inProgress": false + }, + { + "id": "US-003", + "title": "Implement `chief login`", + "description": "As a user, I want to authenticate Chief with chiefloop.com using device OAuth so that I never need to copy-paste tokens.", + "acceptanceCriteria": [ + "`chief login` calls `POST https://chiefloop.com/oauth/device/code` to get `user_code` and `device_code`", + "Prints URL (`chiefloop.com/device`) and user code in clear, prominent formatting", + "Attempts to open browser automatically (`xdg-open` on Linux, `open` on macOS)", + "Polls `POST https://chiefloop.com/oauth/device/token` every 5 seconds with `device_code`", + "On approval, receives access token (1h expiry) and refresh token (90d expiry)", + "Stores both tokens via the credential storage module (US-002)", + "Device name auto-detected from hostname, overridable with `--name` flag", + "Device name sent during authorization request", + "Success message: `Logged in as \u003cuser\u003e (\u003cdevice-name\u003e)`", + "Handles timeout gracefully (user never approves): exits with clear message after 5 minutes", + "Handles network errors with retry and clear error messages", + "If already logged in, warns and asks for confirmation before overwriting" + ], + "priority": 3, + "passes": true, + "inProgress": false + }, + { + "id": "US-004", + "title": "Implement `chief logout` and Automatic Token Refresh", + "description": "As a user, I want to log out (revoking my device) and have tokens refresh automatically so that sessions stay alive without manual intervention.", + "acceptanceCriteria": [ + "`chief logout` calls revocation endpoint server-side to deauthorize the device", + "Deletes `~/.chief/credentials.yaml`", + "Prints: `Logged out. Device \"\u003cdevice-name\u003e\" has been deauthorized.`", + "Handles \"not logged in\" gracefully (no credentials file)", + "Handles revocation API failure gracefully (still deletes local credentials, warns about server-side)", + "`RefreshToken()` function calls `POST https://chiefloop.com/oauth/token` with refresh token", + "Auto-refresh triggers when access token is within 5 minutes of expiry", + "On refresh failure (revoked/expired refresh token), returns clear error: `Session expired. Run 'chief login' again.`", + "Refresh is thread-safe (mutex-protected for concurrent use by serve)" + ], + "priority": 4, + "passes": true, + "inProgress": false + }, + { + "id": "US-005", + "title": "Extract Shared Engine from TUI", + "description": "As a developer, I want the loop orchestration logic extracted into a shared engine so that both the TUI and the serve command can drive Ralph loops and Claude sessions without code duplication.", + "acceptanceCriteria": [ + "New `internal/engine` package with `Engine` interface/struct", + "Engine wraps `loop.Manager` functionality: start/pause/resume/stop runs, event streaming", + "Engine provides a channel-based event API (same events the TUI currently consumes)", + "TUI refactored to consume events from the engine instead of directly from `loop.Manager`", + "TUI behavior is identical before and after refactor (no user-visible changes)", + "Engine supports multiple consumers (TUI gets one, WebSocket handler gets another)", + "Engine manages Claude process lifecycle (spawn, track, kill)", + "Engine exposes project state queries (list PRDs, get run status, etc.)", + "All existing TUI tests still pass" + ], + "priority": 5, + "passes": true, + "inProgress": false + }, + { + "id": "US-006", + "title": "Implement WebSocket Client with Reconnection", + "description": "As a developer, I want a robust WebSocket client that connects to chiefloop.com and automatically reconnects on failure.", + "acceptanceCriteria": [ + "New `internal/ws` package with `Client` struct", + "Connects to `wss://chiefloop.com/ws/server` (default URL)", + "WebSocket URL configurable via `--ws-url` flag, `CHIEF_WS_URL` env var, or `ws_url` field in `~/.chief/config.yaml` (flag \u003e env \u003e config \u003e default) for local development", + "Reconnection with exponential backoff + jitter: 1s, 2s, 4s, 8s ... max 60s", + "Backoff resets on successful connection", + "Each reconnection attempt is logged", + "On reconnection, signals the serve command to re-send full state snapshot", + "`Send(message)` and `Receive() \u003c-chan Message` API", + "Graceful close with WebSocket close frame", + "Context-based cancellation for clean shutdown", + "Ping/pong keepalive support (responds to `ping` with `pong`)" + ], + "priority": 6, + "passes": true, + "inProgress": false + }, + { + "id": "US-007", + "title": "Implement Protocol Handshake", + "description": "As a developer, I want the WebSocket protocol handshake to authenticate the connection and verify version compatibility.", + "acceptanceCriteria": [ + "On WebSocket open, chief sends `hello` message containing: `protocol_version` (1), `chief_version`, `device_name`, `os`, `arch`, and `access_token`", + "Authentication is done via the `access_token` field in the `hello` message (NOT as a URL query parameter)", + "If access token is expired or near-expiry, chief refreshes it before connecting", + "Chief waits for `welcome` or `incompatible` response", + "On `welcome`: connection is established, chief proceeds normally", + "On `incompatible`: chief logs the `message` field, does NOT retry, exits serve", + "On auth failure response: chief logs \"Device deauthorized. Run 'chief login' to re-authenticate.\" and exits", + "Handshake has a 10-second timeout (if no response, reconnect)", + "All subsequent messages include `type`, `id` (UUID), and `timestamp` (ISO8601) fields" + ], + "priority": 7, + "passes": true, + "inProgress": false + }, + { + "id": "US-008", + "title": "Implement Message Types and Serialization", + "description": "As a developer, I want strongly-typed message definitions and JSON serialization for all protocol messages so that the WebSocket communication is reliable and type-safe.", + "acceptanceCriteria": [ + "Go structs for every message type in the protocol catalog (see full list below)", + "Server → Web App: `hello`, `state_snapshot`, `project_list`, `project_state`, `prd_content`, `claude_output`, `run_progress`, `run_complete`, `run_paused`, `diff`, `clone_progress`, `clone_complete`, `error`, `quota_exhausted`, `log_lines`, `session_timeout_warning`, `session_expired`, `settings`, `update_available`", + "Web App → Server: `welcome`, `incompatible`, `list_projects`, `get_project`, `get_prd`, `new_prd`, `prd_message`, `close_prd_session`, `start_run`, `pause_run`, `resume_run`, `stop_run`, `clone_repo`, `create_project`, `get_diff`, `get_logs`, `get_settings`, `update_settings`, `trigger_update`, `ping`", + "Bidirectional: `pong`", + "Standardized error codes as constants: `AUTH_FAILED`, `PROJECT_NOT_FOUND`, `PRD_NOT_FOUND`, `RUN_ALREADY_ACTIVE`, `RUN_NOT_ACTIVE`, `SESSION_NOT_FOUND`, `CLONE_FAILED`, `QUOTA_EXHAUSTED`, `FILESYSTEM_ERROR`, `CLAUDE_ERROR`, `UPDATE_FAILED`, `INCOMPATIBLE_VERSION`", + "Message dispatcher that routes incoming messages by `type` to registered handlers", + "Unknown message types are logged and ignored (forward compatibility)", + "Unit tests for serialization/deserialization round-trips" + ], + "priority": 8, + "passes": true, + "inProgress": false + }, + { + "id": "US-009", + "title": "Implement Basic `chief serve` Command", + "description": "As a user, I want to run `chief serve --workspace ~/projects` to start a headless daemon that connects to chiefloop.com and accepts commands from the web app.", + "acceptanceCriteria": [ + "`chief serve` command registered via Cobra", + "Required `--workspace` flag (or from `~/.chief/config.yaml` `workspace` field)", + "Optional `--name` flag to override device name for this session", + "Optional `--log-file` flag for background use (default: stdout)", + "Validates workspace directory exists", + "Checks for credentials, exits with `Not logged in. Run 'chief login' first.` if missing", + "Refreshes token if near-expiry, then connects WebSocket", + "Completes protocol handshake", + "Runs as long-lived process (no TUI, no interactive prompts)", + "Logs connection status, reconnections, and errors to stdout/log file", + "Handles SIGTERM/SIGINT: kills all child Claude processes, closes WebSocket cleanly, exits" + ], + "priority": 9, + "passes": true, + "inProgress": false + }, + { + "id": "US-010", + "title": "Implement Workspace Scanner", + "description": "As a user, I want chief to automatically discover git repositories in my workspace directory so that all my projects appear in the web app without manual configuration.", + "acceptanceCriteria": [ + "On startup, scans workspace directory one level deep for directories containing `.git/`", + "Every git repository is a project, regardless of `.chief/` presence", + "For each project, tracks: name (directory name), path, `has_chief` (bool), git branch, last commit (hash, message, author, timestamp), PRD list (id, name, story count, completion status)", + "Re-scans every 60 seconds to detect new/removed projects", + "Sends `project_list` update over WebSocket when projects change", + "Ignores non-directory entries and directories without `.git/`", + "Handles permission errors gracefully (logs warning, skips directory)" + ], + "priority": 10, + "passes": true, + "inProgress": false + }, + { + "id": "US-011", + "title": "Implement Selective File Watching", + "description": "As a developer, I want chief to watch filesystem changes only for active projects so that we minimize resource usage while still pushing real-time updates.", + "acceptanceCriteria": [ + "Use `fsnotify` to watch the workspace root directory (for new/removed project directories)", + "Deep watchers (`.chief/` directory, `.git/HEAD`) are only set up for \"active\" projects", + "A project becomes \"active\" when: a run is started, a Claude session is opened, or `get_project` is requested", + "A project becomes \"inactive\" after 10 minutes of no activity (watchers are removed)", + "When `.chief/prds/` changes are detected, push `project_state` update over WebSocket", + "When `.git/HEAD` changes, update branch info and push `project_state`", + "Watcher setup/teardown is logged at debug level" + ], + "priority": 11, + "passes": true, + "inProgress": false + }, + { + "id": "US-012", + "title": "Implement State Snapshot and Project Handlers", + "description": "As a web app developer, I want chief to send a full state snapshot on connect/reconnect and respond to project queries so that the web app always has an accurate cache.", + "acceptanceCriteria": [ + "On successful handshake (and on every reconnect), chief sends `state_snapshot` with: all projects, all active runs, all active Claude sessions", + "`list_projects` handler returns `project_list` with all discovered projects", + "`get_project` handler returns `project_state` for a single project: PRDs, run history, settings", + "`get_prd` handler returns `prd_content` with markdown content and state from `prd.json`", + "`PROJECT_NOT_FOUND` and `PRD_NOT_FOUND` errors returned for invalid references", + "Activates file watching for the requested project (US-011)" + ], + "priority": 12, + "passes": true, + "inProgress": false + }, + { + "id": "US-013", + "title": "Implement Interactive Claude PRD Sessions over WebSocket", + "description": "As a web app user, I want to create and refine PRDs through a chat interface that streams Claude responses in real-time.", + "acceptanceCriteria": [ + "`new_prd` handler spawns `claude -p` in streaming mode with the project's `init_prompt.txt` context and the `initial_message` from the web app", + "Process tracked by `session_id` (UUID generated by the web app)", + "stdout/stderr read in a goroutine and forwarded as `claude_output` messages with `session_id`, `data` (text chunk), and `done` (bool)", + "`prd_message` handler writes user messages to the Claude process stdin", + "`close_prd_session` handler: if `save` is true, waits for Claude to finish writing the PRD file; if false, kills immediately", + "When Claude finishes writing `prd.md`, chief auto-converts to `prd.json` (existing conversion logic)", + "`SESSION_NOT_FOUND` error if session_id doesn't match an active session", + "Claude process runs in the project's working directory" + ], + "priority": 13, + "passes": true, + "inProgress": false + }, + { + "id": "US-014", + "title": "Implement Session Timeout with Warnings", + "description": "As a user, I want Claude PRD sessions to timeout after 30 minutes of inactivity with advance warnings so that I'm not surprised by session loss and resources aren't wasted.", + "acceptanceCriteria": [ + "Each Claude session has a 30-minute inactivity timer that resets on every `prd_message`", + "At 20 minutes of inactivity (10 min remaining): send `session_timeout_warning` with `minutes_remaining: 10`", + "At 25 minutes (5 min remaining): send warning with `minutes_remaining: 5`", + "At 29 minutes (1 min remaining): send warning with `minutes_remaining: 1`", + "At 30 minutes: save whatever PRD state exists to disk, kill Claude process, send `session_expired` with saved state", + "Multiple concurrent sessions each have independent timers", + "Timer is goroutine-safe (no races between message handling and timeout)" + ], + "priority": 14, + "passes": true, + "inProgress": false + }, + { + "id": "US-015", + "title": "Implement Run Control via WebSocket", + "description": "As a web app user, I want to start, pause, resume, and stop Ralph loop runs from the web interface.", + "acceptanceCriteria": [ + "`start_run` handler: starts a Ralph loop for the specified project/PRD via the shared engine (US-005)", + "`pause_run` handler: pauses the loop (finishes current Claude invocation, then stops)", + "`resume_run` handler: resumes a paused loop from where it left off", + "`stop_run` handler: kills the Claude process immediately, marks current story as interrupted", + "Error responses: `RUN_ALREADY_ACTIVE` if starting a run that's running, `RUN_NOT_ACTIVE` if pausing/stopping a run that isn't running", + "Multiple runs across different projects can execute concurrently", + "Run state reflected in `prd.json` (existing `inProgress` field for interrupted stories)" + ], + "priority": 15, + "passes": true, + "inProgress": false + }, + { + "id": "US-016", + "title": "Implement Quota Detection and Auto-Pause", + "description": "As a user, I want Ralph loops to automatically pause when Claude quota is exhausted so that iterations aren't wasted on failures.", + "acceptanceCriteria": [ + "Detect Claude CLI quota/rate-limit errors by non-zero exit code and stderr content matching known patterns (e.g., \"rate limit\", \"quota\", \"429\")", + "On quota detection: pause the current loop immediately", + "Send `run_paused` with `reason: \"quota_exhausted\"`", + "Send `quota_exhausted` listing all affected runs and sessions", + "Do NOT auto-retry or auto-resume; user must explicitly `resume_run`", + "Current story marked as interrupted; on resume, story restarts from scratch (fresh context per iteration)", + "Quota errors are distinguishable from Claude crashes (which trigger existing retry logic)" + ], + "priority": 16, + "passes": true + }, + { + "id": "US-017", + "title": "Implement Run Progress Streaming", + "description": "As a web app user, I want to see real-time progress updates for running Ralph loops so that the UI stays current.", + "acceptanceCriteria": [ + "`run_progress` messages sent on every meaningful state change: story started, iteration started, test pass/fail, story complete", + "Each message includes: `project`, `prd_id`, `story_id`, `status`, `iteration`, `attempt`", + "`run_complete` sent when all stories pass or max iterations reached, with summary: stories completed, duration, pass/fail counts", + "`claude_output` messages streamed during active runs (same content as TUI log view)", + "Progress messages are the primary mechanism for the web app to update its UI" + ], + "priority": 17, + "passes": true + }, + { + "id": "US-018", + "title": "Implement Project Settings via WebSocket", + "description": "As a web app user, I want to view and edit project settings so that I can configure Chief behavior without SSH access.", + "acceptanceCriteria": [ + "`get_settings` handler returns current settings from `.chief/config.yaml`", + "`settings` response includes all editable fields: `max_iterations` (int, default 5), `auto_commit` (bool, default true), `commit_prefix` (string), `claude_model` (string), `test_command` (string)", + "`update_settings` handler: merges provided fields into existing config, writes to `.chief/config.yaml`", + "Response echoes back full updated settings on success, or error on failure", + "Settings not editable via web: `CLAUDE.md`, `init_prompt.txt`, any file outside `.chief/`", + "Invalid setting values return `FILESYSTEM_ERROR` with descriptive message", + "New config fields added to existing `internal/config` package" + ], + "priority": 18, + "passes": true, + "inProgress": false + }, + { + "id": "US-019", + "title": "Implement Per-Story Logging", + "description": "As a web app user, I want to view logs for each story so that I can debug failures and understand what Claude did.", + "acceptanceCriteria": [ + "During Ralph loop runs, Claude output for each story is written to `.chief/prds/\u003cid\u003e/logs/\u003cstory-id\u003e.log`", + "Log files contain raw Claude output including all iterations for that story", + "`get_logs` handler returns historical log lines for a completed story", + "`get_logs` supports optional `story_id` (specific story) and `lines` (count) parameters", + "If `story_id` is omitted, returns the most recent log activity for the PRD", + "Real-time log streaming uses existing `claude_output` messages during active runs", + "Starting a new run on the same PRD overwrites previous log files (V1 simplicity)", + "`log_lines` response includes `project`, `prd_id`, `story_id`, `lines[]`, and `level`" + ], + "priority": 19, + "passes": true, + "inProgress": false + }, + { + "id": "US-020", + "title": "Implement Per-Story Diff Generation and Retrieval", + "description": "As a web app user, I want to see the git diff for each completed story so that I can review what changed.", + "acceptanceCriteria": [ + "`get_diff` handler returns the diff for a specific story", + "Diff is per-story: shows the changes introduced by that story's commit(s) compared to the state before the story started", + "Response includes: `project`, `prd_id`, `story_id`, `files[]` (list of changed file paths), `diff_text` (unified diff format)", + "If auto-commit is enabled, diff is derived from the story's commit(s)", + "If story has not completed or has no changes, returns appropriate error", + "`diff` messages also sent proactively when a story completes during an active run" + ], + "priority": 20, + "passes": true, + "inProgress": false + }, + { + "id": "US-021", + "title": "Implement Git Clone and Project Creation via WebSocket", + "description": "As a web app user, I want to clone repositories and create new projects from the web interface so that I can set up projects without SSH.", + "acceptanceCriteria": [ + "`clone_repo` handler: runs `git clone \u003curl\u003e` in the workspace directory", + "Optional `directory_name` field overrides the default clone directory name", + "Streams `clone_progress` messages with `progress_text` and `percent` (if parseable from git output)", + "Sends `clone_complete` on finish with `success` bool and `error` message if failed", + "Uses the host's existing SSH keys/git credential helpers for authentication", + "`CLONE_FAILED` error for git failures (includes git error message)", + "`create_project` handler: creates a new empty directory in workspace, optionally runs `git init`", + "New projects appear in next workspace scan or immediately via fsnotify" + ], + "priority": 21, + "passes": true, + "inProgress": false + }, + { + "id": "US-022", + "title": "Implement Version Check and `chief update` Command", + "description": "As a user, I want Chief to check for updates and allow self-updating so that I always have the latest version.", + "acceptanceCriteria": [ + "On startup (all commands), check GitHub Releases API: `GET https://api.github.com/repos/MiniCodeMonkey/chief/releases/latest`", + "Compare version tag against embedded build version", + "For interactive CLI: print one-liner if update available: `Chief v0.5.1 available (you have v0.5.0). Run 'chief update' to upgrade.` (non-blocking)", + "For `chief serve`: check every 24 hours, send `update_available` over WebSocket", + "`chief update` command: download binary for current `GOOS`/`GOARCH`, verify SHA256 checksum, atomic rename to replace current binary", + "Before update: check write permissions on binary path; if insufficient, print: `Permission denied. Run 'sudo chief update' to upgrade.`", + "Success message: `Updated to v0.5.1. Restart 'chief serve' to apply.`", + "Does NOT auto-restart" + ], + "priority": 22, + "passes": true, + "inProgress": false + }, + { + "id": "US-023", + "title": "Implement Remote Update via WebSocket", + "description": "As a web app user, I want to trigger Chief updates from the dashboard so that I can update remote servers without SSH.", + "acceptanceCriteria": [ + "`trigger_update` handler: downloads and installs new binary (same as `chief update`)", + "Sends confirmation message over WebSocket before exiting", + "Exits with code 0 so systemd `Restart=always` picks up the new binary", + "If write permissions insufficient, returns `UPDATE_FAILED` error with message about permissions", + "If already on latest version, returns informational message (not an error)" + ], + "priority": 23, + "passes": true, + "inProgress": false + }, + { + "id": "US-024", + "title": "Create Systemd Service and Cloud-Init Script", + "description": "As a user deploying to a VPS, I want a cloud-init script that sets up Chief as a systemd service so that I can get running with minimal manual steps.", + "acceptanceCriteria": [ + "Systemd unit file includes `ConditionPathExists=/home/chief/.chief/credentials.yaml` so service doesn't start/restart-loop before auth", + "Unit file: `Type=simple`, `Restart=always`, `RestartSec=5`, `After=network-online.target`", + "Cloud-init script installs: chief binary, Claude Code CLI (`npm install -g @anthropic-ai/claude-code`), creates `chief` user, creates `/home/chief/projects` workspace, writes and enables systemd unit", + "Service is enabled but NOT started (requires auth first)", + "Cloud-init script is idempotent (safe to run multiple times)", + "Post-deploy instructions guide user through: SSH in, `chief login`, authenticate Claude Code, `sudo systemctl start chief`" + ], + "priority": 24, + "passes": true, + "inProgress": false + }, + { + "id": "US-025", + "title": "Implement One-Time Setup Token for Automated Chief Auth", + "description": "As a user, I want the web app to pass a one-time setup token during VPS provisioning so that Chief can authenticate without an interactive `chief login` step.", + "acceptanceCriteria": [ + "`chief login --setup-token \u003ctoken\u003e` mode: exchanges a one-time token for credentials via `POST https://chiefloop.com/oauth/device/exchange`", + "Token is generated by the web app during the deploy flow and passed to cloud-init as a parameter", + "Cloud-init writes the token to `/tmp/chief-setup-token` (or passes it as an environment variable)", + "After cloud-init runs, a one-shot systemd service runs `chief login --setup-token \u003ctoken\u003e` as the chief user", + "On success: credentials are written, token file is deleted, main chief service can start", + "Token is single-use and expires after 10 minutes", + "If token is invalid or expired, falls back to manual `chief login` with clear instructions", + "User still needs to authenticate Claude Code separately (unavoidable)" + ], + "priority": 25, + "passes": true, + "inProgress": false + }, + { + "id": "US-026", + "title": "Implement WebSocket Rate Limiting", + "description": "As a developer, I want basic rate limiting on incoming WebSocket messages so that a buggy or compromised web app can't overwhelm chief.", + "acceptanceCriteria": [ + "Token bucket rate limiter on incoming messages: 30 messages/second burst, 10 messages/second sustained", + "Expensive operations (`clone_repo`, `start_run`, `new_prd`) have a stricter per-type limit: 2/minute", + "When rate limited, chief responds with an `error` message: code `RATE_LIMITED`, includes retry-after hint", + "Rate limiter state resets on reconnection", + "Keepalive `ping` messages are exempt from rate limiting", + "Limits are hardcoded (not configurable) for V1" + ], + "priority": 26, + "passes": true, + "inProgress": false + }, + { + "id": "US-027", + "title": "Implement Graceful Shutdown for `chief serve`", + "description": "As a user, I want `chief serve` to shut down cleanly on SIGTERM/SIGINT so that systemd restarts are seamless and no data is corrupted.", + "acceptanceCriteria": [ + "SIGTERM and SIGINT handlers registered", + "On signal: kill all child Claude processes (both PRD sessions and Ralph loops)", + "Mark any in-progress stories as interrupted (`inProgress: true`) in `prd.json`", + "Close WebSocket connection with proper close frame", + "Flush and close log file if `--log-file` is used", + "Exit within 10 seconds (force-kill any hung processes after 5 seconds)", + "Log shutdown sequence: \"Shutting down...\", \"Killed N processes\", \"Goodbye.\"" + ], + "priority": 27, + "passes": true, + "inProgress": false + }, + { + "id": "US-028", + "title": "Implement WebSocket URL Configuration for Local Development", + "description": "As a developer working on the chiefloop.com integration, I want to point chief at a local WebSocket server so that I can develop and test without hitting production.", + "acceptanceCriteria": [ + "`chief serve --ws-url ws://localhost:8080/ws/server` flag overrides the default URL", + "`CHIEF_WS_URL` environment variable also overrides (flag takes precedence)", + "`ws_url` field in `~/.chief/config.yaml` as a third option (lowest precedence)", + "Supports both `ws://` (unencrypted) and `wss://` (TLS) schemes", + "Default remains `wss://chiefloop.com/ws/server` when no override is set", + "When using `ws://`, chief does NOT require or attempt TLS (for local dev)", + "Log the target URL on startup so it's obvious which server chief is connecting to" + ], + "priority": 28, + "passes": true + } + ] +} diff --git a/.chief/prds/uplink/progress.md b/.chief/prds/uplink/progress.md new file mode 100644 index 0000000..5d338e9 --- /dev/null +++ b/.chief/prds/uplink/progress.md @@ -0,0 +1,536 @@ +## Codebase Patterns +- CLI entry point is `cmd/chief/main.go`, command implementations are in `internal/cmd/` +- Each command handler accepts a typed Options struct (e.g., `cmd.NewOptions`, `cmd.EditOptions`) +- TUI is built on Bubble Tea (`github.com/charmbracelet/bubbletea`) +- Configuration uses YAML via `gopkg.in/yaml.v3`, stored in `.chief/config.yaml` +- PRD files live in `.chief/prds//prd.json` +- Build requires `pkg-config` and `libasound2-dev` for the audio notification library (`oto`) +- Go version 1.24.0 is required (per go.mod) +- Commit style follows conventional commits: `feat:`, `fix:`, `chore:` prefixes +- Use `cobra.ArbitraryArgs` on root command to accept positional PRD name/path +- Cobra is now the CLI framework (`github.com/spf13/cobra`) +- Auth credentials stored at `~/.chief/credentials.yaml` via `internal/auth` package +- Use `t.Setenv("HOME", dir)` to redirect home directory in tests (auto-cleaned up) +- Use `httptest.NewServer` for mocking HTTP endpoints; pass base URL via Options struct +- New commands: add `newXxxCmd()` in `main.go`, `RunXxx(XxxOptions)` in `internal/cmd/xxx.go` +- Token refresh is mutex-protected in `internal/auth` — use `auth.RefreshToken(baseURL)` for thread-safe refresh +- `auth.RevokeDevice(accessToken, baseURL)` handles server-side device revocation +- Logout gracefully handles revocation failure (warns, still deletes local creds) +- Shared engine: `internal/engine` package wraps `loop.Manager` with fan-out event subscription via `Subscribe()` +- TUI uses `a.eng` (engine) not `a.manager` — all loop operations go through engine +- `engine.New(maxIter)` creates engine; `engine.Subscribe()` returns `(<-chan ManagerEvent, unsubFunc)` +- Tests using App struct directly must use `eng` field (not `manager`) and create `engine.New()` instances +- WebSocket client: `internal/ws` package with `ws.New(url, opts...)`, `Connect(ctx)`, `Send(msg)`, `Receive()`, `Close()` +- Use `gorilla/websocket` library for WebSocket connections +- WebSocket test pattern: use `httptest.NewServer` + `websocket.Upgrader` for mock servers; `wsURL()` helper to convert HTTP to WS URL +- `WithOnReconnect(fn)` option allows serve command to re-send state snapshot on reconnect +- Protocol handshake: `client.Handshake(accessToken, version, deviceName)` after `Connect()` — sends `hello`, waits for `welcome`/`incompatible`/`auth_failed` +- UUID generation: `ws.newUUID()` uses `crypto/rand` (no external dependency); `ws.NewMessage(type)` creates envelope with UUID + ISO8601 timestamp +- Handshake errors: `*ErrIncompatible` (version mismatch, do NOT retry), `ErrAuthFailed` (deauthorized), `ErrHandshakeTimeout` (10s timeout) +- Message types defined as `Type*` constants in `internal/ws/messages.go` (e.g., `TypeStartRun`, `TypeGetProject`) +- Error codes defined as `ErrCode*` constants in `internal/ws/messages.go` (e.g., `ErrCodeProjectNotFound`) +- Use `ws.NewDispatcher()` to create a message router; `Register(type, handler)` to add handlers; `Dispatch(msg)` to route +- Use pointer fields (`*int`, `*bool`, `*string`) for optional/partial update messages to distinguish "not set" from zero values +- Serve command uses `ServeOptions.Ctx` (context) for testability — tests cancel ctx to stop serve loop +- For handshake error tests (incompatible/auth_failed), call `srv.CloseClientConnections()` to prevent ws reconnection loops +- Cancel context before `client.Close()` to avoid race where readLoop reconnects during Close() +- Workspace scanner: `internal/workspace` package; `workspace.New(dir, wsClient)` creates scanner; `scanner.Run(ctx)` runs periodic scan loop; `scanner.Projects()` returns current list +- `ws.Client.Send()` accepts `interface{}` — pass typed message structs directly, no need to marshal separately +- Scanner now supports `SetClient()` for deferred client setup and `FindProject(name)` for single-project lookup +- State snapshot is sent after handshake AND on reconnect; use `sendStateSnapshot(client, scanner)` helper +- `sendError(client, code, message, requestID)` helper sends typed error messages over WebSocket +- `serveTestHelper(t, workspacePath, serverFn)` encapsulates WS test boilerplate (hello/welcome/state_snapshot/cancel) +- After handshake, server always receives a `state_snapshot` message — existing tests must account for this +- `sessionManager` in `internal/cmd/session.go` manages Claude PRD sessions: `newPRD()`, `sendMessage()`, `closeSession()`, `killAll()`, `activeSessions()` +- Mock `claude` binary in tests: write shell script to temp dir, prepend to PATH via `t.Setenv("PATH", dir+":"+origPath)` +- `projectFinder` interface allows testing handlers without real `workspace.Scanner` +- Session timeout: `sessionManager` has configurable `timeout`, `warningThresholds`, and `checkInterval` fields — set to small values in tests for speed +- `expireSession()` closes stdin, waits 2s grace period, then force-kills; sends `session_expired` after process exits +- `runManager` in `internal/cmd/runs.go` manages Ralph loop runs: `startRun()`, `pauseRun()`, `resumeRun()`, `stopRun()`, `activeRuns()`, `stopAll()` +- Run control uses `engine.Engine` — resume works by calling `engine.Start()` again (creates fresh Loop picking up next unfinished story) +- `runKey(project, prdID)` creates engine registration key as `"project/prdID"` +- Quota errors detected via `loop.IsQuotaError(text)` — checks stderr/exit error for patterns like "rate limit", "quota", "429" +- Quota errors bypass retry logic and set `LoopStatePaused` (not `LoopStateError`) for resumability +- `runManager.startEventMonitor(ctx)` subscribes to engine events for cross-cutting concerns like quota detection, progress streaming, and output forwarding +- `sendStateSnapshot`, `handleMessage`, and `serveShutdown` now accept `*runManager` parameter +- `runInfo` tracks `startTime` and `storyID` for progress messages — `storyID` is updated on `EventStoryStarted` +- `handleEvent()` routes engine events to `sendRunProgress()`, `sendRunComplete()`, `sendClaudeOutput()` based on event type +- Settings handlers in `internal/cmd/settings.go`: `handleGetSettings`, `handleUpdateSettings` — use `config.Load()`/`config.Save()` with project path +- Config uses `*bool` for optional booleans (e.g., `AutoCommit`) and `Effective*()` methods to provide defaults +- Per-story logs: `storyLogger` in `internal/cmd/logs.go` writes to `.chief/prds//logs/.log`; `handleGetLogs` handler retrieves them +- `runManager.loggers` map tracks per-run story loggers; `writeStoryLog()` writes during event handling; loggers cleaned up on `cleanup()`/`stopAll()` +- Per-story diffs: `handleGetDiff` in `internal/cmd/diffs.go` uses `git log --grep ` to find commits, then `git show` for diff/files; proactive diffs sent on `EventStoryCompleted` +- Clone/create: `handleCloneRepo`/`handleCreateProject` in `internal/cmd/clone.go`; clone runs async in goroutine; `scanner.WorkspacePath()` exposes workspace dir +- Version check/update: `internal/update` package for GitHub Releases API check + binary download; `internal/cmd/update.go` for `RunUpdate()` command +- `update.CheckForUpdate(version, opts)` returns `CheckResult` with `UpdateAvailable` bool; `update.PerformUpdate(version, opts)` does full download+replace +- Startup version check runs non-blocking in `PersistentPreRun`; serve mode checks every 24h and sends `update_available` WS message +- `ServeOptions.ReleasesURL` allows testing the periodic version checker with a mock server +- `handleMessage` returns `bool` — `true` signals the serve loop to exit cleanly (for remote update); all other handlers return `false` +- Setup token login: `chief login --setup-token ` exchanges via `POST /oauth/device/exchange`; cloud-init uses `CHIEF_SETUP_TOKEN` env var and a one-shot systemd service +- Rate limiting: `ws.NewRateLimiter()` checks in main event loop before `handleMessage()`; `Reset()` on reconnect; ping is exempt; expensive ops (clone_repo, start_run, new_prd) have per-type 2/minute limits +- Graceful shutdown: `serveShutdown()` accepts engine, runs `markInterruptedStories()` → `stopAll()`/`killAll()`/`eng.Shutdown()` in goroutine with 5s force-kill timeout; engine shutdown moved from `defer` to inside `serveShutdown` to avoid blocking +- `runManager.markInterruptedStories()` loads PRD, sets `inProgress: true` for active story that hasn't passed, saves back — must be called before `stopAll()` +- User-level config: `config.LoadUserConfig()` reads `~/.chief/config.yaml` (separate from project-level `.chief/config.yaml`); currently has `ws_url` field +- WS URL precedence: `--ws-url` flag > `CHIEF_WS_URL` env var > `ws_url` in `~/.chief/config.yaml` > `ws.DefaultURL` + +--- + +## 2026-02-15 - US-001 +- **What was implemented:** Migrated CLI from manual flag parsing to Cobra framework +- **Files changed:** + - `cmd/chief/main.go` - Replaced switch-based dispatch and manual flag parsing with Cobra command tree + - `go.mod` / `go.sum` - Added `github.com/spf13/cobra` dependency +- **Learnings for future iterations:** + - Cobra's `SetUsageTemplate` propagates to child commands; use `SetHelpFunc` with a parent check to customize only root help + - Set `SilenceErrors: true` and `SilenceUsage: true` on root for clean error output, but remember to print the error in `main()` yourself + - `cobra.ArbitraryArgs` allows positional args on root while still having subcommands + - The `wiggum` Easter egg command is marked `Hidden: true` so it doesn't appear in help + - Future subcommands (login, logout, serve, update) can be added via `rootCmd.AddCommand()` +--- + +## 2026-02-15 - US-002 +- **What was implemented:** Credential storage module in `internal/auth` package +- **Files changed:** + - `internal/auth/auth.go` - `Credentials` struct, `LoadCredentials()`, `SaveCredentials()`, `DeleteCredentials()`, `IsExpired()`, `IsNearExpiry()` methods + - `internal/auth/auth_test.go` - 11 unit tests covering full save/load/delete cycle, permissions, atomic writes, expiry logic +- **Learnings for future iterations:** + - Credentials use `~/.chief/credentials.yaml` (user home, NOT project dir) — different from project config which is relative to `baseDir` + - `os.UserHomeDir()` is used for home directory; override with `t.Setenv("HOME", dir)` in tests + - `t.Setenv()` automatically restores on cleanup — no need for manual defer + - Atomic write pattern: `os.CreateTemp` in same dir → write → `os.Rename` — ensures no partial writes + - File permissions must be `0600` for credentials (not `0644` like config) + - `LoadCredentials()` returns `ErrNotLoggedIn` (not a default) when file is missing — this differs from config's pattern of returning defaults + - `gopkg.in/yaml.v3` handles `time.Time` natively — no custom marshaling needed +--- + +## 2026-02-15 - US-003 +- **What was implemented:** Device OAuth login command (`chief login`) +- **Files changed:** + - `internal/cmd/login.go` - `RunLogin()` with device OAuth flow: request device code, display URL/code, open browser, poll for token, save credentials + - `internal/cmd/login_test.go` - 6 tests using `httptest` mock server: success flow, device code error, authorization denied, default device name, already-logged-in decline, browser open safety + - `cmd/chief/main.go` - Added `login` subcommand with `--name` flag, updated help text +- **Learnings for future iterations:** + - Use `httptest.NewServer` for mocking HTTP endpoints in tests — no need for interface abstraction + - The `LoginOptions.BaseURL` field allows tests to point at the mock server (defaults to `https://chiefloop.com`) + - `os.Pipe()` + `os.Stdin` override is the pattern for testing interactive stdin prompts + - `sync/atomic.Int32` for tracking poll count across goroutines in tests + - Poll-based tests are slow (5s per poll interval) — keep poll count low in tests + - `openBrowser()` is best-effort and uses `xdg-open` on Linux, `open` on macOS + - Login command follows existing pattern: `LoginOptions` struct → `RunLogin()` function +--- + +## 2026-02-15 - US-004 +- **What was implemented:** `chief logout` command and automatic token refresh with thread safety +- **Files changed:** + - `internal/auth/auth.go` - Added `RefreshToken()` (mutex-protected), `RevokeDevice()`, `ErrSessionExpired`, `refreshResponse` struct + - `internal/auth/auth_test.go` - Added 7 tests: refresh success, not-near-expiry skip, session expired, not-logged-in, thread safety (concurrent goroutines), revoke success, revoke server error + - `internal/cmd/logout.go` - New `RunLogout(LogoutOptions)` with revocation endpoint call, graceful handling of revocation failure, credential deletion + - `internal/cmd/logout_test.go` - 3 tests: success with revocation, not-logged-in, revocation fails but still deletes credentials + - `cmd/chief/main.go` - Added `logout` subcommand, updated help text + - `.chief/prds/uplink/prd.json` - Updated US-004 status +- **Learnings for future iterations:** + - `RefreshToken()` uses a global `sync.Mutex` to prevent concurrent refresh attempts — after acquiring the lock, it re-checks `IsNearExpiry()` in case another goroutine already refreshed + - Logout follows the pattern: try server-side revocation, warn on failure, always delete local credentials + - `RevokeDevice()` and `RefreshToken()` accept a `baseURL` parameter for testability (same pattern as login) + - The `defaultBaseURL` constant was moved to `internal/auth` since both auth and cmd packages need it + - Thread-safety test verifies only 1 actual HTTP call is made when 5 goroutines call `RefreshToken()` concurrently +--- + +## 2026-02-15 - US-005 +- **What was implemented:** Extracted shared engine from TUI into `internal/engine` package +- **Files changed:** + - `internal/engine/engine.go` - New `Engine` struct wrapping `loop.Manager` with fan-out event subscription (`Subscribe()`) for multiple consumers + - `internal/engine/engine_test.go` - 25 tests covering: creation, register/unregister, subscribe/unsubscribe, fan-out events, concurrent access, shutdown, worktree info, config, PRD loading + - `internal/tui/app.go` - Replaced `manager *loop.Manager` with `eng *engine.Engine`; added `eventCh` and `unsubFn` fields; TUI now subscribes to engine events via `Subscribe()`; added `NewAppWithEngine()` for sharing engine with serve command + - `internal/tui/dashboard.go` - Updated `a.manager` → `a.eng` references + - `internal/tui/layout_test.go` - Updated tests to create `engine.Engine` instead of `loop.Manager` directly; added `newTestEngine()` and `newTestEngineWithWorktree()` helpers +- **Learnings for future iterations:** + - The `loop.Manager` already had the right abstraction (channels, start/stop/pause, events). The engine adds fan-out subscription on top. + - Fan-out uses non-blocking sends (`select { case ch <- event: default: }`) to avoid slow consumers blocking the pipeline + - `Subscribe()` returns a cleanup function — must be called to avoid resource leaks + - `engine.Shutdown()` stops the forwarding goroutine; `engine.StopAll()` only stops loops + - TabBar and PRDPicker still use `loop.Manager` directly (via `eng.Manager()`) since they only need read-only state queries + - When tests create `App` struct literals, use `eng` field (not `manager`) and pass `engine.New()` instances +--- + +## 2026-02-15 - US-006 +- **What was implemented:** WebSocket client with automatic reconnection in `internal/ws` package +- **Files changed:** + - `internal/ws/client.go` - `Client` struct with `Connect(ctx)`, `Send(msg)`, `Receive()`, `Close()` API; exponential backoff + jitter reconnection (1s→60s max); ping/pong handler; `WithOnReconnect` callback option; context-based cancellation + - `internal/ws/client_test.go` - 10 tests: connect/send/receive, graceful close, send-when-disconnected, reconnect-on-server-close, context cancellation, ping/pong, backoff calculation, default URL, channel buffer, multiple messages + - `go.mod` / `go.sum` - Added `github.com/gorilla/websocket` dependency + - `.chief/prds/uplink/prd.json` - Updated US-006 status +- **Learnings for future iterations:** + - `gorilla/websocket` is the standard Go WebSocket library — `DefaultDialer.Dial()` for connecting, `ReadMessage()`/`WriteMessage()` for I/O + - Ping/pong: set `SetPingHandler` to auto-respond with pong via `WriteControl(PongMessage, ...)`; also set `SetPongHandler` (even empty) to prevent default pong from interfering + - Reconnection loop lives in `readLoop` — on read error, close old conn, dial new one, set up handlers again, call `onRecon` callback + - Backoff with jitter formula: `base * 2^(attempt-1) * rand(0.5, 1.5)`, capped at max + - Test pattern for WebSocket: `httptest.NewServer` with `websocket.Upgrader` in handler, convert URL with `strings.TrimPrefix(s.URL, "http")` → `"ws" + ...` + - `atomic.Int32` useful for tracking connection counts in reconnection tests + - Message struct uses `json.RawMessage` for `Raw` field to preserve the full original message for downstream consumers +--- + +## 2026-02-15 - US-007 +- **What was implemented:** Protocol handshake for WebSocket connections with authentication and version compatibility verification +- **Files changed:** + - `internal/ws/handshake.go` - `Handshake()` method on `Client`, hello/welcome/incompatible/auth_failed message types, `newUUID()` v4 generator, `NewMessage()` envelope helper, `ErrIncompatible`/`ErrAuthFailed`/`ErrHandshakeTimeout` error types + - `internal/ws/handshake_test.go` - 8 tests: success, incompatible version, auth failure, timeout, correct hello contents verification, connection closed during handshake, UUID format, NewMessage helper + - `.chief/prds/uplink/prd.json` - Updated US-007 status +- **Learnings for future iterations:** + - UUID v4 can be generated with `crypto/rand` + bit manipulation (set version=4, variant=RFC4122) — no need for external `google/uuid` dependency + - Handshake uses `client.Receive()` channel with `time.NewTimer` for timeout — clean select-based pattern + - `*ErrIncompatible` is a struct error (not sentinel) so it can carry the server's message; `ErrAuthFailed` and `ErrHandshakeTimeout` are sentinels + - Handshake must be called after `Connect()` — it sends the hello message and blocks until response or timeout + - When connection closes during handshake, the readLoop reconnects but handshake still times out (correct behavior — caller should retry) + - `runtime.GOOS` and `runtime.GOARCH` provide OS/architecture info for the hello message +--- + +## 2026-02-15 - US-008 +- **What was implemented:** Strongly-typed message definitions, error codes, and message dispatcher for the WebSocket protocol +- **Files changed:** + - `internal/ws/messages.go` - All message type structs for the protocol catalog (server→web app, web app→server, bidirectional), type constants, error code constants + - `internal/ws/dispatcher.go` - `Dispatcher` struct with `Register()`, `Unregister()`, `Dispatch()` for routing messages by type to handlers + - `internal/ws/messages_test.go` - 28 serialization/deserialization round-trip tests covering all message types, optional field omission, partial updates + - `internal/ws/dispatcher_test.go` - 7 tests: register/dispatch, unknown type ignored, unregister, handler replacement, raw JSON passthrough, concurrent access, multiple handlers +- **Learnings for future iterations:** + - Use `*int`, `*bool`, `*string` pointer fields for partial/optional updates (e.g., `UpdateSettingsMessage`) — allows distinguishing "not provided" from zero values + - `json.RawMessage` on the `Message.Raw` field preserves the full original JSON, so dispatcher handlers can unmarshal into the specific message type + - Error codes are string constants (`ErrCodeProjectNotFound`) not enums — easier to serialize and forward-compatible + - Type constants (`TypeStartRun`, `TypeGetProject`, etc.) centralize string literals and prevent typos + - Dispatcher uses `sync.RWMutex` for safe concurrent access — handlers can be registered/unregistered while dispatching + - Unknown message types are logged and ignored (forward compatibility per spec) + - `interface{}` is used for `PRDContentMessage.State` since the PRD state is a flexible JSON object +--- + +## 2026-02-15 - US-009 +- **What was implemented:** Basic `chief serve` command — headless daemon that connects to chiefloop.com via WebSocket +- **Files changed:** + - `internal/cmd/serve.go` - `RunServe(ServeOptions)` with workspace validation, credential check, token refresh, WebSocket connection, protocol handshake, signal handling, ping/pong handling, clean shutdown + - `internal/cmd/serve_test.go` - 9 tests: workspace validation (nonexistent, file-not-dir), not-logged-in, successful connect+handshake, device name override, incompatible version, auth failure, log file output, ping/pong, token refresh + - `cmd/chief/main.go` - Added `serve` subcommand with `--workspace`, `--name`, `--log-file` flags; updated help text + - `.chief/prds/uplink/prd.json` - Updated US-009 status +- **Learnings for future iterations:** + - `ServeOptions.Ctx` (context.Context) is the key testing mechanism — tests create a context, cancel it from the mock server handler after handshake completes, which causes RunServe to exit cleanly + - For error-path tests (incompatible/auth_failed), the ws.Client reconnection behavior creates issues — use `srv.CloseClientConnections()` in the server handler to prevent reconnection loops during Close() + - Must cancel context BEFORE calling `client.Close()` to avoid race where readLoop reconnects and Close() gets a stale conn reference + - The serve command passes `version` from build-time `Version` var in main.go through `ServeOptions.Version` to the handshake + - Device name defaults to credential's device name, overridable via `--name` flag + - Log output defaults to stdout; `--log-file` redirects `log.SetOutput()` to a file +--- + +## 2026-02-15 - US-010 +- **What was implemented:** Workspace scanner that discovers git repositories in the workspace directory +- **Files changed:** + - `internal/workspace/scanner.go` - New `Scanner` struct with `Scan()`, `ScanAndUpdate()`, `Run()`, `Projects()` methods. Scans one level deep for `.git/` dirs, gathers branch/commit/PRD info, sends `project_list` over WebSocket on changes, re-scans every 60s + - `internal/workspace/scanner_test.go` - 12 tests: discover repos, detect .chief, discover PRDs, multiple projects, empty workspace, permission errors, detect add/remove, periodic scanning with WebSocket, context cancellation, projectsEqual, branch detection + - `internal/cmd/serve.go` - Integrated scanner: starts `workspace.New(opts.Workspace, client).Run(ctx)` in a goroutine after handshake + - `.chief/prds/uplink/prd.json` - Updated US-010 status +- **Learnings for future iterations:** + - `ws.Client.Send()` accepts `interface{}` and marshals it internally — pass the message struct directly, don't double-marshal + - `os.Stat()` doesn't require read permission on a file, only traverse on parent — to test permission errors, remove perms on the parent directory + - `sendProjectList()` must guard against nil client (scanner can be used standalone in tests) + - `projectsEqual()` compares by building a name→project map — handles different ordering between scans + - Git info gathered via `git rev-parse --abbrev-ref HEAD` (branch) and `git log -1 --format=%H%n%s%n%an%n%aI` (commit hash, message, author, ISO timestamp) + - PRD completion status formatted as `"passed/total"` (e.g., `"2/3"`) + - Scanner uses `time.NewTicker` for periodic scans; tests set `scanner.interval` to small values for speed +- File watcher: `workspace.NewWatcher(dir, scanner, client)` creates watcher; `watcher.Activate(name)` enables deep watching; `watcher.Run(ctx)` runs event loop +- `fsnotify` does NOT recurse into subdirectories — must explicitly `Add()` each subdirectory (e.g., each PRD dir inside `.chief/prds/`) +- Watcher `Activate()` is called by serve command when `get_project` message is received (or run started, session opened) +- `watcher.inactiveTimeout` can be set to small values in tests for fast inactivity cleanup testing +--- + +## 2026-02-15 - US-011 +- **What was implemented:** Selective file watching using `fsnotify` for workspace root and active project deep watchers +- **Files changed:** + - `internal/workspace/watcher.go` - New `Watcher` struct with `Activate()`, `Run()`, `Close()`, inactivity cleanup, deep watcher setup/teardown for `.chief/`, `.chief/prds/` (+ subdirs), `.git/`. Sends `project_state` updates on file changes. + - `internal/workspace/watcher_test.go` - 9 tests: workspace root changes, activate project, unknown project, activity refresh, inactivity cleanup, PRD change sends project_state, git HEAD change sends project_state, context cancellation, no deep watchers for inactive projects + - `internal/cmd/serve.go` - Integrated watcher: creates `NewWatcher()` after scanner, passes to `handleMessage()` and `serveShutdown()`, activates project on `get_project` message + - `.chief/prds/uplink/prd.json` - Updated US-011 status +- **Learnings for future iterations:** + - `fsnotify` watches individual directories, not recursive trees — must add each subdir explicitly (e.g., `.chief/prds/feature/`) + - `activeProject` struct tracks `watching` bool to prevent duplicate watcher setup + - Inactivity cleanup runs on a 1-minute ticker; tests override `inactiveTimeout` to milliseconds + - `handleMessage()` now accepts `*workspace.Watcher` to activate projects on `get_project` + - `serveShutdown()` now accepts `*workspace.Watcher` to close it during shutdown + - For git HEAD changes, `fsnotify` sees `HEAD.lock` operations — matching `strings.Contains(subPath, "HEAD")` catches both direct HEAD writes and lock-based updates +--- + +## 2026-02-15 - US-012 +- **What was implemented:** State snapshot on connect/reconnect, `list_projects`, `get_project`, and `get_prd` handlers with proper error responses +- **Files changed:** + - `internal/cmd/serve.go` - Added `sendStateSnapshot()`, `sendError()`, `handleListProjects()`, `handleGetProject()`, `handleGetPRD()` functions; refactored `RunServe` to create scanner before WebSocket connect for immediate scan availability; `handleMessage` now routes `list_projects`, `get_project`, `get_prd` messages; state snapshot sent after handshake and on reconnect via `WithOnReconnect` callback + - `internal/cmd/serve_test.go` - Added 7 new tests: `StateSnapshotOnConnect`, `ListProjects`, `GetProject`, `GetProjectNotFound`, `GetPRD`, `GetPRDNotFound`, `GetPRDProjectNotFound`; added `createGitRepo()` helper and `serveTestHelper()` for cleaner test setup; updated PingPong test to handle state_snapshot message + - `internal/workspace/scanner.go` - Added `SetClient()` to set client after creation and `FindProject()` for single-project lookup by name +- **Learnings for future iterations:** + - Scanner is now created before WebSocket client (with nil client), then `SetClient()` is called after client creation — this allows initial scan to populate project list before handshake + - `sendStateSnapshot()` is reused for both initial handshake and reconnect (via `WithOnReconnect` callback) + - `sendError()` utility includes `request_id` field to help clients correlate errors to their requests + - `get_prd` reads both `prd.md` (markdown content, optional) and `prd.json` (state) — content can be empty if prd.md doesn't exist yet + - Test helper `serveTestHelper()` encapsulates the boilerplate: reads hello, sends welcome, reads state_snapshot, then calls custom server function + - Existing PingPong test needed updating because state_snapshot is now sent before pong — the server must read it first + - `FindProject()` uses read lock and linear scan — fine for typical workspace sizes (dozens of projects) +--- + +## 2026-02-15 - US-013 +- **What was implemented:** Interactive Claude PRD sessions over WebSocket — spawn, stream output, send messages, and close sessions +- **Files changed:** + - `internal/cmd/session.go` - New `sessionManager` struct managing Claude PRD sessions: `newPRD()` spawns `claude` with init_prompt, `sendMessage()` writes to stdin, `closeSession()` with save/kill options, `killAll()` for shutdown, `activeSessions()` for state snapshots, auto-conversion of prd.md→prd.json on session end + - `internal/cmd/session_test.go` - 10 tests: new_prd (real and mock claude), project not found, prd_message session not found, close_prd_session session not found, mock claude lifecycle (spawn→message→close), save close, active sessions tracking, send message with echo verification, close errors, duplicate session prevention + - `internal/cmd/serve.go` - Integrated `sessionManager`: created after WS client, passed to `handleMessage()`, `serveShutdown()`, `sendStateSnapshot()`; added routing for `new_prd`, `prd_message`, `close_prd_session` message types; state snapshot now includes active sessions + - `.chief/prds/uplink/prd.json` - Updated US-013 status +- **Learnings for future iterations:** + - Claude interactive sessions use positional arg (not `-p` flag): `exec.Command("claude", prompt)` — the `-p` flag is for non-interactive/print mode + - Use shell script mocks in tests: write `#!/bin/sh` script to temp dir, prepend to `PATH` with `t.Setenv("PATH", dir+":"+origPath)` — this allows testing process lifecycle without real claude binary + - `projectFinder` interface extracted for testability — `*workspace.Scanner` satisfies it implicitly + - `sessionManager` uses `done` channel per session for synchronization: `close(sess.done)` signals process exit, `<-sess.done` blocks in `closeSession()` and `killAll()` + - Auto-conversion after session: scans all PRD dirs in project for `prd.NeedsConversion()` — Claude may create new PRD dirs during session + - `serveTestHelper` reads state_snapshot automatically after handshake — all test server functions receive conn after this step +--- + +## 2026-02-15 - US-014 +- **What was implemented:** Session timeout with warnings — sessions automatically expire after 30 minutes of inactivity with warnings at 20, 25, and 29 minutes +- **Files changed:** + - `internal/cmd/session.go` - Added `lastActive`/`activeMu` fields to `claudeSession`, `resetActivity()`/`inactiveDuration()` methods, `runTimeoutChecker()` goroutine with configurable check interval, `sendTimeoutWarning()`, `expireSession()` (saves state, kills process, sends `session_expired`), `sendMessage()` now resets inactivity timer, `killAll()` stops timeout checker + - `internal/cmd/session_test.go` - Added 5 tests: `TimeoutExpiration` (session expires and sends `session_expired`), `TimeoutWarnings` (warnings sent at correct thresholds), `TimeoutResetOnMessage` (prd_message resets timer), `IndependentTimers` (concurrent sessions have separate timers), `TimeoutCheckerGoroutineSafe` (concurrent session creation/messaging while timeout checker runs) +- **Learnings for future iterations:** + - Timeout checker uses a `stopTimeout` channel (closed by `killAll()`) to cleanly stop the goroutine + - Warning thresholds and check intervals are configurable on `sessionManager` for fast testing — production uses 30s check interval with 20/25/29 minute thresholds + - `expireSession()` gives Claude a 2-second grace period to finish writing after closing stdin before force-killing + - `sentWarnings` map in the checker prevents duplicate warnings — cleaned up when sessions are removed + - Direct `activeMu.Lock()` and `lastActive` manipulation in tests allows simulating time passage without real waits + - The process wait goroutine (from `newPRD`) handles cleanup (sends `claude_output done=true`, auto-converts, removes from sessions map); `expireSession` waits for `<-sess.done` so both paths are coordinated +--- + +## 2026-02-15 - US-015 +- **What was implemented:** Run control handlers for start_run, pause_run, resume_run, stop_run via WebSocket +- **Files changed:** + - `internal/cmd/runs.go` - New `runManager` struct wrapping `engine.Engine` with `startRun()`, `pauseRun()`, `resumeRun()`, `stopRun()`, `activeRuns()`, `stopAll()` methods; handler functions `handleStartRun`, `handlePauseRun`, `handleResumeRun`, `handleStopRun`; `activator` interface for testability + - `internal/cmd/runs_test.go` - 11 tests: start_run routing, project not found, pause/resume/stop not active errors, run manager unit tests (start+already active, pause/resume, stop, active runs, multiple concurrent projects, state string conversion) + - `internal/cmd/serve.go` - Integrated engine and run manager: creates `engine.New(5)` and `newRunManager()`, passes to `handleMessage()`, `sendStateSnapshot()`, `serveShutdown()`; routes `start_run`/`pause_run`/`resume_run`/`stop_run` messages; state snapshot now includes active runs; shutdown stops all runs +- **Learnings for future iterations:** + - Resume is implemented by calling `engine.Start()` again after a paused loop — `Manager.Start()` creates a fresh `Loop` that picks up from the next unfinished story in prd.json + - `runKey(project, prdID)` creates engine registration key as `"project/prdID"` to support multiple PRDs per project + - Error strings like `"RUN_ALREADY_ACTIVE"` and `"RUN_NOT_ACTIVE"` are used as sentinel error messages matched in handlers + - `sendStateSnapshot`, `handleMessage`, and `serveShutdown` now accept `*runManager` parameter — existing tests continued to work because `serveTestHelper` uses these functions indirectly through `RunServe` + - `activator` interface allows testing `handleStartRun` without a real `workspace.Watcher` +--- + +## 2026-02-15 - US-016 +- **What was implemented:** Quota detection and auto-pause for Ralph loop runs +- **Files changed:** + - `internal/loop/parser.go` - Added `IsQuotaError()` function with quota/rate-limit pattern matching, `ErrQuotaExhausted` sentinel error, `EventQuotaExhausted` event type + - `internal/loop/loop.go` - Capture stderr into buffer during `runIteration()`, check for quota patterns on non-zero exit; skip retries for quota errors; emit `EventQuotaExhausted` instead of `EventError` + - `internal/loop/manager.go` - Set `LoopStatePaused` (not `LoopStateError`) when quota exhaustion is detected, making the run resumable + - `internal/cmd/runs.go` - Added `startEventMonitor()` goroutine that subscribes to engine events and detects `EventQuotaExhausted`; `handleQuotaExhausted()` sends `run_paused` with `reason: "quota_exhausted"` and `quota_exhausted` message over WebSocket + - `internal/cmd/serve.go` - Wired up `runs.startEventMonitor(ctx)` after run manager creation + - `internal/loop/parser_test.go` - Added `TestIsQuotaError` with 16 test cases for pattern matching + - `internal/cmd/runs_test.go` - Added 5 tests: `HandleQuotaExhausted`, `HandleQuotaExhaustedUnknownRun`, `EventMonitorQuotaDetection`, `QuotaExhaustedWebSocket` (integration test with mock claude), `IsQuotaErrorIntegration` +- **Learnings for future iterations:** + - Quota errors are detected by checking stderr content and exit code error text against known patterns ("rate limit", "quota", "429", "too many requests", "resource_exhausted", "overloaded") + - Quota errors bypass retry logic entirely — `runIterationWithRetry` returns immediately with `ErrQuotaExhausted` instead of retrying + - The manager sets `LoopStatePaused` for quota errors (not `LoopStateError`) so the run can be resumed by the user + - `runManager.startEventMonitor()` subscribes to engine events and runs in a goroutine — it watches for `EventQuotaExhausted` across all runs + - When sending WS messages from `handleQuotaExhausted`, must guard against nil client (run manager can be used without a client in tests) + - `logAndCaptureStream` captures stderr into a `bytes.Buffer` while still logging it — used instead of `logStream` for the stderr pipe + - Mock claude scripts for quota tests: `echo "rate limit exceeded" >&2; exit 1` simulates quota exhaustion +--- + +## 2026-02-15 - US-017 +- **What was implemented:** Run progress streaming — `run_progress`, `run_complete`, and `claude_output` messages sent over WebSocket during active Ralph loop runs +- **Files changed:** + - `internal/cmd/runs.go` - Extended `startEventMonitor` to handle all engine event types; added `handleEvent()` router, `sendRunProgress()`, `sendRunComplete()`, `sendClaudeOutput()` methods; added `startTime` and `storyID` tracking to `runInfo` + - `internal/cmd/runs_test.go` - Added 5 tests: `HandleEventRunProgress` (all event types with nil client), `HandleEventUnknownRun`, `HandleEventStoryTracking`, `SendRunComplete`, `RunProgressStreaming` (integration test with mock claude) + - `.chief/prds/uplink/prd.json` - Updated US-017 status +- **Learnings for future iterations:** + - `startEventMonitor` was extended (not replaced) — the event loop now handles all event types via `handleEvent()`, not just quota exhaustion + - `runInfo` tracks `storyID` (updated on `EventStoryStarted`) so that `sendRunProgress` and `sendClaudeOutput` can include it even for events that don't carry a story ID + - `sendRunComplete` loads the PRD from disk to calculate pass/fail counts — same pattern as other PRD readers in the codebase + - All send methods guard against nil client, so the run manager can be used in tests without a WebSocket connection + - Mock claude scripts that output stream-json format are useful for integration testing: `echo '{"type":"system","subtype":"init"}'` triggers `EventIterationStart` + - `time.Since(info.startTime).Round(time.Second).String()` gives human-readable durations like "5m0s" for the `run_complete` message +--- + +## 2026-02-15 - US-018 +- **What was implemented:** Project settings via WebSocket — `get_settings` and `update_settings` handlers +- **Files changed:** + - `internal/config/config.go` - Added `MaxIterations`, `AutoCommit` (`*bool`), `CommitPrefix`, `ClaudeModel`, `TestCommand` fields to `Config` struct; added `EffectiveMaxIterations()` and `EffectiveAutoCommit()` helper methods; added `DefaultMaxIterations` constant + - `internal/config/config_test.go` - Added `TestSaveAndLoadSettingsFields` and `TestEffectiveDefaults` tests + - `internal/cmd/settings.go` - New file with `handleGetSettings()` and `handleUpdateSettings()` handlers; loads/saves config via `config.Load()`/`config.Save()`; validates `max_iterations >= 1`; partial update support via pointer fields + - `internal/cmd/settings_test.go` - 7 integration tests: defaults, project not found (get/update), existing config, full update, partial update preserving existing, invalid max_iterations + - `internal/cmd/serve.go` - Added routing for `get_settings` and `update_settings` messages in `handleMessage()` + - `.chief/prds/uplink/prd.json` - Updated US-018 status +- **Learnings for future iterations:** + - Config fields use `*bool` for `AutoCommit` so `false` is distinguishable from "not set" — `EffectiveAutoCommit()` returns `true` when nil + - `config.Load()` returns `Default()` when config file doesn't exist — no error, just zero values with `Effective*()` providing defaults + - Settings handlers reuse `projectFinder` interface (same pattern as sessions and runs) + - `handleUpdateSettings` does load→merge→save pattern: loads existing config, applies only non-nil fields from request, saves back + - YAML tags use `omitempty` to avoid writing zero values to config file +--- + +## 2026-02-15 - US-019 +- **What was implemented:** Per-story logging during Ralph loop runs and `get_logs` handler +- **Files changed:** + - `internal/cmd/logs.go` - New `storyLogger` struct for writing per-story log files, `handleGetLogs` handler for retrieving logs via WebSocket, `readLogFile`/`readMostRecentLog`/`sendLogLines` helper functions + - `internal/cmd/logs_test.go` - 15 tests: story logger write/read, empty story ID, overwrite on new run, line limit, nonexistent files, most recent log, run manager integration, serve integration tests for get_logs (with story ID, without story ID, project not found, PRD not found, line limit), end-to-end logging integration + - `internal/cmd/runs.go` - Added `loggers` map to `runManager`, `writeStoryLog()` method, logger creation in `startRun()`, logger cleanup in `cleanup()`/`stopAll()`, story log writing in `handleEvent()` for AssistantText/ToolStart/ToolResult/Error events + - `internal/cmd/serve.go` - Added `get_logs` message routing in `handleMessage()` + - `.chief/prds/uplink/prd.json` - Updated US-019 status +- **Learnings for future iterations:** + - Per-story logs are stored at `.chief/prds//logs/.log` — separate from the main `claude.log` which logs all raw output + - `newStoryLogger()` removes the entire `logs/` directory on creation (V1 simplicity: starting a new run overwrites previous logs) + - `storyLogger` lazily opens files on first write for each story ID — avoids creating empty log files + - `readMostRecentLog()` uses file modification time to find the most recently active story's log + - `readLogFile()` returns empty slice (not error) for nonexistent files — graceful handling of missing logs + - The `runManager.loggers` map is keyed by the same `runKey(project, prdID)` as the `runs` map + - Story log writing happens in `handleEvent()` alongside WebSocket message sending — they are parallel operations + - `handleGetLogs` follows the same `projectFinder` + error handling pattern as settings/sessions handlers +--- + +## 2026-02-15 - US-020 +- **What was implemented:** Per-story diff generation and retrieval via `get_diff` WebSocket handler, plus proactive diff sending on story completion +- **Files changed:** + - `internal/cmd/diffs.go` - New file with `handleGetDiff` handler, `getStoryDiff()`, `findStoryCommit()`, `getCommitDiff()`, `getCommitFiles()`, `sendDiffMessage()` functions + - `internal/cmd/diffs_test.go` - 11 tests: getStoryDiff success/no-commit/multiple-files, findStoryCommit most-recent, sendDiffMessage nil-safety, runManager sendStoryDiff, serve integration tests (get_diff success, project not found, PRD not found, no commit) + - `internal/cmd/runs.go` - Added `sendStoryDiff()` method to `runManager` for proactive diff on `EventStoryCompleted`; added call in `handleEvent()` + - `internal/cmd/serve.go` - Added `get_diff` message routing in `handleMessage()` +- **Learnings for future iterations:** + - Story commits follow `feat: - ` pattern — use `git log --grep <storyID> -1` to find the most recent matching commit + - `git show --format= --patch <hash>` gives the diff without commit metadata; `--name-only` gives the file list + - `sendStoryDiff` derives project path from `prdPath` by walking 4 levels up: `prd.json → <id> → prds → .chief → project` + - `getStoryDiff` is shared between the `handleGetDiff` handler (on-demand) and `sendStoryDiff` (proactive on story completion) + - `createGitRepoWithStoryCommit()` test helper creates a git repo with a commit matching the story pattern for diff testing +--- + +## 2026-02-15 - US-021 +- **What was implemented:** Git clone and project creation via WebSocket — `clone_repo` and `create_project` handlers +- **Files changed:** + - `internal/cmd/clone.go` - New file with `handleCloneRepo` (async git clone with progress streaming), `handleCreateProject` (directory creation with optional git init), `inferDirName()`, `runClone()`, `scanGitProgress()`, `sendCloneProgress()`, `sendCloneComplete()` functions + - `internal/cmd/clone_test.go` - 15 tests: inferDirName, clone success, custom directory name, directory already exists, invalid URL, create project success, create with git init, already exists, empty name, scanGitProgress splitter, percent pattern parsing, nil client safety + - `internal/cmd/serve.go` - Added `clone_repo` and `create_project` message routing in `handleMessage()` + - `internal/workspace/scanner.go` - Added `WorkspacePath()` method to expose workspace directory path + - `.chief/prds/uplink/prd.json` - Updated US-021 status +- **Learnings for future iterations:** + - Git clone writes progress to stderr, not stdout — use `StderrPipe()` to capture progress + - Git clone uses `\r` for in-place progress updates — custom `scanGitProgress` splitter handles both `\r` and `\n` + - Clone runs in a goroutine to avoid blocking the message loop — sends `clone_progress` and `clone_complete` messages asynchronously + - `inferDirName()` handles both HTTPS URLs and SSH-style URLs (git@github.com:user/repo.git) + - After clone/create, `scanner.ScanAndUpdate()` is called to make the new project immediately discoverable + - `create_project` with `git_init: true` sends `project_state` (project is discoverable); without git_init sends `project_list` + - `WorkspacePath()` was added to `Scanner` to expose the workspace path for clone/create operations +--- + +## 2026-02-15 - US-022 +- **What was implemented:** Version check against GitHub Releases API and `chief update` self-update command +- **Files changed:** + - `internal/update/update.go` - New package with `CheckForUpdate()`, `PerformUpdate()`, `CompareVersions()`, version normalization, asset finding, download/checksum verification, atomic binary replacement + - `internal/update/update_test.go` - 19 tests: version check (update available, already latest, dev version, API error, bad JSON), version normalization, version comparison, asset finding (match, no match, no checksum), write permission check, download to temp, checksum verification (success, mismatch), perform update (already latest, full flow), version with v-prefix + - `internal/cmd/update.go` - `RunUpdate(UpdateOptions)` command, `CheckVersionOnStartup()` (non-blocking goroutine for interactive CLI), `CheckVersionForServe()` for serve mode + - `internal/cmd/update_test.go` - 6 tests: already latest, API error, serve version check (update available, no update, API failure, dev version) + - `internal/cmd/serve.go` - Added `runVersionChecker()` goroutine (checks every 24h), `checkAndNotify()` helper that sends `update_available` over WebSocket, added `ReleasesURL` to `ServeOptions` for testing + - `cmd/chief/main.go` - Added `update` subcommand, `PersistentPreRun` with non-blocking startup version check (skipped for update/serve/version commands), updated help text +- **Learnings for future iterations:** + - `PerformUpdate()` accepts `currentVersion` as parameter (not discovered from binary) — version is set via ldflags at build time and passed through + - Asset naming convention: `chief-<GOOS>-<GOARCH>` for binary, `.sha256` suffix for checksum + - Checksum file format: `"hash filename"` — use `strings.Fields()` to parse + - `os.Executable()` + `filepath.EvalSymlinks()` to get the real binary path for replacement + - Write permission check: try `os.CreateTemp` in the target directory, immediately clean up + - `PersistentPreRun` on root Cobra command runs before all subcommands — use command name to skip specific commands + - Startup version check runs in a goroutine (non-blocking) — print message asynchronously; may appear after other output + - Serve version checker: immediate check on startup + `time.NewTicker(24 * time.Hour)` for periodic checks + - `update.Options.ReleasesURL` field allows tests to point at mock server (same pattern as `auth.BaseURL`) +- `handleMessage` returns `bool` — `true` signals the serve loop to exit cleanly (for remote update); all other handlers return `false` +- Setup token login: `chief login --setup-token <token>` exchanges via `POST /oauth/device/exchange`; cloud-init uses `CHIEF_SETUP_TOKEN` env var and a one-shot systemd service +- Rate limiting: `ws.NewRateLimiter()` checks in main event loop before `handleMessage()`; `Reset()` on reconnect; ping is exempt; expensive ops (clone_repo, start_run, new_prd) have per-type 2/minute limits +--- + +## 2026-02-15 - US-023 +- **What was implemented:** Remote update via WebSocket — `trigger_update` handler that downloads and installs latest binary, sends confirmation, and exits cleanly for systemd restart +- **Files changed:** + - `internal/cmd/remote_update.go` - New `handleTriggerUpdate()` function: checks for update, downloads/installs if available, sends `update_available` confirmation or `UPDATE_FAILED` error, returns bool indicating whether serve should exit + - `internal/cmd/remote_update_test.go` - 4 tests: already latest (unit), API error (unit), serve integration already latest, serve integration API error + - `internal/cmd/serve.go` - Added `trigger_update` routing in `handleMessage()`, changed `handleMessage` to return bool for exit signaling, main event loop handles exit cleanly after successful update + - `.chief/prds/uplink/prd.json` - Updated US-023 status +- **Learnings for future iterations:** + - `handleMessage` now returns a `bool` — returning `true` signals the serve loop to exit cleanly (used for remote update) + - `handleTriggerUpdate` returns `true` only on successful update; errors and "already latest" return `false` + - Avoided `os.Exit(0)` in handler — instead, the serve loop performs clean shutdown and returns `nil` error, allowing systemd `Restart=always` to pick up the new binary + - Integration tests that need `ReleasesURL` cannot use `serveTestHelper` (it doesn't expose that field) — write them manually with the same hello/welcome/state_snapshot pattern + - Permission errors from `update.PerformUpdate` contain "Permission denied" text — matched via `strings.Contains` to send a descriptive `UPDATE_FAILED` error +--- + +## 2026-02-15 - US-024 +- **What was implemented:** Systemd service unit file and cloud-init setup script for VPS deployment +- **Files changed:** + - `deploy/chief.service` - Systemd unit file with `ConditionPathExists` for credentials, `Type=simple`, `Restart=always`, `RestartSec=5`, `After=network-online.target`, security hardening directives + - `deploy/cloud-init.sh` - Idempotent cloud-init script that creates `chief` user, installs Chief binary (via existing `install.sh`), installs Claude Code CLI (via npm), creates workspace directory, writes and enables systemd service (but does NOT start it), prints post-deploy instructions +- **Learnings for future iterations:** + - Deployment files live in `deploy/` directory at project root + - Systemd `ConditionPathExists` prevents service from start/restart-looping before auth — service won't start until credentials file exists + - Cloud-init script reuses the existing `install.sh` for binary installation (via `CHIEF_INSTALL_DIR` env var) + - Service is `enabled` but not `started` — user must first run `chief login` and authenticate Claude Code before starting + - Script handles multiple distros for Node.js installation (apt/dnf/yum) + - `ProtectSystem=strict` + `ReadWritePaths=/home/chief` limits write access to only the chief home directory +--- + +## 2026-02-15 - US-025 +- **What was implemented:** One-time setup token for automated Chief authentication during VPS provisioning +- **Files changed:** + - `internal/cmd/login.go` - Added `SetupToken` field to `LoginOptions`, `exchangeSetupToken()` function that calls `POST /oauth/device/exchange` with the token and device name, returns credentials on success or falls back to manual login instructions on failure + - `internal/cmd/login_test.go` - Added 5 tests: setup token success, invalid token, expired token, server error, default device name with setup token + - `cmd/chief/main.go` - Added `--setup-token` flag to the `login` subcommand + - `deploy/cloud-init.sh` - Added `handle_setup_token()` function: writes token to `/tmp/chief-setup-token`, creates one-shot `chief-setup.service` that runs `chief login --setup-token` and starts the main service on success; updated usage docs and post-deploy instructions for token mode +- **Learnings for future iterations:** + - Setup token flow is much simpler than device OAuth — single HTTP POST, no polling, no browser interaction + - `exchangeSetupToken()` is called early in `RunLogin()` (before the "already logged in" check) since it's non-interactive + - The one-shot systemd service (`chief-setup.service`) chains `chief login` → `rm token file` → `systemctl start chief` in a single ExecStart + - `ExecStartPost` cleans up the token file even if the login fails, ensuring the single-use token doesn't persist + - Cloud-init passes the token via `CHIEF_SETUP_TOKEN` environment variable — safer than command-line args which appear in process listings +--- + +## 2026-02-16 - US-026 +- **What was implemented:** WebSocket rate limiting with token bucket algorithm and per-type expensive operation limiting +- **Files changed:** + - `internal/ws/ratelimit.go` - New `RateLimiter` struct with global token bucket (30 burst, 10/sec sustained), per-type `expensiveTracker` for expensive ops (2/minute for `clone_repo`, `start_run`, `new_prd`), ping exemption, `Reset()` for reconnection, `FormatRetryAfter()` helper + - `internal/ws/ratelimit_test.go` - 19 tests: burst allowance, burst exhaustion, token refill, ping exemption, expensive ops limiting, independent expensive trackers, window expiry, reset, expensive-consumes-global, concurrent access, FormatRetryAfter, IsExpensiveType, IsExemptType, tokenBucket retryAfter, expensiveTracker retryAfter + - `internal/cmd/serve.go` - Created `rateLimiter` before WebSocket client, `Reset()` on reconnect, rate limit check in main event loop before `handleMessage()`, sends `RATE_LIMITED` error with retry-after hint + - `internal/cmd/serve_test.go` - 3 integration tests: global rate limit exhaustion, ping exemption during rate limiting, expensive operation limiting +- **Learnings for future iterations:** + - Token bucket is a good fit for global rate limiting: allows bursts while enforcing sustained rate + - Expensive operations need separate per-type tracking with a sliding window (not token bucket) since the limit is per-minute, not per-second + - Rate limit check is done in the main event loop (before `handleMessage`) rather than inside `handleMessage` — cleaner separation of concerns + - `rateLimiter.Reset()` is called in the `WithOnReconnect` callback alongside `sendStateSnapshot` — rate limiter state resets on reconnection + - Pre-existing race conditions exist in `runManager` tests (unrelated to rate limiting) — these fail with `-race` flag but pass without it +--- + +## 2026-02-16 - US-027 +- **What was implemented:** Graceful shutdown for `chief serve` with process killing, story interruption tracking, timeout enforcement, and proper log flushing +- **Files changed:** + - `internal/cmd/serve.go` - Enhanced `serveShutdown()` with: process counting, in-progress story marking, 5-second force-kill timeout, engine shutdown integration, proper log file sync/flush; removed `defer eng.Shutdown()` (engine now shut down inside `serveShutdown`); all callers updated to pass engine + - `internal/cmd/runs.go` - Added `markInterruptedStories()` (loads PRD, sets `inProgress: true` for active story, saves back), `activeRunCount()` method + - `internal/cmd/session.go` - Added `sessionCount()` method for process counting during shutdown + - `internal/cmd/serve_test.go` - Added 4 tests: `ShutdownLogsSequence`, `ShutdownMarksInterruptedStories` (integration test with mock claude), `ShutdownLogFileFlush`, `SessionManager_SessionCount` + - `internal/cmd/runs_test.go` - Added 3 tests: `MarkInterruptedStories`, `MarkInterruptedStoriesNoStoryID`, `ActiveRunCount` +- **Learnings for future iterations:** + - `defer eng.Shutdown()` in `RunServe` blocks on `StopAll().wg.Wait()` — if processes are still alive after the 5-second force-kill timeout, this causes a hang. Solution: move engine shutdown into `serveShutdown` goroutine alongside `stopAll()`/`killAll()` + - `markInterruptedStories()` must be called BEFORE `stopAll()` because `stopAll()` clears the runs map and closes loggers + - Force-kill timeout must encompass both runs (`StopAll` which sends `Kill()` to processes) and sessions (`killAll` which kills process and waits for `<-done`) + - The 5-second timeout allows the goroutine to continue running in the background while we proceed with watcher/WebSocket cleanup — acceptable because the goroutine will eventually finish or be cleaned up by process exit + - Mock claude scripts using `sleep 300` are good for testing force-kill behavior — the process won't exit on its own, so the 5-second timeout is exercised + - Log file flush uses `Sync()` before `Close()` (via deferred function) to ensure all buffered output is written to disk +--- + +## 2026-02-16 - US-028 +- **What was implemented:** WebSocket URL configuration for local development — `--ws-url` flag, `CHIEF_WS_URL` env var, and `ws_url` field in `~/.chief/config.yaml` with proper precedence +- **Files changed:** + - `internal/config/config.go` - Added `UserConfig` struct with `WSURL` field and `LoadUserConfig()` function for reading user-level config from `~/.chief/config.yaml` + - `internal/config/config_test.go` - Added 3 tests: `LoadUserConfig_NonExistent`, `LoadUserConfig_WithWSURL`, `LoadUserConfig_EmptyWSURL` + - `internal/cmd/serve.go` - Updated WS URL resolution to check flag > `CHIEF_WS_URL` env var > user config `ws_url` > default; added `config` import + - `internal/cmd/serve_test.go` - Added 5 integration tests: `WSURLFromEnvVar`, `WSURLFromUserConfig`, `WSURLPrecedence_FlagOverridesEnv`, `WSURLPrecedence_EnvOverridesConfig`, `WSURLLoggedOnStartup` + - `cmd/chief/main.go` - Added `--ws-url` flag to `serve` subcommand +- **Learnings for future iterations:** + - User-level config (`~/.chief/config.yaml`) is separate from project-level config (`.chief/config.yaml`) — uses `os.UserHomeDir()` like auth credentials + - `LoadUserConfig()` returns empty struct (not error) when file doesn't exist — follows same pattern as project config + - `gorilla/websocket` natively supports both `ws://` and `wss://` — no special handling needed for TLS vs plain + - The WS URL was already logged on startup (line 95 in serve.go), so that AC was already met + - Precedence chain (flag > env > config > default) is cleanly implemented with sequential `if wsURL == ""` checks +--- diff --git a/.gitignore b/.gitignore index 6e94bfb..0f4ce00 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,5 @@ # Binaries -/chief +/bin/ *.exe # Log files @@ -11,3 +11,6 @@ node_modules/ # VitePress docs/.vitepress/cache/ docs/.vitepress/dist/ + +# Contract fixtures (synced from chief-uplink) +contract/fixtures/ diff --git a/Makefile b/Makefile index ed056a0..1323438 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,7 @@ # Chief - Autonomous PRD Agent # https://github.com/minicodemonkey/chief -BINARY_NAME := chief +BINARY_NAME := bin/chief VERSION := $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev") BUILD_DIR := ./build MAIN_PKG := ./cmd/chief @@ -9,12 +9,13 @@ MAIN_PKG := ./cmd/chief # Go build flags LDFLAGS := -ldflags "-X main.Version=$(VERSION)" -.PHONY: all build install test lint clean release snapshot help +.PHONY: all build install test lint clean release snapshot help sync-fixtures test-contract all: build ## build: Build the binary build: + @mkdir -p bin go build $(LDFLAGS) -o $(BINARY_NAME) $(MAIN_PKG) ## install: Install to $GOPATH/bin @@ -64,6 +65,33 @@ release: run: build ./$(BINARY_NAME) +## Contract fixtures — chief-uplink is the source of truth. +## Override FIXTURES_REPO for local dev: make sync-fixtures FIXTURES_REPO=../chief-uplink/contract/fixtures +FIXTURES_REPO ?= https://raw.githubusercontent.com/MiniCodeMonkey/chief-uplink/main/contract/fixtures +FIXTURES_DIR := contract/fixtures + +## sync-fixtures: Download contract fixtures from chief-uplink +sync-fixtures: + @mkdir -p $(FIXTURES_DIR)/cli-to-server $(FIXTURES_DIR)/server-to-cli + @for f in cli-to-server/connect_request.json cli-to-server/state_snapshot.json \ + cli-to-server/messages_batch.json cli-to-server/prds_response.json \ + cli-to-server/settings_response.json cli-to-server/diffs_response.json \ + server-to-cli/welcome_response.json server-to-cli/command_create_project.json \ + server-to-cli/command_list_projects.json server-to-cli/command_start_run.json \ + server-to-cli/command_get_prds.json server-to-cli/command_get_settings.json \ + server-to-cli/command_get_diffs.json server-to-cli/command_new_prd.json \ + server-to-cli/command_refine_prd.json server-to-cli/command_prd_message.json; do \ + if echo "$(FIXTURES_REPO)" | grep -q "^http"; then \ + curl -sf "$(FIXTURES_REPO)/$$f" -o "$(FIXTURES_DIR)/$$f" || echo "WARN: failed to fetch $$f"; \ + else \ + cp "$(FIXTURES_REPO)/$$f" "$(FIXTURES_DIR)/$$f" || echo "WARN: failed to copy $$f"; \ + fi; \ + done + +## test-contract: Run contract tests (syncs fixtures first) +test-contract: sync-fixtures + go test ./internal/contract/ -v + ## help: Show this help help: @echo "Usage: make [target]" diff --git a/WEBSOCKET_REFACTOR.md b/WEBSOCKET_REFACTOR.md new file mode 100644 index 0000000..e0d78a0 --- /dev/null +++ b/WEBSOCKET_REFACTOR.md @@ -0,0 +1,511 @@ +# WebSocket Refactor: Managed Reverb Compatibility + +## Problem + +Laravel Cloud's managed Reverb is a separate service that only runs the standard Pusher protocol routes (`/app/{appKey}`, `/up`, etc.). Our custom `/ws/server` endpoint (registered via `ChiefReverbFactory`) never gets loaded because it lives in our application code, not in the managed Reverb cluster. This means the chief CLI can't connect. + +## Solution + +Keep managed Reverb for browser broadcasting (it already works). Adapt the chief CLI to: + +1. **Send data to the server via HTTP POST** (replaces WebSocket sends) +2. **Receive commands via Reverb's Pusher protocol** (subscribes to a private channel as a standard Pusher client) + +No infrastructure changes required — managed Reverb stays as-is. + +## Architecture: Before vs After + +### Before (custom WebSocket) + +``` +Chief CLI ──WebSocket /ws/server──→ ChiefServerController ──broadcast──→ Browser + ↑ │ + └──── WebSocket (sendToDevice) ←── CommandRelayController ←── HTTP POST─┘ +``` + +- Single persistent WebSocket for all bidirectional communication +- Custom hello/welcome handshake for auth +- In-memory connection tracking (ServerConnectionManager) +- ChiefReverbFactory adds custom route to Reverb server + +### After (HTTP + Pusher channel) + +``` +Chief CLI ──HTTP POST /api/device/messages──→ MessageIngestionController ──broadcast──→ Browser + ↑ │ + └──── Reverb private channel (Pusher protocol) ←── broadcast ←── CommandRelayController ←─┘ +``` + +- CLI sends data via HTTP POST (batched) +- CLI receives commands by subscribing to `private-chief-server.{deviceId}` on managed Reverb +- Auth via OAuth access token (existing) on both HTTP and channel subscription +- No custom Reverb routes, no in-memory connection tracking + +--- + +## Detailed Changes + +### Phase 1: New HTTP ingestion endpoint (chief-uplink) + +Create a new HTTP API endpoint that accepts messages from chief CLIs — the replacement for the WebSocket receive path. + +#### 1.1 New controller: `MessageIngestionController` + +**File:** `app/Http/Controllers/Api/MessageIngestionController.php` + +Accepts batched messages from the CLI via HTTP POST. Replaces `ChiefServerController::handleMessage()`. + +``` +POST /api/device/messages +Authorization: Bearer {access_token} +Content-Type: application/json + +{ + "messages": [ + {"type": "state_snapshot", "id": "...", "timestamp": "...", ...}, + {"type": "claude_output", "id": "...", "timestamp": "...", ...}, + ... + ] +} +``` + +Responsibilities: +- Validate OAuth access token (reuse `DeviceOAuthController::validateAccessToken()`) +- Check device not revoked +- For each message: + - If `project_state` → update `CachedProjectState` (same as current `handleProjectState()`) + - If bufferable → buffer via `WebSocketMessageBuffer` + - Broadcast to browser via `ChiefMessageReceived` event +- Return acknowledgment with any pending server-side state (e.g., new session_id) + +Response: +```json +{ + "accepted": 5, + "session_id": "uuid" +} +``` + +#### 1.2 New controller: `DevicePresenceController` + +**File:** `app/Http/Controllers/Api/DevicePresenceController.php` + +Handles explicit connect/disconnect lifecycle — replaces the WebSocket open/close events. + +``` +POST /api/device/connect +Authorization: Bearer {access_token} +Content-Type: application/json + +{ + "chief_version": "0.5.0", + "device_name": "sierra", + "os": "darwin", + "arch": "arm64", + "protocol_version": 1 +} +``` + +Responsibilities: +- Validate access token, check device not revoked +- Update device metadata (chief_version, os, arch, device_name, is_online, last_connected_at) +- Generate session_id for message buffering +- Mark device reconnected in buffer +- Dispatch `DeviceConnected` event +- Return welcome response (same fields as current WebSocket welcome) + +```json +{ + "type": "welcome", + "protocol_version": 1, + "device_id": 42, + "session_id": "uuid", + "reverb": { + "key": "app-key", + "host": "ws-xxx-reverb.laravel.cloud", + "port": 443, + "scheme": "https" + } +} +``` + +The `reverb` block tells the CLI where to connect as a Pusher client. + +``` +POST /api/device/disconnect +Authorization: Bearer {access_token} +``` + +Responsibilities: +- Mark device offline, dispatch `DeviceDisconnected` +- Start buffer grace period + +#### 1.3 New middleware: `AuthenticateDevice` + +**File:** `app/Http/Middleware/AuthenticateDevice.php` + +Extracts and validates the OAuth access token from the `Authorization: Bearer` header. Sets `$request->attributes->set('device_id', ...)` and `$request->attributes->set('user_id', ...)` for downstream controllers. Reuses `DeviceOAuthController::validateAccessToken()`. + +Apply to all `/api/device/*` routes. + +#### 1.4 Channel auth for CLI devices + +**File:** `routes/channels.php` + +Add authorization for the new CLI channel: + +```php +Broadcast::channel('chief-server.{deviceId}', function ($user, $deviceId) { + // CLI authenticates channel subscription using its OAuth token. + // The token is passed as the auth token in the Pusher subscription. + // We need a custom auth endpoint for this — see 1.5. + return $user->deviceAuthorizations() + ->where('id', $deviceId) + ->whereNull('revoked_at') + ->exists(); +}); +``` + +#### 1.5 Custom Pusher auth endpoint for CLI + +The chief CLI authenticates via OAuth access tokens, not Laravel sessions. We need a custom broadcasting auth endpoint that accepts Bearer tokens. + +**File:** `app/Http/Controllers/Api/DeviceBroadcastAuthController.php` + +``` +POST /api/device/broadcasting/auth +Authorization: Bearer {access_token} +Content-Type: application/json + +{"socket_id": "...", "channel_name": "private-chief-server.42"} +``` + +This endpoint: +- Validates the access token +- Checks the device owns the requested channel +- Returns the Pusher auth signature (same format as Laravel's standard broadcast auth) + +#### 1.6 Refactor `CommandRelayController` + +**File:** `app/Http/Controllers/Api/CommandRelayController.php` + +Currently calls `$this->connectionManager->sendToDevice()` which sends via in-memory WebSocket. Change to broadcast the command via Reverb to the CLI's channel: + +```php +// Before: +$sent = $this->connectionManager->sendToDevice($deviceId, $message); + +// After: +broadcast(new ChiefCommandDispatched($deviceId, $userId, $message)); +``` + +New event: `ChiefCommandDispatched` broadcasts on `private-chief-server.{deviceId}` with event name `chief.command`. + +The `isDeviceOnline` check changes from in-memory lookup to checking `DeviceAuthorization.is_online` in the database. + +#### 1.7 Refactor `ServerConnectionManager` + +The in-memory connection tracking (`$connections`, `$deviceToConnection`, `$connectionObjects`) is no longer needed. The class simplifies to a stateless service that: + +- Delegates to `WebSocketMessageBuffer` for buffering +- Checks device online status via database +- No longer stores Connection objects or session IDs in memory (session IDs stored in Redis or on the DeviceAuthorization model) + +Alternatively, this class can be removed entirely and its responsibilities distributed to the new controllers. + +#### 1.8 New event: `ChiefCommandDispatched` + +**File:** `app/Events/ChiefCommandDispatched.php` + +```php +class ChiefCommandDispatched implements ShouldBroadcast +{ + public function __construct( + public readonly int $deviceId, + public readonly int $userId, + public readonly array $command, + ) {} + + public function broadcastOn(): array + { + return [new Channel("private-chief-server.{$this->deviceId}")]; + } + + public function broadcastAs(): string + { + return 'chief.command'; + } + + public function broadcastWith(): array + { + return $this->command; + } +} +``` + +#### 1.9 New route: `DeviceHeartbeatController` + +The CLI needs to periodically confirm it's still alive, since we no longer have a persistent WebSocket connection to detect disconnects. + +``` +POST /api/device/heartbeat +Authorization: Bearer {access_token} +``` + +Called every 30-60 seconds by the CLI. Updates `last_heartbeat_at` on the device. A scheduled job marks devices as offline if no heartbeat received within 2 minutes. + +--- + +### Phase 2: CLI changes (chief — Go) + +#### 2.1 New HTTP client: `internal/uplink/client.go` + +Replaces the WebSocket client for sending data to the server. Handles: + +- `POST /api/device/connect` — on startup +- `POST /api/device/messages` — batched message sending +- `POST /api/device/heartbeat` — periodic keepalive +- `POST /api/device/disconnect` — on shutdown +- OAuth token refresh (existing logic, but applied to HTTP headers) +- Retry with exponential backoff on failure + +#### 2.2 Message batcher: `internal/uplink/batcher.go` + +Batches outgoing messages to reduce HTTP request volume. Key design: + +- Collects messages in a buffer +- Flushes when: buffer reaches N messages (e.g., 20), OR time threshold elapsed (e.g., 200ms), OR a priority message arrives (e.g., `run_complete`) +- Each flush sends a single `POST /api/device/messages` with the batch +- Priority messages (user-visible state changes) flush immediately +- Streaming messages (`claude_output`) batch on the 200ms timer + +Categories: +- **Immediate flush:** `run_complete`, `run_paused`, `error`, `clone_complete`, `session_expired`, `quota_exhausted` +- **Batched (200ms):** `claude_output`, `prd_output`, `run_progress`, `clone_progress` +- **Low priority (1s):** `state_snapshot`, `project_state`, `project_list`, `settings`, `log_lines` + +#### 2.3 Pusher client: `internal/uplink/pusher.go` + +Subscribes to `private-chief-server.{deviceId}` on managed Reverb to receive commands from the browser. Uses a Go Pusher client library (e.g., `pusher/pusher-websocket-go` or a lightweight implementation). + +Responsibilities: +- Connect to Reverb as a Pusher client using the app key and host from the connect response +- Authenticate the private channel via `POST /api/device/broadcasting/auth` +- Listen for `chief.command` events +- Route received commands to the existing message dispatcher + +The existing `Dispatcher` pattern (type-based routing) stays the same — commands arrive through a different transport but are dispatched identically. + +#### 2.4 Refactor `internal/cmd/serve.go` + +Replace `ws.Client` usage with the new uplink client: + +```go +// Before: +client = ws.New(wsURL, ws.WithOnReconnect(func() { ... })) +client.Connect(ctx) +client.Handshake(creds.AccessToken, version, deviceName) +// ... main loop reads from client.Receive() + +// After: +uplink := uplink.New(baseURL, creds.AccessToken, uplink.WithOnReconnect(func() { ... })) +uplink.Connect(ctx) // POST /api/device/connect + subscribe to Pusher channel +// ... main loop reads from uplink.Receive() (same interface, different transport) +``` + +The `Send()` method now enqueues into the batcher instead of writing to WebSocket. The `Receive()` channel is fed by the Pusher client instead of WebSocket reads. + +#### 2.5 Handshake changes + +The current WebSocket handshake (hello → welcome) becomes an HTTP request: + +```go +// Before: +client.Handshake(accessToken, version, deviceName) + +// After: +welcome, err := httpClient.Connect(ConnectRequest{ + ChiefVersion: version, + DeviceName: deviceName, + OS: runtime.GOOS, + Arch: runtime.GOARCH, +}) +// welcome contains device_id, session_id, reverb config +``` + +#### 2.6 Reconnection logic + +Two reconnection paths: + +1. **HTTP failures:** Retry with exponential backoff (same as current WebSocket retry). If the server is unreachable, buffer messages locally and flush on reconnect. +2. **Pusher disconnection:** The Pusher client library handles reconnection automatically. On reconnect, re-authenticate the channel. + +On any reconnection, re-send state snapshot via HTTP POST (same as current `onRecon` callback). + +#### 2.7 Heartbeat + +Add a periodic heartbeat goroutine that calls `POST /api/device/heartbeat` every 30 seconds. If the heartbeat fails, trigger reconnection logic. + +#### 2.8 Graceful shutdown + +On SIGTERM/SIGINT: +1. Stop accepting new commands from Pusher +2. Flush message batcher +3. Call `POST /api/device/disconnect` +4. Disconnect Pusher client +5. Kill Claude sessions and runs (existing logic) + +--- + +### Phase 3: Remove old WebSocket infrastructure (chief-uplink) + +After the new system is working: + +#### 3.1 Delete files +- `app/WebSocket/ChiefReverbFactory.php` +- `app/WebSocket/ChiefServerController.php` +- `app/Console/Commands/StartReverbServer.php` + +#### 3.2 Simplify `WebSocketServiceProvider` +- Remove the `StartServer::class` → `StartReverbServer::class` container binding +- Remove `ServerConnectionManager` singleton if fully replaced +- Keep `PrdSessionManager` singleton + +#### 3.3 Remove Reverb dependency from `ServerConnectionManager` +- Remove `use Laravel\Reverb\Servers\Reverb\Connection;` +- Remove `$connectionObjects` array and all methods that reference it +- Or delete the class entirely if all responsibilities moved to new controllers + +#### 3.4 Clean up tests +- Update `tests/Feature/WebSocket/MessageRelayTest.php` → test HTTP endpoints +- Add tests for `MessageIngestionController`, `DevicePresenceController`, `DeviceBroadcastAuthController` + +--- + +### Phase 4: Frontend changes (chief-uplink — minimal) + +The browser-side code requires almost no changes. + +#### 4.1 `CommandRelayController` response + +The `isDeviceOnline` check changes from in-memory to database, but the HTTP response contract stays the same. Frontend `useCommandRelay.ts` is unchanged. + +#### 4.2 `useChiefMessages.ts` — unchanged + +Still subscribes to `private-device.{deviceId}` and listens for `chief.message` events. The messages arrive via the same Reverb channel — only the server-side path that triggers the broadcast changes (HTTP controller instead of WebSocket controller). + +#### 4.3 `useEcho.ts` — unchanged + +No changes to Echo setup or connection management. + +#### 4.4 `echo.ts` — unchanged + +Still uses `broadcaster: 'reverb'` with managed Reverb config. + +--- + +## Message Flow Comparison + +### CLI → Browser (e.g., `claude_output` streaming) + +**Before:** +1. CLI sends `claude_output` JSON over WebSocket +2. `ChiefServerController::handleMessage()` receives it +3. Buffers message via `WebSocketMessageBuffer` +4. Dispatches `ChiefMessageReceived` broadcast event +5. Reverb delivers to browser on `private-device.{deviceId}` + +**After:** +1. CLI enqueues `claude_output` into message batcher +2. Batcher flushes batch via `POST /api/device/messages` +3. `MessageIngestionController` receives batch +4. For each message: buffer + dispatch `ChiefMessageReceived` +5. Reverb delivers to browser on `private-device.{deviceId}` (same as before) + +### Browser → CLI (e.g., `start_run` command) + +**Before:** +1. Browser calls `POST /ws/command/{deviceId}` with `{type: "start_run", payload: {...}}` +2. `CommandRelayController::send()` validates request +3. Calls `ServerConnectionManager::sendToDevice()` → writes to in-memory WebSocket connection +4. CLI receives `start_run` in `readLoop()` → dispatches to handler + +**After:** +1. Browser calls `POST /ws/command/{deviceId}` with `{type: "start_run", payload: {...}}` (same) +2. `CommandRelayController::send()` validates request (same) +3. Dispatches `ChiefCommandDispatched` broadcast event on `private-chief-server.{deviceId}` +4. Reverb delivers to CLI's Pusher subscription → dispatches to handler + +--- + +## Device Lifecycle + +### Connect + +1. CLI calls `POST /api/device/connect` with metadata + access token +2. Server validates token, updates device record, generates session_id +3. Server dispatches `DeviceConnected` event to browser +4. Server returns welcome response with Reverb config +5. CLI connects to Reverb as Pusher client, subscribes to `private-chief-server.{deviceId}` +6. CLI sends initial `state_snapshot` via `POST /api/device/messages` + +### Steady state + +- CLI sends messages in batches via `POST /api/device/messages` (every 200ms or on priority) +- CLI sends heartbeat via `POST /api/device/heartbeat` (every 30s) +- CLI receives commands via Pusher channel subscription +- Browser sends commands via `POST /ws/command/{deviceId}` (unchanged) +- Browser receives messages via `private-device.{deviceId}` channel (unchanged) + +### Disconnect + +**Graceful (CLI shutdown):** +1. CLI calls `POST /api/device/disconnect` +2. Server marks device offline, starts buffer grace period +3. Server dispatches `DeviceDisconnected` event + +**Ungraceful (network failure, crash):** +1. Heartbeat stops arriving +2. Scheduled job detects stale heartbeat (>2 min) +3. Marks device offline, starts buffer grace period +4. Dispatches `DeviceDisconnected` event + +--- + +## Risks and Mitigations + +### Latency from HTTP batching +**Risk:** 200ms batch window adds latency to `claude_output` streaming. +**Mitigation:** 200ms is imperceptible for terminal-like output. Priority messages flush immediately. Can tune batch window down to 100ms if needed. + +### HTTP overhead vs WebSocket +**Risk:** More HTTP requests than a single persistent connection. +**Mitigation:** Batching reduces request count significantly. A typical streaming session generates ~5 HTTP requests/second (vs thousands of individual WebSocket frames). HTTP/2 connection reuse minimizes TCP overhead. + +### Pusher message size limits +**Risk:** Managed Reverb (Pusher protocol) may have message size limits for channel events. +**Mitigation:** Commands from browser → CLI are small (typically <1KB). The large data flow (CLI → server) goes via HTTP, not Pusher channels. Reverb's default max message size is 10KB, and commands never approach this. + +### Heartbeat-based disconnect detection +**Risk:** Up to 2 minutes to detect a crashed CLI (vs instant WebSocket close detection). +**Mitigation:** Acceptable for the use case — the browser already debounces disconnect events by 2 seconds. The "offline" indicator updates within 2 minutes, which is fine. Can reduce heartbeat interval to 15s and detection to 45s if needed. + +### Authentication on Pusher channel +**Risk:** The CLI uses OAuth tokens, not Laravel sessions, for auth. Pusher channel auth requires a custom endpoint. +**Mitigation:** `DeviceBroadcastAuthController` provides a standard Pusher auth response using the CLI's OAuth token. The Pusher client library supports custom auth endpoints. + +--- + +## Implementation Order + +1. **Phase 1.1–1.3:** New HTTP endpoints + middleware (can deploy independently, no breaking changes) +2. **Phase 1.4–1.5:** Channel auth for CLI (deploy with Phase 1) +3. **Phase 1.8–1.9:** New event + heartbeat (deploy with Phase 1) +4. **Phase 2.1–2.3:** New Go uplink client, batcher, Pusher client +5. **Phase 2.4–2.8:** Refactor serve command to use new client +6. **Phase 1.6–1.7:** Refactor CommandRelayController to broadcast instead of direct WebSocket send +7. **Test end-to-end with both old and new CLI versions** +8. **Phase 3:** Remove old WebSocket infrastructure after confirming new system works +9. **Phase 4:** Any minor frontend adjustments + +Total estimate: ~15 new/modified files across both repos. diff --git a/cmd/chief/main.go b/cmd/chief/main.go index fca1f4a..7dd9782 100644 --- a/cmd/chief/main.go +++ b/cmd/chief/main.go @@ -5,7 +5,6 @@ import ( "log" "os" "path/filepath" - "strconv" "strings" tea "github.com/charmbracelet/bubbletea" @@ -15,6 +14,7 @@ import ( "github.com/minicodemonkey/chief/internal/notify" "github.com/minicodemonkey/chief/internal/prd" "github.com/minicodemonkey/chief/internal/tui" + "github.com/spf13/cobra" ) // Version is set at build time via ldflags @@ -32,47 +32,269 @@ type TUIOptions struct { } func main() { - // Handle subcommands first - if len(os.Args) > 1 { - switch os.Args[1] { - case "new": - runNew() - return - case "edit": - runEdit() - return - case "status": - runStatus() - return - case "list": - runList() - return - case "help": - printHelp() - return - case "--help", "-h": - printHelp() - return - case "--version", "-v": - fmt.Printf("chief version %s\n", Version) - return - case "wiggum": - printWiggum() + rootCmd := buildRootCmd() + if err := rootCmd.Execute(); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } +} + +func buildRootCmd() *cobra.Command { + opts := &TUIOptions{} + + rootCmd := &cobra.Command{ + Use: "chief [name|path/to/prd.json]", + Short: "Chief - Autonomous PRD Agent", + Long: "Chief breaks down PRDs into user stories and uses Claude Code to implement them autonomously.", + // Accept arbitrary args so positional PRD name/path works + Args: cobra.ArbitraryArgs, + Version: Version, + // Silence Cobra's default error/usage printing so we control output + SilenceErrors: true, + SilenceUsage: true, + PersistentPreRun: func(c *cobra.Command, args []string) { + // Non-blocking version check on startup for all interactive commands + // Skip for update command itself and serve (which has its own check) + name := c.Name() + if name != "update" && name != "serve" && name != "version" { + cmd.CheckVersionOnStartup(Version) + } + }, + RunE: func(c *cobra.Command, args []string) error { + // Resolve positional argument as PRD name or path + if len(args) > 0 { + arg := args[0] + if strings.HasSuffix(arg, ".json") || strings.HasSuffix(arg, "/") { + opts.PRDPath = arg + } else { + opts.PRDPath = fmt.Sprintf(".chief/prds/%s/prd.json", arg) + } + } + runTUIWithOptions(opts) + return nil + }, + } + + // Set custom version template to match previous output format + rootCmd.SetVersionTemplate("chief version {{.Version}}\n") + + // Root flags (TUI mode) + rootCmd.Flags().IntVarP(&opts.MaxIterations, "max-iterations", "n", 0, "Set maximum iterations (default: dynamic)") + rootCmd.Flags().BoolVar(&opts.NoSound, "no-sound", false, "Disable completion sound notifications") + rootCmd.Flags().BoolVar(&opts.NoRetry, "no-retry", false, "Disable auto-retry on Claude crashes") + rootCmd.Flags().BoolVar(&opts.Verbose, "verbose", false, "Show raw Claude output in log") + rootCmd.Flags().BoolVar(&opts.Merge, "merge", false, "Auto-merge progress on conversion conflicts") + rootCmd.Flags().BoolVar(&opts.Force, "force", false, "Auto-overwrite on conversion conflicts") + + // Subcommands + rootCmd.AddCommand(newNewCmd()) + rootCmd.AddCommand(newEditCmd()) + rootCmd.AddCommand(newStatusCmd()) + rootCmd.AddCommand(newListCmd()) + rootCmd.AddCommand(newLoginCmd()) + rootCmd.AddCommand(newLogoutCmd()) + rootCmd.AddCommand(newServeCmd()) + rootCmd.AddCommand(newUpdateCmd()) + rootCmd.AddCommand(newWiggumCmd()) + + // Custom help for root command only (subcommands use default Cobra help) + defaultHelp := rootCmd.HelpFunc() + rootCmd.SetHelpFunc(func(c *cobra.Command, args []string) { + if c != rootCmd { + defaultHelp(c, args) return } + fmt.Print(`Chief - Autonomous PRD Agent + +Usage: + chief [options] [<name>|<path/to/prd.json>] + chief <command> [arguments] + +Commands: + new [name] [context] Create a new PRD interactively + edit [name] [options] Edit an existing PRD interactively + status [name] Show progress for a PRD (default: main) + list List all PRDs with progress + login Authenticate with chiefloop.com + logout Log out and deauthorize this device + serve Start headless daemon for web app + update Update Chief to the latest version + +Options: + --max-iterations N, -n N Set maximum iterations (default: dynamic) + --no-sound Disable completion sound notifications + --no-retry Disable auto-retry on Claude crashes + --verbose Show raw Claude output in log + --merge Auto-merge progress on conversion conflicts + --force Auto-overwrite on conversion conflicts + -h, --help Show this help message + -v, --version Show version number + +Examples: + chief Launch TUI with default PRD (.chief/prds/main/) + chief auth Launch TUI with named PRD (.chief/prds/auth/) + chief ./my-prd.json Launch TUI with specific PRD file + chief -n 20 Launch with 20 max iterations + chief --max-iterations=5 auth + Launch auth PRD with 5 max iterations + chief --no-sound Launch TUI without audio notifications + chief --verbose Launch with raw Claude output visible + chief new Create PRD in .chief/prds/main/ + chief new auth Create PRD in .chief/prds/auth/ + chief new auth "JWT authentication for REST API" + Create PRD with context hint + chief edit Edit PRD in .chief/prds/main/ + chief edit auth Edit PRD in .chief/prds/auth/ + chief edit auth --merge Edit and auto-merge progress + chief status Show progress for default PRD + chief status auth Show progress for auth PRD + chief list List all PRDs with progress + chief --version Show version number +`) + }) + + return rootCmd +} + +func newNewCmd() *cobra.Command { + return &cobra.Command{ + Use: "new [name] [context...]", + Short: "Create a new PRD interactively", + Args: cobra.ArbitraryArgs, + RunE: func(c *cobra.Command, args []string) error { + opts := cmd.NewOptions{} + if len(args) > 0 { + opts.Name = args[0] + } + if len(args) > 1 { + opts.Context = strings.Join(args[1:], " ") + } + return cmd.RunNew(opts) + }, } +} + +func newEditCmd() *cobra.Command { + editOpts := &cmd.EditOptions{} + + editCmd := &cobra.Command{ + Use: "edit [name]", + Short: "Edit an existing PRD interactively", + Args: cobra.MaximumNArgs(1), + RunE: func(c *cobra.Command, args []string) error { + if len(args) > 0 { + editOpts.Name = args[0] + } + return cmd.RunEdit(*editOpts) + }, + } + + editCmd.Flags().BoolVar(&editOpts.Merge, "merge", false, "Auto-merge progress on conversion conflicts") + editCmd.Flags().BoolVar(&editOpts.Force, "force", false, "Auto-overwrite on conversion conflicts") + + return editCmd +} + +func newStatusCmd() *cobra.Command { + return &cobra.Command{ + Use: "status [name]", + Short: "Show progress for a PRD (default: main)", + Args: cobra.MaximumNArgs(1), + RunE: func(c *cobra.Command, args []string) error { + opts := cmd.StatusOptions{} + if len(args) > 0 { + opts.Name = args[0] + } + return cmd.RunStatus(opts) + }, + } +} + +func newListCmd() *cobra.Command { + return &cobra.Command{ + Use: "list", + Short: "List all PRDs with progress", + Args: cobra.NoArgs, + RunE: func(c *cobra.Command, args []string) error { + return cmd.RunList(cmd.ListOptions{}) + }, + } +} + +func newLoginCmd() *cobra.Command { + loginOpts := &cmd.LoginOptions{} + + loginCmd := &cobra.Command{ + Use: "login", + Short: "Authenticate with chiefloop.com", + Args: cobra.NoArgs, + RunE: func(c *cobra.Command, args []string) error { + return cmd.RunLogin(*loginOpts) + }, + } + + loginCmd.Flags().StringVar(&loginOpts.DeviceName, "name", "", "Override device name (default: hostname)") + loginCmd.Flags().StringVar(&loginOpts.SetupToken, "setup-token", "", "One-time setup token for automated auth") - // Parse flags for TUI mode - opts := parseTUIFlags() + return loginCmd +} - // Handle special flags that were parsed - if opts == nil { - // Already handled (--help or --version) - return +func newLogoutCmd() *cobra.Command { + return &cobra.Command{ + Use: "logout", + Short: "Log out and deauthorize this device", + Args: cobra.NoArgs, + RunE: func(c *cobra.Command, args []string) error { + return cmd.RunLogout(cmd.LogoutOptions{}) + }, } +} - // Run the TUI - runTUIWithOptions(opts) +func newServeCmd() *cobra.Command { + serveOpts := &cmd.ServeOptions{} + + serveCmd := &cobra.Command{ + Use: "serve", + Short: "Start headless daemon for web app", + Long: "Starts a headless daemon that connects to chiefloop.com via WebSocket and accepts commands from the web app.", + Args: cobra.NoArgs, + RunE: func(c *cobra.Command, args []string) error { + serveOpts.Version = Version + return cmd.RunServe(*serveOpts) + }, + } + + serveCmd.Flags().StringVar(&serveOpts.Workspace, "workspace", ".", "Path to workspace directory (default: current directory)") + serveCmd.Flags().StringVar(&serveOpts.DeviceName, "name", "", "Override device name for this session") + serveCmd.Flags().StringVar(&serveOpts.LogFile, "log-file", "", "Path to log file (default: stdout)") + serveCmd.Flags().StringVar(&serveOpts.ServerURL, "server-url", "", "Override server URL (default: https://chiefloop.com)") + + return serveCmd +} + +func newUpdateCmd() *cobra.Command { + return &cobra.Command{ + Use: "update", + Short: "Update Chief to the latest version", + Args: cobra.NoArgs, + RunE: func(c *cobra.Command, args []string) error { + return cmd.RunUpdate(cmd.UpdateOptions{ + Version: Version, + }) + }, + } +} + +func newWiggumCmd() *cobra.Command { + return &cobra.Command{ + Use: "wiggum", + Short: "Bake 'em away, toys!", + Hidden: true, + Args: cobra.NoArgs, + Run: func(c *cobra.Command, args []string) { + printWiggum() + }, + } } // findAvailablePRD looks for any available PRD in .chief/prds/ @@ -115,164 +337,6 @@ func listAvailablePRDs() []string { return names } -// parseTUIFlags parses command-line flags for TUI mode -func parseTUIFlags() *TUIOptions { - opts := &TUIOptions{ - PRDPath: "", // Will be resolved later - MaxIterations: 0, // 0 signals dynamic calculation (remaining stories + 5) - NoSound: false, - Verbose: false, - Merge: false, - Force: false, - NoRetry: false, - } - - for i := 1; i < len(os.Args); i++ { - arg := os.Args[i] - - switch { - case arg == "--help" || arg == "-h": - printHelp() - return nil - case arg == "--version" || arg == "-v": - fmt.Printf("chief version %s\n", Version) - return nil - case arg == "--no-sound": - opts.NoSound = true - case arg == "--verbose": - opts.Verbose = true - case arg == "--merge": - opts.Merge = true - case arg == "--force": - opts.Force = true - case arg == "--no-retry": - opts.NoRetry = true - case arg == "--max-iterations" || arg == "-n": - // Next argument should be the number - if i+1 < len(os.Args) { - i++ - n, err := strconv.Atoi(os.Args[i]) - if err != nil { - fmt.Fprintf(os.Stderr, "Error: invalid value for %s: %s\n", arg, os.Args[i]) - os.Exit(1) - } - if n < 1 { - fmt.Fprintf(os.Stderr, "Error: --max-iterations must be at least 1\n") - os.Exit(1) - } - opts.MaxIterations = n - } else { - fmt.Fprintf(os.Stderr, "Error: %s requires a value\n", arg) - os.Exit(1) - } - case strings.HasPrefix(arg, "--max-iterations="): - val := strings.TrimPrefix(arg, "--max-iterations=") - n, err := strconv.Atoi(val) - if err != nil { - fmt.Fprintf(os.Stderr, "Error: invalid value for --max-iterations: %s\n", val) - os.Exit(1) - } - if n < 1 { - fmt.Fprintf(os.Stderr, "Error: --max-iterations must be at least 1\n") - os.Exit(1) - } - opts.MaxIterations = n - case strings.HasPrefix(arg, "-n="): - val := strings.TrimPrefix(arg, "-n=") - n, err := strconv.Atoi(val) - if err != nil { - fmt.Fprintf(os.Stderr, "Error: invalid value for -n: %s\n", val) - os.Exit(1) - } - if n < 1 { - fmt.Fprintf(os.Stderr, "Error: -n must be at least 1\n") - os.Exit(1) - } - opts.MaxIterations = n - case strings.HasPrefix(arg, "-"): - // Unknown flag - fmt.Fprintf(os.Stderr, "Error: unknown flag: %s\n", arg) - fmt.Fprintf(os.Stderr, "Run 'chief --help' for usage.\n") - os.Exit(1) - default: - // Positional argument: PRD name or path - if strings.HasSuffix(arg, ".json") || strings.HasSuffix(arg, "/") { - opts.PRDPath = arg - } else { - // Treat as PRD name - opts.PRDPath = fmt.Sprintf(".chief/prds/%s/prd.json", arg) - } - } - } - - return opts -} - -func runNew() { - opts := cmd.NewOptions{} - - // Parse arguments: chief new [name] [context...] - if len(os.Args) > 2 { - opts.Name = os.Args[2] - } - if len(os.Args) > 3 { - opts.Context = strings.Join(os.Args[3:], " ") - } - - if err := cmd.RunNew(opts); err != nil { - fmt.Fprintf(os.Stderr, "Error: %v\n", err) - os.Exit(1) - } -} - -func runEdit() { - opts := cmd.EditOptions{} - - // Parse arguments: chief edit [name] [--merge] [--force] - for i := 2; i < len(os.Args); i++ { - arg := os.Args[i] - switch arg { - case "--merge": - opts.Merge = true - case "--force": - opts.Force = true - default: - // If not a flag, treat as PRD name (first non-flag arg) - if opts.Name == "" && !strings.HasPrefix(arg, "-") { - opts.Name = arg - } - } - } - - if err := cmd.RunEdit(opts); err != nil { - fmt.Fprintf(os.Stderr, "Error: %v\n", err) - os.Exit(1) - } -} - -func runStatus() { - opts := cmd.StatusOptions{} - - // Parse arguments: chief status [name] - if len(os.Args) > 2 && !strings.HasPrefix(os.Args[2], "-") { - opts.Name = os.Args[2] - } - - if err := cmd.RunStatus(opts); err != nil { - fmt.Fprintf(os.Stderr, "Error: %v\n", err) - os.Exit(1) - } -} - -func runList() { - opts := cmd.ListOptions{} - - if err := cmd.RunList(opts); err != nil { - fmt.Fprintf(os.Stderr, "Error: %v\n", err) - os.Exit(1) - } -} - func runTUIWithOptions(opts *TUIOptions) { prdPath := opts.PRDPath @@ -437,60 +501,6 @@ func runTUIWithOptions(opts *TUIOptions) { } } -func printHelp() { - fmt.Println(`Chief - Autonomous PRD Agent - -Usage: - chief [options] [<name>|<path/to/prd.json>] - chief <command> [arguments] - -Commands: - new [name] [context] Create a new PRD interactively - edit [name] [options] Edit an existing PRD interactively - status [name] Show progress for a PRD (default: main) - list List all PRDs with progress - help Show this help message - -Global Options: - --max-iterations N, -n N Set maximum iterations (default: dynamic) - --no-sound Disable completion sound notifications - --no-retry Disable auto-retry on Claude crashes - --verbose Show raw Claude output in log - --merge Auto-merge progress on conversion conflicts - --force Auto-overwrite on conversion conflicts - --help, -h Show this help message - --version, -v Show version number - -Edit Options: - --merge Auto-merge progress on conversion conflicts - --force Auto-overwrite on conversion conflicts - -Positional Arguments: - <name> PRD name (loads .chief/prds/<name>/prd.json) - <path/to/prd.json> Direct path to a prd.json file - -Examples: - chief Launch TUI with default PRD (.chief/prds/main/) - chief auth Launch TUI with named PRD (.chief/prds/auth/) - chief ./my-prd.json Launch TUI with specific PRD file - chief -n 20 Launch with 20 max iterations - chief --max-iterations=5 auth - Launch auth PRD with 5 max iterations - chief --no-sound Launch TUI without audio notifications - chief --verbose Launch with raw Claude output visible - chief new Create PRD in .chief/prds/main/ - chief new auth Create PRD in .chief/prds/auth/ - chief new auth "JWT authentication for REST API" - Create PRD with context hint - chief edit Edit PRD in .chief/prds/main/ - chief edit auth Edit PRD in .chief/prds/auth/ - chief edit auth --merge Edit and auto-merge progress - chief status Show progress for default PRD - chief status auth Show progress for auth PRD - chief list List all PRDs with progress - chief --version Show version number`) -} - func printWiggum() { // ANSI color codes blue := "\033[34m" diff --git a/deploy/chief.service b/deploy/chief.service new file mode 100644 index 0000000..80af505 --- /dev/null +++ b/deploy/chief.service @@ -0,0 +1,25 @@ +[Unit] +Description=Chief - Autonomous PRD Agent +Documentation=https://github.com/MiniCodeMonkey/chief +After=network-online.target +Wants=network-online.target +ConditionPathExists=/home/chief/.chief/credentials.yaml + +[Service] +Type=simple +User=chief +Group=chief +WorkingDirectory=/home/chief +ExecStart=/usr/local/bin/chief serve --workspace /home/chief/projects --log-file /home/chief/.chief/serve.log +Restart=always +RestartSec=5 +Environment=HOME=/home/chief + +# Security hardening +NoNewPrivileges=true +ProtectSystem=strict +ProtectHome=false +ReadWritePaths=/home/chief + +[Install] +WantedBy=multi-user.target diff --git a/deploy/cloud-init.sh b/deploy/cloud-init.sh new file mode 100755 index 0000000..abc0780 --- /dev/null +++ b/deploy/cloud-init.sh @@ -0,0 +1,263 @@ +#!/bin/bash +# Chief Cloud-Init Setup Script +# https://github.com/MiniCodeMonkey/chief +# +# This script sets up a VPS to run Chief as a systemd service. +# It is designed to be run via cloud-init during VPS provisioning. +# +# Usage (cloud-init user-data): +# #!/bin/bash +# curl -fsSL https://raw.githubusercontent.com/MiniCodeMonkey/chief/main/deploy/cloud-init.sh | bash +# +# With setup token (automated auth): +# #!/bin/bash +# curl -fsSL https://raw.githubusercontent.com/MiniCodeMonkey/chief/main/deploy/cloud-init.sh | CHIEF_SETUP_TOKEN=<token> bash +# +# What this script does: +# 1. Creates a 'chief' user +# 2. Installs the Chief binary +# 3. Installs Claude Code CLI (via npm) +# 4. Creates the workspace directory +# 5. Writes and enables the systemd unit file +# +# After this script runs, you must: +# 1. SSH into the server +# 2. Run: sudo -u chief chief login (skipped if CHIEF_SETUP_TOKEN is set) +# 3. Authenticate Claude Code: sudo -u chief claude +# 4. Start the service: sudo systemctl start chief +# +# This script is idempotent (safe to run multiple times). + +set -euo pipefail + +GITHUB_REPO="MiniCodeMonkey/chief" +CHIEF_USER="chief" +CHIEF_HOME="/home/${CHIEF_USER}" +WORKSPACE_DIR="${CHIEF_HOME}/projects" +BINARY_PATH="/usr/local/bin/chief" +SERVICE_FILE="/etc/systemd/system/chief.service" + +info() { + echo "==> $1" +} + +warn() { + echo "WARNING: $1" +} + +error() { + echo "ERROR: $1" >&2 + exit 1 +} + +# Create chief user if it doesn't exist +create_user() { + if id "${CHIEF_USER}" &>/dev/null; then + info "User '${CHIEF_USER}' already exists" + else + info "Creating user '${CHIEF_USER}'..." + useradd --create-home --shell /bin/bash "${CHIEF_USER}" + fi +} + +# Install Chief binary +install_chief() { + info "Installing Chief binary..." + curl -fsSL "https://raw.githubusercontent.com/${GITHUB_REPO}/main/install.sh" | CHIEF_INSTALL_DIR=/usr/local/bin sh +} + +# Install Node.js and Claude Code CLI +install_claude_code() { + if command -v claude &>/dev/null; then + info "Claude Code CLI already installed" + return 0 + fi + + # Install Node.js if not present + if ! command -v node &>/dev/null; then + info "Installing Node.js..." + if command -v apt-get &>/dev/null; then + curl -fsSL https://deb.nodesource.com/setup_lts.x | bash - + apt-get install -y nodejs + elif command -v dnf &>/dev/null; then + curl -fsSL https://rpm.nodesource.com/setup_lts.x | bash - + dnf install -y nodejs + elif command -v yum &>/dev/null; then + curl -fsSL https://rpm.nodesource.com/setup_lts.x | bash - + yum install -y nodejs + else + warn "Could not install Node.js automatically. Please install it manually." + return 1 + fi + fi + + info "Installing Claude Code CLI..." + npm install -g @anthropic-ai/claude-code +} + +# Create workspace directory +create_workspace() { + if [ -d "${WORKSPACE_DIR}" ]; then + info "Workspace directory already exists: ${WORKSPACE_DIR}" + else + info "Creating workspace directory: ${WORKSPACE_DIR}" + mkdir -p "${WORKSPACE_DIR}" + fi + chown -R "${CHIEF_USER}:${CHIEF_USER}" "${WORKSPACE_DIR}" +} + +# Create .chief config directory +create_config_dir() { + local config_dir="${CHIEF_HOME}/.chief" + if [ -d "${config_dir}" ]; then + info "Config directory already exists: ${config_dir}" + else + info "Creating config directory: ${config_dir}" + mkdir -p "${config_dir}" + fi + chown -R "${CHIEF_USER}:${CHIEF_USER}" "${config_dir}" +} + +# Install and enable systemd service +install_service() { + info "Installing systemd service..." + + cat > "${SERVICE_FILE}" <<'UNIT' +[Unit] +Description=Chief - Autonomous PRD Agent +Documentation=https://github.com/MiniCodeMonkey/chief +After=network-online.target +Wants=network-online.target +ConditionPathExists=/home/chief/.chief/credentials.yaml + +[Service] +Type=simple +User=chief +Group=chief +WorkingDirectory=/home/chief +ExecStart=/usr/local/bin/chief serve --workspace /home/chief/projects --log-file /home/chief/.chief/serve.log +Restart=always +RestartSec=5 +Environment=HOME=/home/chief + +# Security hardening +NoNewPrivileges=true +ProtectSystem=strict +ProtectHome=false +ReadWritePaths=/home/chief + +[Install] +WantedBy=multi-user.target +UNIT + + systemctl daemon-reload + systemctl enable chief.service + info "Service enabled (but NOT started — authentication required first)" +} + +# Handle setup token if provided +handle_setup_token() { + if [ -z "${CHIEF_SETUP_TOKEN:-}" ]; then + return 0 + fi + + info "Setup token provided, configuring automated authentication..." + + # Write the setup token to a temporary file readable only by the chief user + local token_file="/tmp/chief-setup-token" + echo "${CHIEF_SETUP_TOKEN}" > "${token_file}" + chown "${CHIEF_USER}:${CHIEF_USER}" "${token_file}" + chmod 600 "${token_file}" + + # Create a one-shot systemd service that exchanges the token + cat > /etc/systemd/system/chief-setup.service <<SETUP_UNIT +[Unit] +Description=Chief Setup Token Exchange +After=network-online.target +Wants=network-online.target + +[Service] +Type=oneshot +User=${CHIEF_USER} +Group=${CHIEF_USER} +Environment=HOME=${CHIEF_HOME} +ExecStart=/bin/bash -c '/usr/local/bin/chief login --setup-token "\$(cat /tmp/chief-setup-token)" && rm -f /tmp/chief-setup-token && systemctl start chief' +ExecStartPost=/bin/bash -c 'rm -f /tmp/chief-setup-token' +RemainAfterExit=no + +[Install] +WantedBy=multi-user.target +SETUP_UNIT + + systemctl daemon-reload + systemctl enable chief-setup.service + systemctl start chief-setup.service || { + warn "Setup token exchange failed. Please authenticate manually: sudo -u ${CHIEF_USER} chief login" + rm -f "${token_file}" + } +} + +# Print post-deploy instructions +print_instructions() { + echo "" + echo "============================================" + echo " Chief setup complete!" + echo "============================================" + echo "" + if [ -n "${CHIEF_SETUP_TOKEN:-}" ]; then + echo "Chief authentication was configured automatically." + echo "" + echo "Next steps:" + echo "" + echo " 1. SSH into this server" + echo "" + echo " 2. Authenticate Claude Code:" + echo " sudo -u ${CHIEF_USER} claude" + echo "" + echo " 3. Check service status:" + echo " sudo systemctl status chief" + echo " sudo journalctl -u chief -f" + else + echo "Next steps:" + echo "" + echo " 1. SSH into this server" + echo "" + echo " 2. Authenticate Chief with chiefloop.com:" + echo " sudo -u ${CHIEF_USER} chief login" + echo "" + echo " 3. Authenticate Claude Code:" + echo " sudo -u ${CHIEF_USER} claude" + echo "" + echo " 4. Start the Chief service:" + echo " sudo systemctl start chief" + echo "" + echo " 5. Check service status:" + echo " sudo systemctl status chief" + echo " sudo journalctl -u chief -f" + fi + echo "" + echo "============================================" +} + +# Main +main() { + info "Starting Chief setup..." + + # Check if running as root + if [ "$(id -u)" -ne 0 ]; then + error "This script must be run as root (or via cloud-init)" + fi + + create_user + install_chief + install_claude_code + create_workspace + create_config_dir + install_service + handle_setup_token + print_instructions + + info "Chief setup complete!" +} + +main "$@" diff --git a/embed/edit_prompt.txt b/embed/edit_prompt.txt index 46ddf53..08bd16f 100644 --- a/embed/edit_prompt.txt +++ b/embed/edit_prompt.txt @@ -6,57 +6,133 @@ You are helping edit an existing Product Requirements Document (PRD). Edit the existing `prd.md` file at `{{PRD_DIR}}/prd.md` based on the user's requested changes. -## Current PRD +**Important:** Your ONLY job is to edit the `prd.md` file. Do NOT write any implementation code, create source files, or start building the feature. You are a PRD editor, not an implementer. -The current PRD is located at `{{PRD_DIR}}/prd.md`. Read it first to understand the existing structure and content. +--- -## PRD Format +## Step 1: Read the Existing PRD -The PRD should follow this structure: +Read the current PRD at `{{PRD_DIR}}/prd.md` first to understand the existing structure, stories, and progress. -```markdown -# [Project Name] +Note: Some stories may have already been completed. Changes will be merged with progress when the PRD is converted. + +--- + +## Step 2: Clarifying Questions + +Ask clarifying questions about the desired changes using lettered options. This lets users respond quickly with "1A, 2C" instead of writing paragraphs. + +Focus on understanding: + +- **What exactly should change?** Add stories, remove stories, modify scope, update criteria? +- **Why the change?** New requirements, scope reduction, pivot, bug discovered? +- **Impact on existing stories?** Do priorities or dependencies need to shift? + +### Format Questions Like This: + +``` +1. What kind of change are you making? + A. Adding new user stories + B. Modifying existing stories or criteria + C. Removing stories / reducing scope + D. Restructuring priorities or dependencies + +2. Should this affect the success metrics? + A. Yes, update them to reflect the changes + B. No, keep current metrics + C. Not sure, let's discuss +``` + +Remember to indent the lettered options under each question. + +If the requested changes are straightforward and unambiguous, you may skip questions and proceed directly to editing. + +--- + +## Step 3: Edit the PRD + +Make the requested changes while preserving the overall structure. The PRD should maintain these sections: -## Overview -[Brief description of the project/feature] +### 1. Introduction/Overview +Brief description of the feature and the problem it solves. -## User Stories +### 2. Goals +Specific, measurable objectives (bullet list). -### US-001: [Story Title] +### 3. User Stories +Each story needs: +- **Title:** Short descriptive name +- **Description:** "As a [user], I want [feature] so that [benefit]" +- **Acceptance Criteria:** Verifiable checklist of what "done" means + +**Format:** +```markdown +### US-001: [Title] **Priority:** 1 -**Description:** As a [user type], I want [goal] so that [benefit]. +**Description:** As a [user], I want [feature] so that [benefit]. **Acceptance Criteria:** -- [ ] Criterion 1 -- [ ] Criterion 2 -- [ ] Criterion 3 - -### US-002: [Story Title] -**Priority:** 2 -... +- [ ] Specific verifiable criterion +- [ ] Another criterion ``` -## Guidelines +**Editing Guidelines:** +- **Preserve story IDs** - Keep existing US-XXX IDs when modifying stories. +- **Add new stories** with the next available ID number. +- **Update priorities** if story order needs to change. +- Each story should be small enough to implement in one focused coding session. +- Acceptance criteria must be verifiable, not vague. "Works correctly" is bad. "Button shows confirmation dialog before deleting" is good. -1. **Preserve story IDs** - Keep existing US-XXX IDs when modifying stories -2. **Add new stories** with the next available ID number -3. **Keep stories atomic** - each should be completable in one coding session -4. **Update priorities** if story order needs to change -5. **Maintain clear acceptance criteria** - specific, testable requirements -6. **Consider dependencies** - ensure story order makes sense +### 4. Functional Requirements +Numbered list (FR-1, FR-2, etc.). Be explicit and unambiguous. -## Important +### 5. Non-Goals (Out of Scope) +What this feature will NOT include. Update if scope changes. -- **Your ONLY job is to edit the `prd.md` file.** Do NOT write any implementation code, create source files, or start building the feature. You are a PRD editor, not an implementer. -- Once the edits are complete, tell the user to type `/exit` to finish. Chief will automatically convert the updated PRD. +### 6. Design Considerations (Optional) +UI/UX requirements, mockup links, existing components to reuse. -## Instructions +### 7. Technical Considerations (Optional) +Constraints, dependencies, integration points, performance requirements. -1. Read the existing PRD file at `{{PRD_DIR}}/prd.md` first -2. Ask clarifying questions about the desired changes -3. Make the requested edits while preserving the overall structure -4. Save the updated PRD file +### 8. Success Metrics +Specific, measurable indicators of success. Always review these when making changes — if the scope changes, the metrics likely should too. -Note: Some stories may have already been completed. Changes will be merged with progress when the PRD is converted. +### 9. Open Questions +Update with any new questions raised by the changes. + +--- + +## Writing Quality + +The PRD reader may be a junior developer or AI agent. Therefore: + +- Be explicit and unambiguous +- Avoid jargon or explain it when used +- Provide enough detail to understand purpose and core logic +- Number requirements for easy reference +- Use concrete examples where helpful + +--- + +## Checklist + +Before saving the edited PRD, verify: + +- [ ] Asked clarifying questions if changes were ambiguous +- [ ] Existing story IDs are preserved (not renumbered) +- [ ] New stories use the next available ID +- [ ] User stories are small, atomic, and specific +- [ ] Acceptance criteria are verifiable (not vague) +- [ ] Functional requirements are numbered and unambiguous +- [ ] Non-goals updated if scope changed +- [ ] Success metrics reviewed and updated if needed +- [ ] The file is ready for conversion to prd.json + +--- + +## Final Step + +Once the edits are complete, tell the user to type `/exit` to finish. Chief will automatically convert the updated PRD. Start by reading the existing PRD and understanding what changes the user wants to make. diff --git a/embed/init_prompt.txt b/embed/init_prompt.txt index f33d241..cab267d 100644 --- a/embed/init_prompt.txt +++ b/embed/init_prompt.txt @@ -6,55 +6,236 @@ You are helping create a Product Requirements Document (PRD) for a software feat Create a `prd.md` file at `{{PRD_DIR}}/prd.md` with a structured PRD based on the user's description. -## PRD Format +**Important:** Do NOT start implementing. Your ONLY job is to create the `prd.md` file. Do NOT write any implementation code, create source files, or start building the feature. You are a PRD writer, not an implementer. -The PRD should follow this structure: +## Context + +{{CONTEXT}} + +--- + +## Step 1: Clarifying Questions + +Before writing the PRD, ask 3-5 essential clarifying questions where the user's initial prompt is ambiguous. Focus on: + +- **Problem/Goal:** What problem does this solve? Why does it matter? +- **Core Functionality:** What are the key actions or behaviors? +- **Scope/Boundaries:** What should it NOT do? +- **Success Criteria:** How do we know it's done and working? + +### Format Questions With Lettered Options + +This lets users respond quickly with "1A, 2C, 3B" instead of writing paragraphs. + +``` +1. What is the primary goal of this feature? + A. Improve user onboarding experience + B. Increase user retention + C. Reduce support burden + D. Other: [please specify] + +2. Who is the target user? + A. New users only + B. Existing users only + C. All users + D. Admin users only + +3. What is the scope? + A. Minimal viable version + B. Full-featured implementation + C. Just the backend/API + D. Just the UI +``` + +Remember to indent the lettered options under each question. + +--- + +## Step 2: Write the PRD + +After incorporating the user's answers, generate the PRD with these sections: + +### 1. Introduction/Overview +Brief description of the feature and the problem it solves. + +### 2. Goals +Specific, measurable objectives (bullet list). + +### 3. User Stories +Each story needs: +- **Title:** Short descriptive name +- **Description:** "As a [user], I want [feature] so that [benefit]" +- **Acceptance Criteria:** Verifiable checklist of what "done" means + +Each story should be small enough to implement in one focused coding session. + +**Format:** +```markdown +### US-001: [Title] +**Priority:** 1 +**Description:** As a [user], I want [feature] so that [benefit]. + +**Acceptance Criteria:** +- [ ] Specific verifiable criterion +- [ ] Another criterion +``` + +**Guidelines:** +- Lower priority numbers = higher priority. Build foundations first. +- Order stories so earlier ones enable later ones (consider dependencies). +- Acceptance criteria must be verifiable, not vague. "Works correctly" is bad. "Button shows confirmation dialog before deleting" is good. +- Include quality stories for tests and documentation as needed. + +### 4. Functional Requirements +Numbered list of specific functionalities: +- "FR-1: The system must allow users to..." +- "FR-2: When a user clicks X, the system must..." + +Be explicit and unambiguous. + +### 5. Non-Goals (Out of Scope) +What this feature will NOT include. Critical for managing scope and preventing creep. + +### 6. Design Considerations (Optional) +- UI/UX requirements +- Links to mockups if available +- Relevant existing components to reuse + +### 7. Technical Considerations (Optional) +- Known constraints or dependencies +- Integration points with existing systems +- Performance requirements + +### 8. Success Metrics +How will success be measured? Be specific: +- "Reduce time to complete X by 50%" +- "Increase conversion rate by 10%" +- "Users can accomplish Y in under 3 clicks" + +### 9. Open Questions +Remaining questions or areas needing further clarification. + +--- + +## Writing Quality + +The PRD reader may be a junior developer or AI agent. Therefore: + +- Be explicit and unambiguous +- Avoid jargon or explain it when used +- Provide enough detail to understand purpose and core logic +- Number requirements for easy reference +- Use concrete examples where helpful + +--- + +## Example PRD ```markdown -# [Project Name] +# PRD: Task Priority System -## Overview -[Brief description of the project/feature] +## Introduction + +Add priority levels to tasks so users can focus on what matters most. Tasks can be marked as high, medium, or low priority, with visual indicators and filtering to help users manage their workload effectively. + +## Goals + +- Allow assigning priority (high/medium/low) to any task +- Provide clear visual differentiation between priority levels +- Enable filtering and sorting by priority +- Default new tasks to medium priority ## User Stories -### US-001: [Story Title] +### US-001: Add priority field to database **Priority:** 1 -**Description:** As a [user type], I want [goal] so that [benefit]. +**Description:** As a developer, I need to store task priority so it persists across sessions. **Acceptance Criteria:** -- [ ] Criterion 1 -- [ ] Criterion 2 -- [ ] Criterion 3 +- [ ] Add priority column to tasks table: 'high' | 'medium' | 'low' (default 'medium') +- [ ] Generate and run migration successfully +- [ ] Typecheck passes -### US-002: [Story Title] +### US-002: Display priority indicator on task cards **Priority:** 2 -... -``` +**Description:** As a user, I want to see task priority at a glance so I know what needs attention first. -## Guidelines +**Acceptance Criteria:** +- [ ] Each task card shows colored priority badge (red=high, yellow=medium, gray=low) +- [ ] Priority visible without hovering or clicking +- [ ] Typecheck passes -1. **Break down the feature** into small, implementable user stories -2. **Prioritize stories** - lower numbers = higher priority, build foundations first -3. **Keep stories atomic** - each should be completable in one coding session -4. **Include clear acceptance criteria** - specific, testable requirements -5. **Consider dependencies** - order stories so earlier ones enable later ones -6. **Include quality stories** - add stories for tests, documentation as needed +### US-003: Add priority selector to task edit +**Priority:** 3 +**Description:** As a user, I want to change a task's priority when editing it. -## Context +**Acceptance Criteria:** +- [ ] Priority dropdown in task edit modal +- [ ] Shows current priority as selected +- [ ] Saves immediately on selection change +- [ ] Typecheck passes -{{CONTEXT}} +### US-004: Filter tasks by priority +**Priority:** 4 +**Description:** As a user, I want to filter the task list to see only high-priority items when I'm focused. + +**Acceptance Criteria:** +- [ ] Filter dropdown with options: All | High | Medium | Low +- [ ] Filter persists in URL params +- [ ] Empty state message when no tasks match filter +- [ ] Typecheck passes + +## Functional Requirements + +- FR-1: Add `priority` field to tasks table ('high' | 'medium' | 'low', default 'medium') +- FR-2: Display colored priority badge on each task card +- FR-3: Include priority selector in task edit modal +- FR-4: Add priority filter dropdown to task list header +- FR-5: Sort by priority within each status column (high > medium > low) + +## Non-Goals + +- No priority-based notifications or reminders +- No automatic priority assignment based on due date +- No priority inheritance for subtasks + +## Technical Considerations + +- Reuse existing badge component with color variants +- Filter state managed via URL search params +- Priority stored in database, not computed + +## Success Metrics + +- Users can change priority in under 2 clicks +- High-priority tasks immediately visible at top of lists +- No regression in task list performance + +## Open Questions + +- Should priority affect task ordering within a column? +- Should we add keyboard shortcuts for priority changes? +``` + +--- + +## Checklist + +Before saving the PRD, verify: -## Important +- [ ] Asked clarifying questions with lettered options +- [ ] Incorporated user's answers into the PRD +- [ ] User stories are small, atomic, and specific +- [ ] Acceptance criteria are verifiable (not vague) +- [ ] Functional requirements are numbered and unambiguous +- [ ] Non-goals section defines clear boundaries +- [ ] Success metrics are specific and measurable +- [ ] The file is ready for conversion to prd.json -- **Your ONLY job is to create the `prd.md` file.** Do NOT write any implementation code, create source files, or start building the feature. You are a PRD writer, not an implementer. -- Once the PRD file is written, tell the user to type `/exit` to finish. Chief will automatically convert it to the format needed for implementation. +--- -## Instructions +## Final Step -1. Ask clarifying questions if the feature is unclear -2. Propose a high-level breakdown before writing the full PRD -3. Write the complete `prd.md` file -4. The file should be ready for conversion to prd.json +Once the PRD file is written, tell the user to type `/exit` to finish. Chief will automatically convert it to the format needed for implementation. -Start by understanding what the user wants to build, then create the PRD. +Start by understanding what the user wants to build. Ask your clarifying questions first, then create the PRD. diff --git a/go.mod b/go.mod index 99005a4..60a2df9 100644 --- a/go.mod +++ b/go.mod @@ -3,14 +3,17 @@ module github.com/minicodemonkey/chief go 1.24.0 require ( + github.com/alecthomas/chroma/v2 v2.23.1 github.com/charmbracelet/bubbletea v1.3.10 github.com/charmbracelet/lipgloss v1.1.0 github.com/fsnotify/fsnotify v1.9.0 + github.com/gorilla/websocket v1.5.3 github.com/hajimehoshi/oto/v2 v2.4.2 + github.com/spf13/cobra v1.10.2 + gopkg.in/yaml.v3 v3.0.1 ) require ( - github.com/alecthomas/chroma/v2 v2.23.1 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect github.com/charmbracelet/x/ansi v0.10.1 // indirect @@ -19,6 +22,7 @@ require ( github.com/dlclark/regexp2 v1.11.5 // indirect github.com/ebitengine/purego v0.4.1 // indirect github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-localereader v0.0.1 // indirect @@ -27,8 +31,8 @@ require ( github.com/muesli/cancelreader v0.2.2 // indirect github.com/muesli/termenv v0.16.0 // indirect github.com/rivo/uniseg v0.4.7 // indirect + github.com/spf13/pflag v1.0.9 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect golang.org/x/sys v0.36.0 // indirect golang.org/x/text v0.3.8 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index fef6007..ab8d6fe 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,9 @@ +github.com/alecthomas/assert/v2 v2.11.0 h1:2Q9r3ki8+JYXvGsDyBXwH3LcJ+WK5D0gc5E8vS6K3D0= +github.com/alecthomas/assert/v2 v2.11.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k= github.com/alecthomas/chroma/v2 v2.23.1 h1:nv2AVZdTyClGbVQkIzlDm/rnhk1E9bU9nXwmZ/Vk/iY= github.com/alecthomas/chroma/v2 v2.23.1/go.mod h1:NqVhfBR0lte5Ouh3DcthuUCTUpDC9cxBOfyMbMQPs3o= +github.com/alecthomas/repr v0.5.2 h1:SU73FTI9D1P5UNtvseffFSGmdNci/O6RsqzeXJtP0Qs= +github.com/alecthomas/repr v0.5.2/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4= github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw= @@ -14,6 +18,7 @@ github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0G github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs= github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ= github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg= +github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ= github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/ebitengine/purego v0.4.1 h1:atcZEBdukuoClmy7TI89amtqAsJUzDQyY/JU7HaK+io= @@ -22,8 +27,14 @@ github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6 github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hajimehoshi/oto/v2 v2.4.2 h1:uPZq5xEnOv8nIy4eMoDkakLb99YxoNv5XHL7Mm6zHwU= github.com/hajimehoshi/oto/v2 v2.4.2/go.mod h1:tINhdh4kCNJ8N19zqp0Lk/wMFv5WQJYkqnnEZ5W5WtE= +github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= +github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= @@ -41,8 +52,14 @@ github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3 github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= +github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= +github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY= +github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= +go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561 h1:MDc5xs78ZrZr3HMQugiXOAkSZtfTpbJLDr/lwfgO53E= golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE= golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -51,6 +68,7 @@ golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k= golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/text v0.3.8 h1:nAL+RVCQ9uMn3vJZbV+MRnydTJFPf8qqY42YiA6MrqY= golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/auth/auth.go b/internal/auth/auth.go new file mode 100644 index 0000000..b699a8f --- /dev/null +++ b/internal/auth/auth.go @@ -0,0 +1,255 @@ +package auth + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "net/http" + "os" + "path/filepath" + "sync" + "time" + + "gopkg.in/yaml.v3" +) + +const credentialsFile = "credentials.yaml" + +const defaultBaseURL = "https://chiefloop.com" + +// ErrNotLoggedIn is returned when no credentials file exists. +var ErrNotLoggedIn = errors.New("not logged in — run 'chief login' first") + +// ErrSessionExpired is returned when the refresh token is revoked or expired. +var ErrSessionExpired = errors.New("session expired — run 'chief login' again") + +// refreshMu protects concurrent token refresh operations. +var refreshMu sync.Mutex + +// Credentials holds authentication token data for chiefloop.com. +type Credentials struct { + AccessToken string `yaml:"access_token"` + RefreshToken string `yaml:"refresh_token"` + ExpiresAt time.Time `yaml:"expires_at"` + DeviceName string `yaml:"device_name"` + User string `yaml:"user"` + WSURL string `yaml:"ws_url,omitempty"` +} + +// IsExpired returns true if the access token has expired. +func (c *Credentials) IsExpired() bool { + return time.Now().After(c.ExpiresAt) +} + +// IsNearExpiry returns true if the access token will expire within the given duration. +func (c *Credentials) IsNearExpiry(d time.Duration) bool { + return time.Now().Add(d).After(c.ExpiresAt) +} + +// credentialsDir returns the path to the ~/.chief directory. +func credentialsDir() (string, error) { + home, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("determining home directory: %w", err) + } + return filepath.Join(home, ".chief"), nil +} + +// credentialsPath returns the full path to ~/.chief/credentials.yaml. +func credentialsPath() (string, error) { + dir, err := credentialsDir() + if err != nil { + return "", err + } + return filepath.Join(dir, credentialsFile), nil +} + +// LoadCredentials reads credentials from ~/.chief/credentials.yaml. +// Returns ErrNotLoggedIn when the file does not exist. +func LoadCredentials() (*Credentials, error) { + path, err := credentialsPath() + if err != nil { + return nil, err + } + + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return nil, ErrNotLoggedIn + } + return nil, fmt.Errorf("reading credentials: %w", err) + } + + var creds Credentials + if err := yaml.Unmarshal(data, &creds); err != nil { + return nil, fmt.Errorf("parsing credentials: %w", err) + } + + return &creds, nil +} + +// SaveCredentials writes credentials to ~/.chief/credentials.yaml atomically. +// It writes to a temporary file first, then renames it into place. +// The file is created with 0600 permissions (owner read/write only). +func SaveCredentials(creds *Credentials) error { + path, err := credentialsPath() + if err != nil { + return err + } + + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0o755); err != nil { + return fmt.Errorf("creating credentials directory: %w", err) + } + + data, err := yaml.Marshal(creds) + if err != nil { + return fmt.Errorf("marshaling credentials: %w", err) + } + + // Write to temp file in the same directory for atomic rename. + tmp, err := os.CreateTemp(dir, "credentials-*.yaml") + if err != nil { + return fmt.Errorf("creating temp file: %w", err) + } + tmpPath := tmp.Name() + + if err := os.Chmod(tmpPath, 0o600); err != nil { + tmp.Close() + os.Remove(tmpPath) + return fmt.Errorf("setting temp file permissions: %w", err) + } + + if _, err := tmp.Write(data); err != nil { + tmp.Close() + os.Remove(tmpPath) + return fmt.Errorf("writing temp file: %w", err) + } + + if err := tmp.Close(); err != nil { + os.Remove(tmpPath) + return fmt.Errorf("closing temp file: %w", err) + } + + if err := os.Rename(tmpPath, path); err != nil { + os.Remove(tmpPath) + return fmt.Errorf("renaming temp file: %w", err) + } + + return nil +} + +// DeleteCredentials removes the credentials file. +// Returns nil if the file does not exist. +func DeleteCredentials() error { + path, err := credentialsPath() + if err != nil { + return err + } + + if err := os.Remove(path); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("removing credentials: %w", err) + } + + return nil +} + +// refreshResponse is the response from the token refresh endpoint. +type refreshResponse struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int `json:"expires_in"` + WSURL string `json:"ws_url"` + Error string `json:"error"` +} + +// RefreshToken refreshes the access token using the refresh token. +// It is thread-safe (mutex-protected for concurrent use by serve). +// baseURL can be empty to use the default (https://chiefloop.com). +func RefreshToken(baseURL string) (*Credentials, error) { + refreshMu.Lock() + defer refreshMu.Unlock() + + creds, err := LoadCredentials() + if err != nil { + return nil, err + } + + // If token was already refreshed by another goroutine, return it. + if !creds.IsNearExpiry(5 * time.Minute) { + return creds, nil + } + + if baseURL == "" { + baseURL = defaultBaseURL + } + + reqBody, _ := json.Marshal(map[string]string{ + "grant_type": "refresh_token", + "refresh_token": creds.RefreshToken, + }) + + client := &http.Client{Timeout: 10 * time.Second} + req, err := http.NewRequest(http.MethodPost, baseURL+"/api/oauth/token", bytes.NewReader(reqBody)) + if err != nil { + return nil, fmt.Errorf("creating refresh request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("refreshing token: %w", err) + } + defer resp.Body.Close() + + var tokenResp refreshResponse + if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { + return nil, fmt.Errorf("parsing refresh response: %w", err) + } + + if tokenResp.Error != "" || resp.StatusCode != http.StatusOK { + return nil, ErrSessionExpired + } + + creds.AccessToken = tokenResp.AccessToken + if tokenResp.RefreshToken != "" { + creds.RefreshToken = tokenResp.RefreshToken + } + if tokenResp.WSURL != "" { + creds.WSURL = tokenResp.WSURL + } + creds.ExpiresAt = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) + + if err := SaveCredentials(creds); err != nil { + return nil, fmt.Errorf("saving refreshed credentials: %w", err) + } + + return creds, nil +} + +// RevokeDevice calls the revocation endpoint to deauthorize the device server-side. +// baseURL can be empty to use the default (https://chiefloop.com). +func RevokeDevice(accessToken, baseURL string) error { + if baseURL == "" { + baseURL = defaultBaseURL + } + + reqBody, _ := json.Marshal(map[string]string{ + "access_token": accessToken, + }) + + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Post(baseURL+"/api/oauth/revoke", "application/json", bytes.NewReader(reqBody)) + if err != nil { + return fmt.Errorf("revoking device: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("revocation failed: server returned %s", resp.Status) + } + + return nil +} diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go new file mode 100644 index 0000000..1569430 --- /dev/null +++ b/internal/auth/auth_test.go @@ -0,0 +1,610 @@ +package auth + +import ( + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "sync" + "testing" + "time" +) + +// setTestHome overrides HOME so credentials are read/written inside t.TempDir(). +// It returns a cleanup function that restores the original HOME. +func setTestHome(t *testing.T, dir string) { + t.Helper() + orig := os.Getenv("HOME") + t.Setenv("HOME", dir) + t.Cleanup(func() { + os.Setenv("HOME", orig) + }) +} + +func TestLoadCredentials_NotLoggedIn(t *testing.T) { + setTestHome(t, t.TempDir()) + + _, err := LoadCredentials() + if !errors.Is(err, ErrNotLoggedIn) { + t.Fatalf("expected ErrNotLoggedIn, got %v", err) + } +} + +func TestSaveAndLoadCredentials(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + + expires := time.Date(2026, 6, 15, 12, 0, 0, 0, time.UTC) + creds := &Credentials{ + AccessToken: "access-abc", + RefreshToken: "refresh-xyz", + ExpiresAt: expires, + DeviceName: "my-laptop", + User: "user@example.com", + } + + if err := SaveCredentials(creds); err != nil { + t.Fatalf("SaveCredentials failed: %v", err) + } + + loaded, err := LoadCredentials() + if err != nil { + t.Fatalf("LoadCredentials failed: %v", err) + } + + if loaded.AccessToken != "access-abc" { + t.Errorf("expected access_token %q, got %q", "access-abc", loaded.AccessToken) + } + if loaded.RefreshToken != "refresh-xyz" { + t.Errorf("expected refresh_token %q, got %q", "refresh-xyz", loaded.RefreshToken) + } + if !loaded.ExpiresAt.Equal(expires) { + t.Errorf("expected expires_at %v, got %v", expires, loaded.ExpiresAt) + } + if loaded.DeviceName != "my-laptop" { + t.Errorf("expected device_name %q, got %q", "my-laptop", loaded.DeviceName) + } + if loaded.User != "user@example.com" { + t.Errorf("expected user %q, got %q", "user@example.com", loaded.User) + } +} + +func TestSaveCredentials_FilePermissions(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + + creds := &Credentials{ + AccessToken: "token", + } + + if err := SaveCredentials(creds); err != nil { + t.Fatalf("SaveCredentials failed: %v", err) + } + + path := filepath.Join(home, ".chief", "credentials.yaml") + info, err := os.Stat(path) + if err != nil { + t.Fatalf("Stat failed: %v", err) + } + + perm := info.Mode().Perm() + if perm != 0o600 { + t.Errorf("expected permissions 0600, got %04o", perm) + } +} + +func TestSaveCredentials_Atomic(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + + // Save initial credentials. + initial := &Credentials{ + AccessToken: "first", + User: "user1", + } + if err := SaveCredentials(initial); err != nil { + t.Fatalf("SaveCredentials (initial) failed: %v", err) + } + + // Save updated credentials (should atomically replace). + updated := &Credentials{ + AccessToken: "second", + User: "user2", + } + if err := SaveCredentials(updated); err != nil { + t.Fatalf("SaveCredentials (updated) failed: %v", err) + } + + loaded, err := LoadCredentials() + if err != nil { + t.Fatalf("LoadCredentials failed: %v", err) + } + if loaded.AccessToken != "second" { + t.Errorf("expected access_token %q, got %q", "second", loaded.AccessToken) + } + if loaded.User != "user2" { + t.Errorf("expected user %q, got %q", "user2", loaded.User) + } +} + +func TestDeleteCredentials(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + + creds := &Credentials{AccessToken: "to-delete"} + if err := SaveCredentials(creds); err != nil { + t.Fatalf("SaveCredentials failed: %v", err) + } + + if err := DeleteCredentials(); err != nil { + t.Fatalf("DeleteCredentials failed: %v", err) + } + + _, err := LoadCredentials() + if !errors.Is(err, ErrNotLoggedIn) { + t.Fatalf("expected ErrNotLoggedIn after delete, got %v", err) + } +} + +func TestDeleteCredentials_NonExistent(t *testing.T) { + setTestHome(t, t.TempDir()) + + // Deleting when file doesn't exist should not error. + if err := DeleteCredentials(); err != nil { + t.Fatalf("DeleteCredentials on non-existent file failed: %v", err) + } +} + +func TestSaveLoadDeleteCycle(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + + // 1. Not logged in initially. + _, err := LoadCredentials() + if !errors.Is(err, ErrNotLoggedIn) { + t.Fatalf("expected ErrNotLoggedIn initially, got %v", err) + } + + // 2. Save credentials. + creds := &Credentials{ + AccessToken: "cycle-token", + RefreshToken: "cycle-refresh", + ExpiresAt: time.Now().Add(time.Hour), + DeviceName: "test-device", + User: "cycle-user", + } + if err := SaveCredentials(creds); err != nil { + t.Fatalf("SaveCredentials failed: %v", err) + } + + // 3. Load and verify. + loaded, err := LoadCredentials() + if err != nil { + t.Fatalf("LoadCredentials failed: %v", err) + } + if loaded.AccessToken != "cycle-token" { + t.Errorf("expected access_token %q, got %q", "cycle-token", loaded.AccessToken) + } + + // 4. Delete. + if err := DeleteCredentials(); err != nil { + t.Fatalf("DeleteCredentials failed: %v", err) + } + + // 5. Not logged in again. + _, err = LoadCredentials() + if !errors.Is(err, ErrNotLoggedIn) { + t.Fatalf("expected ErrNotLoggedIn after delete, got %v", err) + } +} + +func TestIsExpired(t *testing.T) { + // Expired token. + expired := &Credentials{ + ExpiresAt: time.Now().Add(-time.Hour), + } + if !expired.IsExpired() { + t.Error("expected token to be expired") + } + + // Valid token. + valid := &Credentials{ + ExpiresAt: time.Now().Add(time.Hour), + } + if valid.IsExpired() { + t.Error("expected token to not be expired") + } +} + +func TestIsNearExpiry(t *testing.T) { + // Token expires in 3 minutes — should be near expiry within 5 minutes. + creds := &Credentials{ + ExpiresAt: time.Now().Add(3 * time.Minute), + } + + if !creds.IsNearExpiry(5 * time.Minute) { + t.Error("expected token to be near expiry within 5 minutes") + } + + if creds.IsNearExpiry(1 * time.Minute) { + t.Error("expected token to NOT be near expiry within 1 minute") + } +} + +func TestIsNearExpiry_AlreadyExpired(t *testing.T) { + creds := &Credentials{ + ExpiresAt: time.Now().Add(-time.Hour), + } + + if !creds.IsNearExpiry(5 * time.Minute) { + t.Error("expected already-expired token to be near expiry") + } +} + +func TestSaveCredentials_CreatesDirectory(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + + chiefDir := filepath.Join(home, ".chief") + if _, err := os.Stat(chiefDir); !os.IsNotExist(err) { + t.Fatal("expected .chief directory to not exist initially") + } + + creds := &Credentials{AccessToken: "create-dir"} + if err := SaveCredentials(creds); err != nil { + t.Fatalf("SaveCredentials failed: %v", err) + } + + info, err := os.Stat(chiefDir) + if err != nil { + t.Fatalf("expected .chief directory to exist after save, got: %v", err) + } + if !info.IsDir() { + t.Error("expected .chief to be a directory") + } +} + +func TestRefreshToken_Success(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + + // Save credentials that are near expiry (within 5 minutes) + creds := &Credentials{ + AccessToken: "old-access-token", + RefreshToken: "test-refresh-token", + ExpiresAt: time.Now().Add(2 * time.Minute), // near expiry + DeviceName: "test-device", + User: "user@example.com", + } + if err := SaveCredentials(creds); err != nil { + t.Fatalf("SaveCredentials failed: %v", err) + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/oauth/token" { + if accept := r.Header.Get("Accept"); accept != "application/json" { + t.Errorf("expected Accept header %q, got %q", "application/json", accept) + } + var body map[string]string + json.NewDecoder(r.Body).Decode(&body) + if body["grant_type"] != "refresh_token" { + t.Errorf("expected grant_type %q, got %q", "refresh_token", body["grant_type"]) + } + if body["refresh_token"] != "test-refresh-token" { + t.Errorf("expected refresh_token %q, got %q", "test-refresh-token", body["refresh_token"]) + } + json.NewEncoder(w).Encode(refreshResponse{ + AccessToken: "new-access-token", + RefreshToken: "new-refresh-token", + ExpiresIn: 3600, + }) + return + } + http.NotFound(w, r) + })) + defer server.Close() + + refreshed, err := RefreshToken(server.URL) + if err != nil { + t.Fatalf("RefreshToken failed: %v", err) + } + + if refreshed.AccessToken != "new-access-token" { + t.Errorf("expected access_token %q, got %q", "new-access-token", refreshed.AccessToken) + } + if refreshed.RefreshToken != "new-refresh-token" { + t.Errorf("expected refresh_token %q, got %q", "new-refresh-token", refreshed.RefreshToken) + } + + // Verify credentials were persisted + loaded, err := LoadCredentials() + if err != nil { + t.Fatalf("LoadCredentials failed: %v", err) + } + if loaded.AccessToken != "new-access-token" { + t.Errorf("expected persisted access_token %q, got %q", "new-access-token", loaded.AccessToken) + } +} + +func TestRefreshToken_NotNearExpiry(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + + // Save credentials that are NOT near expiry + creds := &Credentials{ + AccessToken: "valid-token", + RefreshToken: "refresh-token", + ExpiresAt: time.Now().Add(30 * time.Minute), + DeviceName: "test-device", + User: "user@example.com", + } + if err := SaveCredentials(creds); err != nil { + t.Fatalf("SaveCredentials failed: %v", err) + } + + // Server should not be called since token is still valid + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Error("server should not be called when token is not near expiry") + })) + defer server.Close() + + refreshed, err := RefreshToken(server.URL) + if err != nil { + t.Fatalf("RefreshToken failed: %v", err) + } + + if refreshed.AccessToken != "valid-token" { + t.Errorf("expected access_token %q, got %q", "valid-token", refreshed.AccessToken) + } +} + +func TestRefreshToken_SessionExpired(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + + creds := &Credentials{ + AccessToken: "old-token", + RefreshToken: "revoked-refresh-token", + ExpiresAt: time.Now().Add(2 * time.Minute), + DeviceName: "test-device", + User: "user@example.com", + } + if err := SaveCredentials(creds); err != nil { + t.Fatalf("SaveCredentials failed: %v", err) + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/oauth/token" { + json.NewEncoder(w).Encode(refreshResponse{ + Error: "invalid_grant", + }) + return + } + http.NotFound(w, r) + })) + defer server.Close() + + _, err := RefreshToken(server.URL) + if !errors.Is(err, ErrSessionExpired) { + t.Fatalf("expected ErrSessionExpired, got %v", err) + } +} + +func TestRefreshToken_NotLoggedIn(t *testing.T) { + setTestHome(t, t.TempDir()) + + _, err := RefreshToken("") + if !errors.Is(err, ErrNotLoggedIn) { + t.Fatalf("expected ErrNotLoggedIn, got %v", err) + } +} + +func TestRefreshToken_ThreadSafe(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + + creds := &Credentials{ + AccessToken: "old-token", + RefreshToken: "refresh-token", + ExpiresAt: time.Now().Add(2 * time.Minute), + DeviceName: "test-device", + User: "user@example.com", + } + if err := SaveCredentials(creds); err != nil { + t.Fatalf("SaveCredentials failed: %v", err) + } + + var callCount int + var mu sync.Mutex + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/oauth/token" { + mu.Lock() + callCount++ + mu.Unlock() + json.NewEncoder(w).Encode(refreshResponse{ + AccessToken: "new-token", + RefreshToken: "new-refresh", + ExpiresIn: 3600, + }) + return + } + http.NotFound(w, r) + })) + defer server.Close() + + // Run multiple concurrent refreshes + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _, err := RefreshToken(server.URL) + if err != nil { + t.Errorf("RefreshToken failed: %v", err) + } + }() + } + wg.Wait() + + // Only one actual refresh should have hit the server + // (the first one refreshes, subsequent ones see it's no longer near expiry) + mu.Lock() + count := callCount + mu.Unlock() + if count != 1 { + t.Errorf("expected exactly 1 server call, got %d", count) + } +} + +func TestSaveAndLoadCredentials_WithWSURL(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + + creds := &Credentials{ + AccessToken: "access-abc", + RefreshToken: "refresh-xyz", + ExpiresAt: time.Date(2026, 6, 15, 12, 0, 0, 0, time.UTC), + DeviceName: "my-laptop", + User: "user@example.com", + WSURL: "wss://ws-abc123-reverb.laravel.cloud/ws/server", + } + + if err := SaveCredentials(creds); err != nil { + t.Fatalf("SaveCredentials failed: %v", err) + } + + loaded, err := LoadCredentials() + if err != nil { + t.Fatalf("LoadCredentials failed: %v", err) + } + + if loaded.WSURL != "wss://ws-abc123-reverb.laravel.cloud/ws/server" { + t.Errorf("expected ws_url %q, got %q", "wss://ws-abc123-reverb.laravel.cloud/ws/server", loaded.WSURL) + } +} + +func TestRefreshToken_PreservesWSURL(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + + creds := &Credentials{ + AccessToken: "old-access-token", + RefreshToken: "test-refresh-token", + ExpiresAt: time.Now().Add(2 * time.Minute), + DeviceName: "test-device", + User: "user@example.com", + WSURL: "wss://old-host/ws/server", + } + if err := SaveCredentials(creds); err != nil { + t.Fatalf("SaveCredentials failed: %v", err) + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/oauth/token" { + json.NewEncoder(w).Encode(refreshResponse{ + AccessToken: "new-access-token", + RefreshToken: "new-refresh-token", + ExpiresIn: 3600, + WSURL: "wss://new-host/ws/server", + }) + return + } + http.NotFound(w, r) + })) + defer server.Close() + + refreshed, err := RefreshToken(server.URL) + if err != nil { + t.Fatalf("RefreshToken failed: %v", err) + } + + if refreshed.WSURL != "wss://new-host/ws/server" { + t.Errorf("expected ws_url %q, got %q", "wss://new-host/ws/server", refreshed.WSURL) + } + + // Verify persisted + loaded, err := LoadCredentials() + if err != nil { + t.Fatalf("LoadCredentials failed: %v", err) + } + if loaded.WSURL != "wss://new-host/ws/server" { + t.Errorf("expected persisted ws_url %q, got %q", "wss://new-host/ws/server", loaded.WSURL) + } +} + +func TestRefreshToken_WSURLNotReturned_PreservesExisting(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + + creds := &Credentials{ + AccessToken: "old-access-token", + RefreshToken: "test-refresh-token", + ExpiresAt: time.Now().Add(2 * time.Minute), + DeviceName: "test-device", + User: "user@example.com", + WSURL: "wss://existing-host/ws/server", + } + if err := SaveCredentials(creds); err != nil { + t.Fatalf("SaveCredentials failed: %v", err) + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/oauth/token" { + json.NewEncoder(w).Encode(refreshResponse{ + AccessToken: "new-access-token", + RefreshToken: "new-refresh-token", + ExpiresIn: 3600, + // WSURL intentionally omitted + }) + return + } + http.NotFound(w, r) + })) + defer server.Close() + + refreshed, err := RefreshToken(server.URL) + if err != nil { + t.Fatalf("RefreshToken failed: %v", err) + } + + if refreshed.WSURL != "wss://existing-host/ws/server" { + t.Errorf("expected existing ws_url to be preserved %q, got %q", "wss://existing-host/ws/server", refreshed.WSURL) + } +} + +func TestRevokeDevice_Success(t *testing.T) { + var receivedToken string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/oauth/revoke" { + var body map[string]string + json.NewDecoder(r.Body).Decode(&body) + receivedToken = body["access_token"] + w.WriteHeader(http.StatusOK) + return + } + http.NotFound(w, r) + })) + defer server.Close() + + err := RevokeDevice("my-token", server.URL) + if err != nil { + t.Fatalf("RevokeDevice failed: %v", err) + } + if receivedToken != "my-token" { + t.Errorf("expected token %q, got %q", "my-token", receivedToken) + } +} + +func TestRevokeDevice_ServerError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + err := RevokeDevice("my-token", server.URL) + if err == nil { + t.Fatal("expected error for server error response") + } +} diff --git a/internal/cmd/clone.go b/internal/cmd/clone.go new file mode 100644 index 0000000..333ad39 --- /dev/null +++ b/internal/cmd/clone.go @@ -0,0 +1,251 @@ +package cmd + +import ( + "bufio" + "encoding/json" + "fmt" + "log" + "os" + "os/exec" + "path/filepath" + "regexp" + "strconv" + "strings" + + "github.com/minicodemonkey/chief/internal/workspace" + "github.com/minicodemonkey/chief/internal/ws" +) + +// handleCloneRepo handles a clone_repo request. +func handleCloneRepo(sender messageSender, scanner *workspace.Scanner, msg ws.Message) { + var req ws.CloneRepoMessage + if err := json.Unmarshal(msg.Raw, &req); err != nil { + log.Printf("Error parsing clone_repo message: %v", err) + return + } + + if req.URL == "" { + sendError(sender, ws.ErrCodeCloneFailed, "URL is required", msg.ID) + return + } + + workspaceDir := scanner.WorkspacePath() + + // Determine target directory name + dirName := req.DirectoryName + if dirName == "" { + dirName = inferDirName(req.URL) + } + + targetDir := filepath.Join(workspaceDir, dirName) + + // Check if target already exists + if _, err := os.Stat(targetDir); err == nil { + sendError(sender, ws.ErrCodeCloneFailed, + fmt.Sprintf("Directory %q already exists in workspace", dirName), msg.ID) + return + } + + // Run clone in a goroutine so we don't block the message loop + go runClone(sender, scanner, req.URL, dirName, workspaceDir) +} + +// inferDirName extracts a directory name from a git URL. +func inferDirName(url string) string { + // Remove trailing .git + url = strings.TrimSuffix(url, ".git") + // Remove trailing slash + url = strings.TrimRight(url, "/") + // Get the last path component + parts := strings.Split(url, "/") + if len(parts) > 0 { + name := parts[len(parts)-1] + // Also handle ssh-style urls like git@github.com:user/repo + if idx := strings.LastIndex(name, ":"); idx >= 0 { + name = name[idx+1:] + } + if name != "" { + return name + } + } + return "cloned-repo" +} + +// percentPattern matches git clone progress percentages. +var percentPattern = regexp.MustCompile(`(\d+)%`) + +// runClone executes the git clone and streams progress messages. +func runClone(sender messageSender, scanner *workspace.Scanner, url, dirName, workspaceDir string) { + cmd := exec.Command("git", "clone", "--progress", url, dirName) + cmd.Dir = workspaceDir + + // Git clone writes progress to stderr + stderr, err := cmd.StderrPipe() + if err != nil { + sendCloneComplete(sender, url, "", false, fmt.Sprintf("Failed to set up clone: %v", err)) + return + } + + if err := cmd.Start(); err != nil { + sendCloneComplete(sender, url, "", false, fmt.Sprintf("Failed to start clone: %v", err)) + return + } + + // Stream progress from stderr + stderrScanner := bufio.NewScanner(stderr) + stderrScanner.Split(scanGitProgress) + for stderrScanner.Scan() { + line := strings.TrimSpace(stderrScanner.Text()) + if line == "" { + continue + } + + percent := 0 + if matches := percentPattern.FindStringSubmatch(line); len(matches) > 1 { + percent, _ = strconv.Atoi(matches[1]) + } + + sendCloneProgress(sender, url, line, percent) + } + + if err := cmd.Wait(); err != nil { + sendCloneComplete(sender, url, "", false, fmt.Sprintf("Clone failed: %v", err)) + return + } + + // Trigger a rescan so the new project appears immediately + scanner.ScanAndUpdate() + + sendCloneComplete(sender, url, dirName, true, "") +} + +// scanGitProgress is a bufio.SplitFunc that splits on \r or \n, +// since git clone uses \r for progress updates. +func scanGitProgress(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + // Find the first \r or \n + for i, b := range data { + if b == '\r' || b == '\n' { + return i + 1, data[:i], nil + } + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil +} + +// sendCloneProgress sends a clone_progress message. +func sendCloneProgress(sender messageSender, url, progressText string, percent int) { + if sender == nil { + return + } + envelope := ws.NewMessage(ws.TypeCloneProgress) + msg := ws.CloneProgressMessage{ + Type: envelope.Type, + ID: envelope.ID, + Timestamp: envelope.Timestamp, + URL: url, + ProgressText: progressText, + Percent: percent, + } + if err := sender.Send(msg); err != nil { + log.Printf("Error sending clone_progress: %v", err) + } +} + +// sendCloneComplete sends a clone_complete message. +func sendCloneComplete(sender messageSender, url, project string, success bool, errMsg string) { + if sender == nil { + return + } + envelope := ws.NewMessage(ws.TypeCloneComplete) + msg := ws.CloneCompleteMessage{ + Type: envelope.Type, + ID: envelope.ID, + Timestamp: envelope.Timestamp, + URL: url, + Success: success, + Error: errMsg, + Project: project, + } + if err := sender.Send(msg); err != nil { + log.Printf("Error sending clone_complete: %v", err) + } +} + +// handleCreateProject handles a create_project request. +func handleCreateProject(sender messageSender, scanner *workspace.Scanner, msg ws.Message) { + var req ws.CreateProjectMessage + if err := json.Unmarshal(msg.Raw, &req); err != nil { + log.Printf("Error parsing create_project message: %v", err) + return + } + + if req.Name == "" { + sendError(sender, ws.ErrCodeFilesystemError, "Project name is required", msg.ID) + return + } + + workspaceDir := scanner.WorkspacePath() + projectDir := filepath.Join(workspaceDir, req.Name) + + // Check if directory already exists + if _, err := os.Stat(projectDir); err == nil { + sendError(sender, ws.ErrCodeFilesystemError, + fmt.Sprintf("Directory %q already exists", req.Name), msg.ID) + return + } + + // Create the directory + if err := os.MkdirAll(projectDir, 0o755); err != nil { + sendError(sender, ws.ErrCodeFilesystemError, + fmt.Sprintf("Failed to create directory: %v", err), msg.ID) + return + } + + // Optionally run git init + if req.GitInit { + cmd := exec.Command("git", "init", projectDir) + if out, err := cmd.CombinedOutput(); err != nil { + sendError(sender, ws.ErrCodeFilesystemError, + fmt.Sprintf("git init failed: %v\n%s", err, strings.TrimSpace(string(out))), msg.ID) + return + } + } + + // Trigger rescan so new project appears immediately + scanner.ScanAndUpdate() + + // Send updated project_state if git init was done (it's a discoverable project) + if req.GitInit { + project, found := scanner.FindProject(req.Name) + if found { + envelope := ws.NewMessage(ws.TypeProjectState) + psMsg := ws.ProjectStateMessage{ + Type: envelope.Type, + ID: envelope.ID, + Timestamp: envelope.Timestamp, + Project: project, + } + if err := sender.Send(psMsg); err != nil { + log.Printf("Error sending project_state: %v", err) + } + return + } + } + + // Send a simple project_list update for non-git projects + envelope := ws.NewMessage(ws.TypeProjectList) + plMsg := ws.ProjectListMessage{ + Type: envelope.Type, + ID: envelope.ID, + Timestamp: envelope.Timestamp, + Projects: scanner.Projects(), + } + if err := sender.Send(plMsg); err != nil { + log.Printf("Error sending project_list: %v", err) + } +} diff --git a/internal/cmd/clone_test.go b/internal/cmd/clone_test.go new file mode 100644 index 0000000..3a6e260 --- /dev/null +++ b/internal/cmd/clone_test.go @@ -0,0 +1,531 @@ +package cmd + +import ( + "encoding/json" + "os" + "os/exec" + "path/filepath" + "strconv" + "strings" + "testing" + "time" + + "github.com/minicodemonkey/chief/internal/ws" +) + +func TestInferDirName(t *testing.T) { + tests := []struct { + url string + expected string + }{ + {"https://github.com/user/repo.git", "repo"}, + {"https://github.com/user/repo", "repo"}, + {"git@github.com:user/repo.git", "repo"}, + {"https://github.com/user/repo/", "repo"}, + {"https://github.com/user/my-project.git", "my-project"}, + {"git@github.com:org/my-lib.git", "my-lib"}, + } + + for _, tt := range tests { + t.Run(tt.url, func(t *testing.T) { + got := inferDirName(tt.url) + if got != tt.expected { + t.Errorf("inferDirName(%q) = %q, want %q", tt.url, got, tt.expected) + } + }) + } +} + +func TestHandleCloneRepo_Success(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + // Create a bare git repo to clone from + bareRepo := filepath.Join(home, "bare-repo.git") + cmd := exec.Command("git", "init", "--bare", bareRepo) + cmd.Env = append(os.Environ(), "GIT_CONFIG_GLOBAL=/dev/null", "GIT_CONFIG_SYSTEM=/dev/null") + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("git init --bare failed: %v\n%s", err, out) + } + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + // Send clone_repo request + cloneReq := map[string]interface{}{ + "type": "clone_repo", + "id": "req-clone-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "url": bareRepo, + } + if err := ms.sendCommand(cloneReq); err != nil { + t.Fatalf("failed to send clone command: %v", err) + } + + // Wait for clone_complete message + raw, err := ms.waitForMessageType("clone_complete", 5*time.Second) + if err != nil { + t.Fatalf("failed to receive clone_complete: %v", err) + } + + var cloneComplete map[string]interface{} + if err := json.Unmarshal(raw, &cloneComplete); err != nil { + t.Fatalf("failed to unmarshal clone_complete: %v", err) + } + + if cloneComplete["success"] != true { + t.Errorf("expected success=true, got %v (error: %v)", cloneComplete["success"], cloneComplete["error"]) + } + if cloneComplete["project"] != "bare-repo" { + t.Errorf("expected project 'bare-repo', got %v", cloneComplete["project"]) + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + // Verify the directory was created + clonedDir := filepath.Join(workspaceDir, "bare-repo") + if _, err := os.Stat(filepath.Join(clonedDir, ".git")); os.IsNotExist(err) { + t.Error("cloned repository directory does not have .git") + } +} + +func TestHandleCloneRepo_CustomDirectoryName(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + // Create a bare git repo to clone from + bareRepo := filepath.Join(home, "source.git") + cmd := exec.Command("git", "init", "--bare", bareRepo) + cmd.Env = append(os.Environ(), "GIT_CONFIG_GLOBAL=/dev/null", "GIT_CONFIG_SYSTEM=/dev/null") + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("git init --bare failed: %v\n%s", err, out) + } + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + cloneReq := map[string]interface{}{ + "type": "clone_repo", + "id": "req-clone-2", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "url": bareRepo, + "directory_name": "my-custom-name", + } + if err := ms.sendCommand(cloneReq); err != nil { + t.Fatalf("failed to send clone command: %v", err) + } + + // Wait for clone_complete message + raw, err := ms.waitForMessageType("clone_complete", 5*time.Second) + if err != nil { + t.Fatalf("failed to receive clone_complete: %v", err) + } + + var cloneComplete map[string]interface{} + if err := json.Unmarshal(raw, &cloneComplete); err != nil { + t.Fatalf("failed to unmarshal clone_complete: %v", err) + } + + if cloneComplete["success"] != true { + t.Errorf("expected success=true, got %v", cloneComplete["success"]) + } + if cloneComplete["project"] != "my-custom-name" { + t.Errorf("expected project 'my-custom-name', got %v", cloneComplete["project"]) + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + // Verify directory exists under custom name + if _, err := os.Stat(filepath.Join(workspaceDir, "my-custom-name", ".git")); os.IsNotExist(err) { + t.Error("cloned repo not found at custom directory name") + } +} + +func TestHandleCloneRepo_DirectoryExists(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + // Create the target directory ahead of time + if err := os.MkdirAll(filepath.Join(workspaceDir, "existing-repo"), 0o755); err != nil { + t.Fatal(err) + } + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + cloneReq := map[string]interface{}{ + "type": "clone_repo", + "id": "req-clone-3", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "url": "https://github.com/user/existing-repo.git", + } + if err := ms.sendCommand(cloneReq); err != nil { + t.Fatalf("failed to send clone command: %v", err) + } + + // Wait for error message + raw, err := ms.waitForMessageType("error", 2*time.Second) + if err != nil { + t.Fatalf("failed to receive error message: %v", err) + } + + var errorReceived map[string]interface{} + if err := json.Unmarshal(raw, &errorReceived); err != nil { + t.Fatalf("failed to unmarshal error: %v", err) + } + + if errorReceived["code"] != "CLONE_FAILED" { + t.Errorf("expected code CLONE_FAILED, got %v", errorReceived["code"]) + } + if !strings.Contains(errorReceived["message"].(string), "already exists") { + t.Errorf("expected 'already exists' in message, got %v", errorReceived["message"]) + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } +} + +func TestHandleCloneRepo_InvalidURL(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + cloneReq := map[string]interface{}{ + "type": "clone_repo", + "id": "req-clone-4", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "url": "/nonexistent/invalid-repo", + } + if err := ms.sendCommand(cloneReq); err != nil { + t.Fatalf("failed to send clone command: %v", err) + } + + // Wait for clone_complete message + raw, err := ms.waitForMessageType("clone_complete", 5*time.Second) + if err != nil { + t.Fatalf("failed to receive clone_complete: %v", err) + } + + var cloneComplete map[string]interface{} + if err := json.Unmarshal(raw, &cloneComplete); err != nil { + t.Fatalf("failed to unmarshal clone_complete: %v", err) + } + + if cloneComplete["success"] != false { + t.Errorf("expected success=false, got %v", cloneComplete["success"]) + } + errMsg, ok := cloneComplete["error"].(string) + if !ok || errMsg == "" { + t.Error("expected non-empty error message for failed clone") + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } +} + +func TestHandleCreateProject_Success(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + createReq := map[string]interface{}{ + "type": "create_project", + "id": "req-create-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "name": "new-project", + "git_init": false, + } + if err := ms.sendCommand(createReq); err != nil { + t.Fatalf("failed to send create command: %v", err) + } + + // Wait for project_list message (without git_init, project won't show up in scanner) + raw, err := ms.waitForMessageType("project_list", 2*time.Second) + if err != nil { + t.Fatalf("failed to receive project_list: %v", err) + } + + var response map[string]interface{} + if err := json.Unmarshal(raw, &response); err != nil { + t.Fatalf("failed to unmarshal project_list: %v", err) + } + + if response["type"] != "project_list" { + t.Errorf("expected type 'project_list', got %v", response["type"]) + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + // Verify directory was created + projectDir := filepath.Join(workspaceDir, "new-project") + info, err := os.Stat(projectDir) + if err != nil { + t.Fatalf("project directory not created: %v", err) + } + if !info.IsDir() { + t.Error("expected project path to be a directory") + } +} + +func TestHandleCreateProject_WithGitInit(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + createReq := map[string]interface{}{ + "type": "create_project", + "id": "req-create-2", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "name": "git-project", + "git_init": true, + } + if err := ms.sendCommand(createReq); err != nil { + t.Fatalf("failed to send create command: %v", err) + } + + // Wait for project_state message (with git_init, scanner finds the project) + raw, err := ms.waitForMessageType("project_state", 2*time.Second) + if err != nil { + t.Fatalf("failed to receive project_state: %v", err) + } + + var response map[string]interface{} + if err := json.Unmarshal(raw, &response); err != nil { + t.Fatalf("failed to unmarshal project_state: %v", err) + } + + if response["type"] != "project_state" { + t.Errorf("expected type 'project_state', got %v", response["type"]) + } + project, ok := response["project"].(map[string]interface{}) + if !ok { + t.Fatal("expected project object in response") + } + if project["name"] != "git-project" { + t.Errorf("expected project name 'git-project', got %v", project["name"]) + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + // Verify git repo was initialized + projectDir := filepath.Join(workspaceDir, "git-project") + gitDir := filepath.Join(projectDir, ".git") + if _, err := os.Stat(gitDir); os.IsNotExist(err) { + t.Error("expected .git directory to be created") + } +} + +func TestHandleCreateProject_AlreadyExists(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + // Create directory ahead of time + if err := os.MkdirAll(filepath.Join(workspaceDir, "existing"), 0o755); err != nil { + t.Fatal(err) + } + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + createReq := map[string]interface{}{ + "type": "create_project", + "id": "req-create-3", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "name": "existing", + "git_init": false, + } + if err := ms.sendCommand(createReq); err != nil { + t.Fatalf("failed to send create command: %v", err) + } + + // Wait for error message + raw, err := ms.waitForMessageType("error", 2*time.Second) + if err != nil { + t.Fatalf("failed to receive error message: %v", err) + } + + var errorReceived map[string]interface{} + if err := json.Unmarshal(raw, &errorReceived); err != nil { + t.Fatalf("failed to unmarshal error: %v", err) + } + + if errorReceived["type"] != "error" { + t.Errorf("expected type 'error', got %v", errorReceived["type"]) + } + if errorReceived["code"] != "FILESYSTEM_ERROR" { + t.Errorf("expected code FILESYSTEM_ERROR, got %v", errorReceived["code"]) + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } +} + +func TestHandleCreateProject_EmptyName(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + createReq := map[string]interface{}{ + "type": "create_project", + "id": "req-create-4", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "name": "", + "git_init": false, + } + if err := ms.sendCommand(createReq); err != nil { + t.Fatalf("failed to send create command: %v", err) + } + + // Wait for error message + raw, err := ms.waitForMessageType("error", 2*time.Second) + if err != nil { + t.Fatalf("failed to receive error message: %v", err) + } + + var errorReceived map[string]interface{} + if err := json.Unmarshal(raw, &errorReceived); err != nil { + t.Fatalf("failed to unmarshal error: %v", err) + } + + if errorReceived["code"] != "FILESYSTEM_ERROR" { + t.Errorf("expected code FILESYSTEM_ERROR, got %v", errorReceived["code"]) + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } +} + +func TestScanGitProgress(t *testing.T) { + input := "Cloning into 'repo'...\rReceiving objects: 50%\rReceiving objects: 100%\nDone.\n" + var tokens []string + data := []byte(input) + for len(data) > 0 { + advance, token, err := scanGitProgress(data, false) + if err != nil { + t.Fatal(err) + } + if advance == 0 { + // Process remaining at EOF + _, token, _ = scanGitProgress(data, true) + if token != nil { + tokens = append(tokens, string(token)) + } + break + } + if token != nil { + tokens = append(tokens, string(token)) + } + data = data[advance:] + } + + expected := []string{"Cloning into 'repo'...", "Receiving objects: 50%", "Receiving objects: 100%", "Done."} + if len(tokens) != len(expected) { + t.Fatalf("expected %d tokens, got %d: %v", len(expected), len(tokens), tokens) + } + for i, tok := range tokens { + if tok != expected[i] { + t.Errorf("token[%d] = %q, want %q", i, tok, expected[i]) + } + } +} + +// Unit tests for clone/create functions with mock projectFinder + +type mockScanner struct { + workspacePath string + projects []ws.ProjectSummary +} + +func (m *mockScanner) FindProject(name string) (ws.ProjectSummary, bool) { + for _, p := range m.projects { + if p.Name == name { + return p, true + } + } + return ws.ProjectSummary{}, false +} + +func TestCloneProgressParsing(t *testing.T) { + tests := []struct { + input string + percent int + }{ + {"Receiving objects: 50% (1/2)", 50}, + {"Resolving deltas: 100% (10/10)", 100}, + {"Cloning into 'repo'...", 0}, + {"Receiving objects: 3% (1/33)", 3}, + } + + for _, tt := range tests { + matches := percentPattern.FindStringSubmatch(tt.input) + got := 0 + if len(matches) > 1 { + got, _ = strconv.Atoi(matches[1]) + } + if got != tt.percent { + t.Errorf("percent for %q: got %d, want %d", tt.input, got, tt.percent) + } + } +} + +func TestSendCloneComplete_NilClient(t *testing.T) { + // Should not panic + sendCloneComplete(nil, "https://example.com/repo.git", "repo", true, "") +} + +func TestSendCloneProgress_NilClient(t *testing.T) { + // Should not panic + sendCloneProgress(nil, "https://example.com/repo.git", "progress", 50) +} diff --git a/internal/cmd/diffs.go b/internal/cmd/diffs.go new file mode 100644 index 0000000..e9a6286 --- /dev/null +++ b/internal/cmd/diffs.go @@ -0,0 +1,232 @@ +package cmd + +import ( + "encoding/json" + "fmt" + "log" + "os" + "os/exec" + "path/filepath" + "strings" + + "github.com/minicodemonkey/chief/internal/ws" +) + +// handleGetDiffs handles a get_diffs request from the browser. +// Unlike get_diff, this does not require prd_id and returns parsed per-file diffs. +func handleGetDiffs(sender messageSender, finder projectFinder, msg ws.Message) { + var req ws.GetDiffsMessage + if err := json.Unmarshal(msg.Raw, &req); err != nil { + log.Printf("Error parsing get_diffs message: %v", err) + return + } + + project, found := finder.FindProject(req.Project) + if !found { + sendError(sender, ws.ErrCodeProjectNotFound, + fmt.Sprintf("Project %q not found", req.Project), msg.ID) + return + } + + diffText, _, err := getStoryDiff(project.Path, req.StoryID) + if err != nil { + sendError(sender, ws.ErrCodeFilesystemError, + fmt.Sprintf("Failed to get diff for story %q: %v", req.StoryID, err), msg.ID) + return + } + + files := parseDiffFiles(diffText) + + resp := ws.DiffsResponseMessage{ + Type: ws.TypeDiffsResponse, + Payload: ws.DiffsResponsePayload{ + Project: req.Project, + StoryID: req.StoryID, + Files: files, + }, + } + if err := sender.Send(resp); err != nil { + log.Printf("Error sending diffs_response: %v", err) + } +} + +// parseDiffFiles splits a unified diff into per-file details. +func parseDiffFiles(diffText string) []ws.DiffFileDetail { + if diffText == "" { + return []ws.DiffFileDetail{} + } + + // Split on "diff --git" boundaries + chunks := strings.Split(diffText, "diff --git ") + var files []ws.DiffFileDetail + + for _, chunk := range chunks { + chunk = strings.TrimSpace(chunk) + if chunk == "" { + continue + } + + // Extract filename from first line: "a/path b/path" + firstLine := chunk + if idx := strings.IndexByte(chunk, '\n'); idx != -1 { + firstLine = chunk[:idx] + } + + filename := "" + if parts := strings.SplitN(firstLine, " b/", 2); len(parts) == 2 { + filename = parts[1] + } + + // Count additions and deletions + additions, deletions := 0, 0 + for _, line := range strings.Split(chunk, "\n") { + if strings.HasPrefix(line, "+") && !strings.HasPrefix(line, "+++") { + additions++ + } else if strings.HasPrefix(line, "-") && !strings.HasPrefix(line, "---") { + deletions++ + } + } + + files = append(files, ws.DiffFileDetail{ + Filename: filename, + Additions: additions, + Deletions: deletions, + Patch: "diff --git " + chunk, + }) + } + + if files == nil { + files = []ws.DiffFileDetail{} + } + return files +} + +// handleGetDiff handles a get_diff request. +func handleGetDiff(sender messageSender, finder projectFinder, msg ws.Message) { + var req ws.GetDiffMessage + if err := json.Unmarshal(msg.Raw, &req); err != nil { + log.Printf("Error parsing get_diff message: %v", err) + return + } + + project, found := finder.FindProject(req.Project) + if !found { + sendError(sender, ws.ErrCodeProjectNotFound, + fmt.Sprintf("Project %q not found", req.Project), msg.ID) + return + } + + prdDir := filepath.Join(project.Path, ".chief", "prds", req.PRDID) + if _, err := os.Stat(prdDir); os.IsNotExist(err) { + sendError(sender, ws.ErrCodePRDNotFound, + fmt.Sprintf("PRD %q not found in project %q", req.PRDID, req.Project), msg.ID) + return + } + + diffText, files, err := getStoryDiff(project.Path, req.StoryID) + if err != nil { + sendError(sender, ws.ErrCodeFilesystemError, + fmt.Sprintf("Failed to get diff for story %q: %v", req.StoryID, err), msg.ID) + return + } + + sendDiffMessage(sender, req.Project, req.PRDID, req.StoryID, files, diffText) +} + +// getStoryDiff returns the diff and list of changed files for a story's commit(s). +// It finds commits matching the story ID pattern "feat: <story-id> -" in the commit message. +func getStoryDiff(repoDir, storyID string) (string, []string, error) { + // Find commit hash(es) for this story by searching commit messages + commitHash, err := findStoryCommit(repoDir, storyID) + if err != nil { + return "", nil, err + } + if commitHash == "" { + return "", nil, fmt.Errorf("no commit found for story %s", storyID) + } + + // Get the unified diff for the commit + diffText, err := getCommitDiff(repoDir, commitHash) + if err != nil { + return "", nil, fmt.Errorf("getting diff: %w", err) + } + + // Get the list of changed files + files, err := getCommitFiles(repoDir, commitHash) + if err != nil { + return "", nil, fmt.Errorf("getting changed files: %w", err) + } + + return diffText, files, nil +} + +// findStoryCommit finds the most recent commit hash matching a story ID. +// It searches for commits with messages matching "feat: <storyID> -" or +// containing the story ID. +func findStoryCommit(repoDir, storyID string) (string, error) { + // Search for commits with messages containing the story ID + cmd := exec.Command("git", "log", "--format=%H", "--grep", storyID, "-1") + cmd.Dir = repoDir + output, err := cmd.Output() + if err != nil { + return "", fmt.Errorf("searching git log: %w", err) + } + + hash := strings.TrimSpace(string(output)) + return hash, nil +} + +// getCommitDiff returns the unified diff for a specific commit. +func getCommitDiff(repoDir, commitHash string) (string, error) { + cmd := exec.Command("git", "show", "--format=", "--patch", commitHash) + cmd.Dir = repoDir + output, err := cmd.Output() + if err != nil { + return "", err + } + return string(output), nil +} + +// getCommitFiles returns the list of files changed in a specific commit. +func getCommitFiles(repoDir, commitHash string) ([]string, error) { + cmd := exec.Command("git", "show", "--format=", "--name-only", commitHash) + cmd.Dir = repoDir + output, err := cmd.Output() + if err != nil { + return nil, err + } + + raw := strings.TrimSpace(string(output)) + if raw == "" { + return []string{}, nil + } + + files := strings.Split(raw, "\n") + return files, nil +} + +// sendDiffMessage sends a diff message. +func sendDiffMessage(sender messageSender, project, prdID, storyID string, files []string, diffText string) { + if sender == nil { + return + } + + if files == nil { + files = []string{} + } + + envelope := ws.NewMessage(ws.TypeDiff) + msg := ws.DiffMessage{ + Type: envelope.Type, + ID: envelope.ID, + Timestamp: envelope.Timestamp, + Project: project, + PRDID: prdID, + StoryID: storyID, + Files: files, + DiffText: diffText, + } + if err := sender.Send(msg); err != nil { + log.Printf("Error sending diff: %v", err) + } +} diff --git a/internal/cmd/diffs_test.go b/internal/cmd/diffs_test.go new file mode 100644 index 0000000..4f5a1a4 --- /dev/null +++ b/internal/cmd/diffs_test.go @@ -0,0 +1,636 @@ +package cmd + +import ( + "encoding/json" + "os" + "os/exec" + "path/filepath" + "strings" + "sync" + "testing" + "time" + + "github.com/minicodemonkey/chief/internal/engine" +) + +// gitCmd runs a git command in the given directory with test-safe env. +func gitCmd(t *testing.T, dir string, args ...string) string { + t.Helper() + cmd := exec.Command("git", args...) + cmd.Dir = dir + cmd.Env = append(os.Environ(), "GIT_CONFIG_GLOBAL=/dev/null", "GIT_CONFIG_SYSTEM=/dev/null") + out, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("git %s failed: %v\n%s", strings.Join(args, " "), err, out) + } + return strings.TrimSpace(string(out)) +} + +// createGitRepoWithStoryCommit creates a git repo with an initial commit +// and a story commit matching the "feat: <storyID> - <title>" pattern. +func createGitRepoWithStoryCommit(t *testing.T, dir, storyID, title string) { + t.Helper() + createGitRepo(t, dir) + + // Create a file and commit it with the story commit message + filePath := filepath.Join(dir, "feature.go") + if err := os.WriteFile(filePath, []byte("package main\n\nfunc feature() {}\n"), 0o644); err != nil { + t.Fatal(err) + } + gitCmd(t, dir, "add", "feature.go") + gitCmd(t, dir, "commit", "-m", "feat: "+storyID+" - "+title) +} + +func TestGetStoryDiff_Success(t *testing.T) { + dir := t.TempDir() + createGitRepoWithStoryCommit(t, dir, "US-001", "Add feature") + + diffText, files, err := getStoryDiff(dir, "US-001") + if err != nil { + t.Fatalf("getStoryDiff failed: %v", err) + } + + if diffText == "" { + t.Error("expected non-empty diff text") + } + if !strings.Contains(diffText, "feature.go") { + t.Errorf("expected diff to contain 'feature.go', got: %s", diffText) + } + + if len(files) != 1 { + t.Errorf("expected 1 changed file, got %d: %v", len(files), files) + } + if len(files) > 0 && files[0] != "feature.go" { + t.Errorf("expected file 'feature.go', got %q", files[0]) + } +} + +func TestGetStoryDiff_NoCommitFound(t *testing.T) { + dir := t.TempDir() + createGitRepo(t, dir) + + _, _, err := getStoryDiff(dir, "US-999") + if err == nil { + t.Fatal("expected error for missing story commit") + } + if !strings.Contains(err.Error(), "no commit found") { + t.Errorf("expected 'no commit found' error, got: %v", err) + } +} + +func TestGetStoryDiff_MultipleFiles(t *testing.T) { + dir := t.TempDir() + createGitRepo(t, dir) + + // Create multiple files and commit + for _, name := range []string{"a.go", "b.go", "c.go"} { + if err := os.WriteFile(filepath.Join(dir, name), []byte("package main\n"), 0o644); err != nil { + t.Fatal(err) + } + } + gitCmd(t, dir, "add", ".") + gitCmd(t, dir, "commit", "-m", "feat: US-002 - Add multiple files") + + diffText, files, err := getStoryDiff(dir, "US-002") + if err != nil { + t.Fatalf("getStoryDiff failed: %v", err) + } + + if len(files) != 3 { + t.Errorf("expected 3 changed files, got %d: %v", len(files), files) + } + + if !strings.Contains(diffText, "a.go") || !strings.Contains(diffText, "b.go") || !strings.Contains(diffText, "c.go") { + t.Errorf("expected diff to contain all files, got: %s", diffText) + } +} + +func TestFindStoryCommit_FindsMostRecent(t *testing.T) { + dir := t.TempDir() + createGitRepo(t, dir) + + // Create first commit for the story + if err := os.WriteFile(filepath.Join(dir, "v1.go"), []byte("package v1\n"), 0o644); err != nil { + t.Fatal(err) + } + gitCmd(t, dir, "add", ".") + gitCmd(t, dir, "commit", "-m", "feat: US-003 - Initial attempt") + + // Create second commit for the same story (more recent) + if err := os.WriteFile(filepath.Join(dir, "v2.go"), []byte("package v2\n"), 0o644); err != nil { + t.Fatal(err) + } + gitCmd(t, dir, "add", ".") + gitCmd(t, dir, "commit", "-m", "feat: US-003 - Fixed version") + + hash, err := findStoryCommit(dir, "US-003") + if err != nil { + t.Fatalf("findStoryCommit failed: %v", err) + } + if hash == "" { + t.Fatal("expected non-empty commit hash") + } + + // The most recent commit should be the "Fixed version" one + // Verify by checking the commit message + cmd := exec.Command("git", "log", "--format=%s", "-1", hash) + cmd.Dir = dir + out, err := cmd.Output() + if err != nil { + t.Fatalf("git log failed: %v", err) + } + msg := strings.TrimSpace(string(out)) + if msg != "feat: US-003 - Fixed version" { + t.Errorf("expected most recent commit, got: %q", msg) + } +} + +func TestSendDiffMessage(t *testing.T) { + // sendDiffMessage with nil client should not panic + sendDiffMessage(nil, "project", "prd", "US-001", []string{"a.go"}, "diff text") +} + +func TestSendDiffMessage_NilFiles(t *testing.T) { + // sendDiffMessage with nil files should not panic + sendDiffMessage(nil, "project", "prd", "US-001", nil, "diff text") +} + +func TestRunManager_SendStoryDiff(t *testing.T) { + eng := engine.New(5) + defer eng.Shutdown() + + rm := newRunManager(eng, nil) // nil client — just verify no panic + + // Create a temp project with a git repo and story commit + projectDir := t.TempDir() + createGitRepoWithStoryCommit(t, projectDir, "US-001", "Test Story") + + prdDir := filepath.Join(projectDir, ".chief", "prds", "feature") + if err := os.MkdirAll(prdDir, 0o755); err != nil { + t.Fatal(err) + } + prdPath := filepath.Join(prdDir, "prd.json") + if err := os.WriteFile(prdPath, []byte(`{}`), 0o644); err != nil { + t.Fatal(err) + } + + rm.mu.Lock() + rm.runs["myproject/feature"] = &runInfo{ + project: "myproject", + prdID: "feature", + prdPath: prdPath, + startTime: time.Now(), + storyID: "US-001", + } + rm.mu.Unlock() + + // Call sendStoryDiff with nil client — should not panic, just log + info := rm.runs["myproject/feature"] + rm.sendStoryDiff(info, engine.ManagerEvent{}.Event) +} + +func TestRunServe_GetDiff(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + // Create a git repo with a story commit + projectDir := filepath.Join(workspaceDir, "myproject") + createGitRepoWithStoryCommit(t, projectDir, "US-001", "Add feature") + + prdDir := filepath.Join(projectDir, ".chief", "prds", "feature") + if err := os.MkdirAll(prdDir, 0o755); err != nil { + t.Fatal(err) + } + + prdState := `{"project": "My Feature", "userStories": [{"id": "US-001", "title": "Add feature", "passes": true}]}` + if err := os.WriteFile(filepath.Join(prdDir, "prd.json"), []byte(prdState), 0o644); err != nil { + t.Fatal(err) + } + + var response map[string]interface{} + var mu sync.Mutex + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + getDiffReq := map[string]interface{}{ + "type": "get_diff", + "id": "req-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "project": "myproject", + "prd_id": "feature", + "story_id": "US-001", + } + if err := ms.sendCommand(getDiffReq); err != nil { + t.Errorf("sendCommand failed: %v", err) + return + } + + raw, err := ms.waitForMessageType("diff", 5*time.Second) + if err == nil { + mu.Lock() + json.Unmarshal(raw, &response) + mu.Unlock() + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + mu.Lock() + defer mu.Unlock() + + if response == nil { + t.Fatal("response was not received") + } + if response["type"] != "diff" { + t.Errorf("expected type 'diff', got %v", response["type"]) + } + if response["project"] != "myproject" { + t.Errorf("expected project 'myproject', got %v", response["project"]) + } + if response["prd_id"] != "feature" { + t.Errorf("expected prd_id 'feature', got %v", response["prd_id"]) + } + if response["story_id"] != "US-001" { + t.Errorf("expected story_id 'US-001', got %v", response["story_id"]) + } + + // Verify files array + files, ok := response["files"].([]interface{}) + if !ok { + t.Fatal("expected files to be an array") + } + if len(files) != 1 { + t.Errorf("expected 1 file, got %d", len(files)) + } + if len(files) > 0 && files[0] != "feature.go" { + t.Errorf("expected file 'feature.go', got %v", files[0]) + } + + // Verify diff_text is non-empty and contains the file + diffText, ok := response["diff_text"].(string) + if !ok || diffText == "" { + t.Error("expected non-empty diff_text") + } + if !strings.Contains(diffText, "feature.go") { + t.Errorf("expected diff_text to contain 'feature.go'") + } +} + +func TestRunServe_GetDiffProjectNotFound(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + var response map[string]interface{} + var mu sync.Mutex + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + getDiffReq := map[string]interface{}{ + "type": "get_diff", + "id": "req-2", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "project": "nonexistent", + "prd_id": "feature", + "story_id": "US-001", + } + if err := ms.sendCommand(getDiffReq); err != nil { + t.Errorf("sendCommand failed: %v", err) + return + } + + raw, err := ms.waitForMessageType("error", 5*time.Second) + if err == nil { + mu.Lock() + json.Unmarshal(raw, &response) + mu.Unlock() + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + mu.Lock() + defer mu.Unlock() + + if response == nil { + t.Fatal("response was not received") + } + if response["type"] != "error" { + t.Errorf("expected type 'error', got %v", response["type"]) + } + if response["code"] != "PROJECT_NOT_FOUND" { + t.Errorf("expected code 'PROJECT_NOT_FOUND', got %v", response["code"]) + } +} + +func TestRunServe_GetDiffPRDNotFound(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + projectDir := filepath.Join(workspaceDir, "myproject") + createGitRepo(t, projectDir) + + var response map[string]interface{} + var mu sync.Mutex + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + getDiffReq := map[string]interface{}{ + "type": "get_diff", + "id": "req-3", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "project": "myproject", + "prd_id": "nonexistent", + "story_id": "US-001", + } + if err := ms.sendCommand(getDiffReq); err != nil { + t.Errorf("sendCommand failed: %v", err) + return + } + + raw, err := ms.waitForMessageType("error", 5*time.Second) + if err == nil { + mu.Lock() + json.Unmarshal(raw, &response) + mu.Unlock() + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + mu.Lock() + defer mu.Unlock() + + if response == nil { + t.Fatal("response was not received") + } + if response["type"] != "error" { + t.Errorf("expected type 'error', got %v", response["type"]) + } + if response["code"] != "PRD_NOT_FOUND" { + t.Errorf("expected code 'PRD_NOT_FOUND', got %v", response["code"]) + } +} + +func TestRunServe_GetDiffNoCommit(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + projectDir := filepath.Join(workspaceDir, "myproject") + createGitRepo(t, projectDir) + + prdDir := filepath.Join(projectDir, ".chief", "prds", "feature") + if err := os.MkdirAll(prdDir, 0o755); err != nil { + t.Fatal(err) + } + + prdState := `{"project": "My Feature", "userStories": [{"id": "US-001", "title": "Test", "passes": false}]}` + if err := os.WriteFile(filepath.Join(prdDir, "prd.json"), []byte(prdState), 0o644); err != nil { + t.Fatal(err) + } + + var response map[string]interface{} + var mu sync.Mutex + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + getDiffReq := map[string]interface{}{ + "type": "get_diff", + "id": "req-4", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "project": "myproject", + "prd_id": "feature", + "story_id": "US-001", + } + if err := ms.sendCommand(getDiffReq); err != nil { + t.Errorf("sendCommand failed: %v", err) + return + } + + raw, err := ms.waitForMessageType("error", 5*time.Second) + if err == nil { + mu.Lock() + json.Unmarshal(raw, &response) + mu.Unlock() + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + mu.Lock() + defer mu.Unlock() + + if response == nil { + t.Fatal("response was not received") + } + if response["type"] != "error" { + t.Errorf("expected type 'error', got %v", response["type"]) + } + if response["code"] != "FILESYSTEM_ERROR" { + t.Errorf("expected code 'FILESYSTEM_ERROR', got %v", response["code"]) + } +} + +func TestRunServe_GetDiffs(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + // Create a git repo with a story commit + projectDir := filepath.Join(workspaceDir, "myproject") + createGitRepoWithStoryCommit(t, projectDir, "US-001", "Add feature") + + var response map[string]interface{} + var mu sync.Mutex + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + // get_diffs does not require prd_id (unlike get_diff) + req := map[string]interface{}{ + "type": "get_diffs", + "id": "req-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "project": "myproject", + "story_id": "US-001", + } + if err := ms.sendCommand(req); err != nil { + t.Errorf("sendCommand failed: %v", err) + return + } + + raw, err := ms.waitForMessageType("diffs_response", 5*time.Second) + if err == nil { + mu.Lock() + json.Unmarshal(raw, &response) + mu.Unlock() + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + mu.Lock() + defer mu.Unlock() + + if response == nil { + t.Fatal("diffs_response was not received") + } + if response["type"] != "diffs_response" { + t.Errorf("expected type 'diffs_response', got %v", response["type"]) + } + + payload, ok := response["payload"].(map[string]interface{}) + if !ok { + t.Fatal("expected payload to be an object") + } + if payload["project"] != "myproject" { + t.Errorf("expected project 'myproject', got %v", payload["project"]) + } + if payload["story_id"] != "US-001" { + t.Errorf("expected story_id 'US-001', got %v", payload["story_id"]) + } + + files, ok := payload["files"].([]interface{}) + if !ok { + t.Fatal("expected files to be an array") + } + if len(files) != 1 { + t.Fatalf("expected 1 file, got %d", len(files)) + } + + file := files[0].(map[string]interface{}) + if file["filename"] != "feature.go" { + t.Errorf("expected filename 'feature.go', got %v", file["filename"]) + } + if int(file["additions"].(float64)) == 0 { + t.Error("expected additions > 0") + } + if _, ok := file["patch"].(string); !ok || file["patch"] == "" { + t.Error("expected non-empty patch string") + } +} + +func TestRunServe_GetDiffs_ProjectNotFound(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + var response map[string]interface{} + var mu sync.Mutex + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + req := map[string]interface{}{ + "type": "get_diffs", + "id": "req-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "project": "nonexistent", + "story_id": "US-001", + } + if err := ms.sendCommand(req); err != nil { + t.Errorf("sendCommand failed: %v", err) + return + } + + raw, err := ms.waitForMessageType("error", 5*time.Second) + if err == nil { + mu.Lock() + json.Unmarshal(raw, &response) + mu.Unlock() + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + mu.Lock() + defer mu.Unlock() + + if response == nil { + t.Fatal("error message was not received") + } + if response["code"] != "PROJECT_NOT_FOUND" { + t.Errorf("expected code 'PROJECT_NOT_FOUND', got %v", response["code"]) + } +} + +func TestParseDiffFiles(t *testing.T) { + diffText := `diff --git a/main.go b/main.go +index abc..def 100644 +--- a/main.go ++++ b/main.go +@@ -1,3 +1,5 @@ + package main ++import "fmt" ++func hello() { fmt.Println("hi") } + func main() {} +diff --git a/util.go b/util.go +new file mode 100644 +--- /dev/null ++++ b/util.go +@@ -0,0 +1,3 @@ ++package main ++func helper() {} ++func other() {} +` + + files := parseDiffFiles(diffText) + if len(files) != 2 { + t.Fatalf("expected 2 files, got %d", len(files)) + } + + if files[0].Filename != "main.go" { + t.Errorf("files[0].filename = %q, want %q", files[0].Filename, "main.go") + } + if files[0].Additions != 2 { + t.Errorf("files[0].additions = %d, want 2", files[0].Additions) + } + if files[0].Deletions != 0 { + t.Errorf("files[0].deletions = %d, want 0", files[0].Deletions) + } + + if files[1].Filename != "util.go" { + t.Errorf("files[1].filename = %q, want %q", files[1].Filename, "util.go") + } + if files[1].Additions != 3 { + t.Errorf("files[1].additions = %d, want 3", files[1].Additions) + } +} + +func TestParseDiffFiles_Empty(t *testing.T) { + files := parseDiffFiles("") + if len(files) != 0 { + t.Errorf("expected 0 files for empty diff, got %d", len(files)) + } +} diff --git a/internal/cmd/e2e_test.go b/internal/cmd/e2e_test.go new file mode 100644 index 0000000..6a85e32 --- /dev/null +++ b/internal/cmd/e2e_test.go @@ -0,0 +1,400 @@ +package cmd + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "sync" + "testing" + "time" +) + +// End-to-end integration tests verifying the complete CLI ↔ server message flow +// via the uplink HTTP+Pusher transport. These tests complement the unit-level +// uplink tests (internal/uplink/*_test.go) and the existing serve_test.go tests. + +func TestE2E_DeviceLifecycle_ConnectMessagesHeartbeatDisconnect(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + createGitRepo(t, filepath.Join(workspaceDir, "myproject")) + + ctx, cancel := context.WithCancel(context.Background()) + ms := newMockUplinkServer(t) + + go func() { + // Wait for full connection (Pusher subscribe). + if err := ms.waitForPusherSubscribe(10 * time.Second); err != nil { + t.Logf("waitForPusherSubscribe: %v", err) + cancel() + return + } + + // Wait for state_snapshot (CLI sends on connect). + if _, err := ms.waitForMessageType("state_snapshot", 5*time.Second); err != nil { + t.Logf("waitForMessageType(state_snapshot): %v", err) + cancel() + return + } + + // Send a command via Pusher → CLI should respond. + listReq := map[string]string{ + "type": "list_projects", + "id": "req-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + } + if err := ms.sendCommand(listReq); err != nil { + t.Logf("sendCommand: %v", err) + cancel() + return + } + + // Wait for project_list response via HTTP messages endpoint. + if _, err := ms.waitForMessageType("project_list", 5*time.Second); err != nil { + t.Logf("waitForMessageType(project_list): %v", err) + } + + // Cancel to trigger graceful shutdown (disconnect). + cancel() + }() + + err := RunServe(ServeOptions{ + Workspace: workspaceDir, + ServerURL: ms.httpSrv.URL, + Version: "1.0.0", + Ctx: ctx, + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + // Verify the full lifecycle happened: + // 1. Connect was called + if got := ms.connectCalls.Load(); got < 1 { + t.Errorf("connect calls = %d, want >= 1", got) + } + + // 2. Messages were sent (state_snapshot + project_list at minimum) + if got := ms.messagesCalls.Load(); got < 1 { + t.Errorf("messages calls = %d, want >= 1", got) + } + + // 3. State snapshot was received + if _, err := ms.waitForMessageType("state_snapshot", time.Second); err != nil { + t.Error("state_snapshot not received") + } + + // 4. Project list response was received + raw, err := ms.waitForMessageType("project_list", time.Second) + if err != nil { + t.Error("project_list not received") + } else { + var resp map[string]interface{} + json.Unmarshal(raw, &resp) + projects := resp["projects"].([]interface{}) + if len(projects) != 1 { + t.Errorf("expected 1 project, got %d", len(projects)) + } + } + + // 5. Disconnect was called during shutdown + if got := ms.disconnectCalls.Load(); got != 1 { + t.Errorf("disconnect calls = %d, want 1", got) + } +} + +func TestE2E_BidirectionalMessageFlow(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + createGitRepo(t, filepath.Join(workspaceDir, "alpha")) + createGitRepo(t, filepath.Join(workspaceDir, "beta")) + + var responses []map[string]interface{} + var mu sync.Mutex + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + // Send multiple commands rapidly — verify CLI processes and responds to each. + + // Command 1: list_projects + ms.sendCommand(map[string]string{ + "type": "list_projects", + "id": "req-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + }) + + // Command 2: get_project + ms.sendCommand(map[string]string{ + "type": "get_project", + "id": "req-2", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "project": "alpha", + }) + + // Command 3: ping + ms.sendCommand(map[string]string{ + "type": "ping", + "id": "req-3", + "timestamp": time.Now().UTC().Format(time.RFC3339), + }) + + // Wait for all three response types. + types := []string{"project_list", "project_state", "pong"} + for _, typ := range types { + raw, err := ms.waitForMessageType(typ, 5*time.Second) + if err == nil { + var resp map[string]interface{} + json.Unmarshal(raw, &resp) + mu.Lock() + responses = append(responses, resp) + mu.Unlock() + } else { + t.Logf("waitForMessageType(%s): %v", typ, err) + } + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + mu.Lock() + defer mu.Unlock() + + if len(responses) != 3 { + t.Fatalf("expected 3 responses, got %d", len(responses)) + } + + // Verify we received all expected types. + typeSet := make(map[string]bool) + for _, r := range responses { + typeSet[r["type"].(string)] = true + } + for _, expected := range []string{"project_list", "project_state", "pong"} { + if !typeSet[expected] { + t.Errorf("missing response type %q in %v", expected, typeSet) + } + } +} + +func TestE2E_HeartbeatSentDuringSession(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspace := filepath.Join(home, "projects") + if err := os.MkdirAll(workspace, 0o755); err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithCancel(context.Background()) + ms := newMockUplinkServer(t) + + go func() { + if err := ms.waitForPusherSubscribe(10 * time.Second); err != nil { + t.Logf("waitForPusherSubscribe: %v", err) + cancel() + return + } + + // Wait for state_snapshot. + if _, err := ms.waitForMessageType("state_snapshot", 5*time.Second); err != nil { + t.Logf("waitForMessageType(state_snapshot): %v", err) + cancel() + return + } + + // Wait for at least one heartbeat call (up to 40s since default interval is 30s). + deadline := time.After(40 * time.Second) + for { + if ms.heartbeatCalls.Load() > 0 { + break + } + select { + case <-deadline: + t.Logf("timeout waiting for heartbeat") + cancel() + return + case <-time.After(100 * time.Millisecond): + } + } + + cancel() + }() + + err := RunServe(ServeOptions{ + Workspace: workspace, + ServerURL: ms.httpSrv.URL, + Version: "1.0.0", + Ctx: ctx, + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + if got := ms.heartbeatCalls.Load(); got < 1 { + t.Errorf("heartbeat calls = %d, want >= 1", got) + } +} + +func TestE2E_MultipleCommandsBatchedIntoHTTPPosts(t *testing.T) { + // Verifies that multiple CLI responses are batched into HTTP POST calls + // (not one per message) via the batcher. + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + createGitRepo(t, filepath.Join(workspaceDir, "myproject")) + + var finalPongCount int + var finalTotalMsgs int + var finalHTTPCalls int + var mu sync.Mutex + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + // Send several rapid commands that produce responses. + for i := 0; i < 5; i++ { + ms.sendCommand(map[string]string{ + "type": "ping", + "id": fmt.Sprintf("ping-%d", i), + "timestamp": time.Now().UTC().Format(time.RFC3339), + }) + } + + // Wait for all pongs. + deadline := time.After(5 * time.Second) + pongCount := 0 + for pongCount < 5 { + msgs := ms.getMessages() + pongCount = 0 + for _, raw := range msgs { + var msg map[string]interface{} + json.Unmarshal(raw, &msg) + if msg["type"] == "pong" { + pongCount++ + } + } + if pongCount >= 5 { + break + } + select { + case <-deadline: + t.Logf("timeout: got %d pongs", pongCount) + mu.Lock() + finalPongCount = pongCount + finalTotalMsgs = len(ms.getMessages()) + finalHTTPCalls = int(ms.messagesCalls.Load()) + mu.Unlock() + return + case <-time.After(50 * time.Millisecond): + } + } + + mu.Lock() + finalPongCount = pongCount + finalTotalMsgs = len(ms.getMessages()) + finalHTTPCalls = int(ms.messagesCalls.Load()) + mu.Unlock() + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + mu.Lock() + defer mu.Unlock() + + if finalPongCount != 5 { + t.Errorf("expected 5 pong messages, got %d", finalPongCount) + } + + t.Logf("total messages: %d, HTTP POST calls: %d", finalTotalMsgs, finalHTTPCalls) + if finalHTTPCalls > finalTotalMsgs { + t.Errorf("HTTP POST calls (%d) > total messages (%d), batching may not be working", finalHTTPCalls, finalTotalMsgs) + } +} + +func TestE2E_GracefulShutdownFlushesMessagesAndDisconnects(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspace := filepath.Join(home, "projects") + if err := os.MkdirAll(workspace, 0o755); err != nil { + t.Fatal(err) + } + + createGitRepo(t, filepath.Join(workspace, "myproject")) + + ctx, cancel := context.WithCancel(context.Background()) + ms := newMockUplinkServer(t) + + go func() { + if err := ms.waitForPusherSubscribe(10 * time.Second); err != nil { + t.Logf("waitForPusherSubscribe: %v", err) + cancel() + return + } + + // Wait for state_snapshot. + if _, err := ms.waitForMessageType("state_snapshot", 5*time.Second); err != nil { + t.Logf("waitForMessageType(state_snapshot): %v", err) + cancel() + return + } + + // Send a ping command. + ms.sendCommand(map[string]string{ + "type": "ping", + "id": "ping-before-shutdown", + "timestamp": time.Now().UTC().Format(time.RFC3339), + }) + + // Wait for the pong to actually be delivered to the server. + // This confirms the full send path (enqueue → batcher flush → HTTP POST) + // completed before we trigger shutdown. + if _, err := ms.waitForMessageType("pong", 5*time.Second); err != nil { + t.Logf("waitForMessageType(pong): %v", err) + } + + // Trigger shutdown after messages have been flushed. + cancel() + }() + + err := RunServe(ServeOptions{ + Workspace: workspace, + ServerURL: ms.httpSrv.URL, + Version: "1.0.0", + Ctx: ctx, + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + // Verify pong was received by the server (already confirmed in goroutine). + if _, err := ms.waitForMessageType("pong", time.Second); err != nil { + t.Error("pong not received by server") + } + + // Verify disconnect was called during shutdown. + if got := ms.disconnectCalls.Load(); got != 1 { + t.Errorf("disconnect calls = %d, want 1", got) + } +} diff --git a/internal/cmd/login.go b/internal/cmd/login.go new file mode 100644 index 0000000..249eb4a --- /dev/null +++ b/internal/cmd/login.go @@ -0,0 +1,257 @@ +package cmd + +import ( + "bufio" + "bytes" + "encoding/json" + "errors" + "fmt" + "net/http" + "os" + "os/exec" + "runtime" + "strings" + "time" + + "github.com/minicodemonkey/chief/internal/auth" +) + +const ( + defaultBaseURL = "https://chiefloop.com" + pollInterval = 5 * time.Second + loginTimeout = 5 * time.Minute +) + +// LoginOptions contains configuration for the login command. +type LoginOptions struct { + DeviceName string // Override device name (default: hostname) + BaseURL string // Override base URL (for testing) + SetupToken string // One-time setup token for automated auth +} + +// deviceCodeResponse is the response from the device code endpoint. +type deviceCodeResponse struct { + DeviceCode string `json:"device_code"` + UserCode string `json:"user_code"` +} + +// tokenResponse is the response from the token polling endpoint. +type tokenResponse struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int `json:"expires_in"` + User string `json:"user"` + WSURL string `json:"ws_url"` + Error string `json:"error"` +} + +// RunLogin performs the device OAuth login flow. +func RunLogin(opts LoginOptions) error { + baseURL := opts.BaseURL + if baseURL == "" { + baseURL = defaultBaseURL + } + + deviceName := opts.DeviceName + if deviceName == "" { + hostname, err := os.Hostname() + if err != nil { + deviceName = "unknown" + } else { + deviceName = hostname + } + } + + // Setup token mode: exchange token for credentials directly + if opts.SetupToken != "" { + return exchangeSetupToken(baseURL, opts.SetupToken, deviceName) + } + + // Check if already logged in + existing, err := auth.LoadCredentials() + if err == nil && existing != nil { + fmt.Printf("Already logged in as %s (%s).\n", existing.User, existing.DeviceName) + fmt.Print("Do you want to log in again? This will replace your existing credentials. [y/N] ") + reader := bufio.NewReader(os.Stdin) + answer, _ := reader.ReadString('\n') + answer = strings.TrimSpace(strings.ToLower(answer)) + if answer != "y" && answer != "yes" { + fmt.Println("Login cancelled.") + return nil + } + } + + // Request device code + codeReqBody, _ := json.Marshal(map[string]string{ + "device_name": deviceName, + }) + + resp, err := http.Post(baseURL+"/api/oauth/device/code", "application/json", bytes.NewReader(codeReqBody)) + if err != nil { + return fmt.Errorf("requesting device code: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("requesting device code: server returned %s", resp.Status) + } + + var codeResp deviceCodeResponse + if err := json.NewDecoder(resp.Body).Decode(&codeResp); err != nil { + return fmt.Errorf("parsing device code response: %w", err) + } + + // Display the user code and URL + deviceURL := baseURL + "/oauth/device" + fmt.Println() + fmt.Println("To authenticate, open this URL in your browser:") + fmt.Printf("\n %s\n\n", deviceURL) + fmt.Printf("And enter this code: %s\n\n", codeResp.UserCode) + + // Try to open browser automatically + openBrowserFunc(deviceURL) + + fmt.Println("Waiting for authorization...") + + // Poll for token + creds, err := pollForToken(baseURL, codeResp.DeviceCode, deviceName) + if err != nil { + return err + } + + // Save credentials + if err := auth.SaveCredentials(creds); err != nil { + return fmt.Errorf("saving credentials: %w", err) + } + + fmt.Printf("\nLogged in as %s (%s)\n", creds.User, creds.DeviceName) + return nil +} + +// exchangeSetupToken exchanges a one-time setup token for credentials. +// This is used during automated VPS provisioning via cloud-init. +func exchangeSetupToken(baseURL, token, deviceName string) error { + reqBody, _ := json.Marshal(map[string]string{ + "setup_token": token, + "device_name": deviceName, + }) + + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Post(baseURL+"/api/oauth/device/exchange", "application/json", bytes.NewReader(reqBody)) + if err != nil { + return fmt.Errorf("exchanging setup token: %w", err) + } + defer resp.Body.Close() + + var tokenResp tokenResponse + if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { + return fmt.Errorf("parsing setup token response: %w", err) + } + + if resp.StatusCode != http.StatusOK || tokenResp.Error != "" { + errMsg := tokenResp.Error + if errMsg == "" { + errMsg = resp.Status + } + fmt.Fprintf(os.Stderr, "Setup token exchange failed: %s\n", errMsg) + fmt.Fprintf(os.Stderr, "Please authenticate manually by running: chief login\n") + return fmt.Errorf("setup token exchange failed: %s", errMsg) + } + + creds := &auth.Credentials{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + ExpiresAt: time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second), + DeviceName: deviceName, + User: tokenResp.User, + WSURL: tokenResp.WSURL, + } + + if err := auth.SaveCredentials(creds); err != nil { + return fmt.Errorf("saving credentials: %w", err) + } + + fmt.Printf("Logged in as %s (%s)\n", creds.User, creds.DeviceName) + return nil +} + +// pollForToken polls the token endpoint until authorization is granted or timeout. +func pollForToken(baseURL, deviceCode, deviceName string) (*auth.Credentials, error) { + deadline := time.Now().Add(loginTimeout) + client := &http.Client{Timeout: 10 * time.Second} + + for { + if time.Now().After(deadline) { + return nil, errors.New("login timed out — you did not authorize the device within 5 minutes") + } + + time.Sleep(pollInterval) + + reqBody, _ := json.Marshal(map[string]string{ + "device_code": deviceCode, + }) + + resp, err := client.Post(baseURL+"/api/oauth/device/token", "application/json", bytes.NewReader(reqBody)) + if err != nil { + fmt.Fprintf(os.Stderr, "Network error while polling (will retry): %v\n", err) + continue + } + + var tokenResp tokenResponse + if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { + resp.Body.Close() + fmt.Fprintf(os.Stderr, "Error parsing token response (will retry): %v\n", err) + continue + } + resp.Body.Close() + + // Check for pending authorization + if tokenResp.Error == "authorization_pending" { + continue + } + + // Check for other errors + if tokenResp.Error != "" { + return nil, fmt.Errorf("authorization failed: %s", tokenResp.Error) + } + + // Check for successful token response + if resp.StatusCode == http.StatusOK && tokenResp.AccessToken != "" { + return &auth.Credentials{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + ExpiresAt: time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second), + DeviceName: deviceName, + User: tokenResp.User, + WSURL: tokenResp.WSURL, + }, nil + } + + // Non-200 status without a recognized error + if resp.StatusCode != http.StatusOK { + fmt.Fprintf(os.Stderr, "Unexpected status %s (will retry)\n", resp.Status) + continue + } + } +} + +// openBrowserFunc is the function used to open URLs in the browser. +// It can be replaced in tests to prevent actual browser launches. +var openBrowserFunc = openBrowserDefault + +// openBrowserDefault attempts to open the given URL in the default browser. +func openBrowserDefault(url string) { + var cmd *exec.Cmd + switch runtime.GOOS { + case "darwin": + cmd = exec.Command("open", url) + case "linux": + cmd = exec.Command("xdg-open", url) + case "windows": + cmd = exec.Command("rundll32", "url.dll,FileProtocolHandler", url) + default: + return + } + // Ignore errors — browser open is best-effort + cmd.Start() +} diff --git a/internal/cmd/login_test.go b/internal/cmd/login_test.go new file mode 100644 index 0000000..92f1dd7 --- /dev/null +++ b/internal/cmd/login_test.go @@ -0,0 +1,476 @@ +package cmd + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "strings" + "sync/atomic" + "testing" + + "github.com/minicodemonkey/chief/internal/auth" +) + +func init() { + // Prevent tests from opening actual browser windows + openBrowserFunc = func(url string) {} +} + +func setTestHome(t *testing.T, dir string) { + t.Helper() + t.Setenv("HOME", dir) +} + +func TestRunLogin_Success(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + + var pollCount atomic.Int32 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/oauth/device/code": + json.NewEncoder(w).Encode(deviceCodeResponse{ + DeviceCode: "test-device-code", + UserCode: "ABCD-1234", + }) + case "/api/oauth/device/token": + count := pollCount.Add(1) + if count < 2 { + // First poll: authorization pending + json.NewEncoder(w).Encode(tokenResponse{ + Error: "authorization_pending", + }) + return + } + // Second poll: success + json.NewEncoder(w).Encode(tokenResponse{ + AccessToken: "test-access-token", + RefreshToken: "test-refresh-token", + ExpiresIn: 3600, + User: "testuser@example.com", + WSURL: "wss://ws-test-reverb.laravel.cloud/ws/server", + }) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + // Override stdin to avoid blocking on "already logged in" prompt + oldStdin := os.Stdin + defer func() { os.Stdin = oldStdin }() + r, w, _ := os.Pipe() + os.Stdin = r + w.Close() + + err := RunLogin(LoginOptions{ + DeviceName: "test-device", + BaseURL: server.URL, + }) + if err != nil { + t.Fatalf("RunLogin failed: %v", err) + } + + // Verify credentials were saved + creds, err := auth.LoadCredentials() + if err != nil { + t.Fatalf("LoadCredentials after login failed: %v", err) + } + if creds.AccessToken != "test-access-token" { + t.Errorf("expected access_token %q, got %q", "test-access-token", creds.AccessToken) + } + if creds.RefreshToken != "test-refresh-token" { + t.Errorf("expected refresh_token %q, got %q", "test-refresh-token", creds.RefreshToken) + } + if creds.User != "testuser@example.com" { + t.Errorf("expected user %q, got %q", "testuser@example.com", creds.User) + } + if creds.DeviceName != "test-device" { + t.Errorf("expected device_name %q, got %q", "test-device", creds.DeviceName) + } + if creds.WSURL != "wss://ws-test-reverb.laravel.cloud/ws/server" { + t.Errorf("expected ws_url %q, got %q", "wss://ws-test-reverb.laravel.cloud/ws/server", creds.WSURL) + } +} + +func TestRunLogin_DeviceCodeError(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + err := RunLogin(LoginOptions{ + DeviceName: "test-device", + BaseURL: server.URL, + }) + if err == nil { + t.Fatal("expected error for server error response") + } +} + +func TestRunLogin_AuthorizationDenied(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/oauth/device/code": + json.NewEncoder(w).Encode(deviceCodeResponse{ + DeviceCode: "test-device-code", + UserCode: "ABCD-1234", + }) + case "/api/oauth/device/token": + json.NewEncoder(w).Encode(tokenResponse{ + Error: "access_denied", + }) + } + })) + defer server.Close() + + err := RunLogin(LoginOptions{ + DeviceName: "test-device", + BaseURL: server.URL, + }) + if err == nil { + t.Fatal("expected error for denied authorization") + } +} + +func TestRunLogin_DefaultDeviceName(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + + var receivedDeviceName string + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/oauth/device/code": + var body map[string]string + json.NewDecoder(r.Body).Decode(&body) + receivedDeviceName = body["device_name"] + json.NewEncoder(w).Encode(deviceCodeResponse{ + DeviceCode: "test-device-code", + UserCode: "TEST-CODE", + }) + case "/api/oauth/device/token": + json.NewEncoder(w).Encode(tokenResponse{ + AccessToken: "token", + RefreshToken: "refresh", + ExpiresIn: 3600, + User: "user", + }) + } + })) + defer server.Close() + + err := RunLogin(LoginOptions{ + BaseURL: server.URL, + // DeviceName left empty — should default to hostname + }) + if err != nil { + t.Fatalf("RunLogin failed: %v", err) + } + + hostname, _ := os.Hostname() + if receivedDeviceName != hostname { + t.Errorf("expected device name %q (hostname), got %q", hostname, receivedDeviceName) + } +} + +func TestRunLogin_AlreadyLoggedIn_Decline(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + + // Save existing credentials + existing := &auth.Credentials{ + AccessToken: "existing-token", + User: "existing-user", + DeviceName: "existing-device", + } + if err := auth.SaveCredentials(existing); err != nil { + t.Fatalf("SaveCredentials failed: %v", err) + } + + // Pipe "n\n" to stdin to decline + oldStdin := os.Stdin + defer func() { os.Stdin = oldStdin }() + r, w, _ := os.Pipe() + w.Write([]byte("n\n")) + w.Close() + os.Stdin = r + + // Server should not be called at all when declining + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Error("server should not be called when login is declined") + })) + defer server.Close() + + err := RunLogin(LoginOptions{ + DeviceName: "new-device", + BaseURL: server.URL, + }) + if err != nil { + t.Fatalf("RunLogin should not error when declining: %v", err) + } + + // Credentials should remain unchanged + creds, err := auth.LoadCredentials() + if err != nil { + t.Fatalf("LoadCredentials failed: %v", err) + } + if creds.AccessToken != "existing-token" { + t.Errorf("credentials should not have changed, got access_token %q", creds.AccessToken) + } +} + +func TestRunLogin_SetupToken_Success(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + + var receivedToken string + var receivedDeviceName string + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/oauth/device/exchange" { + http.NotFound(w, r) + return + } + var body map[string]string + json.NewDecoder(r.Body).Decode(&body) + receivedToken = body["setup_token"] + receivedDeviceName = body["device_name"] + json.NewEncoder(w).Encode(tokenResponse{ + AccessToken: "setup-access-token", + RefreshToken: "setup-refresh-token", + ExpiresIn: 3600, + User: "setupuser@example.com", + WSURL: "wss://ws-setup-reverb.laravel.cloud/ws/server", + }) + })) + defer server.Close() + + err := RunLogin(LoginOptions{ + DeviceName: "setup-device", + BaseURL: server.URL, + SetupToken: "test-setup-token-abc123", + }) + if err != nil { + t.Fatalf("RunLogin with setup token failed: %v", err) + } + + if receivedToken != "test-setup-token-abc123" { + t.Errorf("expected setup_token %q, got %q", "test-setup-token-abc123", receivedToken) + } + if receivedDeviceName != "setup-device" { + t.Errorf("expected device_name %q, got %q", "setup-device", receivedDeviceName) + } + + // Verify credentials were saved + creds, err := auth.LoadCredentials() + if err != nil { + t.Fatalf("LoadCredentials after setup token login failed: %v", err) + } + if creds.AccessToken != "setup-access-token" { + t.Errorf("expected access_token %q, got %q", "setup-access-token", creds.AccessToken) + } + if creds.RefreshToken != "setup-refresh-token" { + t.Errorf("expected refresh_token %q, got %q", "setup-refresh-token", creds.RefreshToken) + } + if creds.User != "setupuser@example.com" { + t.Errorf("expected user %q, got %q", "setupuser@example.com", creds.User) + } + if creds.DeviceName != "setup-device" { + t.Errorf("expected device_name %q, got %q", "setup-device", creds.DeviceName) + } + if creds.WSURL != "wss://ws-setup-reverb.laravel.cloud/ws/server" { + t.Errorf("expected ws_url %q, got %q", "wss://ws-setup-reverb.laravel.cloud/ws/server", creds.WSURL) + } +} + +func TestRunLogin_WSURLNotReturned(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + + var pollCount atomic.Int32 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/oauth/device/code": + json.NewEncoder(w).Encode(deviceCodeResponse{ + DeviceCode: "test-device-code", + UserCode: "ABCD-1234", + }) + case "/api/oauth/device/token": + count := pollCount.Add(1) + if count < 2 { + json.NewEncoder(w).Encode(tokenResponse{ + Error: "authorization_pending", + }) + return + } + // No ws_url in response + json.NewEncoder(w).Encode(tokenResponse{ + AccessToken: "test-access-token", + RefreshToken: "test-refresh-token", + ExpiresIn: 3600, + User: "testuser@example.com", + }) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + oldStdin := os.Stdin + defer func() { os.Stdin = oldStdin }() + r, w, _ := os.Pipe() + os.Stdin = r + w.Close() + + err := RunLogin(LoginOptions{ + DeviceName: "test-device", + BaseURL: server.URL, + }) + if err != nil { + t.Fatalf("RunLogin failed: %v", err) + } + + creds, err := auth.LoadCredentials() + if err != nil { + t.Fatalf("LoadCredentials after login failed: %v", err) + } + if creds.WSURL != "" { + t.Errorf("expected empty ws_url, got %q", creds.WSURL) + } +} + +func TestRunLogin_SetupToken_InvalidToken(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/oauth/device/exchange" { + http.NotFound(w, r) + return + } + w.WriteHeader(http.StatusUnauthorized) + json.NewEncoder(w).Encode(tokenResponse{ + Error: "invalid_token", + }) + })) + defer server.Close() + + err := RunLogin(LoginOptions{ + DeviceName: "setup-device", + BaseURL: server.URL, + SetupToken: "expired-token", + }) + if err == nil { + t.Fatal("expected error for invalid setup token") + } + if !strings.Contains(err.Error(), "invalid_token") { + t.Errorf("expected error to mention invalid_token, got: %v", err) + } + + // Verify no credentials were saved + _, err = auth.LoadCredentials() + if err == nil { + t.Error("credentials should not have been saved for invalid token") + } +} + +func TestRunLogin_SetupToken_ExpiredToken(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/oauth/device/exchange" { + http.NotFound(w, r) + return + } + w.WriteHeader(http.StatusGone) + json.NewEncoder(w).Encode(tokenResponse{ + Error: "token_expired", + }) + })) + defer server.Close() + + err := RunLogin(LoginOptions{ + DeviceName: "setup-device", + BaseURL: server.URL, + SetupToken: "expired-token", + }) + if err == nil { + t.Fatal("expected error for expired setup token") + } + if !strings.Contains(err.Error(), "token_expired") { + t.Errorf("expected error to mention token_expired, got: %v", err) + } +} + +func TestRunLogin_SetupToken_ServerError(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + err := RunLogin(LoginOptions{ + DeviceName: "setup-device", + BaseURL: server.URL, + SetupToken: "some-token", + }) + if err == nil { + t.Fatal("expected error for server error") + } +} + +func TestRunLogin_SetupToken_DefaultDeviceName(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + + var receivedDeviceName string + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/oauth/device/exchange" { + http.NotFound(w, r) + return + } + var body map[string]string + json.NewDecoder(r.Body).Decode(&body) + receivedDeviceName = body["device_name"] + json.NewEncoder(w).Encode(tokenResponse{ + AccessToken: "token", + RefreshToken: "refresh", + ExpiresIn: 3600, + User: "user", + }) + })) + defer server.Close() + + err := RunLogin(LoginOptions{ + BaseURL: server.URL, + SetupToken: "test-token", + // DeviceName left empty — should default to hostname + }) + if err != nil { + t.Fatalf("RunLogin failed: %v", err) + } + + hostname, _ := os.Hostname() + if receivedDeviceName != hostname { + t.Errorf("expected device name %q (hostname), got %q", hostname, receivedDeviceName) + } +} + +func TestOpenBrowser(t *testing.T) { + // Just verifying it doesn't panic — browser open is best-effort + openBrowserFunc("https://example.com") +} diff --git a/internal/cmd/logout.go b/internal/cmd/logout.go new file mode 100644 index 0000000..4174d45 --- /dev/null +++ b/internal/cmd/logout.go @@ -0,0 +1,42 @@ +package cmd + +import ( + "errors" + "fmt" + + "github.com/minicodemonkey/chief/internal/auth" +) + +// LogoutOptions contains configuration for the logout command. +type LogoutOptions struct { + BaseURL string // Override base URL (for testing) +} + +// RunLogout performs the device logout flow. +func RunLogout(opts LogoutOptions) error { + // Load credentials to get device name and access token + creds, err := auth.LoadCredentials() + if err != nil { + if errors.Is(err, auth.ErrNotLoggedIn) { + fmt.Println("Not logged in.") + return nil + } + return err + } + + deviceName := creds.DeviceName + + // Call revocation endpoint + if err := auth.RevokeDevice(creds.AccessToken, opts.BaseURL); err != nil { + fmt.Printf("Warning: could not deauthorize device server-side: %v\n", err) + fmt.Println("Local credentials will still be removed.") + } + + // Delete local credentials + if err := auth.DeleteCredentials(); err != nil { + return fmt.Errorf("removing local credentials: %w", err) + } + + fmt.Printf("Logged out. Device %q has been deauthorized.\n", deviceName) + return nil +} diff --git a/internal/cmd/logout_test.go b/internal/cmd/logout_test.go new file mode 100644 index 0000000..fa8c2bd --- /dev/null +++ b/internal/cmd/logout_test.go @@ -0,0 +1,100 @@ +package cmd + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/minicodemonkey/chief/internal/auth" +) + +func TestRunLogout_Success(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + + // Save credentials + creds := &auth.Credentials{ + AccessToken: "test-token", + RefreshToken: "test-refresh", + ExpiresAt: time.Now().Add(time.Hour), + DeviceName: "test-device", + User: "user@example.com", + } + if err := auth.SaveCredentials(creds); err != nil { + t.Fatalf("SaveCredentials failed: %v", err) + } + + var revokedToken string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/oauth/revoke" { + var body map[string]string + json.NewDecoder(r.Body).Decode(&body) + revokedToken = body["access_token"] + w.WriteHeader(http.StatusOK) + return + } + http.NotFound(w, r) + })) + defer server.Close() + + err := RunLogout(LogoutOptions{BaseURL: server.URL}) + if err != nil { + t.Fatalf("RunLogout failed: %v", err) + } + + // Verify revocation was called with correct token + if revokedToken != "test-token" { + t.Errorf("expected revoked token %q, got %q", "test-token", revokedToken) + } + + // Verify credentials are deleted + _, err = auth.LoadCredentials() + if err != auth.ErrNotLoggedIn { + t.Fatalf("expected ErrNotLoggedIn after logout, got %v", err) + } +} + +func TestRunLogout_NotLoggedIn(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + + err := RunLogout(LogoutOptions{}) + if err != nil { + t.Fatalf("RunLogout should not error when not logged in: %v", err) + } +} + +func TestRunLogout_RevocationFails_StillDeletesCredentials(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + + creds := &auth.Credentials{ + AccessToken: "test-token", + RefreshToken: "test-refresh", + ExpiresAt: time.Now().Add(time.Hour), + DeviceName: "test-device", + User: "user@example.com", + } + if err := auth.SaveCredentials(creds); err != nil { + t.Fatalf("SaveCredentials failed: %v", err) + } + + // Server returns error on revocation + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + err := RunLogout(LogoutOptions{BaseURL: server.URL}) + if err != nil { + t.Fatalf("RunLogout should not error when revocation fails: %v", err) + } + + // Credentials should still be deleted + _, err = auth.LoadCredentials() + if err != auth.ErrNotLoggedIn { + t.Fatalf("expected ErrNotLoggedIn after logout, got %v", err) + } +} diff --git a/internal/cmd/logs.go b/internal/cmd/logs.go new file mode 100644 index 0000000..761de5d --- /dev/null +++ b/internal/cmd/logs.go @@ -0,0 +1,210 @@ +package cmd + +import ( + "bufio" + "encoding/json" + "fmt" + "log" + "os" + "path/filepath" + "strings" + "sync" + + "github.com/minicodemonkey/chief/internal/ws" +) + +// storyLogger manages per-story log files during Ralph loop runs. +type storyLogger struct { + mu sync.Mutex + logDir string // .chief/prds/<id>/logs/ + files map[string]*os.File // story_id -> open file +} + +// newStoryLogger creates a story logger for a given PRD. +// It creates the logs directory and removes any previous log files (V1 simplicity). +func newStoryLogger(prdPath string) (*storyLogger, error) { + prdDir := filepath.Dir(prdPath) + logDir := filepath.Join(prdDir, "logs") + + // Remove previous logs (overwrite on new run) + os.RemoveAll(logDir) + + if err := os.MkdirAll(logDir, 0o755); err != nil { + return nil, fmt.Errorf("creating log directory: %w", err) + } + + return &storyLogger{ + logDir: logDir, + files: make(map[string]*os.File), + }, nil +} + +// WriteLog writes a line to the log file for the given story. +func (sl *storyLogger) WriteLog(storyID, line string) { + if storyID == "" { + return + } + + sl.mu.Lock() + defer sl.mu.Unlock() + + f, ok := sl.files[storyID] + if !ok { + var err error + logPath := filepath.Join(sl.logDir, storyID+".log") + f, err = os.OpenFile(logPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644) + if err != nil { + log.Printf("Error opening story log file %s: %v", logPath, err) + return + } + sl.files[storyID] = f + } + + f.WriteString(line + "\n") +} + +// Close closes all open log files. +func (sl *storyLogger) Close() { + sl.mu.Lock() + defer sl.mu.Unlock() + + for _, f := range sl.files { + f.Close() + } + sl.files = make(map[string]*os.File) +} + +// handleGetLogs handles a get_logs request. +func handleGetLogs(sender messageSender, finder projectFinder, msg ws.Message) { + var req ws.GetLogsMessage + if err := json.Unmarshal(msg.Raw, &req); err != nil { + log.Printf("Error parsing get_logs message: %v", err) + return + } + + project, found := finder.FindProject(req.Project) + if !found { + sendError(sender, ws.ErrCodeProjectNotFound, + fmt.Sprintf("Project %q not found", req.Project), msg.ID) + return + } + + prdDir := filepath.Join(project.Path, ".chief", "prds", req.PRDID) + if _, err := os.Stat(prdDir); os.IsNotExist(err) { + sendError(sender, ws.ErrCodePRDNotFound, + fmt.Sprintf("PRD %q not found in project %q", req.PRDID, req.Project), msg.ID) + return + } + + logDir := filepath.Join(prdDir, "logs") + + // If story_id is provided, return that specific story's logs + if req.StoryID != "" { + lines, err := readLogFile(filepath.Join(logDir, req.StoryID+".log"), req.Lines) + if err != nil { + sendError(sender, ws.ErrCodeFilesystemError, + fmt.Sprintf("Failed to read logs for story %q: %v", req.StoryID, err), msg.ID) + return + } + + sendLogLines(sender, req.Project, req.PRDID, req.StoryID, lines) + return + } + + // If story_id is omitted, return the most recent log activity for the PRD + storyID, lines, err := readMostRecentLog(logDir, req.Lines) + if err != nil { + sendError(sender, ws.ErrCodeFilesystemError, + fmt.Sprintf("Failed to read logs: %v", err), msg.ID) + return + } + + sendLogLines(sender, req.Project, req.PRDID, storyID, lines) +} + +// readLogFile reads lines from a log file. If maxLines is 0, reads all lines. +func readLogFile(path string, maxLines int) ([]string, error) { + f, err := os.Open(path) + if err != nil { + if os.IsNotExist(err) { + return []string{}, nil + } + return nil, err + } + defer f.Close() + + var lines []string + scanner := bufio.NewScanner(f) + // Increase buffer size for long lines + buf := make([]byte, 0, 64*1024) + scanner.Buffer(buf, 1024*1024) + + for scanner.Scan() { + lines = append(lines, scanner.Text()) + } + + if maxLines > 0 && len(lines) > maxLines { + lines = lines[len(lines)-maxLines:] + } + + return lines, scanner.Err() +} + +// readMostRecentLog finds the most recently modified log file and reads it. +func readMostRecentLog(logDir string, maxLines int) (string, []string, error) { + entries, err := os.ReadDir(logDir) + if err != nil { + if os.IsNotExist(err) { + return "", []string{}, nil + } + return "", nil, err + } + + // Find the most recently modified .log file + var mostRecent string + var mostRecentTime int64 + + for _, entry := range entries { + if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".log") { + continue + } + info, err := entry.Info() + if err != nil { + continue + } + if info.ModTime().UnixNano() > mostRecentTime { + mostRecentTime = info.ModTime().UnixNano() + mostRecent = entry.Name() + } + } + + if mostRecent == "" { + return "", []string{}, nil + } + + storyID := strings.TrimSuffix(mostRecent, ".log") + lines, err := readLogFile(filepath.Join(logDir, mostRecent), maxLines) + return storyID, lines, err +} + +// sendLogLines sends a log_lines message over WebSocket. +func sendLogLines(sender messageSender, project, prdID, storyID string, lines []string) { + if lines == nil { + lines = []string{} + } + + envelope := ws.NewMessage(ws.TypeLogLines) + msg := ws.LogLinesMessage{ + Type: envelope.Type, + ID: envelope.ID, + Timestamp: envelope.Timestamp, + Project: project, + PRDID: prdID, + StoryID: storyID, + Lines: lines, + Level: "info", + } + if err := sender.Send(msg); err != nil { + log.Printf("Error sending log_lines: %v", err) + } +} diff --git a/internal/cmd/logs_test.go b/internal/cmd/logs_test.go new file mode 100644 index 0000000..bb1bdf8 --- /dev/null +++ b/internal/cmd/logs_test.go @@ -0,0 +1,757 @@ +package cmd + +import ( + "encoding/json" + "os" + "path/filepath" + "sync" + "testing" + "time" + + "github.com/minicodemonkey/chief/internal/engine" +) + +func TestStoryLogger_WriteAndRead(t *testing.T) { + tmpDir := t.TempDir() + prdDir := filepath.Join(tmpDir, ".chief", "prds", "feature") + if err := os.MkdirAll(prdDir, 0o755); err != nil { + t.Fatal(err) + } + + prdPath := filepath.Join(prdDir, "prd.json") + if err := os.WriteFile(prdPath, []byte(`{}`), 0o644); err != nil { + t.Fatal(err) + } + + sl, err := newStoryLogger(prdPath) + if err != nil { + t.Fatalf("newStoryLogger failed: %v", err) + } + defer sl.Close() + + // Write some log lines + sl.WriteLog("US-001", "Starting story US-001") + sl.WriteLog("US-001", "Working on implementation") + sl.WriteLog("US-001", "Story complete") + + sl.WriteLog("US-002", "Starting story US-002") + sl.WriteLog("US-002", "Done") + + // Close to flush + sl.Close() + + // Read the log files + logDir := filepath.Join(prdDir, "logs") + + lines, err := readLogFile(filepath.Join(logDir, "US-001.log"), 0) + if err != nil { + t.Fatalf("readLogFile failed: %v", err) + } + if len(lines) != 3 { + t.Errorf("expected 3 lines for US-001, got %d", len(lines)) + } + if lines[0] != "Starting story US-001" { + t.Errorf("expected first line 'Starting story US-001', got %q", lines[0]) + } + + lines, err = readLogFile(filepath.Join(logDir, "US-002.log"), 0) + if err != nil { + t.Fatalf("readLogFile failed: %v", err) + } + if len(lines) != 2 { + t.Errorf("expected 2 lines for US-002, got %d", len(lines)) + } +} + +func TestStoryLogger_WriteEmptyStoryID(t *testing.T) { + tmpDir := t.TempDir() + prdDir := filepath.Join(tmpDir, ".chief", "prds", "feature") + if err := os.MkdirAll(prdDir, 0o755); err != nil { + t.Fatal(err) + } + + prdPath := filepath.Join(prdDir, "prd.json") + if err := os.WriteFile(prdPath, []byte(`{}`), 0o644); err != nil { + t.Fatal(err) + } + + sl, err := newStoryLogger(prdPath) + if err != nil { + t.Fatalf("newStoryLogger failed: %v", err) + } + defer sl.Close() + + // Writing with empty story ID should be a no-op + sl.WriteLog("", "This should not be written") + sl.Close() + + // Verify no files were created + logDir := filepath.Join(prdDir, "logs") + entries, err := os.ReadDir(logDir) + if err != nil { + t.Fatalf("ReadDir failed: %v", err) + } + if len(entries) != 0 { + t.Errorf("expected no log files, got %d", len(entries)) + } +} + +func TestStoryLogger_OverwriteOnNewRun(t *testing.T) { + tmpDir := t.TempDir() + prdDir := filepath.Join(tmpDir, ".chief", "prds", "feature") + if err := os.MkdirAll(prdDir, 0o755); err != nil { + t.Fatal(err) + } + + prdPath := filepath.Join(prdDir, "prd.json") + if err := os.WriteFile(prdPath, []byte(`{}`), 0o644); err != nil { + t.Fatal(err) + } + + // Create first logger and write some logs + sl1, err := newStoryLogger(prdPath) + if err != nil { + t.Fatal(err) + } + sl1.WriteLog("US-001", "First run output") + sl1.Close() + + // Create second logger — should overwrite previous logs + sl2, err := newStoryLogger(prdPath) + if err != nil { + t.Fatal(err) + } + sl2.WriteLog("US-001", "Second run output") + sl2.Close() + + // Read the log — should only have second run's content + logDir := filepath.Join(prdDir, "logs") + lines, err := readLogFile(filepath.Join(logDir, "US-001.log"), 0) + if err != nil { + t.Fatalf("readLogFile failed: %v", err) + } + if len(lines) != 1 { + t.Errorf("expected 1 line (overwritten), got %d: %v", len(lines), lines) + } + if len(lines) > 0 && lines[0] != "Second run output" { + t.Errorf("expected 'Second run output', got %q", lines[0]) + } +} + +func TestReadLogFile_WithLineLimit(t *testing.T) { + tmpDir := t.TempDir() + logPath := filepath.Join(tmpDir, "test.log") + + // Write 10 lines + var content string + for i := 1; i <= 10; i++ { + content += "Line " + string(rune('0'+i)) + "\n" + } + if err := os.WriteFile(logPath, []byte(content), 0o644); err != nil { + t.Fatal(err) + } + + // Read with limit of 3 — should get last 3 lines + lines, err := readLogFile(logPath, 3) + if err != nil { + t.Fatalf("readLogFile failed: %v", err) + } + if len(lines) != 3 { + t.Errorf("expected 3 lines, got %d", len(lines)) + } +} + +func TestReadLogFile_Nonexistent(t *testing.T) { + lines, err := readLogFile("/nonexistent/path/test.log", 0) + if err != nil { + t.Fatalf("expected no error for nonexistent file, got: %v", err) + } + if len(lines) != 0 { + t.Errorf("expected empty lines for nonexistent file, got %d", len(lines)) + } +} + +func TestReadMostRecentLog(t *testing.T) { + tmpDir := t.TempDir() + logDir := filepath.Join(tmpDir, "logs") + if err := os.MkdirAll(logDir, 0o755); err != nil { + t.Fatal(err) + } + + // Write two log files with different mod times + if err := os.WriteFile(filepath.Join(logDir, "US-001.log"), []byte("old log\n"), 0o644); err != nil { + t.Fatal(err) + } + + // Ensure US-002 is newer + time.Sleep(10 * time.Millisecond) + if err := os.WriteFile(filepath.Join(logDir, "US-002.log"), []byte("new log line 1\nnew log line 2\n"), 0o644); err != nil { + t.Fatal(err) + } + + storyID, lines, err := readMostRecentLog(logDir, 0) + if err != nil { + t.Fatalf("readMostRecentLog failed: %v", err) + } + if storyID != "US-002" { + t.Errorf("expected most recent story 'US-002', got %q", storyID) + } + if len(lines) != 2 { + t.Errorf("expected 2 lines, got %d", len(lines)) + } +} + +func TestReadMostRecentLog_EmptyDir(t *testing.T) { + tmpDir := t.TempDir() + logDir := filepath.Join(tmpDir, "logs") + if err := os.MkdirAll(logDir, 0o755); err != nil { + t.Fatal(err) + } + + storyID, lines, err := readMostRecentLog(logDir, 0) + if err != nil { + t.Fatalf("readMostRecentLog failed: %v", err) + } + if storyID != "" { + t.Errorf("expected empty story ID, got %q", storyID) + } + if len(lines) != 0 { + t.Errorf("expected no lines, got %d", len(lines)) + } +} + +func TestReadMostRecentLog_NonexistentDir(t *testing.T) { + storyID, lines, err := readMostRecentLog("/nonexistent/logs", 0) + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + if storyID != "" || len(lines) != 0 { + t.Errorf("expected empty results, got storyID=%q lines=%v", storyID, lines) + } +} + +func TestRunManager_StoryLogWriting(t *testing.T) { + eng := engine.New(5) + defer eng.Shutdown() + + rm := newRunManager(eng, nil) + + // Create a temp project with a PRD + projectDir := t.TempDir() + prdDir := filepath.Join(projectDir, ".chief", "prds", "feature") + if err := os.MkdirAll(prdDir, 0o755); err != nil { + t.Fatal(err) + } + + prdState := `{"project": "Test", "userStories": [{"id": "US-001", "title": "Story", "passes": false}]}` + prdPath := filepath.Join(prdDir, "prd.json") + if err := os.WriteFile(prdPath, []byte(prdState), 0o644); err != nil { + t.Fatal(err) + } + + // Start a run (creates logger) + err := rm.startRun("myproject", "feature", projectDir) + if err != nil { + t.Fatalf("startRun failed: %v", err) + } + + // Verify logger was created + rm.mu.RLock() + _, hasLogger := rm.loggers["myproject/feature"] + rm.mu.RUnlock() + if !hasLogger { + t.Fatal("expected logger to be created for the run") + } + + // Write some story logs + rm.writeStoryLog("myproject/feature", "US-001", "Hello from story log") + rm.writeStoryLog("myproject/feature", "US-001", "Another line") + + // Stop and cleanup + rm.stopAll() + + // Read the log file + logPath := filepath.Join(prdDir, "logs", "US-001.log") + lines, err := readLogFile(logPath, 0) + if err != nil { + t.Fatalf("readLogFile failed: %v", err) + } + if len(lines) != 2 { + t.Errorf("expected 2 lines, got %d", len(lines)) + } + if len(lines) > 0 && lines[0] != "Hello from story log" { + t.Errorf("expected 'Hello from story log', got %q", lines[0]) + } +} + +func TestRunServe_GetLogs(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + // Create a git repo with a PRD that has log files + projectDir := filepath.Join(workspaceDir, "myproject") + createGitRepo(t, projectDir) + + prdDir := filepath.Join(projectDir, ".chief", "prds", "feature") + logDir := filepath.Join(prdDir, "logs") + if err := os.MkdirAll(logDir, 0o755); err != nil { + t.Fatal(err) + } + + prdState := `{"project": "My Feature", "userStories": [{"id": "US-001", "title": "Test Story", "passes": true}]}` + if err := os.WriteFile(filepath.Join(prdDir, "prd.json"), []byte(prdState), 0o644); err != nil { + t.Fatal(err) + } + + // Write a log file for US-001 + logContent := "Starting story US-001\nWorking on implementation\nStory complete\n" + if err := os.WriteFile(filepath.Join(logDir, "US-001.log"), []byte(logContent), 0o644); err != nil { + t.Fatal(err) + } + + var response map[string]interface{} + var mu sync.Mutex + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + // Send get_logs request with story_id + getLogsReq := map[string]interface{}{ + "type": "get_logs", + "id": "req-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "project": "myproject", + "prd_id": "feature", + "story_id": "US-001", + } + ms.sendCommand(getLogsReq) + + raw, err := ms.waitForMessageType("log_lines", 5*time.Second) + if err == nil { + mu.Lock() + json.Unmarshal(raw, &response) + mu.Unlock() + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + mu.Lock() + defer mu.Unlock() + + if response == nil { + t.Fatal("response was not received") + } + if response["type"] != "log_lines" { + t.Errorf("expected type 'log_lines', got %v", response["type"]) + } + if response["project"] != "myproject" { + t.Errorf("expected project 'myproject', got %v", response["project"]) + } + if response["prd_id"] != "feature" { + t.Errorf("expected prd_id 'feature', got %v", response["prd_id"]) + } + if response["story_id"] != "US-001" { + t.Errorf("expected story_id 'US-001', got %v", response["story_id"]) + } + + lines, ok := response["lines"].([]interface{}) + if !ok { + t.Fatal("expected lines to be an array") + } + if len(lines) != 3 { + t.Errorf("expected 3 lines, got %d", len(lines)) + } +} + +func TestRunServe_GetLogsNoStoryID(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + projectDir := filepath.Join(workspaceDir, "myproject") + createGitRepo(t, projectDir) + + prdDir := filepath.Join(projectDir, ".chief", "prds", "feature") + logDir := filepath.Join(prdDir, "logs") + if err := os.MkdirAll(logDir, 0o755); err != nil { + t.Fatal(err) + } + + prdState := `{"project": "My Feature", "userStories": [{"id": "US-001", "title": "Test Story", "passes": true}]}` + if err := os.WriteFile(filepath.Join(prdDir, "prd.json"), []byte(prdState), 0o644); err != nil { + t.Fatal(err) + } + + // Write a log file — when no story_id provided, should return most recent + if err := os.WriteFile(filepath.Join(logDir, "US-001.log"), []byte("recent log line\n"), 0o644); err != nil { + t.Fatal(err) + } + + var response map[string]interface{} + var mu sync.Mutex + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + getLogsReq := map[string]interface{}{ + "type": "get_logs", + "id": "req-2", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "project": "myproject", + "prd_id": "feature", + } + ms.sendCommand(getLogsReq) + + raw, err := ms.waitForMessageType("log_lines", 5*time.Second) + if err == nil { + mu.Lock() + json.Unmarshal(raw, &response) + mu.Unlock() + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + mu.Lock() + defer mu.Unlock() + + if response == nil { + t.Fatal("response was not received") + } + if response["type"] != "log_lines" { + t.Errorf("expected type 'log_lines', got %v", response["type"]) + } + if response["story_id"] != "US-001" { + t.Errorf("expected story_id 'US-001', got %v", response["story_id"]) + } +} + +func TestRunServe_GetLogsProjectNotFound(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + var response map[string]interface{} + var mu sync.Mutex + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + getLogsReq := map[string]interface{}{ + "type": "get_logs", + "id": "req-3", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "project": "nonexistent", + "prd_id": "feature", + "story_id": "US-001", + } + ms.sendCommand(getLogsReq) + + raw, err := ms.waitForMessageType("error", 5*time.Second) + if err == nil { + mu.Lock() + json.Unmarshal(raw, &response) + mu.Unlock() + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + mu.Lock() + defer mu.Unlock() + + if response == nil { + t.Fatal("response was not received") + } + if response["type"] != "error" { + t.Errorf("expected type 'error', got %v", response["type"]) + } + if response["code"] != "PROJECT_NOT_FOUND" { + t.Errorf("expected code 'PROJECT_NOT_FOUND', got %v", response["code"]) + } +} + +func TestRunServe_GetLogsPRDNotFound(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + projectDir := filepath.Join(workspaceDir, "myproject") + createGitRepo(t, projectDir) + + var response map[string]interface{} + var mu sync.Mutex + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + getLogsReq := map[string]interface{}{ + "type": "get_logs", + "id": "req-4", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "project": "myproject", + "prd_id": "nonexistent", + "story_id": "US-001", + } + ms.sendCommand(getLogsReq) + + raw, err := ms.waitForMessageType("error", 5*time.Second) + if err == nil { + mu.Lock() + json.Unmarshal(raw, &response) + mu.Unlock() + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + mu.Lock() + defer mu.Unlock() + + if response == nil { + t.Fatal("response was not received") + } + if response["type"] != "error" { + t.Errorf("expected type 'error', got %v", response["type"]) + } + if response["code"] != "PRD_NOT_FOUND" { + t.Errorf("expected code 'PRD_NOT_FOUND', got %v", response["code"]) + } +} + +func TestRunServe_GetLogsWithLineLimit(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + projectDir := filepath.Join(workspaceDir, "myproject") + createGitRepo(t, projectDir) + + prdDir := filepath.Join(projectDir, ".chief", "prds", "feature") + logDir := filepath.Join(prdDir, "logs") + if err := os.MkdirAll(logDir, 0o755); err != nil { + t.Fatal(err) + } + + prdState := `{"project": "My Feature", "userStories": [{"id": "US-001", "title": "Test Story", "passes": true}]}` + if err := os.WriteFile(filepath.Join(prdDir, "prd.json"), []byte(prdState), 0o644); err != nil { + t.Fatal(err) + } + + // Write 5 log lines + logContent := "Line 1\nLine 2\nLine 3\nLine 4\nLine 5\n" + if err := os.WriteFile(filepath.Join(logDir, "US-001.log"), []byte(logContent), 0o644); err != nil { + t.Fatal(err) + } + + var response map[string]interface{} + var mu sync.Mutex + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + // Request only 2 lines + getLogsReq := map[string]interface{}{ + "type": "get_logs", + "id": "req-5", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "project": "myproject", + "prd_id": "feature", + "story_id": "US-001", + "lines": 2, + } + ms.sendCommand(getLogsReq) + + raw, err := ms.waitForMessageType("log_lines", 5*time.Second) + if err == nil { + mu.Lock() + json.Unmarshal(raw, &response) + mu.Unlock() + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + mu.Lock() + defer mu.Unlock() + + if response == nil { + t.Fatal("response was not received") + } + if response["type"] != "log_lines" { + t.Errorf("expected type 'log_lines', got %v", response["type"]) + } + + lines, ok := response["lines"].([]interface{}) + if !ok { + t.Fatal("expected lines to be an array") + } + if len(lines) != 2 { + t.Errorf("expected 2 lines (limited), got %d", len(lines)) + } + // Should return the last 2 lines + if len(lines) >= 2 { + if lines[0] != "Line 4" { + t.Errorf("expected 'Line 4', got %v", lines[0]) + } + if lines[1] != "Line 5" { + t.Errorf("expected 'Line 5', got %v", lines[1]) + } + } +} + +func TestRunServe_LoggingIntegration(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + projectDir := filepath.Join(workspaceDir, "myproject") + createGitRepo(t, projectDir) + + prdDir := filepath.Join(projectDir, ".chief", "prds", "feature") + if err := os.MkdirAll(prdDir, 0o755); err != nil { + t.Fatal(err) + } + + prdState := `{"project": "My Feature", "userStories": [{"id": "US-001", "title": "Test Story", "passes": false}]}` + if err := os.WriteFile(filepath.Join(prdDir, "prd.json"), []byte(prdState), 0o644); err != nil { + t.Fatal(err) + } + + // Create a mock claude that outputs stream-json + mockDir := t.TempDir() + mockScript := `#!/bin/sh +echo '{"type":"system","subtype":"init"}' +echo '{"type":"assistant","message":{"content":[{"type":"text","text":"Working on <ralph-status>US-001</ralph-status>"}]}}' +echo '{"type":"assistant","message":{"content":[{"type":"text","text":"Implementing feature"}]}}' +echo '{"type":"result"}' +exit 0 +` + if err := os.WriteFile(filepath.Join(mockDir, "claude"), []byte(mockScript), 0o755); err != nil { + t.Fatal(err) + } + origPath := os.Getenv("PATH") + t.Setenv("PATH", mockDir+":"+origPath) + + var messages []map[string]interface{} + var mu sync.Mutex + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + // Send start_run request + startReq := map[string]string{ + "type": "start_run", + "id": "req-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "project": "myproject", + "prd_id": "feature", + } + ms.sendCommand(startReq) + + // Wait for multiple messages (expecting at least some stream messages) + rawMessages, err := ms.waitForMessages(15, 5*time.Second) + if err == nil { + mu.Lock() + for _, raw := range rawMessages { + var msg map[string]interface{} + if json.Unmarshal(raw, &msg) == nil { + messages = append(messages, msg) + } + } + mu.Unlock() + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + // Verify that log files were created in .chief/prds/feature/logs/ + logDir := filepath.Join(prdDir, "logs") + if _, err := os.Stat(logDir); os.IsNotExist(err) { + t.Error("expected logs directory to be created") + } + + // The log file should exist for US-001 (the story that was started) + logFile := filepath.Join(logDir, "US-001.log") + if _, err := os.Stat(logFile); os.IsNotExist(err) { + t.Error("expected US-001.log to be created") + } else { + lines, err := readLogFile(logFile, 0) + if err != nil { + t.Fatalf("readLogFile failed: %v", err) + } + // Should have at least some log content + if len(lines) == 0 { + t.Error("expected log file to have content") + } + } +} + +func TestRunManager_CleanupClosesLogger(t *testing.T) { + eng := engine.New(5) + defer eng.Shutdown() + + rm := newRunManager(eng, nil) + + tmpDir := t.TempDir() + prdDir := filepath.Join(tmpDir, ".chief", "prds", "feature") + if err := os.MkdirAll(prdDir, 0o755); err != nil { + t.Fatal(err) + } + + prdPath := filepath.Join(prdDir, "prd.json") + prdState := `{"project": "Test", "userStories": [{"id": "US-001", "title": "Story", "passes": false}]}` + if err := os.WriteFile(prdPath, []byte(prdState), 0o644); err != nil { + t.Fatal(err) + } + + if err := rm.startRun("myproject", "feature", tmpDir); err != nil { + t.Fatalf("startRun failed: %v", err) + } + + key := "myproject/feature" + + // Verify logger exists + rm.mu.RLock() + _, hasLogger := rm.loggers[key] + rm.mu.RUnlock() + if !hasLogger { + t.Fatal("expected logger to be created") + } + + // Cleanup should remove the logger + rm.cleanup(key) + + rm.mu.RLock() + _, hasLogger = rm.loggers[key] + rm.mu.RUnlock() + if hasLogger { + t.Error("expected logger to be removed after cleanup") + } +} diff --git a/internal/cmd/prds.go b/internal/cmd/prds.go new file mode 100644 index 0000000..f9345d0 --- /dev/null +++ b/internal/cmd/prds.go @@ -0,0 +1,64 @@ +package cmd + +import ( + "encoding/json" + "fmt" + "log" + "strings" + + "github.com/minicodemonkey/chief/internal/ws" +) + +// handleGetPRDs handles a get_prds request. +func handleGetPRDs(sender messageSender, finder projectFinder, msg ws.Message) { + var req ws.GetPRDsMessage + if err := json.Unmarshal(msg.Raw, &req); err != nil { + log.Printf("Error parsing get_prds message: %v", err) + return + } + + project, found := finder.FindProject(req.Project) + if !found { + sendError(sender, ws.ErrCodeProjectNotFound, + fmt.Sprintf("Project %q not found", req.Project), msg.ID) + return + } + + items := make([]ws.PRDItem, 0, len(project.PRDs)) + for _, prd := range project.PRDs { + items = append(items, ws.PRDItem{ + ID: prd.ID, + Name: prd.Name, + StoryCount: prd.StoryCount, + Status: mapCompletionStatus(prd.CompletionStatus), + }) + } + + resp := ws.PRDsResponseMessage{ + Type: ws.TypePRDsResponse, + Payload: ws.PRDsResponsePayload{ + Project: req.Project, + PRDs: items, + }, + } + if err := sender.Send(resp); err != nil { + log.Printf("Error sending prds_response: %v", err) + } +} + +// mapCompletionStatus converts a "passed/total" completion status to a +// browser-friendly status string: "draft", "active", or "done". +func mapCompletionStatus(status string) string { + parts := strings.SplitN(status, "/", 2) + if len(parts) != 2 { + return "draft" + } + passed, total := parts[0], parts[1] + if total == "0" { + return "draft" + } + if passed == total { + return "done" + } + return "active" +} diff --git a/internal/cmd/prds_test.go b/internal/cmd/prds_test.go new file mode 100644 index 0000000..14d4126 --- /dev/null +++ b/internal/cmd/prds_test.go @@ -0,0 +1,260 @@ +package cmd + +import ( + "encoding/json" + "os" + "path/filepath" + "sync" + "testing" + "time" +) + +func TestRunServe_GetPRDs(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + projectDir := filepath.Join(workspaceDir, "myproject") + createGitRepo(t, projectDir) + + // Create .chief/prds with two PRDs + prd1Dir := filepath.Join(projectDir, ".chief", "prds", "feature-auth") + prd2Dir := filepath.Join(projectDir, ".chief", "prds", "feature-dashboard") + if err := os.MkdirAll(prd1Dir, 0o755); err != nil { + t.Fatal(err) + } + if err := os.MkdirAll(prd2Dir, 0o755); err != nil { + t.Fatal(err) + } + + // PRD with 2/3 stories passing → "active" + prd1JSON := `{"project": "Auth System", "userStories": [ + {"id": "US-001", "passes": true}, + {"id": "US-002", "passes": true}, + {"id": "US-003", "passes": false} + ]}` + if err := os.WriteFile(filepath.Join(prd1Dir, "prd.json"), []byte(prd1JSON), 0o644); err != nil { + t.Fatal(err) + } + + // PRD with 2/2 stories passing → "done" + prd2JSON := `{"project": "Dashboard", "userStories": [ + {"id": "US-010", "passes": true}, + {"id": "US-011", "passes": true} + ]}` + if err := os.WriteFile(filepath.Join(prd2Dir, "prd.json"), []byte(prd2JSON), 0o644); err != nil { + t.Fatal(err) + } + + var response map[string]interface{} + var mu sync.Mutex + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + req := map[string]interface{}{ + "type": "get_prds", + "id": "req-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "project": "myproject", + } + if err := ms.sendCommand(req); err != nil { + t.Errorf("sendCommand failed: %v", err) + return + } + + raw, err := ms.waitForMessageType("prds_response", 5*time.Second) + if err == nil { + mu.Lock() + json.Unmarshal(raw, &response) + mu.Unlock() + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + mu.Lock() + defer mu.Unlock() + + if response == nil { + t.Fatal("prds_response was not received") + } + if response["type"] != "prds_response" { + t.Errorf("expected type 'prds_response', got %v", response["type"]) + } + + payload, ok := response["payload"].(map[string]interface{}) + if !ok { + t.Fatal("expected payload to be an object") + } + if payload["project"] != "myproject" { + t.Errorf("expected project 'myproject', got %v", payload["project"]) + } + + prds, ok := payload["prds"].([]interface{}) + if !ok { + t.Fatal("expected prds to be an array") + } + if len(prds) != 2 { + t.Fatalf("expected 2 PRDs, got %d", len(prds)) + } + + // Build a map by ID for easier assertions + prdMap := make(map[string]map[string]interface{}) + for _, p := range prds { + prd := p.(map[string]interface{}) + prdMap[prd["id"].(string)] = prd + } + + // feature-auth: 2/3 passing → "active" + auth := prdMap["feature-auth"] + if auth == nil { + t.Fatal("expected feature-auth PRD") + } + if auth["name"] != "Auth System" { + t.Errorf("expected name 'Auth System', got %v", auth["name"]) + } + if int(auth["story_count"].(float64)) != 3 { + t.Errorf("expected story_count 3, got %v", auth["story_count"]) + } + if auth["status"] != "active" { + t.Errorf("expected status 'active', got %v", auth["status"]) + } + + // feature-dashboard: 2/2 passing → "done" + dash := prdMap["feature-dashboard"] + if dash == nil { + t.Fatal("expected feature-dashboard PRD") + } + if dash["status"] != "done" { + t.Errorf("expected status 'done', got %v", dash["status"]) + } +} + +func TestRunServe_GetPRDs_ProjectNotFound(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + var response map[string]interface{} + var mu sync.Mutex + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + req := map[string]interface{}{ + "type": "get_prds", + "id": "req-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "project": "nonexistent", + } + if err := ms.sendCommand(req); err != nil { + t.Errorf("sendCommand failed: %v", err) + return + } + + raw, err := ms.waitForMessageType("error", 5*time.Second) + if err == nil { + mu.Lock() + json.Unmarshal(raw, &response) + mu.Unlock() + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + mu.Lock() + defer mu.Unlock() + + if response == nil { + t.Fatal("error message was not received") + } + if response["code"] != "PROJECT_NOT_FOUND" { + t.Errorf("expected code 'PROJECT_NOT_FOUND', got %v", response["code"]) + } +} + +func TestRunServe_GetPRDs_EmptyProject(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + // Project with no .chief directory → empty PRD list + createGitRepo(t, filepath.Join(workspaceDir, "myproject")) + + var response map[string]interface{} + var mu sync.Mutex + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + req := map[string]interface{}{ + "type": "get_prds", + "id": "req-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "project": "myproject", + } + if err := ms.sendCommand(req); err != nil { + t.Errorf("sendCommand failed: %v", err) + return + } + + raw, err := ms.waitForMessageType("prds_response", 5*time.Second) + if err == nil { + mu.Lock() + json.Unmarshal(raw, &response) + mu.Unlock() + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + mu.Lock() + defer mu.Unlock() + + if response == nil { + t.Fatal("prds_response was not received") + } + + payload := response["payload"].(map[string]interface{}) + prds, ok := payload["prds"].([]interface{}) + if !ok { + t.Fatal("expected prds to be an array") + } + if len(prds) != 0 { + t.Errorf("expected 0 PRDs for project without .chief, got %d", len(prds)) + } +} + +func TestMapCompletionStatus(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"0/0", "draft"}, + {"0/5", "active"}, + {"3/5", "active"}, + {"5/5", "done"}, + {"", "draft"}, + {"invalid", "draft"}, + } + + for _, tc := range tests { + result := mapCompletionStatus(tc.input) + if result != tc.expected { + t.Errorf("mapCompletionStatus(%q) = %q, want %q", tc.input, result, tc.expected) + } + } +} diff --git a/internal/cmd/remote_update.go b/internal/cmd/remote_update.go new file mode 100644 index 0000000..d577e90 --- /dev/null +++ b/internal/cmd/remote_update.go @@ -0,0 +1,77 @@ +package cmd + +import ( + "fmt" + "log" + "strings" + + "github.com/minicodemonkey/chief/internal/update" + "github.com/minicodemonkey/chief/internal/ws" +) + +// handleTriggerUpdate handles a trigger_update request from the web app. +// It downloads and installs the latest binary, sends confirmation, +// and returns true if the process should exit (so systemd Restart=always picks up the new binary). +func handleTriggerUpdate(sender messageSender, msg ws.Message, version, releasesURL string) bool { + log.Println("Received trigger_update request") + + // Check for update + result, err := update.CheckForUpdate(version, update.Options{ + ReleasesURL: releasesURL, + }) + if err != nil { + sendError(sender, ws.ErrCodeUpdateFailed, + fmt.Sprintf("checking for updates: %v", err), msg.ID) + return false + } + + if !result.UpdateAvailable { + // Already on latest — send informational message (not an error) + envelope := ws.NewMessage(ws.TypeUpdateAvailable) + infoMsg := ws.UpdateAvailableMessage{ + Type: envelope.Type, + ID: envelope.ID, + Timestamp: envelope.Timestamp, + CurrentVersion: result.CurrentVersion, + LatestVersion: result.LatestVersion, + } + if err := sender.Send(infoMsg); err != nil { + log.Printf("Error sending update_available: %v", err) + } + log.Printf("Already on latest version (v%s)", result.CurrentVersion) + return false + } + + // Perform the update + log.Printf("Downloading v%s (current: v%s)...", result.LatestVersion, result.CurrentVersion) + _, err = update.PerformUpdate(version, update.Options{ + ReleasesURL: releasesURL, + }) + if err != nil { + errMsg := err.Error() + if strings.Contains(errMsg, "Permission denied") { + sendError(sender, ws.ErrCodeUpdateFailed, + "Permission denied. The chief binary is not writable. Ensure the service user has write permissions to the binary path.", msg.ID) + } else { + sendError(sender, ws.ErrCodeUpdateFailed, + fmt.Sprintf("update failed: %v", err), msg.ID) + } + return false + } + + // Send confirmation before exiting + log.Printf("Updated to v%s. Exiting for restart.", result.LatestVersion) + envelope := ws.NewMessage(ws.TypeUpdateAvailable) + confirmMsg := ws.UpdateAvailableMessage{ + Type: envelope.Type, + ID: envelope.ID, + Timestamp: envelope.Timestamp, + CurrentVersion: result.CurrentVersion, + LatestVersion: result.LatestVersion, + } + if err := sender.Send(confirmMsg); err != nil { + log.Printf("Error sending update confirmation: %v", err) + } + + return true +} diff --git a/internal/cmd/remote_update_test.go b/internal/cmd/remote_update_test.go new file mode 100644 index 0000000..126aa80 --- /dev/null +++ b/internal/cmd/remote_update_test.go @@ -0,0 +1,266 @@ +package cmd + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "sync" + "testing" + "time" + + "github.com/minicodemonkey/chief/internal/update" + "github.com/minicodemonkey/chief/internal/ws" +) + +func TestHandleTriggerUpdate_AlreadyLatest(t *testing.T) { + // Mock GitHub releases API — same version + release := update.Release{TagName: "v1.0.0"} + releaseSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(release) + })) + defer releaseSrv.Close() + + sender := &captureSender{} + + msg := ws.Message{ + Type: ws.TypeTriggerUpdate, + ID: "req-1", + } + + shouldExit := handleTriggerUpdate(sender, msg, "1.0.0", releaseSrv.URL) + + if shouldExit { + t.Error("should not exit when already on latest version") + } + + msgs := sender.getMessages() + if len(msgs) == 0 { + t.Fatal("expected update_available message to be sent") + } + + var receivedMsg map[string]interface{} + for _, m := range msgs { + if m["type"] == "update_available" { + receivedMsg = m + break + } + } + + if receivedMsg == nil { + t.Fatal("expected update_available message to be sent") + } + if receivedMsg["current_version"] != "1.0.0" { + t.Errorf("expected current_version '1.0.0', got %v", receivedMsg["current_version"]) + } + if receivedMsg["latest_version"] != "1.0.0" { + t.Errorf("expected latest_version '1.0.0', got %v", receivedMsg["latest_version"]) + } +} + +func TestHandleTriggerUpdate_APIError(t *testing.T) { + // Mock GitHub releases API — error + releaseSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer releaseSrv.Close() + + sender := &captureSender{} + + msg := ws.Message{ + Type: ws.TypeTriggerUpdate, + ID: "req-1", + } + + shouldExit := handleTriggerUpdate(sender, msg, "1.0.0", releaseSrv.URL) + + if shouldExit { + t.Error("should not exit on API error") + } + + msgs := sender.getMessages() + if len(msgs) == 0 { + t.Fatal("expected error message to be sent") + } + + var receivedMsg map[string]interface{} + for _, m := range msgs { + if m["type"] == "error" { + receivedMsg = m + break + } + } + + if receivedMsg == nil { + t.Fatal("expected error message to be sent") + } + if receivedMsg["code"] != "UPDATE_FAILED" { + t.Errorf("expected code 'UPDATE_FAILED', got %v", receivedMsg["code"]) + } + if receivedMsg["request_id"] != "req-1" { + t.Errorf("expected request_id 'req-1', got %v", receivedMsg["request_id"]) + } +} + +func TestRunServe_TriggerUpdateAlreadyLatest(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + // Mock releases API — same version + release := update.Release{TagName: "v1.0.0"} + releaseSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(release) + })) + defer releaseSrv.Close() + + var responseReceived map[string]interface{} + var mu sync.Mutex + + ctx, cancel := context.WithCancel(context.Background()) + ms := newMockUplinkServer(t) + + go func() { + if err := ms.waitForPusherSubscribe(10 * time.Second); err != nil { + t.Logf("waitForPusherSubscribe: %v", err) + cancel() + return + } + + // Wait for initial state_snapshot + if _, err := ms.waitForMessageType("state_snapshot", 5*time.Second); err != nil { + t.Logf("waitForMessageType(state_snapshot): %v", err) + cancel() + return + } + + // Send trigger_update command via Pusher + triggerReq := map[string]string{ + "type": "trigger_update", + "id": "req-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + } + ms.sendCommand(triggerReq) + + // Wait for update_available response + raw, err := ms.waitForMessageType("update_available", 5*time.Second) + if err == nil { + mu.Lock() + json.Unmarshal(raw, &responseReceived) + mu.Unlock() + } + + cancel() + }() + + err := RunServe(ServeOptions{ + Workspace: workspaceDir, + ServerURL: ms.httpSrv.URL, + Version: "1.0.0", + ReleasesURL: releaseSrv.URL, + Ctx: ctx, + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + mu.Lock() + defer mu.Unlock() + + if responseReceived == nil { + t.Fatal("expected response message") + } + if responseReceived["type"] != "update_available" { + t.Errorf("expected type 'update_available', got %v", responseReceived["type"]) + } + if responseReceived["current_version"] != "1.0.0" { + t.Errorf("expected current_version '1.0.0', got %v", responseReceived["current_version"]) + } +} + +func TestRunServe_TriggerUpdateAPIError(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + // Mock releases API — error + releaseSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer releaseSrv.Close() + + var errorReceived map[string]interface{} + var mu sync.Mutex + + ctx, cancel := context.WithCancel(context.Background()) + ms := newMockUplinkServer(t) + + go func() { + if err := ms.waitForPusherSubscribe(10 * time.Second); err != nil { + t.Logf("waitForPusherSubscribe: %v", err) + cancel() + return + } + + // Wait for initial state_snapshot + if _, err := ms.waitForMessageType("state_snapshot", 5*time.Second); err != nil { + t.Logf("waitForMessageType(state_snapshot): %v", err) + cancel() + return + } + + // Send trigger_update command via Pusher + triggerReq := map[string]string{ + "type": "trigger_update", + "id": "req-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + } + ms.sendCommand(triggerReq) + + // Wait for error response + raw, err := ms.waitForMessageType("error", 5*time.Second) + if err == nil { + mu.Lock() + json.Unmarshal(raw, &errorReceived) + mu.Unlock() + } + + cancel() + }() + + err := RunServe(ServeOptions{ + Workspace: workspaceDir, + ServerURL: ms.httpSrv.URL, + Version: "1.0.0", + ReleasesURL: releaseSrv.URL, + Ctx: ctx, + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + mu.Lock() + defer mu.Unlock() + + if errorReceived == nil { + t.Fatal("expected error message") + } + if errorReceived["type"] != "error" { + t.Errorf("expected type 'error', got %v", errorReceived["type"]) + } + if errorReceived["code"] != "UPDATE_FAILED" { + t.Errorf("expected code 'UPDATE_FAILED', got %v", errorReceived["code"]) + } +} diff --git a/internal/cmd/runs.go b/internal/cmd/runs.go new file mode 100644 index 0000000..08594cc --- /dev/null +++ b/internal/cmd/runs.go @@ -0,0 +1,632 @@ +package cmd + +import ( + "context" + "encoding/json" + "fmt" + "log" + "path/filepath" + "sync" + "time" + + "github.com/minicodemonkey/chief/internal/engine" + "github.com/minicodemonkey/chief/internal/loop" + "github.com/minicodemonkey/chief/internal/prd" + "github.com/minicodemonkey/chief/internal/ws" +) + +// runManager manages Ralph loop runs driven by WebSocket commands. +type runManager struct { + mu sync.RWMutex + eng *engine.Engine + sender messageSender + // tracks which engine registration key maps to which project/prd + runs map[string]*runInfo + loggers map[string]*storyLogger +} + +// runInfo tracks metadata about a registered run. +type runInfo struct { + project string + prdID string + prdPath string // absolute path to prd.json + startTime time.Time + storyID string // currently active story ID +} + +// runKey returns the engine registration key for a project/PRD combination. +func runKey(project, prdID string) string { + return project + "/" + prdID +} + +// newRunManager creates a new run manager. +func newRunManager(eng *engine.Engine, sender messageSender) *runManager { + return &runManager{ + eng: eng, + sender: sender, + runs: make(map[string]*runInfo), + loggers: make(map[string]*storyLogger), + } +} + +// startEventMonitor subscribes to engine events and handles progress streaming, +// run completion, Claude output streaming, and quota exhaustion. +// It runs until the context is cancelled. +func (rm *runManager) startEventMonitor(ctx context.Context) { + eventCh, unsub := rm.eng.Subscribe() + go func() { + defer unsub() + for { + select { + case <-ctx.Done(): + return + case event, ok := <-eventCh: + if !ok { + return + } + rm.handleEvent(event) + } + } + }() +} + +// handleEvent routes an engine event to the appropriate handler. +func (rm *runManager) handleEvent(event engine.ManagerEvent) { + rm.mu.RLock() + info, exists := rm.runs[event.PRDName] + rm.mu.RUnlock() + + if !exists { + // Events from runs we don't track (e.g., TUI-driven runs) + return + } + + switch event.Event.Type { + case loop.EventQuotaExhausted: + rm.handleQuotaExhausted(event.PRDName) + + case loop.EventIterationStart: + rm.sendRunProgress(info, "iteration_started", event.Event) + + case loop.EventStoryStarted: + // Track the current story ID + rm.mu.Lock() + info.storyID = event.Event.StoryID + rm.mu.Unlock() + rm.sendRunProgress(info, "story_started", event.Event) + + case loop.EventStoryCompleted: + rm.sendRunProgress(info, "story_completed", event.Event) + rm.sendStoryDiff(info, event.Event) + + case loop.EventComplete: + rm.sendRunProgress(info, "complete", event.Event) + rm.sendRunComplete(info, event.PRDName) + + case loop.EventMaxIterationsReached: + rm.sendRunProgress(info, "max_iterations_reached", event.Event) + rm.sendRunComplete(info, event.PRDName) + + case loop.EventRetrying: + rm.sendRunProgress(info, "retrying", event.Event) + + case loop.EventAssistantText: + rm.writeStoryLog(event.PRDName, info.storyID, event.Event.Text) + rm.sendClaudeOutput(info, event.Event.Text, false) + + case loop.EventToolStart: + text := fmt.Sprintf("[tool_use] %s", event.Event.Tool) + rm.writeStoryLog(event.PRDName, info.storyID, text) + rm.sendClaudeOutput(info, text, false) + + case loop.EventToolResult: + rm.writeStoryLog(event.PRDName, info.storyID, event.Event.Text) + rm.sendClaudeOutput(info, event.Event.Text, false) + + case loop.EventError: + errText := "" + if event.Event.Err != nil { + errText = event.Event.Err.Error() + } + rm.writeStoryLog(event.PRDName, info.storyID, errText) + rm.sendClaudeOutput(info, errText, true) + } +} + +// sendRunProgress sends a run_progress message over WebSocket. +func (rm *runManager) sendRunProgress(info *runInfo, status string, event loop.Event) { + if rm.sender == nil { + return + } + + rm.mu.RLock() + storyID := info.storyID + rm.mu.RUnlock() + + // Use the event's story ID if available, otherwise use tracked story ID + if event.StoryID != "" { + storyID = event.StoryID + } + + envelope := ws.NewMessage(ws.TypeRunProgress) + msg := ws.RunProgressMessage{ + Type: envelope.Type, + ID: envelope.ID, + Timestamp: envelope.Timestamp, + Project: info.project, + PRDID: info.prdID, + StoryID: storyID, + Status: status, + Iteration: event.Iteration, + Attempt: event.RetryCount, + } + if err := rm.sender.Send(msg); err != nil { + log.Printf("Error sending run_progress: %v", err) + } +} + +// sendRunComplete sends a run_complete message over WebSocket. +func (rm *runManager) sendRunComplete(info *runInfo, prdName string) { + if rm.sender == nil { + return + } + + // Calculate duration + rm.mu.RLock() + duration := time.Since(info.startTime) + rm.mu.RUnlock() + + // Load PRD to get pass/fail counts + var passCount, failCount, storiesCompleted int + p, err := prd.LoadPRD(info.prdPath) + if err == nil { + for _, s := range p.UserStories { + if s.Passes { + passCount++ + storiesCompleted++ + } else { + failCount++ + } + } + } + + envelope := ws.NewMessage(ws.TypeRunComplete) + msg := ws.RunCompleteMessage{ + Type: envelope.Type, + ID: envelope.ID, + Timestamp: envelope.Timestamp, + Project: info.project, + PRDID: info.prdID, + StoriesCompleted: storiesCompleted, + Duration: duration.Round(time.Second).String(), + PassCount: passCount, + FailCount: failCount, + } + if err := rm.sender.Send(msg); err != nil { + log.Printf("Error sending run_complete: %v", err) + } +} + +// sendClaudeOutput sends a claude_output message for an active run over WebSocket. +func (rm *runManager) sendClaudeOutput(info *runInfo, data string, done bool) { + if rm.sender == nil { + return + } + + rm.mu.RLock() + storyID := info.storyID + rm.mu.RUnlock() + + envelope := ws.NewMessage(ws.TypeClaudeOutput) + msg := ws.ClaudeOutputMessage{ + Type: envelope.Type, + ID: envelope.ID, + Timestamp: envelope.Timestamp, + Project: info.project, + PRDID: info.prdID, + StoryID: storyID, + Data: data, + Done: done, + } + if err := rm.sender.Send(msg); err != nil { + log.Printf("Error sending claude_output: %v", err) + } +} + +// sendStoryDiff sends a proactive diff message when a story completes during a run. +func (rm *runManager) sendStoryDiff(info *runInfo, event loop.Event) { + if rm.sender == nil { + return + } + + storyID := event.StoryID + if storyID == "" { + rm.mu.RLock() + storyID = info.storyID + rm.mu.RUnlock() + } + if storyID == "" { + return + } + + // Get the project path from the PRD path + // prdPath is like /path/to/project/.chief/prds/<id>/prd.json + projectPath := filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(info.prdPath)))) + + diffText, files, err := getStoryDiff(projectPath, storyID) + if err != nil { + log.Printf("Could not get diff for story %s: %v", storyID, err) + return + } + + sendDiffMessage(rm.sender, info.project, info.prdID, storyID, files, diffText) +} + +// handleQuotaExhausted handles a quota exhaustion event for a specific run. +func (rm *runManager) handleQuotaExhausted(prdName string) { + rm.mu.RLock() + info, exists := rm.runs[prdName] + rm.mu.RUnlock() + + if !exists { + log.Printf("Quota exhausted for unknown run key: %s", prdName) + return + } + + log.Printf("Quota exhausted for %s/%s, auto-pausing", info.project, info.prdID) + + if rm.sender == nil { + return + } + + // Send run_paused with reason quota_exhausted + envelope := ws.NewMessage(ws.TypeRunPaused) + pausedMsg := ws.RunPausedMessage{ + Type: envelope.Type, + ID: envelope.ID, + Timestamp: envelope.Timestamp, + Project: info.project, + PRDID: info.prdID, + Reason: "quota_exhausted", + } + if err := rm.sender.Send(pausedMsg); err != nil { + log.Printf("Error sending run_paused: %v", err) + } + + // Send quota_exhausted message listing affected runs + rm.sendQuotaExhausted(info.project, info.prdID) +} + +// sendQuotaExhausted sends a quota_exhausted message over WebSocket. +func (rm *runManager) sendQuotaExhausted(project, prdID string) { + if rm.sender == nil { + return + } + envelope := ws.NewMessage(ws.TypeQuotaExhausted) + msg := ws.QuotaExhaustedMessage{ + Type: envelope.Type, + ID: envelope.ID, + Timestamp: envelope.Timestamp, + Runs: []string{runKey(project, prdID)}, + Sessions: []string{}, + } + if err := rm.sender.Send(msg); err != nil { + log.Printf("Error sending quota_exhausted: %v", err) + } +} + +// activeRuns returns the list of active runs for state snapshots. +func (rm *runManager) activeRuns() []ws.RunState { + rm.mu.RLock() + defer rm.mu.RUnlock() + + var runs []ws.RunState + for key := range rm.runs { + info := rm.runs[key] + instance := rm.eng.GetInstance(key) + if instance == nil { + continue + } + + status := loopStateToString(instance.State) + runs = append(runs, ws.RunState{ + Project: info.project, + PRDID: info.prdID, + Status: status, + Iteration: instance.Iteration, + }) + } + return runs +} + +// startRun starts a Ralph loop for a project/PRD. +func (rm *runManager) startRun(project, prdID, projectPath string) error { + key := runKey(project, prdID) + + // Check if already running + if instance := rm.eng.GetInstance(key); instance != nil { + if instance.State == loop.LoopStateRunning { + return fmt.Errorf("RUN_ALREADY_ACTIVE") + } + } + + prdPath := filepath.Join(projectPath, ".chief", "prds", prdID, "prd.json") + + // Register if not already registered + if instance := rm.eng.GetInstance(key); instance == nil { + if err := rm.eng.Register(key, prdPath); err != nil { + return fmt.Errorf("failed to register PRD: %w", err) + } + } + + // Create per-story logger (removes previous logs for this PRD) + sl, err := newStoryLogger(prdPath) + if err != nil { + log.Printf("Warning: could not create story logger: %v", err) + } + + rm.mu.Lock() + rm.runs[key] = &runInfo{ + project: project, + prdID: prdID, + prdPath: prdPath, + startTime: time.Now(), + } + if sl != nil { + rm.loggers[key] = sl + } + rm.mu.Unlock() + + if err := rm.eng.Start(key); err != nil { + return fmt.Errorf("failed to start run: %w", err) + } + + return nil +} + +// pauseRun pauses a running loop. +func (rm *runManager) pauseRun(project, prdID string) error { + key := runKey(project, prdID) + + instance := rm.eng.GetInstance(key) + if instance == nil || instance.State != loop.LoopStateRunning { + return fmt.Errorf("RUN_NOT_ACTIVE") + } + + if err := rm.eng.Pause(key); err != nil { + return fmt.Errorf("failed to pause run: %w", err) + } + + return nil +} + +// resumeRun resumes a paused loop by starting it again. +func (rm *runManager) resumeRun(project, prdID string) error { + key := runKey(project, prdID) + + instance := rm.eng.GetInstance(key) + if instance == nil || instance.State != loop.LoopStatePaused { + return fmt.Errorf("RUN_NOT_ACTIVE") + } + + // Start creates a fresh Loop that picks up from the next unfinished story + if err := rm.eng.Start(key); err != nil { + return fmt.Errorf("failed to resume run: %w", err) + } + + return nil +} + +// stopRun stops a running or paused loop immediately. +func (rm *runManager) stopRun(project, prdID string) error { + key := runKey(project, prdID) + + instance := rm.eng.GetInstance(key) + if instance == nil || (instance.State != loop.LoopStateRunning && instance.State != loop.LoopStatePaused) { + return fmt.Errorf("RUN_NOT_ACTIVE") + } + + if err := rm.eng.Stop(key); err != nil { + return fmt.Errorf("failed to stop run: %w", err) + } + + return nil +} + +// writeStoryLog writes a line to the per-story log file. +func (rm *runManager) writeStoryLog(runKey, storyID, text string) { + rm.mu.RLock() + sl := rm.loggers[runKey] + rm.mu.RUnlock() + + if sl != nil { + sl.WriteLog(storyID, text) + } +} + +// cleanup removes tracking for a completed/stopped run. +func (rm *runManager) cleanup(key string) { + rm.mu.Lock() + if sl, ok := rm.loggers[key]; ok { + sl.Close() + delete(rm.loggers, key) + } + delete(rm.runs, key) + rm.mu.Unlock() +} + +// markInterruptedStories marks any in-progress stories as interrupted in prd.json +// so that the next run resumes from where it left off. +func (rm *runManager) markInterruptedStories() { + rm.mu.RLock() + runs := make([]*runInfo, 0, len(rm.runs)) + for _, info := range rm.runs { + runs = append(runs, info) + } + rm.mu.RUnlock() + + for _, info := range runs { + if info.storyID == "" { + continue + } + p, err := prd.LoadPRD(info.prdPath) + if err != nil { + log.Printf("Warning: could not load PRD %s to mark interrupted story: %v", info.prdPath, err) + continue + } + for i := range p.UserStories { + if p.UserStories[i].ID == info.storyID && !p.UserStories[i].Passes { + p.UserStories[i].InProgress = true + if err := p.Save(info.prdPath); err != nil { + log.Printf("Warning: could not save PRD %s: %v", info.prdPath, err) + } + break + } + } + } +} + +// activeRunCount returns the number of currently tracked runs. +func (rm *runManager) activeRunCount() int { + rm.mu.RLock() + defer rm.mu.RUnlock() + return len(rm.runs) +} + +// stopAll stops all active runs (for shutdown). +func (rm *runManager) stopAll() { + rm.eng.StopAll() + + // Close all story loggers + rm.mu.Lock() + for key, sl := range rm.loggers { + sl.Close() + delete(rm.loggers, key) + } + rm.mu.Unlock() +} + +// loopStateToString converts a LoopState to a string for WebSocket messages. +func loopStateToString(state loop.LoopState) string { + switch state { + case loop.LoopStateReady: + return "ready" + case loop.LoopStateRunning: + return "running" + case loop.LoopStatePaused: + return "paused" + case loop.LoopStateStopped: + return "stopped" + case loop.LoopStateComplete: + return "complete" + case loop.LoopStateError: + return "error" + default: + return "unknown" + } +} + +// handleStartRun handles a start_run WebSocket message. +func handleStartRun(sender messageSender, scanner projectFinder, runs *runManager, watcher activator, msg ws.Message) { + var req ws.StartRunMessage + if err := json.Unmarshal(msg.Raw, &req); err != nil { + log.Printf("Error parsing start_run message: %v", err) + return + } + + project, found := scanner.FindProject(req.Project) + if !found { + sendError(sender, ws.ErrCodeProjectNotFound, + fmt.Sprintf("Project %q not found", req.Project), msg.ID) + return + } + + if err := runs.startRun(req.Project, req.PRDID, project.Path); err != nil { + if err.Error() == "RUN_ALREADY_ACTIVE" { + sendError(sender, ws.ErrCodeRunAlreadyActive, + fmt.Sprintf("Run already active for %s/%s", req.Project, req.PRDID), msg.ID) + } else { + sendError(sender, ws.ErrCodeClaudeError, + fmt.Sprintf("Failed to start run: %v", err), msg.ID) + } + return + } + + // Activate file watching for the project + if watcher != nil { + watcher.Activate(req.Project) + } + + log.Printf("Started run for %s/%s", req.Project, req.PRDID) +} + +// handlePauseRun handles a pause_run WebSocket message. +func handlePauseRun(sender messageSender, runs *runManager, msg ws.Message) { + var req ws.PauseRunMessage + if err := json.Unmarshal(msg.Raw, &req); err != nil { + log.Printf("Error parsing pause_run message: %v", err) + return + } + + if err := runs.pauseRun(req.Project, req.PRDID); err != nil { + if err.Error() == "RUN_NOT_ACTIVE" { + sendError(sender, ws.ErrCodeRunNotActive, + fmt.Sprintf("No active run for %s/%s", req.Project, req.PRDID), msg.ID) + } else { + sendError(sender, ws.ErrCodeClaudeError, + fmt.Sprintf("Failed to pause run: %v", err), msg.ID) + } + return + } + + log.Printf("Paused run for %s/%s", req.Project, req.PRDID) +} + +// handleResumeRun handles a resume_run WebSocket message. +func handleResumeRun(sender messageSender, runs *runManager, msg ws.Message) { + var req ws.ResumeRunMessage + if err := json.Unmarshal(msg.Raw, &req); err != nil { + log.Printf("Error parsing resume_run message: %v", err) + return + } + + if err := runs.resumeRun(req.Project, req.PRDID); err != nil { + if err.Error() == "RUN_NOT_ACTIVE" { + sendError(sender, ws.ErrCodeRunNotActive, + fmt.Sprintf("No paused run for %s/%s", req.Project, req.PRDID), msg.ID) + } else { + sendError(sender, ws.ErrCodeClaudeError, + fmt.Sprintf("Failed to resume run: %v", err), msg.ID) + } + return + } + + log.Printf("Resumed run for %s/%s", req.Project, req.PRDID) +} + +// handleStopRun handles a stop_run WebSocket message. +func handleStopRun(sender messageSender, runs *runManager, msg ws.Message) { + var req ws.StopRunMessage + if err := json.Unmarshal(msg.Raw, &req); err != nil { + log.Printf("Error parsing stop_run message: %v", err) + return + } + + if err := runs.stopRun(req.Project, req.PRDID); err != nil { + if err.Error() == "RUN_NOT_ACTIVE" { + sendError(sender, ws.ErrCodeRunNotActive, + fmt.Sprintf("No active run for %s/%s", req.Project, req.PRDID), msg.ID) + } else { + sendError(sender, ws.ErrCodeClaudeError, + fmt.Sprintf("Failed to stop run: %v", err), msg.ID) + } + return + } + + log.Printf("Stopped run for %s/%s", req.Project, req.PRDID) +} + +// activator is an interface for activating file watching (for testability). +type activator interface { + Activate(name string) +} diff --git a/internal/cmd/runs_test.go b/internal/cmd/runs_test.go new file mode 100644 index 0000000..f7dabe0 --- /dev/null +++ b/internal/cmd/runs_test.go @@ -0,0 +1,1001 @@ +package cmd + +import ( + "context" + "encoding/json" + "os" + "path/filepath" + "sync" + "testing" + "time" + + "github.com/minicodemonkey/chief/internal/engine" + "github.com/minicodemonkey/chief/internal/loop" + "github.com/minicodemonkey/chief/internal/ws" +) + +func TestRunServe_StartRun(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + // Create a git repo with a PRD + projectDir := filepath.Join(workspaceDir, "myproject") + createGitRepo(t, projectDir) + + prdDir := filepath.Join(projectDir, ".chief", "prds", "feature") + if err := os.MkdirAll(prdDir, 0o755); err != nil { + t.Fatal(err) + } + + // Write a minimal prd.json with one story + prdState := `{"project": "My Feature", "userStories": [{"id": "US-001", "title": "Test Story", "passes": false}]}` + if err := os.WriteFile(filepath.Join(prdDir, "prd.json"), []byte(prdState), 0o644); err != nil { + t.Fatal(err) + } + + var responseReceived map[string]interface{} + var mu sync.Mutex + gotError := false + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + // Send start_run request + startReq := map[string]string{ + "type": "start_run", + "id": "req-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "project": "myproject", + "prd_id": "feature", + } + ms.sendCommand(startReq) + + // Wait a moment for the run to start, then check if error was returned + raw, err := ms.waitForMessages(1, 2*time.Second) + if err == nil && len(raw) > 0 { + mu.Lock() + json.Unmarshal(raw[0], &responseReceived) + // If it's an error, it means the run couldn't start (expected in test env without claude) + if responseReceived["type"] == "error" { + gotError = true + } + mu.Unlock() + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + // In a test environment without a real claude binary, the engine.Start() call + // will succeed (registers + starts the loop) but the loop itself will fail + // quickly since there's no claude. We verify the handler routed correctly + // by checking that we didn't get a PROJECT_NOT_FOUND error. + mu.Lock() + defer mu.Unlock() + + if responseReceived != nil && gotError { + // If we got an error, it should NOT be PROJECT_NOT_FOUND + code, _ := responseReceived["code"].(string) + if code == "PROJECT_NOT_FOUND" { + t.Errorf("should not have gotten PROJECT_NOT_FOUND for existing project") + } + } +} + +func TestRunServe_StartRunProjectNotFound(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + var errorReceived map[string]interface{} + var mu sync.Mutex + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + startReq := map[string]string{ + "type": "start_run", + "id": "req-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "project": "nonexistent", + "prd_id": "feature", + } + ms.sendCommand(startReq) + + raw, err := ms.waitForMessageType("error", 2*time.Second) + if err == nil { + mu.Lock() + json.Unmarshal(raw, &errorReceived) + mu.Unlock() + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + mu.Lock() + defer mu.Unlock() + + if errorReceived == nil { + t.Fatal("error message was not received") + } + if errorReceived["type"] != "error" { + t.Errorf("expected type 'error', got %v", errorReceived["type"]) + } + if errorReceived["code"] != "PROJECT_NOT_FOUND" { + t.Errorf("expected code 'PROJECT_NOT_FOUND', got %v", errorReceived["code"]) + } +} + +func TestRunServe_PauseRunNotActive(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + createGitRepo(t, filepath.Join(workspaceDir, "myproject")) + + var errorReceived map[string]interface{} + var mu sync.Mutex + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + pauseReq := map[string]string{ + "type": "pause_run", + "id": "req-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "project": "myproject", + "prd_id": "feature", + } + ms.sendCommand(pauseReq) + + raw, err := ms.waitForMessageType("error", 2*time.Second) + if err == nil { + mu.Lock() + json.Unmarshal(raw, &errorReceived) + mu.Unlock() + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + mu.Lock() + defer mu.Unlock() + + if errorReceived == nil { + t.Fatal("error message was not received") + } + if errorReceived["type"] != "error" { + t.Errorf("expected type 'error', got %v", errorReceived["type"]) + } + if errorReceived["code"] != "RUN_NOT_ACTIVE" { + t.Errorf("expected code 'RUN_NOT_ACTIVE', got %v", errorReceived["code"]) + } +} + +func TestRunServe_ResumeRunNotActive(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + createGitRepo(t, filepath.Join(workspaceDir, "myproject")) + + var errorReceived map[string]interface{} + var mu sync.Mutex + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + resumeReq := map[string]string{ + "type": "resume_run", + "id": "req-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "project": "myproject", + "prd_id": "feature", + } + ms.sendCommand(resumeReq) + + raw, err := ms.waitForMessageType("error", 2*time.Second) + if err == nil { + mu.Lock() + json.Unmarshal(raw, &errorReceived) + mu.Unlock() + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + mu.Lock() + defer mu.Unlock() + + if errorReceived == nil { + t.Fatal("error message was not received") + } + if errorReceived["type"] != "error" { + t.Errorf("expected type 'error', got %v", errorReceived["type"]) + } + if errorReceived["code"] != "RUN_NOT_ACTIVE" { + t.Errorf("expected code 'RUN_NOT_ACTIVE', got %v", errorReceived["code"]) + } +} + +func TestRunServe_StopRunNotActive(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + createGitRepo(t, filepath.Join(workspaceDir, "myproject")) + + var errorReceived map[string]interface{} + var mu sync.Mutex + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + stopReq := map[string]string{ + "type": "stop_run", + "id": "req-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "project": "myproject", + "prd_id": "feature", + } + ms.sendCommand(stopReq) + + raw, err := ms.waitForMessageType("error", 2*time.Second) + if err == nil { + mu.Lock() + json.Unmarshal(raw, &errorReceived) + mu.Unlock() + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + mu.Lock() + defer mu.Unlock() + + if errorReceived == nil { + t.Fatal("error message was not received") + } + if errorReceived["type"] != "error" { + t.Errorf("expected type 'error', got %v", errorReceived["type"]) + } + if errorReceived["code"] != "RUN_NOT_ACTIVE" { + t.Errorf("expected code 'RUN_NOT_ACTIVE', got %v", errorReceived["code"]) + } +} + +func TestRunManager_StartAndAlreadyActive(t *testing.T) { + eng := engine.New(5) + defer eng.Shutdown() + + rm := newRunManager(eng, nil) + + // Create a temp project with a PRD + projectDir := t.TempDir() + prdDir := filepath.Join(projectDir, ".chief", "prds", "feature") + if err := os.MkdirAll(prdDir, 0o755); err != nil { + t.Fatal(err) + } + + // Write a minimal prd.json + prdState := `{"project": "Test", "userStories": [{"id": "US-001", "title": "Story", "passes": false}]}` + if err := os.WriteFile(filepath.Join(prdDir, "prd.json"), []byte(prdState), 0o644); err != nil { + t.Fatal(err) + } + + // Start a run + err := rm.startRun("myproject", "feature", projectDir) + if err != nil { + t.Fatalf("startRun failed: %v", err) + } + + // Wait briefly for the engine to register the run as running + time.Sleep(100 * time.Millisecond) + + // Try to start the same run again — should get RUN_ALREADY_ACTIVE + err = rm.startRun("myproject", "feature", projectDir) + if err == nil { + t.Fatal("expected error for already active run") + } + if err.Error() != "RUN_ALREADY_ACTIVE" { + t.Errorf("expected 'RUN_ALREADY_ACTIVE', got: %v", err) + } + + // Clean up + rm.stopAll() +} + +func TestRunManager_PauseAndResume(t *testing.T) { + eng := engine.New(5) + defer eng.Shutdown() + + rm := newRunManager(eng, nil) + + // Trying to pause when nothing is running + err := rm.pauseRun("myproject", "feature") + if err == nil || err.Error() != "RUN_NOT_ACTIVE" { + t.Errorf("expected RUN_NOT_ACTIVE, got: %v", err) + } + + // Trying to resume when nothing is paused + err = rm.resumeRun("myproject", "feature") + if err == nil || err.Error() != "RUN_NOT_ACTIVE" { + t.Errorf("expected RUN_NOT_ACTIVE, got: %v", err) + } +} + +func TestRunManager_StopNotActive(t *testing.T) { + eng := engine.New(5) + defer eng.Shutdown() + + rm := newRunManager(eng, nil) + + err := rm.stopRun("myproject", "feature") + if err == nil || err.Error() != "RUN_NOT_ACTIVE" { + t.Errorf("expected RUN_NOT_ACTIVE, got: %v", err) + } +} + +func TestRunManager_ActiveRuns(t *testing.T) { + eng := engine.New(5) + defer eng.Shutdown() + + rm := newRunManager(eng, nil) + + // No active runs initially + runs := rm.activeRuns() + if runs != nil && len(runs) != 0 { + t.Errorf("expected no active runs, got %d", len(runs)) + } + + // Create a temp project with a PRD + projectDir := t.TempDir() + prdDir := filepath.Join(projectDir, ".chief", "prds", "feature") + if err := os.MkdirAll(prdDir, 0o755); err != nil { + t.Fatal(err) + } + + prdState := `{"project": "Test", "userStories": [{"id": "US-001", "title": "Story", "passes": false}]}` + if err := os.WriteFile(filepath.Join(prdDir, "prd.json"), []byte(prdState), 0o644); err != nil { + t.Fatal(err) + } + + // Start a run + if err := rm.startRun("myproject", "feature", projectDir); err != nil { + t.Fatalf("startRun failed: %v", err) + } + + // Wait briefly for the engine to start + time.Sleep(100 * time.Millisecond) + + // Should have one active run + runs = rm.activeRuns() + if len(runs) != 1 { + t.Fatalf("expected 1 active run, got %d", len(runs)) + } + + if runs[0].Project != "myproject" { + t.Errorf("expected project 'myproject', got %q", runs[0].Project) + } + if runs[0].PRDID != "feature" { + t.Errorf("expected prd_id 'feature', got %q", runs[0].PRDID) + } + + rm.stopAll() +} + +func TestRunManager_MultipleConcurrentProjects(t *testing.T) { + eng := engine.New(5) + defer eng.Shutdown() + + rm := newRunManager(eng, nil) + + // Create two projects with PRDs + for _, name := range []string{"project-a", "project-b"} { + projectDir := filepath.Join(t.TempDir(), name) + prdDir := filepath.Join(projectDir, ".chief", "prds", "feature") + if err := os.MkdirAll(prdDir, 0o755); err != nil { + t.Fatal(err) + } + + prdState := `{"project": "Test", "userStories": [{"id": "US-001", "title": "Story", "passes": false}]}` + if err := os.WriteFile(filepath.Join(prdDir, "prd.json"), []byte(prdState), 0o644); err != nil { + t.Fatal(err) + } + + if err := rm.startRun(name, "feature", projectDir); err != nil { + t.Fatalf("startRun %s failed: %v", name, err) + } + } + + // Wait briefly + time.Sleep(100 * time.Millisecond) + + // Should have two active runs + runs := rm.activeRuns() + if len(runs) != 2 { + t.Errorf("expected 2 active runs, got %d", len(runs)) + } + + rm.stopAll() +} + +func TestRunManager_LoopStateToString(t *testing.T) { + tests := []struct { + state ws.RunState + expected string + }{ + {ws.RunState{Status: "running"}, "running"}, + {ws.RunState{Status: "paused"}, "paused"}, + {ws.RunState{Status: "stopped"}, "stopped"}, + } + + for _, tt := range tests { + if tt.state.Status != tt.expected { + t.Errorf("expected %q, got %q", tt.expected, tt.state.Status) + } + } +} + +func TestRunManager_HandleQuotaExhausted(t *testing.T) { + eng := engine.New(5) + defer eng.Shutdown() + + // Create a mock WS client to capture sent messages + rm := newRunManager(eng, nil) + + // Add a run to the runs map + rm.mu.Lock() + rm.runs["myproject/feature"] = &runInfo{ + project: "myproject", + prdID: "feature", + prdPath: "/tmp/test/prd.json", + } + rm.mu.Unlock() + + // handleQuotaExhausted should not panic even with nil client + // (it logs errors but continues) + rm.handleQuotaExhausted("myproject/feature") +} + +func TestRunManager_HandleQuotaExhaustedUnknownRun(t *testing.T) { + eng := engine.New(5) + defer eng.Shutdown() + + rm := newRunManager(eng, nil) + + // Should not panic for unknown run + rm.handleQuotaExhausted("unknown/run") +} + +func TestRunManager_EventMonitorQuotaDetection(t *testing.T) { + eng := engine.New(5) + defer eng.Shutdown() + + rm := newRunManager(eng, nil) + + // Set up run tracking + rm.mu.Lock() + rm.runs["test/feature"] = &runInfo{ + project: "test", + prdID: "feature", + prdPath: "/tmp/test/prd.json", + } + rm.mu.Unlock() + + // Start the event monitor + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + rm.startEventMonitor(ctx) + + // Give the goroutine time to start + time.Sleep(50 * time.Millisecond) + + // Cancel context to stop the monitor + cancel() + time.Sleep(50 * time.Millisecond) +} + +func TestRunManager_QuotaExhaustedWebSocket(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + // Create a git repo with a PRD that uses a mock claude that simulates quota error + projectDir := filepath.Join(workspaceDir, "myproject") + createGitRepo(t, projectDir) + + prdDir := filepath.Join(projectDir, ".chief", "prds", "feature") + if err := os.MkdirAll(prdDir, 0o755); err != nil { + t.Fatal(err) + } + + prdState := `{"project": "My Feature", "userStories": [{"id": "US-001", "title": "Test Story", "passes": false}]}` + if err := os.WriteFile(filepath.Join(prdDir, "prd.json"), []byte(prdState), 0o644); err != nil { + t.Fatal(err) + } + + // Create a mock claude that outputs a quota error on stderr and exits with non-zero + mockDir := t.TempDir() + mockScript := `#!/bin/sh +echo "rate limit exceeded" >&2 +exit 1 +` + if err := os.WriteFile(filepath.Join(mockDir, "claude"), []byte(mockScript), 0o755); err != nil { + t.Fatal(err) + } + origPath := os.Getenv("PATH") + t.Setenv("PATH", mockDir+":"+origPath) + + var messages []map[string]interface{} + var mu sync.Mutex + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + // Send start_run request + startReq := map[string]string{ + "type": "start_run", + "id": "req-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "project": "myproject", + "prd_id": "feature", + } + ms.sendCommand(startReq) + + // Read messages — we expect run_paused with reason "quota_exhausted" + // and a quota_exhausted message + raws, err := ms.waitForMessages(5, 5*time.Second) + if err == nil { + for _, data := range raws { + var msg map[string]interface{} + if json.Unmarshal(data, &msg) == nil { + mu.Lock() + messages = append(messages, msg) + mu.Unlock() + } + } + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + mu.Lock() + defer mu.Unlock() + + // Check that we received a run_paused with reason quota_exhausted + foundRunPaused := false + foundQuotaExhausted := false + for _, msg := range messages { + if msg["type"] == "run_paused" { + if msg["reason"] == "quota_exhausted" { + foundRunPaused = true + } + } + if msg["type"] == "quota_exhausted" { + foundQuotaExhausted = true + } + } + + if !foundRunPaused { + t.Errorf("expected run_paused with reason quota_exhausted, got messages: %v", messages) + } + if !foundQuotaExhausted { + t.Errorf("expected quota_exhausted message, got messages: %v", messages) + } +} + +func TestIsQuotaErrorIntegration(t *testing.T) { + // Test that the loop package's IsQuotaError function correctly identifies quota errors + tests := []struct { + stderr string + expected bool + }{ + {"Error: rate limit exceeded for model", true}, + {"HTTP 429 Too Many Requests", true}, + {"quota has been exceeded", true}, + {"normal crash: segfault", false}, + } + + for _, tt := range tests { + got := loop.IsQuotaError(tt.stderr) + if got != tt.expected { + t.Errorf("IsQuotaError(%q) = %v, want %v", tt.stderr, got, tt.expected) + } + } +} + +func TestRunManager_HandleEventRunProgress(t *testing.T) { + eng := engine.New(5) + defer eng.Shutdown() + + rm := newRunManager(eng, nil) // nil client — sendRunProgress guards against nil + + // Add a run to the runs map + rm.mu.Lock() + rm.runs["myproject/feature"] = &runInfo{ + project: "myproject", + prdID: "feature", + prdPath: "/tmp/test/prd.json", + startTime: time.Now(), + } + rm.mu.Unlock() + + // Test that handleEvent does not panic for each event type with nil client + eventTypes := []loop.EventType{ + loop.EventIterationStart, + loop.EventStoryStarted, + loop.EventStoryCompleted, + loop.EventComplete, + loop.EventMaxIterationsReached, + loop.EventRetrying, + loop.EventAssistantText, + loop.EventToolStart, + loop.EventToolResult, + loop.EventError, + } + + for _, et := range eventTypes { + event := engine.ManagerEvent{ + PRDName: "myproject/feature", + Event: loop.Event{ + Type: et, + Iteration: 1, + StoryID: "US-001", + Text: "test text", + Tool: "TestTool", + }, + } + rm.handleEvent(event) // should not panic with nil client + } +} + +func TestRunManager_HandleEventUnknownRun(t *testing.T) { + eng := engine.New(5) + defer eng.Shutdown() + + rm := newRunManager(eng, nil) + + // Events for unknown runs should be silently ignored + event := engine.ManagerEvent{ + PRDName: "unknown/run", + Event: loop.Event{ + Type: loop.EventIterationStart, + Iteration: 1, + }, + } + rm.handleEvent(event) // should not panic +} + +func TestRunManager_HandleEventStoryTracking(t *testing.T) { + eng := engine.New(5) + defer eng.Shutdown() + + rm := newRunManager(eng, nil) + + rm.mu.Lock() + rm.runs["myproject/feature"] = &runInfo{ + project: "myproject", + prdID: "feature", + prdPath: "/tmp/test/prd.json", + startTime: time.Now(), + } + rm.mu.Unlock() + + // Send a StoryStarted event — should update the tracked storyID + event := engine.ManagerEvent{ + PRDName: "myproject/feature", + Event: loop.Event{ + Type: loop.EventStoryStarted, + StoryID: "US-042", + }, + } + rm.handleEvent(event) + + rm.mu.RLock() + storyID := rm.runs["myproject/feature"].storyID + rm.mu.RUnlock() + + if storyID != "US-042" { + t.Errorf("expected storyID 'US-042', got %q", storyID) + } +} + +func TestRunManager_SendRunComplete(t *testing.T) { + eng := engine.New(5) + defer eng.Shutdown() + + rm := newRunManager(eng, nil) // nil client — guards against nil + + // Create a temp PRD with known pass/fail counts + tmpDir := t.TempDir() + prdDir := filepath.Join(tmpDir, ".chief", "prds", "feature") + if err := os.MkdirAll(prdDir, 0o755); err != nil { + t.Fatal(err) + } + + prdJSON := `{"project": "Test", "userStories": [{"id": "US-001", "passes": true}, {"id": "US-002", "passes": true}, {"id": "US-003", "passes": false}]}` + prdPath := filepath.Join(prdDir, "prd.json") + if err := os.WriteFile(prdPath, []byte(prdJSON), 0o644); err != nil { + t.Fatal(err) + } + + info := &runInfo{ + project: "myproject", + prdID: "feature", + prdPath: prdPath, + startTime: time.Now().Add(-5 * time.Minute), + } + + // Should not panic with nil client + rm.sendRunComplete(info, "myproject/feature") +} + +func TestRunManager_MarkInterruptedStories(t *testing.T) { + eng := engine.New(5) + defer eng.Shutdown() + + rm := newRunManager(eng, nil) + + // Create a temp project with a PRD + projectDir := t.TempDir() + prdDir := filepath.Join(projectDir, ".chief", "prds", "feature") + if err := os.MkdirAll(prdDir, 0o755); err != nil { + t.Fatal(err) + } + + prdPath := filepath.Join(prdDir, "prd.json") + prdState := `{"project": "Test", "userStories": [{"id": "US-001", "title": "Story 1", "passes": false}, {"id": "US-002", "title": "Story 2", "passes": true}]}` + if err := os.WriteFile(prdPath, []byte(prdState), 0o644); err != nil { + t.Fatal(err) + } + + // Add a run with an active story + rm.mu.Lock() + rm.runs["test/feature"] = &runInfo{ + project: "test", + prdID: "feature", + prdPath: prdPath, + startTime: time.Now(), + storyID: "US-001", + } + rm.mu.Unlock() + + // Mark interrupted stories + rm.markInterruptedStories() + + // Verify the PRD was updated + data, err := os.ReadFile(prdPath) + if err != nil { + t.Fatalf("failed to read PRD: %v", err) + } + + var result map[string]interface{} + if err := json.Unmarshal(data, &result); err != nil { + t.Fatalf("failed to parse PRD: %v", err) + } + + stories := result["userStories"].([]interface{}) + story1 := stories[0].(map[string]interface{}) + if story1["inProgress"] != true { + t.Errorf("expected US-001 to have inProgress=true, got %v", story1["inProgress"]) + } + + // US-002 is already passing, should NOT be marked as inProgress + story2 := stories[1].(map[string]interface{}) + if _, hasInProgress := story2["inProgress"]; hasInProgress && story2["inProgress"] == true { + t.Error("expected US-002 to NOT have inProgress=true (already passes)") + } +} + +func TestRunManager_MarkInterruptedStoriesNoStoryID(t *testing.T) { + eng := engine.New(5) + defer eng.Shutdown() + + rm := newRunManager(eng, nil) + + // Create a temp project with a PRD + projectDir := t.TempDir() + prdDir := filepath.Join(projectDir, ".chief", "prds", "feature") + if err := os.MkdirAll(prdDir, 0o755); err != nil { + t.Fatal(err) + } + + prdPath := filepath.Join(prdDir, "prd.json") + prdState := `{"project": "Test", "userStories": [{"id": "US-001", "title": "Story 1", "passes": false}]}` + if err := os.WriteFile(prdPath, []byte(prdState), 0o644); err != nil { + t.Fatal(err) + } + + // Add a run WITHOUT a story ID (no story started yet) + rm.mu.Lock() + rm.runs["test/feature"] = &runInfo{ + project: "test", + prdID: "feature", + prdPath: prdPath, + startTime: time.Now(), + storyID: "", // no story started + } + rm.mu.Unlock() + + // Mark interrupted stories — should be a no-op + rm.markInterruptedStories() + + // Verify the PRD was NOT modified + data, err := os.ReadFile(prdPath) + if err != nil { + t.Fatalf("failed to read PRD: %v", err) + } + + var result map[string]interface{} + if err := json.Unmarshal(data, &result); err != nil { + t.Fatalf("failed to parse PRD: %v", err) + } + + stories := result["userStories"].([]interface{}) + story1 := stories[0].(map[string]interface{}) + if _, hasInProgress := story1["inProgress"]; hasInProgress && story1["inProgress"] == true { + t.Error("expected US-001 to NOT have inProgress=true when no story was started") + } +} + +func TestRunManager_ActiveRunCount(t *testing.T) { + eng := engine.New(5) + defer eng.Shutdown() + + rm := newRunManager(eng, nil) + + if rm.activeRunCount() != 0 { + t.Errorf("expected 0 active runs, got %d", rm.activeRunCount()) + } + + rm.mu.Lock() + rm.runs["test/feature"] = &runInfo{ + project: "test", + prdID: "feature", + } + rm.mu.Unlock() + + if rm.activeRunCount() != 1 { + t.Errorf("expected 1 active run, got %d", rm.activeRunCount()) + } +} + +func TestRunServe_RunProgressStreaming(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + // Create a git repo with a PRD + projectDir := filepath.Join(workspaceDir, "myproject") + createGitRepo(t, projectDir) + + prdDir := filepath.Join(projectDir, ".chief", "prds", "feature") + if err := os.MkdirAll(prdDir, 0o755); err != nil { + t.Fatal(err) + } + + prdState := `{"project": "My Feature", "userStories": [{"id": "US-001", "title": "Test Story", "passes": false}]}` + if err := os.WriteFile(filepath.Join(prdDir, "prd.json"), []byte(prdState), 0o644); err != nil { + t.Fatal(err) + } + + // Create a mock claude that outputs stream-json with a story start marker then exits successfully + mockDir := t.TempDir() + mockScript := `#!/bin/sh +echo '{"type":"system","subtype":"init"}' +echo '{"type":"assistant","message":{"content":[{"type":"text","text":"Working on <ralph-status>US-001</ralph-status>"}]}}' +echo '{"type":"assistant","message":{"content":[{"type":"text","text":"Hello from Claude"}]}}' +echo '{"type":"result"}' +exit 0 +` + if err := os.WriteFile(filepath.Join(mockDir, "claude"), []byte(mockScript), 0o755); err != nil { + t.Fatal(err) + } + origPath := os.Getenv("PATH") + t.Setenv("PATH", mockDir+":"+origPath) + + var messages []map[string]interface{} + var mu sync.Mutex + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + // Send start_run request + startReq := map[string]string{ + "type": "start_run", + "id": "req-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "project": "myproject", + "prd_id": "feature", + } + ms.sendCommand(startReq) + + // Read messages — expect run_progress and claude_output messages + raws, err := ms.waitForMessages(15, 5*time.Second) + if err == nil { + for _, data := range raws { + var msg map[string]interface{} + if json.Unmarshal(data, &msg) == nil { + mu.Lock() + messages = append(messages, msg) + mu.Unlock() + } + } + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + mu.Lock() + defer mu.Unlock() + + // Check for run_progress messages + foundIterationStart := false + foundStoryStarted := false + foundClaudeOutput := false + for _, msg := range messages { + if msg["type"] == "run_progress" { + status, _ := msg["status"].(string) + if status == "iteration_started" { + foundIterationStart = true + } + if status == "story_started" { + foundStoryStarted = true + if msg["story_id"] != "US-001" { + t.Errorf("expected story_id 'US-001', got %v", msg["story_id"]) + } + if msg["project"] != "myproject" { + t.Errorf("expected project 'myproject', got %v", msg["project"]) + } + if msg["prd_id"] != "feature" { + t.Errorf("expected prd_id 'feature', got %v", msg["prd_id"]) + } + } + } + if msg["type"] == "claude_output" { + foundClaudeOutput = true + if msg["project"] != "myproject" { + t.Errorf("expected project 'myproject', got %v", msg["project"]) + } + } + } + + if !foundIterationStart { + t.Errorf("expected run_progress with status 'iteration_started', messages: %v", messages) + } + if !foundStoryStarted { + t.Errorf("expected run_progress with status 'story_started', messages: %v", messages) + } + if !foundClaudeOutput { + t.Errorf("expected claude_output messages, messages: %v", messages) + } +} diff --git a/internal/cmd/sender.go b/internal/cmd/sender.go new file mode 100644 index 0000000..5b7f08e --- /dev/null +++ b/internal/cmd/sender.go @@ -0,0 +1,45 @@ +package cmd + +import ( + "encoding/json" + "fmt" + "log" + + "github.com/minicodemonkey/chief/internal/uplink" +) + +// messageSender is an interface for sending messages to the server. +// The uplink adapter satisfies this interface. +type messageSender interface { + Send(msg interface{}) error +} + +// uplinkSender adapts *uplink.Uplink to the messageSender interface. +// It JSON-marshals the message, extracts the "type" field for the batcher's +// priority tier classification, and enqueues via Uplink.Send(). +type uplinkSender struct { + uplink *uplink.Uplink +} + +func newUplinkSender(u *uplink.Uplink) *uplinkSender { + return &uplinkSender{uplink: u} +} + +func (s *uplinkSender) Send(msg interface{}) error { + data, err := json.Marshal(msg) + if err != nil { + return fmt.Errorf("marshaling message: %w", err) + } + + // Extract the "type" field for batcher tier classification. + var envelope struct { + Type string `json:"type"` + } + if err := json.Unmarshal(data, &envelope); err != nil { + log.Printf("uplinkSender: could not extract message type: %v", err) + envelope.Type = "unknown" + } + + s.uplink.Send(data, envelope.Type) + return nil +} diff --git a/internal/cmd/serve.go b/internal/cmd/serve.go new file mode 100644 index 0000000..f9b2a8c --- /dev/null +++ b/internal/cmd/serve.go @@ -0,0 +1,655 @@ +package cmd + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log" + "os" + "os/signal" + "path/filepath" + "syscall" + "time" + + "github.com/minicodemonkey/chief/internal/auth" + "github.com/minicodemonkey/chief/internal/engine" + "github.com/minicodemonkey/chief/internal/uplink" + "github.com/minicodemonkey/chief/internal/update" + "github.com/minicodemonkey/chief/internal/workspace" + "github.com/minicodemonkey/chief/internal/ws" +) + +const ( + // defaultServerURL is the default HTTP base URL for the chief server. + defaultServerURL = "https://chiefloop.com" +) + +// ServeOptions contains configuration for the serve command. +type ServeOptions struct { + Workspace string // Path to workspace directory + DeviceName string // Override device name (default: from credentials) + LogFile string // Path to log file (default: stdout) + BaseURL string // Override base URL (for testing) + ServerURL string // Override server URL for uplink (for testing/dev) + Version string // Chief version string + ReleasesURL string // Override GitHub releases URL (for testing) + Ctx context.Context // Optional context for cancellation (for testing) + +} + +// RunServe starts the headless serve daemon. +func RunServe(opts ServeOptions) error { + // Validate workspace directory exists + if opts.Workspace == "" { + opts.Workspace = "." + } + absWorkspace, err := filepath.Abs(opts.Workspace) + if err != nil { + return fmt.Errorf("resolving workspace path: %w", err) + } + opts.Workspace = absWorkspace + info, err := os.Stat(opts.Workspace) + if err != nil { + if os.IsNotExist(err) { + return fmt.Errorf("workspace directory does not exist: %s", opts.Workspace) + } + return fmt.Errorf("checking workspace directory: %w", err) + } + if !info.IsDir() { + return fmt.Errorf("workspace path is not a directory: %s", opts.Workspace) + } + + // Set up logging + var logFile *os.File + if opts.LogFile != "" { + f, err := os.OpenFile(opts.LogFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644) + if err != nil { + return fmt.Errorf("opening log file: %w", err) + } + logFile = f + defer func() { + logFile.Sync() + logFile.Close() + }() + log.SetOutput(f) + } + + // Check for credentials + creds, err := auth.LoadCredentials() + if err != nil { + if errors.Is(err, auth.ErrNotLoggedIn) { + return fmt.Errorf("Not logged in. Run 'chief login' first.") + } + return fmt.Errorf("loading credentials: %w", err) + } + + // Refresh token if near-expiry + if creds.IsNearExpiry(5 * time.Minute) { + log.Println("Access token near expiry, refreshing...") + creds, err = auth.RefreshToken(opts.BaseURL) + if err != nil { + return fmt.Errorf("refreshing token: %w", err) + } + log.Println("Token refreshed successfully") + } + + // Determine device name + deviceName := opts.DeviceName + if deviceName == "" { + deviceName = creds.DeviceName + } + + // Determine server URL (precedence: ServerURL flag > env > default) + serverURL := opts.ServerURL + if serverURL == "" { + serverURL = os.Getenv("CHIEF_SERVER_URL") + } + if serverURL == "" { + serverURL = defaultServerURL + } + + log.Printf("Starting chief serve (workspace: %s, device: %s)", opts.Workspace, deviceName) + log.Printf("Connecting to %s", serverURL) + + // Set up context with cancellation for clean shutdown + ctx := opts.Ctx + if ctx == nil { + ctx = context.Background() + } + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + // Determine version string + version := opts.Version + if version == "" { + version = "dev" + } + + // Start workspace scanner (before connect so initial scan is ready) + scanner := workspace.New(opts.Workspace, nil) // sender set after connect + scanner.ScanAndUpdate() + + // Create engine for Ralph loop runs (default 5 max iterations) + eng := engine.New(5) + + // Create rate limiter for incoming messages + rateLimiter := ws.NewRateLimiter() + + // Create the uplink HTTP client + httpClient, err := uplink.New(serverURL, creds.AccessToken, + uplink.WithDeviceName(deviceName), + uplink.WithChiefVersion(version), + ) + if err != nil { + return fmt.Errorf("creating uplink client: %w", err) + } + + // Create the uplink with reconnect handler that re-sends state + var sender messageSender + var sessions *sessionManager + var runs *runManager + var ul *uplink.Uplink + + ul = uplink.NewUplink(httpClient, + uplink.WithOnReconnect(func() { + log.Println("Uplink reconnected, re-sending state snapshot") + rateLimiter.Reset() + sendStateSnapshot(sender, scanner, sessions, runs) + }), + uplink.WithOnAuthFailure(func() error { + log.Println("Auth failed during reconnection, refreshing token...") + newCreds, err := auth.RefreshToken(opts.BaseURL) + if err != nil { + return fmt.Errorf("token refresh failed: %w", err) + } + ul.SetAccessToken(newCreds.AccessToken) + log.Println("Token refreshed successfully during reconnection") + return nil + }), + ) + + // Create the sender adapter that wraps the uplink + sender = newUplinkSender(ul) + + // Set scanner's sender now that it exists + scanner.SetSender(sender) + + // Create session manager for Claude PRD sessions + sessions = newSessionManager(sender) + + // Create run manager for Ralph loop runs + runs = newRunManager(eng, sender) + + // Start engine event monitor for quota detection + runs.startEventMonitor(ctx) + + // Connect to server (HTTP connect + Pusher subscribe + batcher start) + if err := ul.Connect(ctx); err != nil { + if errors.Is(err, uplink.ErrAuthFailed) { + return fmt.Errorf("Device deauthorized. Run 'chief login' to re-authenticate.") + } + if errors.Is(err, uplink.ErrDeviceRevoked) { + return fmt.Errorf("Device deauthorized. Run 'chief login' to re-authenticate.") + } + return fmt.Errorf("connecting to server: %w", err) + } + log.Println("Connected to server") + + // Send initial state snapshot after successful connect + sendStateSnapshot(sender, scanner, sessions, runs) + + // Start periodic scanning loop + go scanner.Run(ctx) + log.Println("Workspace scanner started") + + // Start file watcher + watcher, err := workspace.NewWatcher(opts.Workspace, scanner, sender) + if err != nil { + log.Printf("Warning: could not start file watcher: %v", err) + } else { + go watcher.Run(ctx) + log.Println("File watcher started") + } + + // Start periodic version check (every 24 hours) + go runVersionChecker(ctx, sender, opts.Version, opts.ReleasesURL) + + // Set up signal handling + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT) + defer signal.Stop(sigCh) + + log.Println("Serve is running. Press Ctrl+C to stop.") + + // Main event loop — commands arrive from Pusher via uplink.Receive() + for { + select { + case <-ctx.Done(): + log.Println("Context cancelled, shutting down...") + return serveShutdown(ul, watcher, sessions, runs, eng) + + case sig := <-sigCh: + log.Printf("Received signal %s, shutting down...", sig) + return serveShutdown(ul, watcher, sessions, runs, eng) + + case raw, ok := <-ul.Receive(): + if !ok { + // Channel closed, connection lost permanently + log.Println("Uplink connection closed permanently") + return serveShutdown(ul, watcher, sessions, runs, eng) + } + + // Parse the raw JSON into a ws.Message for dispatch + var msg ws.Message + if err := json.Unmarshal(raw, &msg); err != nil { + log.Printf("Ignoring unparseable command: %v", err) + continue + } + msg.Raw = raw + + // Extract payload wrapper if present. + // The CommandRelayController sends {"type": "...", "payload": {...}} + // but handlers expect fields at the top level of msg.Raw. + var env struct { + Type string `json:"type"` + Payload json.RawMessage `json:"payload,omitempty"` + } + if err := json.Unmarshal(raw, &env); err == nil && len(env.Payload) > 0 { + msg.Raw = env.Payload + } + + // Check rate limit before processing + if result := rateLimiter.Allow(msg.Type); !result.Allowed { + log.Printf("Rate limited message type=%s, retry after %s", msg.Type, ws.FormatRetryAfter(result.RetryAfter)) + sendError(sender, ws.ErrCodeRateLimited, + fmt.Sprintf("Rate limited. Try again in %s.", ws.FormatRetryAfter(result.RetryAfter)), + msg.ID) + continue + } + + if shouldExit := handleMessage(sender, scanner, watcher, sessions, runs, msg, version, opts.ReleasesURL); shouldExit { + log.Println("Update installed, exiting for restart...") + serveShutdown(ul, watcher, sessions, runs, eng) + return nil + } + } + } +} + +// sendStateSnapshot sends a full state snapshot via the uplink. +func sendStateSnapshot(sender messageSender, scanner *workspace.Scanner, sessions *sessionManager, runs *runManager) { + projects := scanner.Projects() + envelope := ws.NewMessage(ws.TypeStateSnapshot) + + var activeSessions []ws.SessionState + if sessions != nil { + activeSessions = sessions.activeSessions() + } + if activeSessions == nil { + activeSessions = []ws.SessionState{} + } + + activeRuns := []ws.RunState{} + if runs != nil { + if r := runs.activeRuns(); r != nil { + activeRuns = r + } + } + + snapshot := ws.StateSnapshotMessage{ + Type: envelope.Type, + ID: envelope.ID, + Timestamp: envelope.Timestamp, + Projects: projects, + Runs: activeRuns, + Sessions: activeSessions, + } + if err := sender.Send(snapshot); err != nil { + log.Printf("Error sending state_snapshot: %v", err) + } else { + log.Printf("Sent state_snapshot with %d projects", len(projects)) + } +} + +// sendError sends an error message. +func sendError(sender messageSender, code, message, requestID string) { + envelope := ws.NewMessage(ws.TypeError) + errMsg := ws.ErrorMessage{ + Type: envelope.Type, + ID: envelope.ID, + Timestamp: envelope.Timestamp, + Code: code, + Message: message, + RequestID: requestID, + } + if err := sender.Send(errMsg); err != nil { + log.Printf("Error sending error message: %v", err) + } +} + +// handleMessage routes incoming commands. +// Returns true if the serve loop should exit (e.g., after a successful remote update). +func handleMessage(sender messageSender, scanner *workspace.Scanner, watcher *workspace.Watcher, sessions *sessionManager, runs *runManager, msg ws.Message, version, releasesURL string) bool { + log.Printf("Received command type=%s id=%s", msg.Type, msg.ID) + + switch msg.Type { + case ws.TypePing: + pong := ws.NewMessage(ws.TypePong) + if err := sender.Send(pong); err != nil { + log.Printf("Error sending pong: %v", err) + } + + case ws.TypeListProjects: + handleListProjects(sender, scanner) + + case ws.TypeGetProject: + handleGetProject(sender, scanner, watcher, msg) + + case ws.TypeGetPRD: + handleGetPRD(sender, scanner, msg) + + case ws.TypeGetPRDs: + handleGetPRDs(sender, scanner, msg) + + case ws.TypeNewPRD: + handleNewPRD(sender, scanner, sessions, msg) + + case ws.TypeRefinePRD: + handleRefinePRD(sender, scanner, sessions, msg) + + case ws.TypePRDMessage: + handlePRDMessage(sender, sessions, msg) + + case ws.TypeClosePRDSession: + handleClosePRDSession(sender, sessions, msg) + + case ws.TypeStartRun: + handleStartRun(sender, scanner, runs, watcher, msg) + + case ws.TypePauseRun: + handlePauseRun(sender, runs, msg) + + case ws.TypeResumeRun: + handleResumeRun(sender, runs, msg) + + case ws.TypeStopRun: + handleStopRun(sender, runs, msg) + + case ws.TypeGetDiff: + handleGetDiff(sender, scanner, msg) + + case ws.TypeGetDiffs: + handleGetDiffs(sender, scanner, msg) + + case ws.TypeGetLogs: + handleGetLogs(sender, scanner, msg) + + case ws.TypeGetSettings: + handleGetSettings(sender, scanner, msg) + + case ws.TypeUpdateSettings: + handleUpdateSettings(sender, scanner, msg) + + case ws.TypeCloneRepo: + handleCloneRepo(sender, scanner, msg) + + case ws.TypeCreateProject: + handleCreateProject(sender, scanner, msg) + + case ws.TypeTriggerUpdate: + return handleTriggerUpdate(sender, msg, version, releasesURL) + + default: + log.Printf("Received message type: %s", msg.Type) + } + return false +} + +// handleListProjects handles a list_projects request. +func handleListProjects(sender messageSender, scanner *workspace.Scanner) { + projects := scanner.Projects() + envelope := ws.NewMessage(ws.TypeProjectList) + plMsg := ws.ProjectListMessage{ + Type: envelope.Type, + ID: envelope.ID, + Timestamp: envelope.Timestamp, + Projects: projects, + } + if err := sender.Send(plMsg); err != nil { + log.Printf("Error sending project_list: %v", err) + } +} + +// handleGetProject handles a get_project request. +func handleGetProject(sender messageSender, scanner *workspace.Scanner, watcher *workspace.Watcher, msg ws.Message) { + var req ws.GetProjectMessage + if err := json.Unmarshal(msg.Raw, &req); err != nil { + log.Printf("Error parsing get_project message: %v", err) + return + } + + project, found := scanner.FindProject(req.Project) + if !found { + sendError(sender, ws.ErrCodeProjectNotFound, + fmt.Sprintf("Project %q not found", req.Project), msg.ID) + return + } + + // Activate file watching for the requested project + if watcher != nil { + watcher.Activate(req.Project) + } + + envelope := ws.NewMessage(ws.TypeProjectState) + psMsg := ws.ProjectStateMessage{ + Type: envelope.Type, + ID: envelope.ID, + Timestamp: envelope.Timestamp, + Project: project, + } + if err := sender.Send(psMsg); err != nil { + log.Printf("Error sending project_state: %v", err) + } +} + +// handleGetPRD handles a get_prd request. +func handleGetPRD(sender messageSender, scanner *workspace.Scanner, msg ws.Message) { + var req ws.GetPRDMessage + if err := json.Unmarshal(msg.Raw, &req); err != nil { + log.Printf("Error parsing get_prd message: %v", err) + return + } + + project, found := scanner.FindProject(req.Project) + if !found { + sendError(sender, ws.ErrCodeProjectNotFound, + fmt.Sprintf("Project %q not found", req.Project), msg.ID) + return + } + + // Read PRD markdown content + prdDir := filepath.Join(project.Path, ".chief", "prds", req.PRDID) + prdMD := filepath.Join(prdDir, "prd.md") + prdJSON := filepath.Join(prdDir, "prd.json") + + // Check that the PRD directory exists + if _, err := os.Stat(prdDir); os.IsNotExist(err) { + sendError(sender, ws.ErrCodePRDNotFound, + fmt.Sprintf("PRD %q not found in project %q", req.PRDID, req.Project), msg.ID) + return + } + + // Read markdown content (optional — may not exist yet) + var content string + if data, err := os.ReadFile(prdMD); err == nil { + content = string(data) + } + + // Read prd.json state + var state interface{} + if data, err := os.ReadFile(prdJSON); err == nil { + var parsed interface{} + if json.Unmarshal(data, &parsed) == nil { + state = parsed + } + } + + envelope := ws.NewMessage(ws.TypePRDContent) + prdMsg := ws.PRDContentMessage{ + Type: envelope.Type, + ID: envelope.ID, + Timestamp: envelope.Timestamp, + Project: req.Project, + PRDID: req.PRDID, + Content: content, + State: state, + } + if err := sender.Send(prdMsg); err != nil { + log.Printf("Error sending prd_content: %v", err) + } +} + +// runVersionChecker periodically checks for updates and sends update_available. +func runVersionChecker(ctx context.Context, sender messageSender, version, releasesURL string) { + // Check immediately on startup + checkAndNotify(sender, version, releasesURL) + + ticker := time.NewTicker(24 * time.Hour) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + checkAndNotify(sender, version, releasesURL) + } + } +} + +// checkAndNotify performs a version check and sends update_available if needed. +func checkAndNotify(sender messageSender, version, releasesURL string) { + result, err := update.CheckForUpdate(version, update.Options{ + ReleasesURL: releasesURL, + }) + if err != nil { + log.Printf("Version check failed: %v", err) + return + } + if result.UpdateAvailable { + log.Printf("Update available: v%s (current: v%s)", result.LatestVersion, result.CurrentVersion) + envelope := ws.NewMessage(ws.TypeUpdateAvailable) + msg := ws.UpdateAvailableMessage{ + Type: envelope.Type, + ID: envelope.ID, + Timestamp: envelope.Timestamp, + CurrentVersion: result.CurrentVersion, + LatestVersion: result.LatestVersion, + } + if err := sender.Send(msg); err != nil { + log.Printf("Error sending update_available: %v", err) + } + } +} + +// uplinkCloser is an interface for closing the uplink connection. +type uplinkCloser interface { + Close() error + CloseWithTimeout(timeout time.Duration) error +} + +// shutdownTimeout is the maximum time allowed for the entire shutdown sequence. +// This covers: process killing, batcher flush, Pusher close, and HTTP disconnect. +const shutdownTimeout = 10 * time.Second + +// processKillTimeout is the maximum time to wait for Claude processes to exit gracefully. +const processKillTimeout = 5 * time.Second + +// serveShutdown performs clean shutdown of the serve command. +// It kills all child Claude processes, marks interrupted stories, closes the uplink, +// and flushes log files. Processes are force-killed after 5 seconds if they haven't +// exited gracefully. The entire shutdown completes within 10 seconds. +func serveShutdown(closer uplinkCloser, watcher *workspace.Watcher, sessions *sessionManager, runs *runManager, eng *engine.Engine) error { + log.Println("Shutting down...") + + // Enforce an overall shutdown deadline. + shutdownDone := make(chan struct{}) + go func() { + defer close(shutdownDone) + doShutdown(closer, watcher, sessions, runs, eng) + }() + + select { + case <-shutdownDone: + // Normal shutdown completed within the timeout. + case <-time.After(shutdownTimeout): + log.Printf("Shutdown timed out after %s — forcing exit", shutdownTimeout) + } + + log.Println("Goodbye.") + return nil +} + +// doShutdown performs the actual shutdown sequence. +func doShutdown(closer uplinkCloser, watcher *workspace.Watcher, sessions *sessionManager, runs *runManager, eng *engine.Engine) { + // Count processes before shutdown. + processCount := 0 + if sessions != nil { + processCount += sessions.sessionCount() + } + if runs != nil { + processCount += runs.activeRunCount() + } + + // Mark any in-progress stories as interrupted in prd.json. + if runs != nil { + runs.markInterruptedStories() + } + + // Use a channel to track when graceful process shutdown completes. + done := make(chan struct{}) + go func() { + // Stop all active Ralph loop runs. + if runs != nil { + runs.stopAll() + } + + // Kill all active Claude sessions. + if sessions != nil { + sessions.killAll() + } + + // Shut down the engine (stops event forwarding goroutine). + if eng != nil { + eng.Shutdown() + } + + close(done) + }() + + // Wait for graceful shutdown or force-kill after 5 seconds. + select { + case <-done: + // Graceful shutdown completed. + case <-time.After(processKillTimeout): + log.Println("Force-killing hung processes after 5 second timeout") + } + + if processCount > 0 { + log.Printf("Killed %d processes", processCount) + } + + // Close file watcher. + if watcher != nil { + if err := watcher.Close(); err != nil { + log.Printf("Error closing file watcher: %v", err) + } + } + + // Close the uplink with a timeout to prevent hanging on unreachable servers. + // The batcher flush + Pusher close + HTTP disconnect must complete within this window. + if err := closer.CloseWithTimeout(5 * time.Second); err != nil { + log.Printf("Error closing connection: %v", err) + } +} diff --git a/internal/cmd/serve_test.go b/internal/cmd/serve_test.go new file mode 100644 index 0000000..8f40b64 --- /dev/null +++ b/internal/cmd/serve_test.go @@ -0,0 +1,1663 @@ +package cmd + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "os" + "os/exec" + "path/filepath" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/minicodemonkey/chief/internal/auth" +) + +// serveWsURL converts an httptest.Server URL to a WebSocket URL. +// Kept for use by session_test.go and remote_update_test.go. +func serveWsURL(s *httptest.Server) string { + return "ws" + strings.TrimPrefix(s.URL, "http") +} + +func setupServeCredentials(t *testing.T) { + t.Helper() + creds := &auth.Credentials{ + AccessToken: "test-token", + RefreshToken: "test-refresh", + ExpiresAt: time.Now().Add(time.Hour), + DeviceName: "test-device", + User: "user@example.com", + } + if err := auth.SaveCredentials(creds); err != nil { + t.Fatalf("SaveCredentials failed: %v", err) + } +} + +func TestRunServe_WorkspaceDefaultsToCwd(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + + setupServeCredentials(t) + + // Empty workspace should default to "." and resolve to an absolute path. + // Cancel immediately — we only care that workspace validation passes. + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := RunServe(ServeOptions{Ctx: ctx}) + if err != nil && strings.Contains(err.Error(), "does not exist") { + t.Errorf("empty workspace should default to cwd, got: %v", err) + } +} + +func TestRunServe_WorkspaceDoesNotExist(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + + setupServeCredentials(t) + + err := RunServe(ServeOptions{ + Workspace: "/nonexistent/path", + }) + if err == nil { + t.Fatal("expected error for nonexistent workspace") + } + if !strings.Contains(err.Error(), "does not exist") { + t.Errorf("expected 'does not exist' error, got: %v", err) + } +} + +func TestRunServe_WorkspaceIsFile(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + + setupServeCredentials(t) + + // Create a file instead of directory + filePath := filepath.Join(home, "not-a-dir") + if err := os.WriteFile(filePath, []byte("not a directory"), 0o644); err != nil { + t.Fatal(err) + } + + err := RunServe(ServeOptions{ + Workspace: filePath, + }) + if err == nil { + t.Fatal("expected error for file workspace") + } + if !strings.Contains(err.Error(), "not a directory") { + t.Errorf("expected 'not a directory' error, got: %v", err) + } +} + +func TestRunServe_NotLoggedIn(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + + workspace := filepath.Join(home, "projects") + if err := os.MkdirAll(workspace, 0o755); err != nil { + t.Fatal(err) + } + + err := RunServe(ServeOptions{ + Workspace: workspace, + }) + if err == nil { + t.Fatal("expected error for missing credentials") + } + if !strings.Contains(err.Error(), "Not logged in") { + t.Errorf("expected 'Not logged in' error, got: %v", err) + } +} + +func TestRunServe_ConnectsAndHandshakes(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspace := filepath.Join(home, "projects") + if err := os.MkdirAll(workspace, 0o755); err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithCancel(context.Background()) + ms := newMockUplinkServer(t) + + go func() { + if err := ms.waitForPusherSubscribe(10 * time.Second); err != nil { + t.Logf("waitForPusherSubscribe: %v", err) + cancel() + return + } + cancel() + }() + + err := RunServe(ServeOptions{ + Workspace: workspace, + ServerURL: ms.httpSrv.URL, + Version: "1.0.0", + Ctx: ctx, + }) + + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + // Check the connect body for expected metadata + connectBody := ms.getConnectBody() + if connectBody == nil { + t.Fatal("connect request body was not received") + } + + if connectBody["chief_version"] != "1.0.0" { + t.Errorf("expected chief_version '1.0.0', got %v", connectBody["chief_version"]) + } + if connectBody["device_name"] != "test-device" { + t.Errorf("expected device_name 'test-device', got %v", connectBody["device_name"]) + } +} + +func TestRunServe_DeviceNameOverride(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspace := filepath.Join(home, "projects") + if err := os.MkdirAll(workspace, 0o755); err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithCancel(context.Background()) + ms := newMockUplinkServer(t) + + go func() { + if err := ms.waitForPusherSubscribe(10 * time.Second); err != nil { + t.Logf("waitForPusherSubscribe: %v", err) + cancel() + return + } + cancel() + }() + + err := RunServe(ServeOptions{ + Workspace: workspace, + DeviceName: "my-custom-device", + ServerURL: ms.httpSrv.URL, + Version: "1.0.0", + Ctx: ctx, + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + connectBody := ms.getConnectBody() + if connectBody == nil { + t.Fatal("connect request body was not received") + } + + if connectBody["device_name"] != "my-custom-device" { + t.Errorf("expected device name 'my-custom-device', got %q", connectBody["device_name"]) + } +} + +func TestRunServe_AuthFailed(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspace := filepath.Join(home, "projects") + if err := os.MkdirAll(workspace, 0o755); err != nil { + t.Fatal(err) + } + + ms := newMockUplinkServer(t) + // Set connect endpoint to return 401 + ms.connectStatus.Store(401) + + err := RunServe(ServeOptions{ + Workspace: workspace, + ServerURL: ms.httpSrv.URL, + Version: "1.0.0", + }) + if err == nil { + t.Fatal("expected error for auth failure") + } + if !strings.Contains(err.Error(), "deauthorized") { + t.Errorf("expected 'deauthorized' error, got: %v", err) + } +} + +func TestRunServe_LogFile(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspace := filepath.Join(home, "projects") + if err := os.MkdirAll(workspace, 0o755); err != nil { + t.Fatal(err) + } + + logFile := filepath.Join(home, "chief.log") + + ctx, cancel := context.WithCancel(context.Background()) + ms := newMockUplinkServer(t) + + go func() { + if err := ms.waitForPusherSubscribe(10 * time.Second); err != nil { + t.Logf("waitForPusherSubscribe: %v", err) + cancel() + return + } + cancel() + }() + + err := RunServe(ServeOptions{ + Workspace: workspace, + ServerURL: ms.httpSrv.URL, + LogFile: logFile, + Version: "1.0.0", + Ctx: ctx, + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + // Verify log file was created and has content + data, err := os.ReadFile(logFile) + if err != nil { + t.Fatalf("failed to read log file: %v", err) + } + if len(data) == 0 { + t.Error("log file is empty") + } + content := string(data) + if !strings.Contains(content, "Starting chief serve") { + t.Errorf("log file missing startup message, got: %s", content) + } +} + +func TestRunServe_PingPong(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspace := filepath.Join(home, "projects") + if err := os.MkdirAll(workspace, 0o755); err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithCancel(context.Background()) + ms := newMockUplinkServer(t) + + go func() { + // Wait for connection + if err := ms.waitForPusherSubscribe(10 * time.Second); err != nil { + t.Logf("waitForPusherSubscribe: %v", err) + cancel() + return + } + + // Wait for initial state_snapshot + if _, err := ms.waitForMessageType("state_snapshot", 5*time.Second); err != nil { + t.Logf("waitForMessageType(state_snapshot): %v", err) + cancel() + return + } + + // Send ping command via Pusher + pingReq := map[string]string{ + "type": "ping", + "id": "ping-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + } + if err := ms.sendCommand(pingReq); err != nil { + t.Logf("sendCommand(ping): %v", err) + cancel() + return + } + + // Wait for pong response via HTTP messages + if _, err := ms.waitForMessageType("pong", 5*time.Second); err != nil { + t.Logf("waitForMessageType(pong): %v", err) + } + + cancel() + }() + + err := RunServe(ServeOptions{ + Workspace: workspace, + ServerURL: ms.httpSrv.URL, + Version: "1.0.0", + Ctx: ctx, + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + // Verify pong was received + pong, err := ms.waitForMessageType("pong", time.Second) + if err != nil { + t.Error("expected pong response to be received by server") + } else { + var msg map[string]interface{} + json.Unmarshal(pong, &msg) + if msg["type"] != "pong" { + t.Errorf("expected type 'pong', got %v", msg["type"]) + } + } +} + +func TestRunServe_TokenRefresh(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + + // Save credentials that are near expiry + creds := &auth.Credentials{ + AccessToken: "old-token", + RefreshToken: "test-refresh", + ExpiresAt: time.Now().Add(2 * time.Minute), // Within 5 min threshold + DeviceName: "test-device", + User: "user@example.com", + } + if err := auth.SaveCredentials(creds); err != nil { + t.Fatalf("SaveCredentials failed: %v", err) + } + + workspace := filepath.Join(home, "projects") + if err := os.MkdirAll(workspace, 0o755); err != nil { + t.Fatal(err) + } + + var tokenRefreshed bool + var mu sync.Mutex + + ctx, cancel := context.WithCancel(context.Background()) + + // Create a mock uplink server for the uplink endpoints + ms := newMockUplinkServer(t) + + // Create a mux that combines token refresh with uplink server endpoints. + // Token refresh goes to BaseURL, uplink goes to ServerURL. We use + // a separate mux server as the BaseURL for token refresh. + tokenSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/oauth/token" { + mu.Lock() + tokenRefreshed = true + mu.Unlock() + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "access_token": "new-refreshed-token", + "refresh_token": "new-refresh", + "expires_in": 3600, + }) + return + } + http.NotFound(w, r) + })) + defer tokenSrv.Close() + + go func() { + if err := ms.waitForPusherSubscribe(10 * time.Second); err != nil { + t.Logf("waitForPusherSubscribe: %v", err) + cancel() + return + } + cancel() + }() + + err := RunServe(ServeOptions{ + Workspace: workspace, + ServerURL: ms.httpSrv.URL, + BaseURL: tokenSrv.URL, + Version: "1.0.0", + Ctx: ctx, + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + mu.Lock() + defer mu.Unlock() + + if !tokenRefreshed { + t.Error("expected token refresh to be called for near-expiry credentials") + } +} + +// createGitRepo creates a minimal git repository for testing. +func createGitRepo(t *testing.T, dir string) { + t.Helper() + if err := os.MkdirAll(dir, 0o755); err != nil { + t.Fatal(err) + } + // Initialize git repo + cmd := exec.Command("git", "init", dir) + cmd.Env = append(os.Environ(), "GIT_CONFIG_GLOBAL=/dev/null", "GIT_CONFIG_SYSTEM=/dev/null") + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("git init failed: %v\n%s", err, out) + } + // Configure user for the repo + for _, args := range [][]string{ + {"config", "user.email", "test@test.com"}, + {"config", "user.name", "Test"}, + } { + cmd := exec.Command("git", args...) + cmd.Dir = dir + cmd.Env = append(os.Environ(), "GIT_CONFIG_GLOBAL=/dev/null", "GIT_CONFIG_SYSTEM=/dev/null") + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("git config failed: %v\n%s", err, out) + } + } + // Create initial commit + readmePath := filepath.Join(dir, "README.md") + if err := os.WriteFile(readmePath, []byte("# Test\n"), 0o644); err != nil { + t.Fatal(err) + } + cmd = exec.Command("git", "add", ".") + cmd.Dir = dir + cmd.Env = append(os.Environ(), "GIT_CONFIG_GLOBAL=/dev/null", "GIT_CONFIG_SYSTEM=/dev/null") + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("git add failed: %v\n%s", err, out) + } + cmd = exec.Command("git", "commit", "-m", "initial commit") + cmd.Dir = dir + cmd.Env = append(os.Environ(), "GIT_CONFIG_GLOBAL=/dev/null", "GIT_CONFIG_SYSTEM=/dev/null") + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("git commit failed: %v\n%s", err, out) + } +} + +// serveTestHelper sets up a serve test with a mock uplink server. +// The serverFn receives the mock uplink server after the CLI has connected +// (HTTP connect + Pusher subscribe) and sent the initial state_snapshot. +// The serverFn should send commands via ms.sendCommand() and read responses +// via ms.waitForMessageType() or ms.getMessages(). +func serveTestHelper(t *testing.T, workspacePath string, serverFn func(ms *mockUplinkServer)) error { + t.Helper() + + ctx, cancel := context.WithCancel(context.Background()) + ms := newMockUplinkServer(t) + + // Wait for the CLI to connect and send state_snapshot, then run test logic. + go func() { + // Wait for Pusher subscription (indicates full connection). + if err := ms.waitForPusherSubscribe(10 * time.Second); err != nil { + t.Logf("serveTestHelper: %v", err) + cancel() + return + } + + // Wait for initial state_snapshot to arrive. + if _, err := ms.waitForMessageType("state_snapshot", 5*time.Second); err != nil { + t.Logf("serveTestHelper: %v", err) + cancel() + return + } + + // Run test-specific server logic. + serverFn(ms) + + // Cancel context to stop serve loop. + cancel() + }() + + return RunServe(ServeOptions{ + Workspace: workspacePath, + ServerURL: ms.httpSrv.URL, + Version: "1.0.0", + Ctx: ctx, + }) +} + +func TestRunServe_StateSnapshotOnConnect(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + // Create a git repo in the workspace + projectDir := filepath.Join(workspaceDir, "myproject") + createGitRepo(t, projectDir) + + ctx, cancel := context.WithCancel(context.Background()) + ms := newMockUplinkServer(t) + + go func() { + if err := ms.waitForPusherSubscribe(10 * time.Second); err != nil { + t.Logf("waitForPusherSubscribe: %v", err) + cancel() + return + } + + // Wait for state_snapshot + if _, err := ms.waitForMessageType("state_snapshot", 5*time.Second); err != nil { + t.Logf("waitForMessageType(state_snapshot): %v", err) + } + + cancel() + }() + + err := RunServe(ServeOptions{ + Workspace: workspaceDir, + ServerURL: ms.httpSrv.URL, + Version: "1.0.0", + Ctx: ctx, + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + // Retrieve the state_snapshot message + raw, err := ms.waitForMessageType("state_snapshot", time.Second) + if err != nil { + t.Fatal("state_snapshot was not received") + } + + var snapshotReceived map[string]interface{} + json.Unmarshal(raw, &snapshotReceived) + + if snapshotReceived["type"] != "state_snapshot" { + t.Errorf("expected type 'state_snapshot', got %v", snapshotReceived["type"]) + } + + // Verify projects are included + projects, ok := snapshotReceived["projects"].([]interface{}) + if !ok { + t.Fatal("expected projects array in state_snapshot") + } + if len(projects) != 1 { + t.Fatalf("expected 1 project, got %d", len(projects)) + } + proj := projects[0].(map[string]interface{}) + if proj["name"] != "myproject" { + t.Errorf("expected project name 'myproject', got %v", proj["name"]) + } + + // Verify runs and sessions are empty arrays + runs, ok := snapshotReceived["runs"].([]interface{}) + if !ok { + t.Fatal("expected runs array in state_snapshot") + } + if len(runs) != 0 { + t.Errorf("expected 0 runs, got %d", len(runs)) + } + sessions, ok := snapshotReceived["sessions"].([]interface{}) + if !ok { + t.Fatal("expected sessions array in state_snapshot") + } + if len(sessions) != 0 { + t.Errorf("expected 0 sessions, got %d", len(sessions)) + } +} + +func TestRunServe_ListProjects(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + // Create two git repos + createGitRepo(t, filepath.Join(workspaceDir, "alpha")) + createGitRepo(t, filepath.Join(workspaceDir, "beta")) + + var projectListReceived map[string]interface{} + var mu sync.Mutex + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + // Send list_projects request + listReq := map[string]string{ + "type": "list_projects", + "id": "req-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + } + ms.sendCommand(listReq) + + // Read project_list response + raw, err := ms.waitForMessageType("project_list", 5*time.Second) + if err == nil { + mu.Lock() + json.Unmarshal(raw, &projectListReceived) + mu.Unlock() + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + mu.Lock() + defer mu.Unlock() + + if projectListReceived == nil { + t.Fatal("project_list was not received") + } + if projectListReceived["type"] != "project_list" { + t.Errorf("expected type 'project_list', got %v", projectListReceived["type"]) + } + + projects, ok := projectListReceived["projects"].([]interface{}) + if !ok { + t.Fatal("expected projects array") + } + if len(projects) != 2 { + t.Fatalf("expected 2 projects, got %d", len(projects)) + } + + // Collect project names + names := make(map[string]bool) + for _, p := range projects { + proj := p.(map[string]interface{}) + names[proj["name"].(string)] = true + } + if !names["alpha"] || !names["beta"] { + t.Errorf("expected projects alpha and beta, got %v", names) + } +} + +func TestRunServe_GetProject(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + createGitRepo(t, filepath.Join(workspaceDir, "myproject")) + + var projectStateReceived map[string]interface{} + var mu sync.Mutex + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + // Send get_project request + getReq := map[string]string{ + "type": "get_project", + "id": "req-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "project": "myproject", + } + ms.sendCommand(getReq) + + // Read project_state response + raw, err := ms.waitForMessageType("project_state", 5*time.Second) + if err == nil { + mu.Lock() + json.Unmarshal(raw, &projectStateReceived) + mu.Unlock() + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + mu.Lock() + defer mu.Unlock() + + if projectStateReceived == nil { + t.Fatal("project_state was not received") + } + if projectStateReceived["type"] != "project_state" { + t.Errorf("expected type 'project_state', got %v", projectStateReceived["type"]) + } + + project, ok := projectStateReceived["project"].(map[string]interface{}) + if !ok { + t.Fatal("expected project object in project_state") + } + if project["name"] != "myproject" { + t.Errorf("expected project name 'myproject', got %v", project["name"]) + } +} + +func TestRunServe_GetProjectNotFound(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + var errorReceived map[string]interface{} + var mu sync.Mutex + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + // Send get_project for nonexistent project + getReq := map[string]string{ + "type": "get_project", + "id": "req-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "project": "nonexistent", + } + ms.sendCommand(getReq) + + // Read error response + raw, err := ms.waitForMessageType("error", 5*time.Second) + if err == nil { + mu.Lock() + json.Unmarshal(raw, &errorReceived) + mu.Unlock() + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + mu.Lock() + defer mu.Unlock() + + if errorReceived == nil { + t.Fatal("error message was not received") + } + if errorReceived["type"] != "error" { + t.Errorf("expected type 'error', got %v", errorReceived["type"]) + } + if errorReceived["code"] != "PROJECT_NOT_FOUND" { + t.Errorf("expected code 'PROJECT_NOT_FOUND', got %v", errorReceived["code"]) + } +} + +func TestRunServe_GetPRD(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + // Create a git repo with .chief/prds/feature/ + projectDir := filepath.Join(workspaceDir, "myproject") + createGitRepo(t, projectDir) + + prdDir := filepath.Join(projectDir, ".chief", "prds", "feature") + if err := os.MkdirAll(prdDir, 0o755); err != nil { + t.Fatal(err) + } + + // Write prd.md + prdMD := "# My Feature\nThis is a feature PRD." + if err := os.WriteFile(filepath.Join(prdDir, "prd.md"), []byte(prdMD), 0o644); err != nil { + t.Fatal(err) + } + + // Write prd.json + prdState := `{"project": "My Feature", "userStories": [{"id": "US-001", "passes": true}]}` + if err := os.WriteFile(filepath.Join(prdDir, "prd.json"), []byte(prdState), 0o644); err != nil { + t.Fatal(err) + } + + var prdContentReceived map[string]interface{} + var mu sync.Mutex + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + // Send get_prd request + getReq := map[string]string{ + "type": "get_prd", + "id": "req-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "project": "myproject", + "prd_id": "feature", + } + ms.sendCommand(getReq) + + // Read prd_content response + raw, err := ms.waitForMessageType("prd_content", 5*time.Second) + if err == nil { + mu.Lock() + json.Unmarshal(raw, &prdContentReceived) + mu.Unlock() + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + mu.Lock() + defer mu.Unlock() + + if prdContentReceived == nil { + t.Fatal("prd_content was not received") + } + if prdContentReceived["type"] != "prd_content" { + t.Errorf("expected type 'prd_content', got %v", prdContentReceived["type"]) + } + if prdContentReceived["project"] != "myproject" { + t.Errorf("expected project 'myproject', got %v", prdContentReceived["project"]) + } + if prdContentReceived["prd_id"] != "feature" { + t.Errorf("expected prd_id 'feature', got %v", prdContentReceived["prd_id"]) + } + if prdContentReceived["content"] != prdMD { + t.Errorf("expected content %q, got %v", prdMD, prdContentReceived["content"]) + } + + // Verify state is present and contains expected data + state, ok := prdContentReceived["state"].(map[string]interface{}) + if !ok { + t.Fatal("expected state object in prd_content") + } + if state["project"] != "My Feature" { + t.Errorf("expected state.project 'My Feature', got %v", state["project"]) + } +} + +func TestRunServe_GetPRDNotFound(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + // Create a git repo without any PRDs + createGitRepo(t, filepath.Join(workspaceDir, "myproject")) + + var errorReceived map[string]interface{} + var mu sync.Mutex + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + // Send get_prd for nonexistent PRD + getReq := map[string]string{ + "type": "get_prd", + "id": "req-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "project": "myproject", + "prd_id": "nonexistent", + } + ms.sendCommand(getReq) + + // Read error response + raw, err := ms.waitForMessageType("error", 5*time.Second) + if err == nil { + mu.Lock() + json.Unmarshal(raw, &errorReceived) + mu.Unlock() + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + mu.Lock() + defer mu.Unlock() + + if errorReceived == nil { + t.Fatal("error message was not received") + } + if errorReceived["type"] != "error" { + t.Errorf("expected type 'error', got %v", errorReceived["type"]) + } + if errorReceived["code"] != "PRD_NOT_FOUND" { + t.Errorf("expected code 'PRD_NOT_FOUND', got %v", errorReceived["code"]) + } +} + +func TestRunServe_GetPRDProjectNotFound(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + var errorReceived map[string]interface{} + var mu sync.Mutex + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + // Send get_prd for nonexistent project + getReq := map[string]string{ + "type": "get_prd", + "id": "req-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "project": "nonexistent", + "prd_id": "feature", + } + ms.sendCommand(getReq) + + // Read error response + raw, err := ms.waitForMessageType("error", 5*time.Second) + if err == nil { + mu.Lock() + json.Unmarshal(raw, &errorReceived) + mu.Unlock() + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + mu.Lock() + defer mu.Unlock() + + if errorReceived == nil { + t.Fatal("error message was not received") + } + if errorReceived["type"] != "error" { + t.Errorf("expected type 'error', got %v", errorReceived["type"]) + } + if errorReceived["code"] != "PROJECT_NOT_FOUND" { + t.Errorf("expected code 'PROJECT_NOT_FOUND', got %v", errorReceived["code"]) + } +} + +func TestRunServe_RateLimitGlobal(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + var rateLimitReceived bool + var mu sync.Mutex + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + // Send more than globalBurst (30) messages rapidly to trigger rate limiting + for i := 0; i < 35; i++ { + msg := map[string]string{ + "type": "list_projects", + "id": fmt.Sprintf("req-%d", i), + "timestamp": time.Now().UTC().Format(time.RFC3339), + } + ms.sendCommand(msg) + } + + // Wait for a RATE_LIMITED error response + deadline := time.After(5 * time.Second) + for { + msgs := ms.getMessages() + for _, raw := range msgs { + var resp map[string]interface{} + json.Unmarshal(raw, &resp) + if resp["type"] == "error" && resp["code"] == "RATE_LIMITED" { + mu.Lock() + rateLimitReceived = true + mu.Unlock() + return + } + } + select { + case <-deadline: + return + case <-time.After(50 * time.Millisecond): + } + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + mu.Lock() + defer mu.Unlock() + + if !rateLimitReceived { + t.Error("expected RATE_LIMITED error after burst exhaustion") + } +} + +func TestRunServe_RateLimitPingExempt(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + var pongReceived bool + var rateLimitSeen bool + var mu sync.Mutex + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + // Exhaust the global rate limit with normal messages + for i := 0; i < 35; i++ { + msg := map[string]string{ + "type": "list_projects", + "id": fmt.Sprintf("req-%d", i), + "timestamp": time.Now().UTC().Format(time.RFC3339), + } + ms.sendCommand(msg) + } + + // Now immediately send a ping — should bypass rate limiting + ping := map[string]string{ + "type": "ping", + "id": "ping-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + } + ms.sendCommand(ping) + + // Wait for both RATE_LIMITED and pong responses + deadline := time.After(5 * time.Second) + for { + msgs := ms.getMessages() + for _, raw := range msgs { + var resp map[string]interface{} + json.Unmarshal(raw, &resp) + if resp["type"] == "pong" { + mu.Lock() + pongReceived = true + mu.Unlock() + } + if resp["type"] == "error" && resp["code"] == "RATE_LIMITED" { + mu.Lock() + rateLimitSeen = true + mu.Unlock() + } + } + mu.Lock() + done := pongReceived && rateLimitSeen + mu.Unlock() + if done { + return + } + select { + case <-deadline: + return + case <-time.After(50 * time.Millisecond): + } + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + mu.Lock() + defer mu.Unlock() + + if !rateLimitSeen { + t.Error("expected RATE_LIMITED error to confirm rate limiting was active") + } + if !pongReceived { + t.Error("expected pong response even after rate limit exhaustion — ping should be exempt") + } +} + +func TestRunServe_RateLimitExpensiveOps(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + var rateLimitReceived bool + var mu sync.Mutex + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + // Send 3 start_run messages (limit is 2/minute) + for i := 0; i < 3; i++ { + msg := map[string]interface{}{ + "type": "start_run", + "id": fmt.Sprintf("req-%d", i), + "timestamp": time.Now().UTC().Format(time.RFC3339), + "project": "nonexistent", + "prd_id": "test", + } + ms.sendCommand(msg) + } + + // Wait for a RATE_LIMITED error response + deadline := time.After(5 * time.Second) + for { + msgs := ms.getMessages() + for _, raw := range msgs { + var resp map[string]interface{} + json.Unmarshal(raw, &resp) + if resp["type"] == "error" && resp["code"] == "RATE_LIMITED" { + mu.Lock() + rateLimitReceived = true + mu.Unlock() + return + } + } + select { + case <-deadline: + return + case <-time.After(50 * time.Millisecond): + } + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + mu.Lock() + defer mu.Unlock() + + if !rateLimitReceived { + t.Error("expected RATE_LIMITED error for excessive expensive operations") + } +} + +func TestRunServe_ShutdownLogsSequence(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspace := filepath.Join(home, "projects") + if err := os.MkdirAll(workspace, 0o755); err != nil { + t.Fatal(err) + } + + logFile := filepath.Join(home, "chief-shutdown.log") + + ctx, cancel := context.WithCancel(context.Background()) + ms := newMockUplinkServer(t) + + go func() { + if err := ms.waitForPusherSubscribe(10 * time.Second); err != nil { + t.Logf("waitForPusherSubscribe: %v", err) + cancel() + return + } + + // Wait for state_snapshot to ensure connection is fully established + if _, err := ms.waitForMessageType("state_snapshot", 5*time.Second); err != nil { + t.Logf("waitForMessageType(state_snapshot): %v", err) + } + + // Cancel to trigger shutdown + cancel() + }() + + err := RunServe(ServeOptions{ + Workspace: workspace, + ServerURL: ms.httpSrv.URL, + LogFile: logFile, + Version: "1.0.0", + Ctx: ctx, + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + // Verify log file contains shutdown sequence + data, err := os.ReadFile(logFile) + if err != nil { + t.Fatalf("failed to read log file: %v", err) + } + content := string(data) + + if !strings.Contains(content, "Shutting down...") { + t.Error("log file missing 'Shutting down...' message") + } + if !strings.Contains(content, "Goodbye.") { + t.Error("log file missing 'Goodbye.' message") + } +} + +func TestRunServe_ShutdownMarksInterruptedStories(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + // Create a git repo with a PRD + projectDir := filepath.Join(workspaceDir, "myproject") + createGitRepo(t, projectDir) + + prdDir := filepath.Join(projectDir, ".chief", "prds", "feature") + if err := os.MkdirAll(prdDir, 0o755); err != nil { + t.Fatal(err) + } + + prdPath := filepath.Join(prdDir, "prd.json") + prdState := `{"project": "My Feature", "userStories": [{"id": "US-001", "title": "Test Story", "passes": false}, {"id": "US-002", "title": "Done Story", "passes": true}]}` + if err := os.WriteFile(prdPath, []byte(prdState), 0o644); err != nil { + t.Fatal(err) + } + + // Create a mock claude that hangs until killed (simulates in-progress work) + mockDir := t.TempDir() + mockScript := `#!/bin/sh +echo '{"type":"system","subtype":"init"}' +echo '{"type":"assistant","message":{"content":[{"type":"text","text":"Working on <ralph-status>US-001</ralph-status>"}]}}' +sleep 300 +` + if err := os.WriteFile(filepath.Join(mockDir, "claude"), []byte(mockScript), 0o755); err != nil { + t.Fatal(err) + } + origPath := os.Getenv("PATH") + t.Setenv("PATH", mockDir+":"+origPath) + + ctx, cancel := context.WithCancel(context.Background()) + ms := newMockUplinkServer(t) + + go func() { + if err := ms.waitForPusherSubscribe(10 * time.Second); err != nil { + t.Logf("waitForPusherSubscribe: %v", err) + cancel() + return + } + + // Wait for state_snapshot + if _, err := ms.waitForMessageType("state_snapshot", 5*time.Second); err != nil { + t.Logf("waitForMessageType(state_snapshot): %v", err) + cancel() + return + } + + // Send start_run to get a story in-progress + startReq := map[string]string{ + "type": "start_run", + "id": "req-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "project": "myproject", + "prd_id": "feature", + } + ms.sendCommand(startReq) + + // Wait for run_progress with story_started so we know US-001 is tracked + deadline := time.After(10 * time.Second) + for { + msgs := ms.getMessages() + for _, raw := range msgs { + var msg map[string]interface{} + json.Unmarshal(raw, &msg) + if msg["type"] == "run_progress" { + status, _ := msg["status"].(string) + if status == "story_started" { + // Now cancel to trigger shutdown while story is in-progress + cancel() + return + } + } + } + select { + case <-deadline: + t.Logf("timeout waiting for story_started") + cancel() + return + case <-time.After(50 * time.Millisecond): + } + } + }() + + err := RunServe(ServeOptions{ + Workspace: workspaceDir, + ServerURL: ms.httpSrv.URL, + Version: "1.0.0", + Ctx: ctx, + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + // Verify the PRD was updated with inProgress: true for US-001 + data, err := os.ReadFile(prdPath) + if err != nil { + t.Fatalf("failed to read PRD: %v", err) + } + + var result map[string]interface{} + if err := json.Unmarshal(data, &result); err != nil { + t.Fatalf("failed to parse PRD: %v", err) + } + + stories := result["userStories"].([]interface{}) + story1 := stories[0].(map[string]interface{}) + if story1["inProgress"] != true { + t.Errorf("expected US-001 to have inProgress=true after shutdown, got %v", story1["inProgress"]) + } + + // US-002 is already passing, should NOT be marked as inProgress + story2 := stories[1].(map[string]interface{}) + if _, hasInProgress := story2["inProgress"]; hasInProgress && story2["inProgress"] == true { + t.Error("expected US-002 to NOT have inProgress=true (already passes)") + } +} + +func TestRunServe_ShutdownLogFileFlush(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspace := filepath.Join(home, "projects") + if err := os.MkdirAll(workspace, 0o755); err != nil { + t.Fatal(err) + } + + logFile := filepath.Join(home, "chief-flush.log") + + ctx, cancel := context.WithCancel(context.Background()) + ms := newMockUplinkServer(t) + + go func() { + if err := ms.waitForPusherSubscribe(10 * time.Second); err != nil { + t.Logf("waitForPusherSubscribe: %v", err) + cancel() + return + } + + // Wait for state_snapshot + if _, err := ms.waitForMessageType("state_snapshot", 5*time.Second); err != nil { + t.Logf("waitForMessageType(state_snapshot): %v", err) + } + + cancel() + }() + + err := RunServe(ServeOptions{ + Workspace: workspace, + ServerURL: ms.httpSrv.URL, + LogFile: logFile, + Version: "1.0.0", + Ctx: ctx, + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + // Verify the log file was flushed and contains the final "Goodbye." message + data, err := os.ReadFile(logFile) + if err != nil { + t.Fatalf("failed to read log file: %v", err) + } + content := string(data) + if !strings.Contains(content, "Goodbye.") { + t.Error("log file missing 'Goodbye.' — may not have been flushed properly") + } +} + +func TestSessionManager_SessionCount(t *testing.T) { + sm := &sessionManager{ + sessions: make(map[string]*claudeSession), + stopTimeout: make(chan struct{}), + } + // Manually close stopTimeout since we're not starting the timeout checker + close(sm.stopTimeout) + + if sm.sessionCount() != 0 { + t.Errorf("expected 0 sessions, got %d", sm.sessionCount()) + } + + sm.sessions["sess1"] = &claudeSession{sessionID: "sess1"} + sm.sessions["sess2"] = &claudeSession{sessionID: "sess2"} + + if sm.sessionCount() != 2 { + t.Errorf("expected 2 sessions, got %d", sm.sessionCount()) + } +} + +func TestRunServe_ServerURLFromEnvVar(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspace := filepath.Join(home, "projects") + if err := os.MkdirAll(workspace, 0o755); err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithCancel(context.Background()) + ms := newMockUplinkServer(t) + + go func() { + if err := ms.waitForPusherSubscribe(10 * time.Second); err != nil { + t.Logf("waitForPusherSubscribe: %v", err) + cancel() + return + } + cancel() + }() + + // Set env var to point at our test server (no ServerURL in ServeOptions) + t.Setenv("CHIEF_SERVER_URL", ms.httpSrv.URL) + + err := RunServe(ServeOptions{ + Workspace: workspace, + Version: "1.0.0", + Ctx: ctx, + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } +} + +func TestRunServe_ServerURLPrecedence_FlagOverridesEnv(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspace := filepath.Join(home, "projects") + if err := os.MkdirAll(workspace, 0o755); err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithCancel(context.Background()) + ms := newMockUplinkServer(t) + + go func() { + if err := ms.waitForPusherSubscribe(10 * time.Second); err != nil { + t.Logf("waitForPusherSubscribe: %v", err) + cancel() + return + } + cancel() + }() + + // Set env var to a bad URL — flag should override it + t.Setenv("CHIEF_SERVER_URL", "http://bad-url-that-should-not-be-used:9999") + + err := RunServe(ServeOptions{ + Workspace: workspace, + ServerURL: ms.httpSrv.URL, // Flag value — should take precedence + Version: "1.0.0", + Ctx: ctx, + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } +} + +func TestRunServe_ServerURLLoggedOnStartup(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspace := filepath.Join(home, "projects") + if err := os.MkdirAll(workspace, 0o755); err != nil { + t.Fatal(err) + } + + logFile := filepath.Join(home, "serve.log") + + ctx, cancel := context.WithCancel(context.Background()) + ms := newMockUplinkServer(t) + + go func() { + if err := ms.waitForPusherSubscribe(10 * time.Second); err != nil { + t.Logf("waitForPusherSubscribe: %v", err) + cancel() + return + } + cancel() + }() + + serverURL := ms.httpSrv.URL + err := RunServe(ServeOptions{ + Workspace: workspace, + ServerURL: serverURL, + LogFile: logFile, + Version: "1.0.0", + Ctx: ctx, + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + data, err := os.ReadFile(logFile) + if err != nil { + t.Fatalf("failed to read log file: %v", err) + } + content := string(data) + if !strings.Contains(content, "Connecting to "+serverURL) { + t.Errorf("expected log to contain 'Connecting to %s', got: %s", serverURL, content) + } +} + +// --- serveShutdown unit tests --- + +// mockCloser implements uplinkCloser for testing serveShutdown directly. +type mockCloser struct { + closeCalled atomic.Int32 + closeDelay time.Duration // Simulates a slow Close operation. + closeErr error +} + +func (m *mockCloser) Close() error { + m.closeCalled.Add(1) + if m.closeDelay > 0 { + time.Sleep(m.closeDelay) + } + return m.closeErr +} + +func (m *mockCloser) CloseWithTimeout(timeout time.Duration) error { + done := make(chan error, 1) + go func() { + done <- m.Close() + }() + + select { + case err := <-done: + return err + case <-time.After(timeout): + return nil + } +} + +func TestServeShutdown_CompletesWithinTimeout(t *testing.T) { + closer := &mockCloser{} + + start := time.Now() + err := serveShutdown(closer, nil, nil, nil, nil) + elapsed := time.Since(start) + + if err != nil { + t.Errorf("serveShutdown returned error: %v", err) + } + + // Should complete quickly. + if elapsed > 2*time.Second { + t.Errorf("serveShutdown took %s, expected < 2s", elapsed) + } + + // Close should have been called. + if got := closer.closeCalled.Load(); got != 1 { + t.Errorf("close called %d times, want 1", got) + } +} + +func TestServeShutdown_TimesOutWithHangingClose(t *testing.T) { + // Create a closer that hangs longer than the shutdown timeout. + closer := &mockCloser{closeDelay: 30 * time.Second} + + start := time.Now() + err := serveShutdown(closer, nil, nil, nil, nil) + elapsed := time.Since(start) + + if err != nil { + t.Errorf("serveShutdown returned error: %v", err) + } + + // Should complete within the shutdown timeout (10s) + small buffer, + // not hang for the full 30s close delay. + if elapsed > 15*time.Second { + t.Errorf("serveShutdown took %s, expected < 15s (shutdown timeout is 10s)", elapsed) + } + + t.Logf("serveShutdown completed in %s", elapsed.Round(time.Millisecond)) +} + +func TestServeShutdown_DisconnectFailureLoggedNotBlocking(t *testing.T) { + closer := &mockCloser{closeErr: fmt.Errorf("connection refused")} + + err := serveShutdown(closer, nil, nil, nil, nil) + + // serveShutdown should not propagate the close error. + if err != nil { + t.Errorf("serveShutdown returned error: %v, want nil", err) + } + + // Close should have been called. + if got := closer.closeCalled.Load(); got != 1 { + t.Errorf("close called %d times, want 1", got) + } +} + +func TestServeShutdown_CallsDisconnectOnServer(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspace := filepath.Join(home, "projects") + if err := os.MkdirAll(workspace, 0o755); err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithCancel(context.Background()) + ms := newMockUplinkServer(t) + + go func() { + if err := ms.waitForPusherSubscribe(10 * time.Second); err != nil { + t.Logf("waitForPusherSubscribe: %v", err) + cancel() + return + } + + // Wait for state_snapshot to ensure connection is fully established. + if _, err := ms.waitForMessageType("state_snapshot", 5*time.Second); err != nil { + t.Logf("waitForMessageType(state_snapshot): %v", err) + } + + // Cancel to trigger shutdown. + cancel() + }() + + err := RunServe(ServeOptions{ + Workspace: workspace, + ServerURL: ms.httpSrv.URL, + Version: "1.0.0", + Ctx: ctx, + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + // Verify disconnect was called during shutdown. + if got := ms.disconnectCalls.Load(); got != 1 { + t.Errorf("disconnect calls = %d, want 1", got) + } +} diff --git a/internal/cmd/serve_test_helper.go b/internal/cmd/serve_test_helper.go new file mode 100644 index 0000000..563a63b --- /dev/null +++ b/internal/cmd/serve_test_helper.go @@ -0,0 +1,423 @@ +package cmd + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/gorilla/websocket" +) + +// mockUplinkServer is a combined HTTP API + Pusher WebSocket server for testing +// the serve command with the uplink transport. It replaces the old WebSocket-only +// test server used before the uplink refactor. +type mockUplinkServer struct { + httpSrv *httptest.Server + pusherSrv *mockPusherServer + + mu sync.Mutex + messageBatches []mockMessageBatch + connectBody map[string]interface{} + + connectCalls atomic.Int32 + disconnectCalls atomic.Int32 + heartbeatCalls atomic.Int32 + messagesCalls atomic.Int32 + + // connectStatus controls the HTTP status returned by /api/device/connect. + // 0 means success (200). + connectStatus atomic.Int32 +} + +type mockMessageBatch struct { + BatchID string `json:"batch_id"` + Messages []json.RawMessage `json:"messages"` +} + +// newMockUplinkServer creates a new combined test server. +func newMockUplinkServer(t *testing.T) *mockUplinkServer { + t.Helper() + + ps := newMockPusherServer(t) + + ms := &mockUplinkServer{ + pusherSrv: ps, + } + + reverbCfg := ps.reverbConfig() + + ms.httpSrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ms.handleHTTP(w, r, reverbCfg) + })) + t.Cleanup(func() { ms.httpSrv.Close() }) + + return ms +} + +func (ms *mockUplinkServer) handleHTTP(w http.ResponseWriter, r *http.Request, reverbCfg mockReverbConfig) { + // Check auth header. + auth := r.Header.Get("Authorization") + if !strings.HasPrefix(auth, "Bearer ") { + w.WriteHeader(http.StatusUnauthorized) + json.NewEncoder(w).Encode(map[string]string{"error": "missing token"}) + return + } + + // Check for simulated auth failure. + if r.URL.Path == "/api/device/connect" { + status := int(ms.connectStatus.Load()) + if status >= 400 { + w.WriteHeader(status) + json.NewEncoder(w).Encode(map[string]string{"error": "auth failed"}) + return + } + } + + w.Header().Set("Content-Type", "application/json") + + switch r.URL.Path { + case "/api/device/connect": + ms.connectCalls.Add(1) + + var body map[string]interface{} + json.NewDecoder(r.Body).Decode(&body) + ms.mu.Lock() + ms.connectBody = body + ms.mu.Unlock() + + json.NewEncoder(w).Encode(map[string]interface{}{ + "type": "welcome", + "protocol_version": 1, + "device_id": 42, + "session_id": "test-session-1", + "reverb": map[string]interface{}{ + "key": reverbCfg.Key, + "host": reverbCfg.Host, + "port": reverbCfg.Port, + "scheme": reverbCfg.Scheme, + }, + }) + + case "/api/device/disconnect": + ms.disconnectCalls.Add(1) + json.NewEncoder(w).Encode(map[string]string{"status": "disconnected"}) + + case "/api/device/heartbeat": + ms.heartbeatCalls.Add(1) + json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) + + case "/api/device/messages": + ms.messagesCalls.Add(1) + + var req struct { + BatchID string `json:"batch_id"` + Messages []json.RawMessage `json:"messages"` + } + json.NewDecoder(r.Body).Decode(&req) + + ms.mu.Lock() + ms.messageBatches = append(ms.messageBatches, mockMessageBatch{ + BatchID: req.BatchID, + Messages: req.Messages, + }) + ms.mu.Unlock() + + json.NewEncoder(w).Encode(map[string]interface{}{ + "accepted": len(req.Messages), + "batch_id": req.BatchID, + "session_id": "test-session-1", + }) + + case "/api/device/broadcasting/auth": + var body struct { + SocketID string `json:"socket_id"` + ChannelName string `json:"channel_name"` + } + json.NewDecoder(r.Body).Decode(&body) + + sig := generateTestAuthSignature( + ms.pusherSrv.appKey, + ms.pusherSrv.appSecret, + body.SocketID, + body.ChannelName, + ) + json.NewEncoder(w).Encode(map[string]string{"auth": sig}) + + default: + http.NotFound(w, r) + } +} + +// getConnectBody returns the last connect request body. +func (ms *mockUplinkServer) getConnectBody() map[string]interface{} { + ms.mu.Lock() + defer ms.mu.Unlock() + return ms.connectBody +} + +// getMessages returns all messages received across all batches, flattened. +func (ms *mockUplinkServer) getMessages() []json.RawMessage { + ms.mu.Lock() + defer ms.mu.Unlock() + var msgs []json.RawMessage + for _, b := range ms.messageBatches { + msgs = append(msgs, b.Messages...) + } + return msgs +} + +// waitForMessageType waits for a message of the given type to arrive. +// Returns the first matching message or an error on timeout. +func (ms *mockUplinkServer) waitForMessageType(msgType string, timeout time.Duration) (json.RawMessage, error) { + deadline := time.After(timeout) + for { + msgs := ms.getMessages() + for _, raw := range msgs { + var envelope struct { + Type string `json:"type"` + } + if json.Unmarshal(raw, &envelope) == nil && envelope.Type == msgType { + return raw, nil + } + } + select { + case <-deadline: + return nil, fmt.Errorf("timeout waiting for message type %q (got %d messages total)", msgType, len(msgs)) + case <-time.After(10 * time.Millisecond): + } + } +} + +// waitForMessages waits until at least n messages have been received. +func (ms *mockUplinkServer) waitForMessages(n int, timeout time.Duration) ([]json.RawMessage, error) { + deadline := time.After(timeout) + for { + msgs := ms.getMessages() + if len(msgs) >= n { + return msgs, nil + } + select { + case <-deadline: + return msgs, fmt.Errorf("timeout waiting for %d messages (got %d)", n, len(msgs)) + case <-time.After(10 * time.Millisecond): + } + } +} + +// sendCommand sends a command to the CLI via the Pusher server. +// Commands are wrapped in a {"type": ..., "payload": {...}} envelope +// to match the real CommandRelayController format. +func (ms *mockUplinkServer) sendCommand(command interface{}) error { + data, err := json.Marshal(command) + if err != nil { + return fmt.Errorf("marshaling command: %w", err) + } + + // Wrap in payload envelope: extract "type", put everything else under "payload". + var flat map[string]json.RawMessage + if err := json.Unmarshal(data, &flat); err == nil { + cmdType := flat["type"] + delete(flat, "type") + + payload, _ := json.Marshal(flat) + wrapped, _ := json.Marshal(map[string]json.RawMessage{ + "type": cmdType, + "payload": payload, + }) + data = wrapped + } + + channel := fmt.Sprintf("private-chief-server.%d", 42) // device ID 42 + return ms.pusherSrv.sendCommand(channel, data) +} + +// waitForPusherSubscribe waits for the CLI to subscribe to its Pusher channel. +func (ms *mockUplinkServer) waitForPusherSubscribe(timeout time.Duration) error { + select { + case <-ms.pusherSrv.onSubscribe: + return nil + case <-time.After(timeout): + return fmt.Errorf("timeout waiting for Pusher subscription") + } +} + +// generateTestAuthSignature generates a Pusher auth signature for testing. +func generateTestAuthSignature(appKey, appSecret, socketID, channelName string) string { + toSign := socketID + ":" + channelName + mac := hmac.New(sha256.New, []byte(appSecret)) + mac.Write([]byte(toSign)) + sig := hex.EncodeToString(mac.Sum(nil)) + return appKey + ":" + sig +} + +// mockPusherServer is a minimal Pusher protocol WebSocket server for testing. +type mockPusherServer struct { + srv *httptest.Server + upgrader websocket.Upgrader + + mu sync.Mutex + conn *websocket.Conn + + appKey string + appSecret string + socketID string + activityTimeout int + + onSubscribe chan string +} + +type mockReverbConfig struct { + Key string `json:"key"` + Host string `json:"host"` + Port int `json:"port"` + Scheme string `json:"scheme"` +} + +func newMockPusherServer(t *testing.T) *mockPusherServer { + t.Helper() + + ps := &mockPusherServer{ + appKey: "test-app-key", + appSecret: "test-app-secret", + socketID: "123456.7890", + activityTimeout: 120, + onSubscribe: make(chan string, 10), + upgrader: websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, + }, + } + + ps.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ps.handleWS(t, w, r) + })) + + t.Cleanup(func() { + ps.mu.Lock() + if ps.conn != nil { + ps.conn.Close() + } + ps.mu.Unlock() + ps.srv.Close() + }) + + return ps +} + +type pusherMsg struct { + Event string `json:"event"` + Data json.RawMessage `json:"data,omitempty"` + Channel string `json:"channel,omitempty"` +} + +type pusherConnData struct { + SocketID string `json:"socket_id"` + ActivityTimeout int `json:"activity_timeout"` +} + +func (ps *mockPusherServer) handleWS(t *testing.T, w http.ResponseWriter, r *http.Request) { + t.Helper() + + expectedPath := fmt.Sprintf("/app/%s", ps.appKey) + if !strings.HasPrefix(r.URL.Path, expectedPath) { + http.Error(w, "invalid path", http.StatusNotFound) + return + } + + conn, err := ps.upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + + ps.mu.Lock() + ps.conn = conn + ps.mu.Unlock() + + // Send connection_established. + // Real Pusher protocol double-encodes the data: the data field is a JSON string + // containing the connection info, not an embedded object. + connDataInner, _ := json.Marshal(pusherConnData{ + SocketID: ps.socketID, + ActivityTimeout: ps.activityTimeout, + }) + connDataStr, _ := json.Marshal(string(connDataInner)) + conn.WriteJSON(pusherMsg{ + Event: "pusher:connection_established", + Data: connDataStr, + }) + + // Read loop. + for { + _, data, err := conn.ReadMessage() + if err != nil { + return + } + + var msg pusherMsg + if json.Unmarshal(data, &msg) != nil { + continue + } + + switch msg.Event { + case "pusher:subscribe": + var subData map[string]string + json.Unmarshal(msg.Data, &subData) + channel := subData["channel"] + + select { + case ps.onSubscribe <- channel: + default: + } + + conn.WriteJSON(pusherMsg{ + Event: "pusher_internal:subscription_succeeded", + Channel: channel, + Data: json.RawMessage("{}"), + }) + + case "pusher:pong": + // Ignore pong responses. + } + } +} + +// sendCommand sends a chief.command event to the connected client. +func (ps *mockPusherServer) sendCommand(channel string, command json.RawMessage) error { + ps.mu.Lock() + conn := ps.conn + ps.mu.Unlock() + + if conn == nil { + return fmt.Errorf("no client connected") + } + + return conn.WriteJSON(pusherMsg{ + Event: "chief.command", + Channel: channel, + Data: command, + }) +} + +// reverbConfig returns a config pointing at the test server. +func (ps *mockPusherServer) reverbConfig() mockReverbConfig { + addr := ps.srv.Listener.Addr().String() + parts := strings.Split(addr, ":") + host := parts[0] + port := 0 + fmt.Sscanf(parts[1], "%d", &port) + + return mockReverbConfig{ + Key: ps.appKey, + Host: host, + Port: port, + Scheme: "http", + } +} diff --git a/internal/cmd/session.go b/internal/cmd/session.go new file mode 100644 index 0000000..552304d --- /dev/null +++ b/internal/cmd/session.go @@ -0,0 +1,663 @@ +package cmd + +import ( + "bufio" + "encoding/json" + "fmt" + "io" + "log" + "os" + "os/exec" + "path/filepath" + "sync" + "time" + + "github.com/minicodemonkey/chief/embed" + "github.com/minicodemonkey/chief/internal/prd" + "github.com/minicodemonkey/chief/internal/ws" +) + +// Default session timeout configuration. +const ( + defaultSessionTimeout = 30 * time.Minute +) + +// Default warning thresholds (minutes of inactivity at which to warn). +var defaultWarningThresholds = []int{20, 25, 29} + +// claudeSession tracks a single Claude PRD session process. +type claudeSession struct { + sessionID string + project string + projectPath string + cmd *exec.Cmd + stdin io.WriteCloser + done chan struct{} // closed when the process exits + lastActive time.Time // last time a prd_message was received + activeMu sync.Mutex // protects lastActive +} + +// resetActivity updates the last active time for this session. +func (s *claudeSession) resetActivity() { + s.activeMu.Lock() + s.lastActive = time.Now() + s.activeMu.Unlock() +} + +// inactiveDuration returns how long the session has been inactive. +func (s *claudeSession) inactiveDuration() time.Duration { + s.activeMu.Lock() + defer s.activeMu.Unlock() + return time.Since(s.lastActive) +} + +// sessionManager manages Claude PRD sessions spawned via WebSocket. +type sessionManager struct { + mu sync.RWMutex + sessions map[string]*claudeSession + sender messageSender + timeout time.Duration // session inactivity timeout + warningThresholds []int // minutes of inactivity at which to send warnings + checkInterval time.Duration // how often to check for timeouts (configurable for tests) + stopTimeout chan struct{} // closed to stop the timeout checker +} + +// newSessionManager creates a new session manager. +func newSessionManager(sender messageSender) *sessionManager { + sm := &sessionManager{ + sessions: make(map[string]*claudeSession), + sender: sender, + timeout: defaultSessionTimeout, + warningThresholds: defaultWarningThresholds, + checkInterval: 30 * time.Second, + stopTimeout: make(chan struct{}), + } + go sm.runTimeoutChecker(sm.stopTimeout) + return sm +} + +// sessionCount returns the number of active sessions. +func (sm *sessionManager) sessionCount() int { + sm.mu.RLock() + defer sm.mu.RUnlock() + return len(sm.sessions) +} + +// getSession returns a session by ID, or nil if not found. +func (sm *sessionManager) getSession(sessionID string) *claudeSession { + sm.mu.RLock() + defer sm.mu.RUnlock() + return sm.sessions[sessionID] +} + +// activeSessions returns a list of active session states for state snapshots. +func (sm *sessionManager) activeSessions() []ws.SessionState { + sm.mu.RLock() + defer sm.mu.RUnlock() + var sessions []ws.SessionState + for _, s := range sm.sessions { + sessions = append(sessions, ws.SessionState{ + SessionID: s.sessionID, + Project: s.project, + }) + } + return sessions +} + +// newPRD spawns a new Claude PRD session. +func (sm *sessionManager) newPRD(projectPath, projectName, sessionID, initialMessage string) error { + sm.mu.Lock() + if _, exists := sm.sessions[sessionID]; exists { + sm.mu.Unlock() + return fmt.Errorf("session %s already exists", sessionID) + } + sm.mu.Unlock() + + // Ensure .chief/prds directory structure exists + prdsDir := filepath.Join(projectPath, ".chief", "prds") + if err := os.MkdirAll(prdsDir, 0o755); err != nil { + return fmt.Errorf("failed to create prds directory: %w", err) + } + + // Build prompt from init_prompt.txt template + // Use a temp PRD dir name based on session ID — Claude will create the actual + // directory when it writes prd.md (the init prompt instructs it to). + // We pass the prds base dir so the prompt has the right context. + prompt := embed.GetInitPrompt(prdsDir, initialMessage) + + // Spawn claude in print mode for non-interactive piped I/O + cmd := exec.Command(claudeBinary(), "-p", "--dangerously-skip-permissions", prompt) + cmd.Dir = projectPath + cmd.Env = filterEnv(os.Environ(), "CLAUDECODE") + + stdinPipe, err := cmd.StdinPipe() + if err != nil { + return fmt.Errorf("failed to create stdin pipe: %w", err) + } + + stdoutPipe, err := cmd.StdoutPipe() + if err != nil { + stdinPipe.Close() + return fmt.Errorf("failed to create stdout pipe: %w", err) + } + + stderrPipe, err := cmd.StderrPipe() + if err != nil { + stdinPipe.Close() + return fmt.Errorf("failed to create stderr pipe: %w", err) + } + + if err := cmd.Start(); err != nil { + stdinPipe.Close() + return fmt.Errorf("failed to start Claude: %w", err) + } + + sess := &claudeSession{ + sessionID: sessionID, + project: projectName, + projectPath: projectPath, + cmd: cmd, + stdin: stdinPipe, + done: make(chan struct{}), + lastActive: time.Now(), + } + + sm.mu.Lock() + sm.sessions[sessionID] = sess + sm.mu.Unlock() + + // Stream stdout in a goroutine + go sm.streamOutput(sessionID, stdoutPipe) + + // Stream stderr in a goroutine (merged into same claude_output) + go sm.streamOutput(sessionID, stderrPipe) + + // Wait for process to exit + go func() { + err := cmd.Wait() + if err != nil { + log.Printf("Claude session %s exited with error: %v", sessionID, err) + } else { + log.Printf("Claude session %s exited normally", sessionID) + } + + // Send prd_response_complete to signal the PRD session is done + completeMsg := ws.PRDResponseCompleteMessage{ + Type: ws.TypePRDResponseComplete, + Payload: ws.PRDResponseCompletePayload{ + SessionID: sessionID, + Project: projectName, + }, + } + if sendErr := sm.sender.Send(completeMsg); sendErr != nil { + log.Printf("Error sending prd_response_complete: %v", sendErr) + } + + // Auto-convert prd.md to prd.json if prd.md was created + sm.autoConvert(projectPath) + + close(sess.done) + + sm.mu.Lock() + delete(sm.sessions, sessionID) + sm.mu.Unlock() + }() + + return nil +} + +// refinePRD spawns a Claude PRD session to edit an existing PRD. +func (sm *sessionManager) refinePRD(projectPath, projectName, sessionID, prdID, message string) error { + sm.mu.Lock() + if _, exists := sm.sessions[sessionID]; exists { + sm.mu.Unlock() + return fmt.Errorf("session %s already exists", sessionID) + } + sm.mu.Unlock() + + // Verify the PRD directory exists + prdDir := filepath.Join(projectPath, ".chief", "prds", prdID) + if _, err := os.Stat(prdDir); os.IsNotExist(err) { + return fmt.Errorf("PRD %q not found in project", prdID) + } + + // Build prompt from edit_prompt.txt template + prompt := embed.GetEditPrompt(prdDir) + + // Spawn claude in print mode for non-interactive piped I/O + cmd := exec.Command(claudeBinary(), "-p", "--dangerously-skip-permissions", prompt) + cmd.Dir = projectPath + cmd.Env = filterEnv(os.Environ(), "CLAUDECODE") + + stdinPipe, err := cmd.StdinPipe() + if err != nil { + return fmt.Errorf("failed to create stdin pipe: %w", err) + } + + stdoutPipe, err := cmd.StdoutPipe() + if err != nil { + stdinPipe.Close() + return fmt.Errorf("failed to create stdout pipe: %w", err) + } + + stderrPipe, err := cmd.StderrPipe() + if err != nil { + stdinPipe.Close() + return fmt.Errorf("failed to create stderr pipe: %w", err) + } + + if err := cmd.Start(); err != nil { + stdinPipe.Close() + return fmt.Errorf("failed to start Claude: %w", err) + } + + sess := &claudeSession{ + sessionID: sessionID, + project: projectName, + projectPath: projectPath, + cmd: cmd, + stdin: stdinPipe, + done: make(chan struct{}), + lastActive: time.Now(), + } + + sm.mu.Lock() + sm.sessions[sessionID] = sess + sm.mu.Unlock() + + // Stream stdout in a goroutine + go sm.streamOutput(sessionID, stdoutPipe) + + // Stream stderr in a goroutine (merged into same prd_output) + go sm.streamOutput(sessionID, stderrPipe) + + // Send the user's message as the first input to Claude + go func() { + // Small delay to let Claude process the initial prompt + time.Sleep(100 * time.Millisecond) + fmt.Fprintf(stdinPipe, "%s\n", message) + }() + + // Wait for process to exit + go func() { + err := cmd.Wait() + if err != nil { + log.Printf("Claude session %s exited with error: %v", sessionID, err) + } else { + log.Printf("Claude session %s exited normally", sessionID) + } + + // Send prd_response_complete to signal the PRD session is done + completeMsg := ws.PRDResponseCompleteMessage{ + Type: ws.TypePRDResponseComplete, + Payload: ws.PRDResponseCompletePayload{ + SessionID: sessionID, + Project: projectName, + }, + } + if sendErr := sm.sender.Send(completeMsg); sendErr != nil { + log.Printf("Error sending prd_response_complete: %v", sendErr) + } + + // Auto-convert prd.md to prd.json if prd.md was updated + sm.autoConvert(projectPath) + + close(sess.done) + + sm.mu.Lock() + delete(sm.sessions, sessionID) + sm.mu.Unlock() + }() + + return nil +} + +// streamOutput reads from an io.Reader and sends each chunk as a prd_output message. +func (sm *sessionManager) streamOutput(sessionID string, r io.Reader) { + sm.mu.RLock() + sess := sm.sessions[sessionID] + sm.mu.RUnlock() + if sess == nil { + return + } + + scanner := bufio.NewScanner(r) + scanner.Buffer(make([]byte, 64*1024), 1024*1024) + for scanner.Scan() { + line := scanner.Text() + msg := ws.PRDOutputMessage{ + Type: ws.TypePRDOutput, + Payload: ws.PRDOutputPayload{ + Content: line + "\n", + SessionID: sessionID, + Project: sess.project, + }, + } + if err := sm.sender.Send(msg); err != nil { + log.Printf("Error sending prd_output: %v", err) + return + } + } +} + +// sendMessage writes a user message to an active session's stdin. +func (sm *sessionManager) sendMessage(sessionID, content string) error { + sess := sm.getSession(sessionID) + if sess == nil { + return fmt.Errorf("session not found") + } + + // Reset the inactivity timer + sess.resetActivity() + + // Write the message followed by a newline to the Claude process stdin + _, err := fmt.Fprintf(sess.stdin, "%s\n", content) + if err != nil { + return fmt.Errorf("failed to write to Claude stdin: %w", err) + } + return nil +} + +// closeSession closes a PRD session. If save is true, waits for Claude to finish. +// If save is false, kills immediately. +func (sm *sessionManager) closeSession(sessionID string, save bool) error { + sess := sm.getSession(sessionID) + if sess == nil { + return fmt.Errorf("session not found") + } + + if save { + // Close stdin to signal EOF to Claude, then wait for it to finish + sess.stdin.Close() + <-sess.done + } else { + // Kill immediately + if sess.cmd.Process != nil { + sess.cmd.Process.Kill() + } + <-sess.done + } + + return nil +} + +// killAll kills all active sessions (used during shutdown). +func (sm *sessionManager) killAll() { + // Stop the timeout checker + select { + case <-sm.stopTimeout: + // Already closed + default: + close(sm.stopTimeout) + } + + sm.mu.RLock() + sessions := make([]*claudeSession, 0, len(sm.sessions)) + for _, s := range sm.sessions { + sessions = append(sessions, s) + } + sm.mu.RUnlock() + + for _, s := range sessions { + if s.cmd.Process != nil { + s.cmd.Process.Kill() + } + } + + // Wait for all to finish + for _, s := range sessions { + <-s.done + } +} + +// runTimeoutChecker periodically checks all sessions for inactivity and sends +// warnings at the configured thresholds. When the timeout is reached, the session +// is expired: state is saved to disk, the process is killed, and session_expired is sent. +func (sm *sessionManager) runTimeoutChecker(stopCh <-chan struct{}) { + ticker := time.NewTicker(sm.checkInterval) + defer ticker.Stop() + + // Track which warnings have been sent for each session to avoid duplicates. + // Key: sessionID, Value: set of warning minutes already sent. + sentWarnings := make(map[string]map[int]bool) + + for { + select { + case <-stopCh: + return + case <-ticker.C: + sm.mu.RLock() + sessions := make([]*claudeSession, 0, len(sm.sessions)) + for _, s := range sm.sessions { + sessions = append(sessions, s) + } + sm.mu.RUnlock() + + for _, sess := range sessions { + inactive := sess.inactiveDuration() + inactiveMinutes := int(inactive.Minutes()) + + // Check if session should be expired + if inactive >= sm.timeout { + log.Printf("Session %s timed out after %v of inactivity", sess.sessionID, sm.timeout) + sm.expireSession(sess) + delete(sentWarnings, sess.sessionID) + continue + } + + // Check warning thresholds + if _, ok := sentWarnings[sess.sessionID]; !ok { + sentWarnings[sess.sessionID] = make(map[int]bool) + } + + for _, threshold := range sm.warningThresholds { + if inactiveMinutes >= threshold && !sentWarnings[sess.sessionID][threshold] { + timeoutMinutes := int(sm.timeout.Minutes()) + remaining := timeoutMinutes - threshold + log.Printf("Session %s: sending timeout warning (%d minutes remaining)", sess.sessionID, remaining) + sm.sendTimeoutWarning(sess.sessionID, remaining) + sentWarnings[sess.sessionID][threshold] = true + } + } + } + + // Clean up sentWarnings for sessions that no longer exist + sm.mu.RLock() + for sid := range sentWarnings { + if _, exists := sm.sessions[sid]; !exists { + delete(sentWarnings, sid) + } + } + sm.mu.RUnlock() + } + } +} + +// sendTimeoutWarning sends a session_timeout_warning message over WebSocket. +func (sm *sessionManager) sendTimeoutWarning(sessionID string, minutesRemaining int) { + envelope := ws.NewMessage(ws.TypeSessionTimeoutWarning) + msg := ws.SessionTimeoutWarningMessage{ + Type: envelope.Type, + ID: envelope.ID, + Timestamp: envelope.Timestamp, + SessionID: sessionID, + MinutesRemaining: minutesRemaining, + } + if err := sm.sender.Send(msg); err != nil { + log.Printf("Error sending session_timeout_warning: %v", err) + } +} + +// expireSession saves whatever PRD state exists, kills the Claude process, +// and sends a session_expired message. +func (sm *sessionManager) expireSession(sess *claudeSession) { + // Close stdin to let Claude finish writing, then kill after a brief grace period + sess.stdin.Close() + + // Give Claude 2 seconds to finish writing + select { + case <-sess.done: + // Process exited cleanly + case <-time.After(2 * time.Second): + // Force kill + if sess.cmd.Process != nil { + sess.cmd.Process.Kill() + } + <-sess.done + } + + // Send session_expired message + envelope := ws.NewMessage(ws.TypeSessionExpired) + expiredMsg := ws.SessionExpiredMessage{ + Type: envelope.Type, + ID: envelope.ID, + Timestamp: envelope.Timestamp, + SessionID: sess.sessionID, + } + if err := sm.sender.Send(expiredMsg); err != nil { + log.Printf("Error sending session_expired: %v", err) + } + + log.Printf("Session %s expired and cleaned up", sess.sessionID) +} + +// autoConvert scans for any prd.md files that need conversion and converts them. +func (sm *sessionManager) autoConvert(projectPath string) { + prdsDir := filepath.Join(projectPath, ".chief", "prds") + entries, err := os.ReadDir(prdsDir) + if err != nil { + return + } + + for _, entry := range entries { + if !entry.IsDir() { + continue + } + prdDir := filepath.Join(prdsDir, entry.Name()) + needs, err := prd.NeedsConversion(prdDir) + if err != nil { + log.Printf("Error checking conversion for %s: %v", prdDir, err) + continue + } + if needs { + log.Printf("Auto-converting PRD in %s", prdDir) + if err := prd.Convert(prd.ConvertOptions{PRDDir: prdDir}); err != nil { + log.Printf("Auto-conversion failed for %s: %v", prdDir, err) + } else { + log.Printf("Auto-conversion succeeded for %s", prdDir) + } + } + } +} + +// handleNewPRD handles a new_prd WebSocket message. +func handleNewPRD(sender messageSender, scanner projectFinder, sessions *sessionManager, msg ws.Message) { + var req ws.NewPRDMessage + if err := json.Unmarshal(msg.Raw, &req); err != nil { + log.Printf("Error parsing new_prd message: %v", err) + return + } + + project, found := scanner.FindProject(req.Project) + if !found { + sendError(sender, ws.ErrCodeProjectNotFound, + fmt.Sprintf("Project %q not found", req.Project), msg.ID) + return + } + + if err := sessions.newPRD(project.Path, req.Project, req.SessionID, req.Message); err != nil { + sendError(sender, ws.ErrCodeClaudeError, + fmt.Sprintf("Failed to start Claude session: %v", err), msg.ID) + return + } + + log.Printf("Started Claude PRD session %s for project %s", req.SessionID, req.Project) +} + +// handleRefinePRD handles a refine_prd WebSocket message. +func handleRefinePRD(sender messageSender, scanner projectFinder, sessions *sessionManager, msg ws.Message) { + var req ws.RefinePRDMessage + if err := json.Unmarshal(msg.Raw, &req); err != nil { + log.Printf("Error parsing refine_prd message: %v", err) + return + } + + project, found := scanner.FindProject(req.Project) + if !found { + sendError(sender, ws.ErrCodeProjectNotFound, + fmt.Sprintf("Project %q not found", req.Project), msg.ID) + return + } + + if err := sessions.refinePRD(project.Path, req.Project, req.SessionID, req.PRDID, req.Message); err != nil { + sendError(sender, ws.ErrCodeClaudeError, + fmt.Sprintf("Failed to start Claude session: %v", err), msg.ID) + return + } + + log.Printf("Started Claude PRD refine session %s for project %s (prd: %s)", req.SessionID, req.Project, req.PRDID) +} + +// handlePRDMessage handles a prd_message WebSocket message. +func handlePRDMessage(sender messageSender, sessions *sessionManager, msg ws.Message) { + var req ws.PRDMessageMessage + if err := json.Unmarshal(msg.Raw, &req); err != nil { + log.Printf("Error parsing prd_message: %v", err) + return + } + + if err := sessions.sendMessage(req.SessionID, req.Message); err != nil { + sendError(sender, ws.ErrCodeSessionNotFound, + fmt.Sprintf("Session %q not found", req.SessionID), msg.ID) + return + } +} + +// handleClosePRDSession handles a close_prd_session WebSocket message. +func handleClosePRDSession(sender messageSender, sessions *sessionManager, msg ws.Message) { + var req ws.ClosePRDSessionMessage + if err := json.Unmarshal(msg.Raw, &req); err != nil { + log.Printf("Error parsing close_prd_session: %v", err) + return + } + + if err := sessions.closeSession(req.SessionID, req.Save); err != nil { + sendError(sender, ws.ErrCodeSessionNotFound, + fmt.Sprintf("Session %q not found", req.SessionID), msg.ID) + return + } + + log.Printf("Closed Claude PRD session %s (save=%v)", req.SessionID, req.Save) +} + +// claudeBinary returns the path to the claude CLI binary. +// It checks the CHIEF_CLAUDE_BINARY environment variable first, falling back to "claude". +func claudeBinary() string { + if bin := os.Getenv("CHIEF_CLAUDE_BINARY"); bin != "" { + return bin + } + return "claude" +} + +// filterEnv returns a copy of env with the named variables removed. +func filterEnv(env []string, keys ...string) []string { + filtered := make([]string, 0, len(env)) + for _, e := range env { + skip := false + for _, key := range keys { + if len(e) > len(key) && e[:len(key)+1] == key+"=" { + skip = true + break + } + } + if !skip { + filtered = append(filtered, e) + } + } + return filtered +} + +// projectFinder is an interface for finding projects (for testability). +type projectFinder interface { + FindProject(name string) (ws.ProjectSummary, bool) +} diff --git a/internal/cmd/session_test.go b/internal/cmd/session_test.go new file mode 100644 index 0000000..afa3142 --- /dev/null +++ b/internal/cmd/session_test.go @@ -0,0 +1,1469 @@ +package cmd + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + "sync" + "testing" + "time" + + "github.com/minicodemonkey/chief/internal/ws" +) + +// captureSender is a mock messageSender that captures sent messages. +type captureSender struct { + mu sync.Mutex + messages []map[string]interface{} +} + +func (c *captureSender) Send(msg interface{}) error { + data, err := json.Marshal(msg) + if err != nil { + return err + } + var m map[string]interface{} + if err := json.Unmarshal(data, &m); err != nil { + return err + } + c.mu.Lock() + c.messages = append(c.messages, m) + c.mu.Unlock() + return nil +} + +func (c *captureSender) getMessages() []map[string]interface{} { + c.mu.Lock() + defer c.mu.Unlock() + cp := make([]map[string]interface{}, len(c.messages)) + copy(cp, c.messages) + return cp +} + +// discardSender is a mock messageSender that discards all messages. +type discardSender struct{} + +func (d *discardSender) Send(msg interface{}) error { return nil } + +// mockProjectFinder implements projectFinder for tests. +type mockProjectFinder struct { + projects map[string]ws.ProjectSummary +} + +func (m *mockProjectFinder) FindProject(name string) (ws.ProjectSummary, bool) { + p, ok := m.projects[name] + return p, ok +} + +func TestSessionManager_NewPRD(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + projectDir := filepath.Join(workspaceDir, "myproject") + createGitRepo(t, projectDir) + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + // Send new_prd request + newPRDReq := map[string]string{ + "type": "new_prd", + "id": "req-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "project": "myproject", + "session_id": "sess-123", + "message": "Build a todo app", + } + ms.sendCommand(newPRDReq) + + // We should receive prd_output messages. + // Since we can't actually run claude in tests, expect an error response + // (claude binary not available in test) — this tests the error path. + // Wait for first message with 3 second timeout. + raw, err := ms.waitForMessageType("error", 3*time.Second) + if err != nil { + // If not an error, might be prd_output + raw, err = ms.waitForMessageType("prd_output", 3*time.Second) + if err != nil { + t.Fatal("expected error or prd_output message") + } + } + + var msg map[string]interface{} + if err := json.Unmarshal(raw, &msg); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + // Check that we got some kind of response + msgType := msg["type"].(string) + if msgType != "error" && msgType != "prd_output" { + t.Errorf("expected error or prd_output message, got %s", msgType) + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } +} + +func TestSessionManager_NewPRD_ProjectNotFound(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + // Send new_prd for nonexistent project + newPRDReq := map[string]string{ + "type": "new_prd", + "id": "req-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "project": "nonexistent", + "session_id": "sess-123", + "message": "Build a todo app", + } + ms.sendCommand(newPRDReq) + + // Read error response + raw, err := ms.waitForMessageType("error", 2*time.Second) + if err != nil { + t.Fatalf("expected error message: %v", err) + } + + var errorReceived map[string]interface{} + if err := json.Unmarshal(raw, &errorReceived); err != nil { + t.Fatalf("failed to unmarshal error: %v", err) + } + + if errorReceived["type"] != "error" { + t.Errorf("expected type 'error', got %v", errorReceived["type"]) + } + if errorReceived["code"] != "PROJECT_NOT_FOUND" { + t.Errorf("expected code 'PROJECT_NOT_FOUND', got %v", errorReceived["code"]) + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } +} + +func TestSessionManager_PRDMessage_SessionNotFound(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + // Send prd_message for nonexistent session + prdMsg := map[string]string{ + "type": "prd_message", + "id": "req-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "session_id": "nonexistent-session", + "message": "hello", + } + ms.sendCommand(prdMsg) + + // Read error response + raw, err := ms.waitForMessageType("error", 2*time.Second) + if err != nil { + t.Fatalf("expected error message: %v", err) + } + + var errorReceived map[string]interface{} + if err := json.Unmarshal(raw, &errorReceived); err != nil { + t.Fatalf("failed to unmarshal error: %v", err) + } + + if errorReceived["type"] != "error" { + t.Errorf("expected type 'error', got %v", errorReceived["type"]) + } + if errorReceived["code"] != "SESSION_NOT_FOUND" { + t.Errorf("expected code 'SESSION_NOT_FOUND', got %v", errorReceived["code"]) + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } +} + +func TestSessionManager_ClosePRDSession_SessionNotFound(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + // Send close_prd_session for nonexistent session + closeMsg := map[string]interface{}{ + "type": "close_prd_session", + "id": "req-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "session_id": "nonexistent-session", + "save": false, + } + ms.sendCommand(closeMsg) + + // Read error response + raw, err := ms.waitForMessageType("error", 2*time.Second) + if err != nil { + t.Fatalf("expected error message: %v", err) + } + + var errorReceived map[string]interface{} + if err := json.Unmarshal(raw, &errorReceived); err != nil { + t.Fatalf("failed to unmarshal error: %v", err) + } + + if errorReceived["type"] != "error" { + t.Errorf("expected type 'error', got %v", errorReceived["type"]) + } + if errorReceived["code"] != "SESSION_NOT_FOUND" { + t.Errorf("expected code 'SESSION_NOT_FOUND', got %v", errorReceived["code"]) + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } +} + +// TestSessionManager_WithMockClaude uses a shell script to simulate Claude, +// testing the full session lifecycle: spawn, stream output, send message, close. +func TestSessionManager_WithMockClaude(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + projectDir := filepath.Join(workspaceDir, "myproject") + createGitRepo(t, projectDir) + + // Create a mock "claude" script that echoes input + mockClaudeBin := filepath.Join(home, "claude") + mockScript := `#!/bin/sh +echo "Claude PRD session started" +echo "Processing: $1" +# Read from stdin and echo back +while IFS= read -r line; do + echo "Received: $line" +done +echo "Session complete" +` + if err := os.WriteFile(mockClaudeBin, []byte(mockScript), 0o755); err != nil { + t.Fatal(err) + } + + // Add mock claude to PATH + origPath := os.Getenv("PATH") + t.Setenv("PATH", home+":"+origPath) + + ctx, cancel := context.WithCancel(context.Background()) + ms := newMockUplinkServer(t) + + go func() { + if err := ms.waitForPusherSubscribe(10 * time.Second); err != nil { + t.Logf("waitForPusherSubscribe: %v", err) + cancel() + return + } + + // Wait for initial state_snapshot + if _, err := ms.waitForMessageType("state_snapshot", 5*time.Second); err != nil { + t.Logf("waitForMessageType(state_snapshot): %v", err) + cancel() + return + } + + // Send new_prd request via Pusher + newPRDReq := map[string]string{ + "type": "new_prd", + "id": "req-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "project": "myproject", + "session_id": "sess-mock-1", + "message": "Build a todo app", + } + ms.sendCommand(newPRDReq) + + // Wait a bit for process to start and produce output + time.Sleep(500 * time.Millisecond) + + // Send a prd_message + prdMsg := map[string]string{ + "type": "prd_message", + "id": "req-2", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "session_id": "sess-mock-1", + "message": "Add user authentication", + } + ms.sendCommand(prdMsg) + + // Wait for output + time.Sleep(500 * time.Millisecond) + + // Close the session (save=false, kill immediately) + closeMsg := map[string]interface{}{ + "type": "close_prd_session", + "id": "req-3", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "session_id": "sess-mock-1", + "save": false, + } + ms.sendCommand(closeMsg) + + // Wait for a prd_response_complete message + deadline := time.After(5 * time.Second) + for { + msgs := ms.getMessages() + for _, raw := range msgs { + var msg map[string]interface{} + json.Unmarshal(raw, &msg) + if msg["type"] == "prd_response_complete" { + cancel() + return + } + } + select { + case <-deadline: + cancel() + return + case <-time.After(50 * time.Millisecond): + } + } + }() + + err := RunServe(ServeOptions{ + Workspace: workspaceDir, + ServerURL: ms.httpSrv.URL, + Version: "1.0.0", + Ctx: ctx, + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + // Collect all prd_output messages + allMsgs := ms.getMessages() + var prdOutputs []map[string]interface{} + for _, raw := range allMsgs { + var msg map[string]interface{} + if json.Unmarshal(raw, &msg) == nil && msg["type"] == "prd_output" { + prdOutputs = append(prdOutputs, msg) + } + } + + if len(prdOutputs) == 0 { + t.Fatal("expected at least one prd_output message") + } + + // Verify session_id and project are set on all prd_output messages (inside payload) + for _, co := range prdOutputs { + payload, _ := co["payload"].(map[string]interface{}) + if payload == nil { + t.Error("expected prd_output to have a payload field") + continue + } + if payload["session_id"] != "sess-mock-1" { + t.Errorf("expected payload.session_id 'sess-mock-1', got %v", payload["session_id"]) + } + if payload["project"] != "myproject" { + t.Errorf("expected payload.project 'myproject', got %v", payload["project"]) + } + } + + // Verify we got a prd_response_complete message + hasComplete := false + for _, raw := range allMsgs { + var msg map[string]interface{} + if json.Unmarshal(raw, &msg) == nil && msg["type"] == "prd_response_complete" { + hasComplete = true + payload, _ := msg["payload"].(map[string]interface{}) + if payload == nil { + t.Error("expected prd_response_complete to have a payload field") + } else if payload["session_id"] != "sess-mock-1" { + t.Errorf("expected payload.session_id 'sess-mock-1' on prd_response_complete, got %v", payload["session_id"]) + } + break + } + } + if !hasComplete { + t.Error("expected a prd_response_complete message") + } + + // Verify we received some actual content + hasContent := false + for _, co := range prdOutputs { + payload, _ := co["payload"].(map[string]interface{}) + if payload != nil { + if content, ok := payload["content"].(string); ok && strings.TrimSpace(content) != "" { + hasContent = true + break + } + } + } + if !hasContent { + t.Error("expected at least one prd_output with non-empty content") + } +} + +// TestSessionManager_WithMockClaude_SaveClose tests save=true close behavior. +func TestSessionManager_WithMockClaude_SaveClose(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + projectDir := filepath.Join(workspaceDir, "myproject") + createGitRepo(t, projectDir) + + // Create a mock "claude" script that exits on EOF + mockClaudeBin := filepath.Join(home, "claude") + mockScript := `#!/bin/sh +echo "Session started" +# Read until EOF (stdin closed) +while IFS= read -r line; do + echo "Got: $line" +done +echo "Saving PRD..." +exit 0 +` + if err := os.WriteFile(mockClaudeBin, []byte(mockScript), 0o755); err != nil { + t.Fatal(err) + } + + origPath := os.Getenv("PATH") + t.Setenv("PATH", home+":"+origPath) + + ctx, cancel := context.WithCancel(context.Background()) + ms := newMockUplinkServer(t) + + go func() { + if err := ms.waitForPusherSubscribe(10 * time.Second); err != nil { + t.Logf("waitForPusherSubscribe: %v", err) + cancel() + return + } + + // Wait for initial state_snapshot + if _, err := ms.waitForMessageType("state_snapshot", 5*time.Second); err != nil { + t.Logf("waitForMessageType(state_snapshot): %v", err) + cancel() + return + } + + // Send new_prd via Pusher + newPRDReq := map[string]string{ + "type": "new_prd", + "id": "req-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "project": "myproject", + "session_id": "sess-save-1", + "message": "Build an API", + } + ms.sendCommand(newPRDReq) + + time.Sleep(500 * time.Millisecond) + + // Close with save=true (waits for Claude to finish) + closeMsg := map[string]interface{}{ + "type": "close_prd_session", + "id": "req-2", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "session_id": "sess-save-1", + "save": true, + } + ms.sendCommand(closeMsg) + + // Wait for a prd_response_complete message + deadline := time.After(5 * time.Second) + for { + msgs := ms.getMessages() + for _, raw := range msgs { + var msg map[string]interface{} + json.Unmarshal(raw, &msg) + if msg["type"] == "prd_response_complete" { + cancel() + return + } + } + select { + case <-deadline: + cancel() + return + case <-time.After(50 * time.Millisecond): + } + } + }() + + err := RunServe(ServeOptions{ + Workspace: workspaceDir, + ServerURL: ms.httpSrv.URL, + Version: "1.0.0", + Ctx: ctx, + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + // Verify we received a prd_response_complete message + allMsgs := ms.getMessages() + hasComplete := false + for _, raw := range allMsgs { + var msg map[string]interface{} + if json.Unmarshal(raw, &msg) == nil && msg["type"] == "prd_response_complete" { + hasComplete = true + break + } + } + if !hasComplete { + t.Error("expected a prd_response_complete message after save close") + } +} + +func TestSessionManager_ActiveSessions(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + + // Create a mock "claude" script that stays alive + mockClaudeBin := filepath.Join(home, "claude") + mockScript := `#!/bin/sh +while IFS= read -r line; do + echo "$line" +done +` + if err := os.WriteFile(mockClaudeBin, []byte(mockScript), 0o755); err != nil { + t.Fatal(err) + } + origPath := os.Getenv("PATH") + t.Setenv("PATH", home+":"+origPath) + + sm := newSessionManager(&discardSender{}) + + // Initially no active sessions + sessions := sm.activeSessions() + if len(sessions) != 0 { + t.Errorf("expected 0 active sessions, got %d", len(sessions)) + } + + // Create a project dir for the session + projectDir := filepath.Join(home, "testproject") + if err := os.MkdirAll(projectDir, 0o755); err != nil { + t.Fatal(err) + } + + // Start a session + err := sm.newPRD(projectDir, "testproject", "sess-1", "test message") + if err != nil { + t.Fatalf("newPRD failed: %v", err) + } + + // Now should have 1 active session + sessions = sm.activeSessions() + if len(sessions) != 1 { + t.Fatalf("expected 1 active session, got %d", len(sessions)) + } + if sessions[0].SessionID != "sess-1" { + t.Errorf("expected session_id 'sess-1', got %q", sessions[0].SessionID) + } + if sessions[0].Project != "testproject" { + t.Errorf("expected project 'testproject', got %q", sessions[0].Project) + } + + // Kill all sessions + sm.killAll() + + // Now should have 0 active sessions + sessions = sm.activeSessions() + if len(sessions) != 0 { + t.Errorf("expected 0 active sessions after killAll, got %d", len(sessions)) + } +} + +func TestSessionManager_SendMessage(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + + // Create a mock "claude" script that echoes input + mockClaudeBin := filepath.Join(home, "claude") + mockScript := `#!/bin/sh +echo "ready" +while IFS= read -r line; do + echo "echo: $line" +done +` + if err := os.WriteFile(mockClaudeBin, []byte(mockScript), 0o755); err != nil { + t.Fatal(err) + } + origPath := os.Getenv("PATH") + t.Setenv("PATH", home+":"+origPath) + + sender := &captureSender{} + sm := newSessionManager(sender) + + projectDir := filepath.Join(home, "testproject") + if err := os.MkdirAll(projectDir, 0o755); err != nil { + t.Fatal(err) + } + + err := sm.newPRD(projectDir, "testproject", "sess-msg-1", "test") + if err != nil { + t.Fatalf("newPRD failed: %v", err) + } + + // Wait for process to start + time.Sleep(300 * time.Millisecond) + + // Send a message + if err := sm.sendMessage("sess-msg-1", "hello world"); err != nil { + t.Fatalf("sendMessage failed: %v", err) + } + + // Wait for echo + time.Sleep(500 * time.Millisecond) + + // Verify error on nonexistent session + if err := sm.sendMessage("nonexistent", "test"); err == nil { + t.Error("expected error for nonexistent session") + } + + // Check that we received the echoed message via captureSender + msgs := sender.getMessages() + hasEcho := false + for _, msg := range msgs { + if msg["type"] == "prd_output" { + payload, _ := msg["payload"].(map[string]interface{}) + if payload != nil { + if content, ok := payload["content"].(string); ok && strings.Contains(content, "echo: hello world") { + hasEcho = true + break + } + } + } + } + if !hasEcho { + t.Errorf("expected echoed message 'echo: hello world' in captured messages: %v", msgs) + } + + // Clean up + sm.killAll() +} + +func TestSessionManager_CloseSession_Errors(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + + sm := newSessionManager(&discardSender{}) + + // Close nonexistent session + err := sm.closeSession("nonexistent", false) + if err == nil { + t.Error("expected error for nonexistent session") + } + if !strings.Contains(err.Error(), "session not found") { + t.Errorf("expected 'session not found' error, got: %v", err) + } +} + +func TestSessionManager_DuplicateSession(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + + // Create a mock "claude" script + mockClaudeBin := filepath.Join(home, "claude") + mockScript := fmt.Sprintf("#!/bin/sh\nwhile IFS= read -r line; do echo \"$line\"; done") + if err := os.WriteFile(mockClaudeBin, []byte(mockScript), 0o755); err != nil { + t.Fatal(err) + } + origPath := os.Getenv("PATH") + t.Setenv("PATH", home+":"+origPath) + + sm := newSessionManager(&discardSender{}) + + projectDir := filepath.Join(home, "testproject") + if err := os.MkdirAll(projectDir, 0o755); err != nil { + t.Fatal(err) + } + + // Start first session + err := sm.newPRD(projectDir, "testproject", "sess-dup", "test") + if err != nil { + t.Fatalf("first newPRD failed: %v", err) + } + + // Try to start duplicate session + err = sm.newPRD(projectDir, "testproject", "sess-dup", "test") + if err == nil { + t.Error("expected error for duplicate session_id") + } + if !strings.Contains(err.Error(), "already exists") { + t.Errorf("expected 'already exists' error, got: %v", err) + } + + sm.killAll() +} + +func TestSessionManager_RefinePRD(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + projectDir := filepath.Join(workspaceDir, "myproject") + createGitRepo(t, projectDir) + + // Create existing PRD directory with prd.md + prdDir := filepath.Join(projectDir, ".chief", "prds", "feature-auth") + if err := os.MkdirAll(prdDir, 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(prdDir, "prd.md"), []byte("# Auth PRD\n"), 0o644); err != nil { + t.Fatal(err) + } + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + // Send refine_prd request + refinePRDReq := map[string]string{ + "type": "refine_prd", + "id": "req-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "project": "myproject", + "session_id": "sess-refine-1", + "prd_id": "feature-auth", + "message": "Add OAuth support", + } + ms.sendCommand(refinePRDReq) + + // Since we can't actually run claude in tests, expect an error response + // (claude binary not available in test) — this tests the error path. + raw, err := ms.waitForMessageType("error", 3*time.Second) + if err != nil { + raw, err = ms.waitForMessageType("prd_output", 3*time.Second) + if err != nil { + t.Fatal("expected error or prd_output message") + } + } + + var msg map[string]interface{} + if err := json.Unmarshal(raw, &msg); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + msgType := msg["type"].(string) + if msgType != "error" && msgType != "prd_output" { + t.Errorf("expected error or prd_output message, got %s", msgType) + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } +} + +func TestSessionManager_RefinePRD_ProjectNotFound(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + refinePRDReq := map[string]string{ + "type": "refine_prd", + "id": "req-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "project": "nonexistent", + "session_id": "sess-refine-1", + "prd_id": "feature-auth", + "message": "Add OAuth support", + } + ms.sendCommand(refinePRDReq) + + raw, err := ms.waitForMessageType("error", 2*time.Second) + if err != nil { + t.Fatalf("expected error message: %v", err) + } + + var errorReceived map[string]interface{} + if err := json.Unmarshal(raw, &errorReceived); err != nil { + t.Fatalf("failed to unmarshal error: %v", err) + } + + if errorReceived["type"] != "error" { + t.Errorf("expected type 'error', got %v", errorReceived["type"]) + } + if errorReceived["code"] != "PROJECT_NOT_FOUND" { + t.Errorf("expected code 'PROJECT_NOT_FOUND', got %v", errorReceived["code"]) + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } +} + +func TestSessionManager_RefinePRD_PRDNotFound(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + projectDir := filepath.Join(workspaceDir, "myproject") + createGitRepo(t, projectDir) + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + refinePRDReq := map[string]string{ + "type": "refine_prd", + "id": "req-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "project": "myproject", + "session_id": "sess-refine-1", + "prd_id": "nonexistent-prd", + "message": "Add OAuth support", + } + ms.sendCommand(refinePRDReq) + + raw, err := ms.waitForMessageType("error", 2*time.Second) + if err != nil { + t.Fatalf("expected error message: %v", err) + } + + var errorReceived map[string]interface{} + if err := json.Unmarshal(raw, &errorReceived); err != nil { + t.Fatalf("failed to unmarshal error: %v", err) + } + + if errorReceived["type"] != "error" { + t.Errorf("expected type 'error', got %v", errorReceived["type"]) + } + if errorReceived["code"] != "CLAUDE_ERROR" { + t.Errorf("expected code 'CLAUDE_ERROR', got %v", errorReceived["code"]) + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } +} + +func TestSessionManager_WithMockClaude_RefinePRD(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + projectDir := filepath.Join(workspaceDir, "myproject") + createGitRepo(t, projectDir) + + // Create existing PRD directory with prd.md + prdDir := filepath.Join(projectDir, ".chief", "prds", "feature-auth") + if err := os.MkdirAll(prdDir, 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(prdDir, "prd.md"), []byte("# Auth PRD\n"), 0o644); err != nil { + t.Fatal(err) + } + + // Create a mock "claude" script that echoes input + mockClaudeBin := filepath.Join(home, "claude") + mockScript := `#!/bin/sh +echo "Claude PRD edit session started" +echo "Processing: $1" +while IFS= read -r line; do + echo "Received: $line" +done +echo "Edit complete" +` + if err := os.WriteFile(mockClaudeBin, []byte(mockScript), 0o755); err != nil { + t.Fatal(err) + } + + origPath := os.Getenv("PATH") + t.Setenv("PATH", home+":"+origPath) + + ctx, cancel := context.WithCancel(context.Background()) + ms := newMockUplinkServer(t) + + go func() { + if err := ms.waitForPusherSubscribe(10 * time.Second); err != nil { + t.Logf("waitForPusherSubscribe: %v", err) + cancel() + return + } + + if _, err := ms.waitForMessageType("state_snapshot", 5*time.Second); err != nil { + t.Logf("waitForMessageType(state_snapshot): %v", err) + cancel() + return + } + + // Send refine_prd request via Pusher + refinePRDReq := map[string]string{ + "type": "refine_prd", + "id": "req-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "project": "myproject", + "session_id": "sess-refine-mock-1", + "prd_id": "feature-auth", + "message": "Add OAuth support", + } + ms.sendCommand(refinePRDReq) + + // Wait for output + time.Sleep(500 * time.Millisecond) + + // Send a follow-up message + prdMsg := map[string]string{ + "type": "prd_message", + "id": "req-2", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "session_id": "sess-refine-mock-1", + "message": "Also add RBAC", + } + ms.sendCommand(prdMsg) + + // Wait for output + time.Sleep(500 * time.Millisecond) + + // Close the session + closeMsg := map[string]interface{}{ + "type": "close_prd_session", + "id": "req-3", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "session_id": "sess-refine-mock-1", + "save": false, + } + ms.sendCommand(closeMsg) + + // Wait for prd_response_complete + deadline := time.After(5 * time.Second) + for { + msgs := ms.getMessages() + for _, raw := range msgs { + var msg map[string]interface{} + json.Unmarshal(raw, &msg) + if msg["type"] == "prd_response_complete" { + cancel() + return + } + } + select { + case <-deadline: + cancel() + return + case <-time.After(50 * time.Millisecond): + } + } + }() + + err := RunServe(ServeOptions{ + Workspace: workspaceDir, + ServerURL: ms.httpSrv.URL, + Version: "1.0.0", + Ctx: ctx, + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + // Collect all prd_output messages + allMsgs := ms.getMessages() + var prdOutputs []map[string]interface{} + for _, raw := range allMsgs { + var msg map[string]interface{} + if json.Unmarshal(raw, &msg) == nil && msg["type"] == "prd_output" { + prdOutputs = append(prdOutputs, msg) + } + } + + if len(prdOutputs) == 0 { + t.Fatal("expected at least one prd_output message") + } + + // Verify session_id and project are set on all prd_output messages (inside payload) + for _, co := range prdOutputs { + payload, _ := co["payload"].(map[string]interface{}) + if payload == nil { + t.Error("expected prd_output to have a payload field") + continue + } + if payload["session_id"] != "sess-refine-mock-1" { + t.Errorf("expected payload.session_id 'sess-refine-mock-1', got %v", payload["session_id"]) + } + if payload["project"] != "myproject" { + t.Errorf("expected payload.project 'myproject', got %v", payload["project"]) + } + } + + // Verify we got a prd_response_complete message + hasComplete := false + for _, raw := range allMsgs { + var msg map[string]interface{} + if json.Unmarshal(raw, &msg) == nil && msg["type"] == "prd_response_complete" { + hasComplete = true + payload, _ := msg["payload"].(map[string]interface{}) + if payload == nil { + t.Error("expected prd_response_complete to have a payload field") + } else if payload["session_id"] != "sess-refine-mock-1" { + t.Errorf("expected payload.session_id 'sess-refine-mock-1' on prd_response_complete, got %v", payload["session_id"]) + } + break + } + } + if !hasComplete { + t.Error("expected a prd_response_complete message") + } + + // Verify the user's message was received by Claude (should appear in prd_output) + hasUserMessage := false + for _, co := range prdOutputs { + payload, _ := co["payload"].(map[string]interface{}) + if payload != nil { + if content, ok := payload["content"].(string); ok && strings.Contains(content, "Add OAuth support") { + hasUserMessage = true + break + } + } + } + if !hasUserMessage { + t.Error("expected user's refine message 'Add OAuth support' to appear in prd_output") + } +} + +// newTestSessionManager creates a session manager with configurable timeouts for testing. +// It does NOT start the timeout checker goroutine automatically. +func newTestSessionManager(t *testing.T, timeout time.Duration, warningThresholds []int, checkInterval time.Duration) (*sessionManager, *captureSender, func()) { + t.Helper() + + home := t.TempDir() + setTestHome(t, home) + + // Create a mock "claude" script that stays alive + mockClaudeBin := filepath.Join(home, "claude") + mockScript := `#!/bin/sh +while IFS= read -r line; do + echo "$line" +done +` + if err := os.WriteFile(mockClaudeBin, []byte(mockScript), 0o755); err != nil { + t.Fatal(err) + } + origPath := os.Getenv("PATH") + t.Setenv("PATH", home+":"+origPath) + + sender := &captureSender{} + sm := &sessionManager{ + sessions: make(map[string]*claudeSession), + sender: sender, + timeout: timeout, + warningThresholds: warningThresholds, + checkInterval: checkInterval, + stopTimeout: make(chan struct{}), + } + + cleanup := func() { + sm.killAll() + } + + return sm, sender, cleanup +} + +func TestSessionManager_TimeoutExpiration(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + + // Create mock claude + mockClaudeBin := filepath.Join(home, "claude") + mockScript := `#!/bin/sh +while IFS= read -r line; do + echo "$line" +done +` + if err := os.WriteFile(mockClaudeBin, []byte(mockScript), 0o755); err != nil { + t.Fatal(err) + } + origPath := os.Getenv("PATH") + t.Setenv("PATH", home+":"+origPath) + + sender := &captureSender{} + sm := &sessionManager{ + sessions: make(map[string]*claudeSession), + sender: sender, + timeout: 200 * time.Millisecond, // Very short for testing + warningThresholds: []int{}, // No warnings, just test expiry + checkInterval: 50 * time.Millisecond, // Check frequently + stopTimeout: make(chan struct{}), + } + go sm.runTimeoutChecker(sm.stopTimeout) + + projectDir := filepath.Join(home, "testproject") + if err := os.MkdirAll(projectDir, 0o755); err != nil { + t.Fatal(err) + } + + err := sm.newPRD(projectDir, "testproject", "sess-timeout-1", "test") + if err != nil { + t.Fatalf("newPRD failed: %v", err) + } + + // Session should be active + if len(sm.activeSessions()) != 1 { + t.Fatal("expected 1 active session") + } + + // Wait for timeout to expire + some buffer + time.Sleep(500 * time.Millisecond) + + // Session should be expired and removed + if len(sm.activeSessions()) != 0 { + t.Errorf("expected 0 active sessions after timeout, got %d", len(sm.activeSessions())) + } + + // Check that session_expired message was sent + msgs := sender.getMessages() + hasExpired := false + for _, msg := range msgs { + if msg["type"] == "session_expired" && msg["session_id"] == "sess-timeout-1" { + hasExpired = true + break + } + } + if !hasExpired { + t.Error("expected session_expired message to be sent") + } + + sm.killAll() +} + +func TestSessionManager_TimeoutWarnings(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + + // Create mock claude + mockClaudeBin := filepath.Join(home, "claude") + mockScript := `#!/bin/sh +while IFS= read -r line; do + echo "$line" +done +` + if err := os.WriteFile(mockClaudeBin, []byte(mockScript), 0o755); err != nil { + t.Fatal(err) + } + origPath := os.Getenv("PATH") + t.Setenv("PATH", home+":"+origPath) + + sender := &captureSender{} + + // Use a 3-minute timeout with thresholds at 1 and 2 minutes. + // We simulate time by setting lastActive in the past. + sm := &sessionManager{ + sessions: make(map[string]*claudeSession), + sender: sender, + timeout: 3 * time.Minute, + warningThresholds: []int{1, 2}, // Warn at 1min and 2min of inactivity + checkInterval: 50 * time.Millisecond, // Check frequently + stopTimeout: make(chan struct{}), + } + go sm.runTimeoutChecker(sm.stopTimeout) + + projectDir := filepath.Join(home, "testproject") + if err := os.MkdirAll(projectDir, 0o755); err != nil { + t.Fatal(err) + } + + err := sm.newPRD(projectDir, "testproject", "sess-warn-1", "test") + if err != nil { + t.Fatalf("newPRD failed: %v", err) + } + + // Simulate 90 seconds of inactivity by setting lastActive in the past + sess := sm.getSession("sess-warn-1") + if sess == nil { + t.Fatal("session not found") + } + sess.activeMu.Lock() + sess.lastActive = time.Now().Add(-90 * time.Second) + sess.activeMu.Unlock() + + // Wait for the checker to pick it up + time.Sleep(200 * time.Millisecond) + + // Should have the 1-minute warning (2 remaining) + msgs := sender.getMessages() + var warningMessages []map[string]interface{} + for _, msg := range msgs { + if msg["type"] == "session_timeout_warning" { + warningMessages = append(warningMessages, msg) + } + } + + if len(warningMessages) != 1 { + t.Fatalf("expected 1 warning message, got %d", len(warningMessages)) + } + + // The warning at 1 min means 3-1 = 2 minutes remaining + if warningMessages[0]["minutes_remaining"] != float64(2) { + t.Errorf("expected minutes_remaining=2, got %v", warningMessages[0]["minutes_remaining"]) + } + + // Now simulate 2.5 minutes of inactivity + sess.activeMu.Lock() + sess.lastActive = time.Now().Add(-150 * time.Second) + sess.activeMu.Unlock() + + time.Sleep(200 * time.Millisecond) + + msgs = sender.getMessages() + warningMessages = nil + for _, msg := range msgs { + if msg["type"] == "session_timeout_warning" { + warningMessages = append(warningMessages, msg) + } + } + + // Should now have 2 warnings (1 min and 2 min thresholds) + if len(warningMessages) != 2 { + t.Fatalf("expected 2 warning messages, got %d", len(warningMessages)) + } + + sm.killAll() +} + +func TestSessionManager_TimeoutResetOnMessage(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + + // Create mock claude + mockClaudeBin := filepath.Join(home, "claude") + mockScript := `#!/bin/sh +while IFS= read -r line; do + echo "$line" +done +` + if err := os.WriteFile(mockClaudeBin, []byte(mockScript), 0o755); err != nil { + t.Fatal(err) + } + origPath := os.Getenv("PATH") + t.Setenv("PATH", home+":"+origPath) + + sm := &sessionManager{ + sessions: make(map[string]*claudeSession), + sender: &discardSender{}, + timeout: 300 * time.Millisecond, + warningThresholds: []int{}, + checkInterval: 50 * time.Millisecond, + stopTimeout: make(chan struct{}), + } + go sm.runTimeoutChecker(sm.stopTimeout) + + projectDir := filepath.Join(home, "testproject") + if err := os.MkdirAll(projectDir, 0o755); err != nil { + t.Fatal(err) + } + + err := sm.newPRD(projectDir, "testproject", "sess-reset-1", "test") + if err != nil { + t.Fatalf("newPRD failed: %v", err) + } + + // Wait 200ms (timeout is 300ms) + time.Sleep(200 * time.Millisecond) + + // Session should still be active + if len(sm.activeSessions()) != 1 { + t.Fatal("expected session to still be active before timeout") + } + + // Send a message to reset the timer + if err := sm.sendMessage("sess-reset-1", "keep alive"); err != nil { + t.Fatalf("sendMessage failed: %v", err) + } + + // Wait another 200ms (total 400ms since start, but only 200ms since last activity) + time.Sleep(200 * time.Millisecond) + + // Session should still be active because we reset the timer + if len(sm.activeSessions()) != 1 { + t.Error("expected session to still be active after timer reset") + } + + // Wait for the full timeout from last activity (another 200ms) + time.Sleep(200 * time.Millisecond) + + // Now it should have timed out + if len(sm.activeSessions()) != 0 { + t.Errorf("expected 0 active sessions after timeout, got %d", len(sm.activeSessions())) + } + + sm.killAll() +} + +func TestSessionManager_IndependentTimers(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + + // Create mock claude + mockClaudeBin := filepath.Join(home, "claude") + mockScript := `#!/bin/sh +while IFS= read -r line; do + echo "$line" +done +` + if err := os.WriteFile(mockClaudeBin, []byte(mockScript), 0o755); err != nil { + t.Fatal(err) + } + origPath := os.Getenv("PATH") + t.Setenv("PATH", home+":"+origPath) + + sender := &captureSender{} + sm := &sessionManager{ + sessions: make(map[string]*claudeSession), + sender: sender, + timeout: 300 * time.Millisecond, + warningThresholds: []int{}, + checkInterval: 50 * time.Millisecond, + stopTimeout: make(chan struct{}), + } + go sm.runTimeoutChecker(sm.stopTimeout) + + projectDir1 := filepath.Join(home, "project1") + projectDir2 := filepath.Join(home, "project2") + os.MkdirAll(projectDir1, 0o755) + os.MkdirAll(projectDir2, 0o755) + + // Start two sessions + if err := sm.newPRD(projectDir1, "project1", "sess-a", "test"); err != nil { + t.Fatalf("newPRD a failed: %v", err) + } + if err := sm.newPRD(projectDir2, "project2", "sess-b", "test"); err != nil { + t.Fatalf("newPRD b failed: %v", err) + } + + // Both should be active + if len(sm.activeSessions()) != 2 { + t.Fatalf("expected 2 active sessions, got %d", len(sm.activeSessions())) + } + + // Keep session B alive by sending a message after 200ms + time.Sleep(200 * time.Millisecond) + if err := sm.sendMessage("sess-b", "keep alive"); err != nil { + t.Fatalf("sendMessage failed: %v", err) + } + + // Wait for session A to expire (another 200ms) + time.Sleep(200 * time.Millisecond) + + // Session A should be expired, session B should still be active + sessions := sm.activeSessions() + if len(sessions) != 1 { + t.Fatalf("expected 1 active session, got %d", len(sessions)) + } + if sessions[0].SessionID != "sess-b" { + t.Errorf("expected session 'sess-b' to survive, got %q", sessions[0].SessionID) + } + + // Verify session_expired was sent for sess-a + msgs := sender.getMessages() + hasExpiredA := false + for _, msg := range msgs { + if msg["type"] == "session_expired" && msg["session_id"] == "sess-a" { + hasExpiredA = true + break + } + } + + if !hasExpiredA { + t.Error("expected session_expired for sess-a") + } + + sm.killAll() +} + +func TestSessionManager_TimeoutCheckerGoroutineSafe(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + + // Create mock claude + mockClaudeBin := filepath.Join(home, "claude") + mockScript := `#!/bin/sh +while IFS= read -r line; do + echo "$line" +done +` + if err := os.WriteFile(mockClaudeBin, []byte(mockScript), 0o755); err != nil { + t.Fatal(err) + } + origPath := os.Getenv("PATH") + t.Setenv("PATH", home+":"+origPath) + + sm := &sessionManager{ + sessions: make(map[string]*claudeSession), + sender: &discardSender{}, + timeout: 500 * time.Millisecond, + warningThresholds: []int{}, + checkInterval: 50 * time.Millisecond, + stopTimeout: make(chan struct{}), + } + go sm.runTimeoutChecker(sm.stopTimeout) + + projectDir := filepath.Join(home, "testproject") + os.MkdirAll(projectDir, 0o755) + + // Concurrently create sessions and send messages while timeout checker runs + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + sid := fmt.Sprintf("sess-conc-%d", idx) + dir := filepath.Join(home, fmt.Sprintf("proj-%d", idx)) + os.MkdirAll(dir, 0o755) + + if err := sm.newPRD(dir, fmt.Sprintf("proj-%d", idx), sid, "test"); err != nil { + t.Errorf("newPRD %s failed: %v", sid, err) + return + } + + // Send some messages + for j := 0; j < 3; j++ { + time.Sleep(50 * time.Millisecond) + sm.sendMessage(sid, fmt.Sprintf("msg-%d", j)) + } + }(i) + } + wg.Wait() + + // No crash = goroutine-safe. Wait for all to expire. + time.Sleep(700 * time.Millisecond) + + if len(sm.activeSessions()) != 0 { + t.Errorf("expected all sessions to expire, got %d active", len(sm.activeSessions())) + } + + sm.killAll() +} diff --git a/internal/cmd/settings.go b/internal/cmd/settings.go new file mode 100644 index 0000000..09c01bf --- /dev/null +++ b/internal/cmd/settings.go @@ -0,0 +1,119 @@ +package cmd + +import ( + "encoding/json" + "fmt" + "log" + + "github.com/minicodemonkey/chief/internal/config" + "github.com/minicodemonkey/chief/internal/ws" +) + +// handleGetSettings handles a get_settings request. +func handleGetSettings(sender messageSender, finder projectFinder, msg ws.Message) { + var req ws.GetSettingsMessage + if err := json.Unmarshal(msg.Raw, &req); err != nil { + log.Printf("Error parsing get_settings message: %v", err) + return + } + + project, found := finder.FindProject(req.Project) + if !found { + sendError(sender, ws.ErrCodeProjectNotFound, + fmt.Sprintf("Project %q not found", req.Project), msg.ID) + return + } + + cfg, err := config.Load(project.Path) + if err != nil { + sendError(sender, ws.ErrCodeFilesystemError, + fmt.Sprintf("Failed to load settings: %v", err), msg.ID) + return + } + + resp := ws.SettingsResponseMessage{ + Type: ws.TypeSettingsResponse, + Payload: ws.SettingsResponsePayload{ + Project: req.Project, + Settings: ws.SettingsData{ + MaxIterations: cfg.EffectiveMaxIterations(), + AutoCommit: cfg.EffectiveAutoCommit(), + CommitPrefix: cfg.CommitPrefix, + ClaudeModel: cfg.ClaudeModel, + TestCommand: cfg.TestCommand, + }, + }, + } + if err := sender.Send(resp); err != nil { + log.Printf("Error sending settings_response: %v", err) + } +} + +// handleUpdateSettings handles an update_settings request. +func handleUpdateSettings(sender messageSender, finder projectFinder, msg ws.Message) { + var req ws.UpdateSettingsMessage + if err := json.Unmarshal(msg.Raw, &req); err != nil { + log.Printf("Error parsing update_settings message: %v", err) + return + } + + project, found := finder.FindProject(req.Project) + if !found { + sendError(sender, ws.ErrCodeProjectNotFound, + fmt.Sprintf("Project %q not found", req.Project), msg.ID) + return + } + + cfg, err := config.Load(project.Path) + if err != nil { + sendError(sender, ws.ErrCodeFilesystemError, + fmt.Sprintf("Failed to load settings: %v", err), msg.ID) + return + } + + // Merge provided fields + if req.MaxIterations != nil { + if *req.MaxIterations < 1 { + sendError(sender, ws.ErrCodeFilesystemError, + "max_iterations must be at least 1", msg.ID) + return + } + cfg.MaxIterations = *req.MaxIterations + } + if req.AutoCommit != nil { + cfg.AutoCommit = req.AutoCommit + } + if req.CommitPrefix != nil { + cfg.CommitPrefix = *req.CommitPrefix + } + if req.ClaudeModel != nil { + cfg.ClaudeModel = *req.ClaudeModel + } + if req.TestCommand != nil { + cfg.TestCommand = *req.TestCommand + } + + if err := config.Save(project.Path, cfg); err != nil { + sendError(sender, ws.ErrCodeFilesystemError, + fmt.Sprintf("Failed to save settings: %v", err), msg.ID) + return + } + + // Echo back full updated settings + resp := ws.SettingsResponseMessage{ + Type: ws.TypeSettingsUpdated, + Payload: ws.SettingsResponsePayload{ + Project: req.Project, + Settings: ws.SettingsData{ + MaxIterations: cfg.EffectiveMaxIterations(), + AutoCommit: cfg.EffectiveAutoCommit(), + CommitPrefix: cfg.CommitPrefix, + ClaudeModel: cfg.ClaudeModel, + TestCommand: cfg.TestCommand, + }, + }, + } + if err := sender.Send(resp); err != nil { + log.Printf("Error sending settings_updated: %v", err) + } +} diff --git a/internal/cmd/settings_test.go b/internal/cmd/settings_test.go new file mode 100644 index 0000000..bb70377 --- /dev/null +++ b/internal/cmd/settings_test.go @@ -0,0 +1,488 @@ +package cmd + +import ( + "encoding/json" + "os" + "path/filepath" + "sync" + "testing" + "time" + + "github.com/minicodemonkey/chief/internal/config" +) + +func TestRunServe_GetSettings_Defaults(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + createGitRepo(t, filepath.Join(workspaceDir, "myproject")) + + var settingsReceived map[string]interface{} + var mu sync.Mutex + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + req := map[string]string{ + "type": "get_settings", + "id": "req-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "project": "myproject", + } + if err := ms.sendCommand(req); err != nil { + t.Errorf("sendCommand failed: %v", err) + return + } + + raw, err := ms.waitForMessageType("settings_response", 5*time.Second) + if err == nil { + mu.Lock() + json.Unmarshal(raw, &settingsReceived) + mu.Unlock() + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + mu.Lock() + defer mu.Unlock() + + if settingsReceived == nil { + t.Fatal("settings_response was not received") + } + if settingsReceived["type"] != "settings_response" { + t.Errorf("expected type 'settings_response', got %v", settingsReceived["type"]) + } + payload, ok := settingsReceived["payload"].(map[string]interface{}) + if !ok { + t.Fatal("expected payload to be an object") + } + if payload["project"] != "myproject" { + t.Errorf("expected project 'myproject', got %v", payload["project"]) + } + settings, ok := payload["settings"].(map[string]interface{}) + if !ok { + t.Fatal("expected settings to be an object") + } + // Default max_iterations should be 5 + if maxIter, ok := settings["max_iterations"].(float64); !ok || int(maxIter) != 5 { + t.Errorf("expected max_iterations 5, got %v", settings["max_iterations"]) + } + // Default auto_commit should be true + if autoCommit, ok := settings["auto_commit"].(bool); !ok || !autoCommit { + t.Errorf("expected auto_commit true, got %v", settings["auto_commit"]) + } + // Other fields should be empty strings + if settings["commit_prefix"] != "" { + t.Errorf("expected empty commit_prefix, got %v", settings["commit_prefix"]) + } + if settings["claude_model"] != "" { + t.Errorf("expected empty claude_model, got %v", settings["claude_model"]) + } + if settings["test_command"] != "" { + t.Errorf("expected empty test_command, got %v", settings["test_command"]) + } +} + +func TestRunServe_GetSettings_ProjectNotFound(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + var errorReceived map[string]interface{} + var mu sync.Mutex + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + req := map[string]string{ + "type": "get_settings", + "id": "req-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "project": "nonexistent", + } + if err := ms.sendCommand(req); err != nil { + t.Errorf("sendCommand failed: %v", err) + return + } + + raw, err := ms.waitForMessageType("error", 5*time.Second) + if err == nil { + mu.Lock() + json.Unmarshal(raw, &errorReceived) + mu.Unlock() + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + mu.Lock() + defer mu.Unlock() + + if errorReceived == nil { + t.Fatal("error message was not received") + } + if errorReceived["type"] != "error" { + t.Errorf("expected type 'error', got %v", errorReceived["type"]) + } + if errorReceived["code"] != "PROJECT_NOT_FOUND" { + t.Errorf("expected code 'PROJECT_NOT_FOUND', got %v", errorReceived["code"]) + } +} + +func TestRunServe_GetSettings_WithExistingConfig(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + projectDir := filepath.Join(workspaceDir, "myproject") + createGitRepo(t, projectDir) + + // Write existing config + autoCommit := false + cfg := &config.Config{ + MaxIterations: 10, + AutoCommit: &autoCommit, + CommitPrefix: "fix:", + ClaudeModel: "claude-sonnet-4-5-20250929", + TestCommand: "npm test", + } + if err := config.Save(projectDir, cfg); err != nil { + t.Fatal(err) + } + + var settingsReceived map[string]interface{} + var mu sync.Mutex + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + req := map[string]string{ + "type": "get_settings", + "id": "req-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "project": "myproject", + } + if err := ms.sendCommand(req); err != nil { + t.Errorf("sendCommand failed: %v", err) + return + } + + raw, err := ms.waitForMessageType("settings_response", 5*time.Second) + if err == nil { + mu.Lock() + json.Unmarshal(raw, &settingsReceived) + mu.Unlock() + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + mu.Lock() + defer mu.Unlock() + + if settingsReceived == nil { + t.Fatal("settings_response was not received") + } + payload := settingsReceived["payload"].(map[string]interface{}) + settings := payload["settings"].(map[string]interface{}) + if maxIter, ok := settings["max_iterations"].(float64); !ok || int(maxIter) != 10 { + t.Errorf("expected max_iterations 10, got %v", settings["max_iterations"]) + } + if autoCommitVal, ok := settings["auto_commit"].(bool); !ok || autoCommitVal { + t.Errorf("expected auto_commit false, got %v", settings["auto_commit"]) + } + if settings["commit_prefix"] != "fix:" { + t.Errorf("expected commit_prefix 'fix:', got %v", settings["commit_prefix"]) + } + if settings["claude_model"] != "claude-sonnet-4-5-20250929" { + t.Errorf("expected claude_model 'claude-sonnet-4-5-20250929', got %v", settings["claude_model"]) + } + if settings["test_command"] != "npm test" { + t.Errorf("expected test_command 'npm test', got %v", settings["test_command"]) + } +} + +func TestRunServe_UpdateSettings(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + projectDir := filepath.Join(workspaceDir, "myproject") + createGitRepo(t, projectDir) + + var settingsReceived map[string]interface{} + var mu sync.Mutex + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + maxIter := 8 + autoCommit := false + commitPrefix := "chore:" + claudeModel := "claude-sonnet-4-5-20250929" + testCommand := "go test ./..." + + req := map[string]interface{}{ + "type": "update_settings", + "id": "req-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "project": "myproject", + "max_iterations": maxIter, + "auto_commit": autoCommit, + "commit_prefix": commitPrefix, + "claude_model": claudeModel, + "test_command": testCommand, + } + if err := ms.sendCommand(req); err != nil { + t.Errorf("sendCommand failed: %v", err) + return + } + + raw, err := ms.waitForMessageType("settings_updated", 5*time.Second) + if err == nil { + mu.Lock() + json.Unmarshal(raw, &settingsReceived) + mu.Unlock() + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + mu.Lock() + defer mu.Unlock() + + if settingsReceived == nil { + t.Fatal("settings_updated was not received") + } + if settingsReceived["type"] != "settings_updated" { + t.Errorf("expected type 'settings_updated', got %v", settingsReceived["type"]) + } + payload := settingsReceived["payload"].(map[string]interface{}) + settings := payload["settings"].(map[string]interface{}) + if maxIter, ok := settings["max_iterations"].(float64); !ok || int(maxIter) != 8 { + t.Errorf("expected max_iterations 8, got %v", settings["max_iterations"]) + } + if autoCommitVal, ok := settings["auto_commit"].(bool); !ok || autoCommitVal { + t.Errorf("expected auto_commit false, got %v", settings["auto_commit"]) + } + if settings["commit_prefix"] != "chore:" { + t.Errorf("expected commit_prefix 'chore:', got %v", settings["commit_prefix"]) + } + if settings["claude_model"] != "claude-sonnet-4-5-20250929" { + t.Errorf("expected claude_model 'claude-sonnet-4-5-20250929', got %v", settings["claude_model"]) + } + if settings["test_command"] != "go test ./..." { + t.Errorf("expected test_command 'go test ./...', got %v", settings["test_command"]) + } + + // Verify the config was persisted to disk + cfg, err := config.Load(filepath.Join(workspaceDir, "myproject")) + if err != nil { + t.Fatalf("config.Load failed: %v", err) + } + if cfg.MaxIterations != 8 { + t.Errorf("expected saved max_iterations 8, got %d", cfg.MaxIterations) + } + if cfg.AutoCommit == nil || *cfg.AutoCommit != false { + t.Errorf("expected saved auto_commit false, got %v", cfg.AutoCommit) + } +} + +func TestRunServe_UpdateSettings_PartialUpdate(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + projectDir := filepath.Join(workspaceDir, "myproject") + createGitRepo(t, projectDir) + + // Set initial config + autoCommit := false + cfg := &config.Config{ + MaxIterations: 10, + AutoCommit: &autoCommit, + CommitPrefix: "fix:", + TestCommand: "npm test", + } + if err := config.Save(projectDir, cfg); err != nil { + t.Fatal(err) + } + + var settingsReceived map[string]interface{} + var mu sync.Mutex + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + // Only update test_command — other fields should be preserved + req := map[string]interface{}{ + "type": "update_settings", + "id": "req-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "project": "myproject", + "test_command": "go test ./...", + } + if err := ms.sendCommand(req); err != nil { + t.Errorf("sendCommand failed: %v", err) + return + } + + raw, err := ms.waitForMessageType("settings_updated", 5*time.Second) + if err == nil { + mu.Lock() + json.Unmarshal(raw, &settingsReceived) + mu.Unlock() + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + mu.Lock() + defer mu.Unlock() + + if settingsReceived == nil { + t.Fatal("settings_updated was not received") + } + payload := settingsReceived["payload"].(map[string]interface{}) + settings := payload["settings"].(map[string]interface{}) + // Existing values should be preserved + if maxIter, ok := settings["max_iterations"].(float64); !ok || int(maxIter) != 10 { + t.Errorf("expected max_iterations 10 preserved, got %v", settings["max_iterations"]) + } + if autoCommitVal, ok := settings["auto_commit"].(bool); !ok || autoCommitVal { + t.Errorf("expected auto_commit false preserved, got %v", settings["auto_commit"]) + } + if settings["commit_prefix"] != "fix:" { + t.Errorf("expected commit_prefix 'fix:' preserved, got %v", settings["commit_prefix"]) + } + // Updated value + if settings["test_command"] != "go test ./..." { + t.Errorf("expected test_command 'go test ./...', got %v", settings["test_command"]) + } +} + +func TestRunServe_UpdateSettings_ProjectNotFound(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + var errorReceived map[string]interface{} + var mu sync.Mutex + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + req := map[string]interface{}{ + "type": "update_settings", + "id": "req-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "project": "nonexistent", + "max_iterations": 3, + } + if err := ms.sendCommand(req); err != nil { + t.Errorf("sendCommand failed: %v", err) + return + } + + raw, err := ms.waitForMessageType("error", 5*time.Second) + if err == nil { + mu.Lock() + json.Unmarshal(raw, &errorReceived) + mu.Unlock() + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + mu.Lock() + defer mu.Unlock() + + if errorReceived == nil { + t.Fatal("error message was not received") + } + if errorReceived["type"] != "error" { + t.Errorf("expected type 'error', got %v", errorReceived["type"]) + } + if errorReceived["code"] != "PROJECT_NOT_FOUND" { + t.Errorf("expected code 'PROJECT_NOT_FOUND', got %v", errorReceived["code"]) + } +} + +func TestRunServe_UpdateSettings_InvalidMaxIterations(t *testing.T) { + home := t.TempDir() + setTestHome(t, home) + setupServeCredentials(t) + + workspaceDir := filepath.Join(home, "projects") + if err := os.MkdirAll(workspaceDir, 0o755); err != nil { + t.Fatal(err) + } + + createGitRepo(t, filepath.Join(workspaceDir, "myproject")) + + var errorReceived map[string]interface{} + var mu sync.Mutex + + err := serveTestHelper(t, workspaceDir, func(ms *mockUplinkServer) { + req := map[string]interface{}{ + "type": "update_settings", + "id": "req-1", + "timestamp": time.Now().UTC().Format(time.RFC3339), + "project": "myproject", + "max_iterations": 0, + } + if err := ms.sendCommand(req); err != nil { + t.Errorf("sendCommand failed: %v", err) + return + } + + raw, err := ms.waitForMessageType("error", 5*time.Second) + if err == nil { + mu.Lock() + json.Unmarshal(raw, &errorReceived) + mu.Unlock() + } + }) + if err != nil { + t.Fatalf("RunServe returned error: %v", err) + } + + mu.Lock() + defer mu.Unlock() + + if errorReceived == nil { + t.Fatal("error message was not received") + } + if errorReceived["type"] != "error" { + t.Errorf("expected type 'error', got %v", errorReceived["type"]) + } + if errorReceived["code"] != "FILESYSTEM_ERROR" { + t.Errorf("expected code 'FILESYSTEM_ERROR', got %v", errorReceived["code"]) + } +} diff --git a/internal/cmd/update.go b/internal/cmd/update.go new file mode 100644 index 0000000..458a551 --- /dev/null +++ b/internal/cmd/update.go @@ -0,0 +1,72 @@ +package cmd + +import ( + "fmt" + "log" + + "github.com/minicodemonkey/chief/internal/update" +) + +// UpdateOptions holds configuration for the update command. +type UpdateOptions struct { + Version string // Current version (from build ldflags) + ReleasesURL string // Override GitHub API URL (for testing) +} + +// RunUpdate downloads and installs the latest version of Chief. +func RunUpdate(opts UpdateOptions) error { + fmt.Println("Checking for updates...") + + // First check if an update is available + result, err := update.CheckForUpdate(opts.Version, update.Options{ + ReleasesURL: opts.ReleasesURL, + }) + if err != nil { + return fmt.Errorf("checking for updates: %w", err) + } + + if !result.UpdateAvailable { + fmt.Printf("Already on latest version (v%s).\n", result.CurrentVersion) + return nil + } + + fmt.Printf("Downloading v%s (you have v%s)...\n", result.LatestVersion, result.CurrentVersion) + + // Perform the update + if _, err := update.PerformUpdate(opts.Version, update.Options{ + ReleasesURL: opts.ReleasesURL, + }); err != nil { + return err + } + + fmt.Printf("Updated to v%s. Restart 'chief serve' to apply.\n", result.LatestVersion) + return nil +} + +// CheckVersionOnStartup performs a non-blocking version check and prints a message if an update is available. +// This is called on startup for interactive CLI commands. +func CheckVersionOnStartup(version string) { + go func() { + result, err := update.CheckForUpdate(version, update.Options{}) + if err != nil { + // Silently fail — version check is best-effort + return + } + if result.UpdateAvailable { + fmt.Printf("Chief v%s available (you have v%s). Run 'chief update' to upgrade.\n", + result.LatestVersion, result.CurrentVersion) + } + }() +} + +// CheckVersionForServe performs a version check and returns the result for use by the serve command. +func CheckVersionForServe(version, releasesURL string) *update.CheckResult { + result, err := update.CheckForUpdate(version, update.Options{ + ReleasesURL: releasesURL, + }) + if err != nil { + log.Printf("Version check failed: %v", err) + return nil + } + return result +} diff --git a/internal/cmd/update_test.go b/internal/cmd/update_test.go new file mode 100644 index 0000000..37dbf0a --- /dev/null +++ b/internal/cmd/update_test.go @@ -0,0 +1,104 @@ +package cmd + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/minicodemonkey/chief/internal/update" +) + +func TestRunUpdate_AlreadyLatest(t *testing.T) { + release := update.Release{TagName: "v0.5.0"} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(release) + })) + defer srv.Close() + + err := RunUpdate(UpdateOptions{ + Version: "0.5.0", + ReleasesURL: srv.URL, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestRunUpdate_APIError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer srv.Close() + + err := RunUpdate(UpdateOptions{ + Version: "0.5.0", + ReleasesURL: srv.URL, + }) + if err == nil { + t.Error("expected error for API failure") + } +} + +func TestCheckVersionForServe_UpdateAvailable(t *testing.T) { + release := update.Release{TagName: "v0.6.0"} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(release) + })) + defer srv.Close() + + result := CheckVersionForServe("0.5.0", srv.URL) + if result == nil { + t.Fatal("expected non-nil result") + } + if !result.UpdateAvailable { + t.Error("expected update to be available") + } + if result.LatestVersion != "0.6.0" { + t.Errorf("expected latest version 0.6.0, got %s", result.LatestVersion) + } +} + +func TestCheckVersionForServe_NoUpdate(t *testing.T) { + release := update.Release{TagName: "v0.5.0"} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(release) + })) + defer srv.Close() + + result := CheckVersionForServe("0.5.0", srv.URL) + if result == nil { + t.Fatal("expected non-nil result") + } + if result.UpdateAvailable { + t.Error("expected no update available") + } +} + +func TestCheckVersionForServe_APIFailure(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer srv.Close() + + result := CheckVersionForServe("0.5.0", srv.URL) + if result != nil { + t.Error("expected nil result on API failure") + } +} + +func TestCheckVersionForServe_DevVersion(t *testing.T) { + release := update.Release{TagName: "v1.0.0"} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(release) + })) + defer srv.Close() + + result := CheckVersionForServe("dev", srv.URL) + if result == nil { + t.Fatal("expected non-nil result") + } + if result.UpdateAvailable { + t.Error("dev version should not report update available") + } +} diff --git a/internal/config/config.go b/internal/config/config.go index 6f43ec0..84084f9 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,6 +1,7 @@ package config import ( + "fmt" "os" "path/filepath" @@ -9,10 +10,45 @@ import ( const configFile = ".chief/config.yaml" +// UserConfig holds user-level settings from ~/.chief/config.yaml. +type UserConfig struct { + WSURL string `yaml:"ws_url,omitempty"` +} + +// LoadUserConfig reads the user-level config from ~/.chief/config.yaml. +// Returns an empty UserConfig when the file doesn't exist (no error). +func LoadUserConfig() (*UserConfig, error) { + home, err := os.UserHomeDir() + if err != nil { + return &UserConfig{}, fmt.Errorf("determining home directory: %w", err) + } + + path := filepath.Join(home, ".chief", "config.yaml") + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return &UserConfig{}, nil + } + return &UserConfig{}, err + } + + cfg := &UserConfig{} + if err := yaml.Unmarshal(data, cfg); err != nil { + return &UserConfig{}, err + } + + return cfg, nil +} + // Config holds project-level settings for Chief. type Config struct { - Worktree WorktreeConfig `yaml:"worktree"` - OnComplete OnCompleteConfig `yaml:"onComplete"` + Worktree WorktreeConfig `yaml:"worktree"` + OnComplete OnCompleteConfig `yaml:"onComplete"` + MaxIterations int `yaml:"maxIterations,omitempty"` + AutoCommit *bool `yaml:"autoCommit,omitempty"` + CommitPrefix string `yaml:"commitPrefix,omitempty"` + ClaudeModel string `yaml:"claudeModel,omitempty"` + TestCommand string `yaml:"testCommand,omitempty"` } // WorktreeConfig holds worktree-related settings. @@ -26,11 +62,30 @@ type OnCompleteConfig struct { CreatePR bool `yaml:"createPR"` } +// DefaultMaxIterations is the default value for MaxIterations when not set. +const DefaultMaxIterations = 5 + // Default returns a Config with zero-value defaults. func Default() *Config { return &Config{} } +// EffectiveMaxIterations returns MaxIterations or the default if not set. +func (c *Config) EffectiveMaxIterations() int { + if c.MaxIterations > 0 { + return c.MaxIterations + } + return DefaultMaxIterations +} + +// EffectiveAutoCommit returns AutoCommit or true if not set. +func (c *Config) EffectiveAutoCommit() bool { + if c.AutoCommit != nil { + return *c.AutoCommit + } + return true +} + // configPath returns the full path to the config file. func configPath(baseDir string) string { return filepath.Join(baseDir, configFile) diff --git a/internal/config/config_test.go b/internal/config/config_test.go index dacee39..7a67823 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -62,6 +62,121 @@ func TestSaveAndLoad(t *testing.T) { } } +func TestSaveAndLoadSettingsFields(t *testing.T) { + dir := t.TempDir() + autoCommit := false + cfg := &Config{ + MaxIterations: 10, + AutoCommit: &autoCommit, + CommitPrefix: "feat:", + ClaudeModel: "claude-sonnet-4-5-20250929", + TestCommand: "go test ./...", + } + + if err := Save(dir, cfg); err != nil { + t.Fatalf("Save failed: %v", err) + } + + loaded, err := Load(dir) + if err != nil { + t.Fatalf("Load failed: %v", err) + } + + if loaded.MaxIterations != 10 { + t.Errorf("expected MaxIterations 10, got %d", loaded.MaxIterations) + } + if loaded.AutoCommit == nil || *loaded.AutoCommit != false { + t.Errorf("expected AutoCommit false, got %v", loaded.AutoCommit) + } + if loaded.CommitPrefix != "feat:" { + t.Errorf("expected CommitPrefix %q, got %q", "feat:", loaded.CommitPrefix) + } + if loaded.ClaudeModel != "claude-sonnet-4-5-20250929" { + t.Errorf("expected ClaudeModel %q, got %q", "claude-sonnet-4-5-20250929", loaded.ClaudeModel) + } + if loaded.TestCommand != "go test ./..." { + t.Errorf("expected TestCommand %q, got %q", "go test ./...", loaded.TestCommand) + } +} + +func TestEffectiveDefaults(t *testing.T) { + cfg := Default() + + if cfg.EffectiveMaxIterations() != 5 { + t.Errorf("expected EffectiveMaxIterations 5, got %d", cfg.EffectiveMaxIterations()) + } + if !cfg.EffectiveAutoCommit() { + t.Error("expected EffectiveAutoCommit true") + } + + // With explicit values + cfg.MaxIterations = 3 + autoCommit := false + cfg.AutoCommit = &autoCommit + + if cfg.EffectiveMaxIterations() != 3 { + t.Errorf("expected EffectiveMaxIterations 3, got %d", cfg.EffectiveMaxIterations()) + } + if cfg.EffectiveAutoCommit() { + t.Error("expected EffectiveAutoCommit false") + } +} + +func TestLoadUserConfig_NonExistent(t *testing.T) { + home := t.TempDir() + t.Setenv("HOME", home) + + cfg, err := LoadUserConfig() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cfg.WSURL != "" { + t.Errorf("expected empty WSURL, got %q", cfg.WSURL) + } +} + +func TestLoadUserConfig_WithWSURL(t *testing.T) { + home := t.TempDir() + t.Setenv("HOME", home) + + chiefDir := filepath.Join(home, ".chief") + if err := os.MkdirAll(chiefDir, 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(chiefDir, "config.yaml"), []byte("ws_url: ws://localhost:8080/ws/server\n"), 0o644); err != nil { + t.Fatal(err) + } + + cfg, err := LoadUserConfig() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cfg.WSURL != "ws://localhost:8080/ws/server" { + t.Errorf("expected ws://localhost:8080/ws/server, got %q", cfg.WSURL) + } +} + +func TestLoadUserConfig_EmptyWSURL(t *testing.T) { + home := t.TempDir() + t.Setenv("HOME", home) + + chiefDir := filepath.Join(home, ".chief") + if err := os.MkdirAll(chiefDir, 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(chiefDir, "config.yaml"), []byte("{}\n"), 0o644); err != nil { + t.Fatal(err) + } + + cfg, err := LoadUserConfig() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cfg.WSURL != "" { + t.Errorf("expected empty WSURL, got %q", cfg.WSURL) + } +} + func TestExists(t *testing.T) { dir := t.TempDir() diff --git a/internal/contract/contract_test.go b/internal/contract/contract_test.go new file mode 100644 index 0000000..b3d9a88 --- /dev/null +++ b/internal/contract/contract_test.go @@ -0,0 +1,558 @@ +package contract + +import ( + "encoding/json" + "os" + "path/filepath" + "runtime" + "testing" + + "github.com/minicodemonkey/chief/internal/uplink" + "github.com/minicodemonkey/chief/internal/ws" +) + +// fixturesDir returns the absolute path to contract/fixtures relative to the repo root. +func fixturesDir(t *testing.T) string { + t.Helper() + // Determine repo root from this test file's location: + // internal/contract/contract_test.go → ../../contract/fixtures + _, thisFile, _, ok := runtime.Caller(0) + if !ok { + t.Fatal("cannot determine test file location") + } + return filepath.Join(filepath.Dir(thisFile), "..", "..", "contract", "fixtures") +} + +func loadFixture(t *testing.T, relPath string) []byte { + t.Helper() + data, err := os.ReadFile(filepath.Join(fixturesDir(t), relPath)) + if err != nil { + t.Fatalf("loading fixture %s: %v", relPath, err) + } + return data +} + +// --- server-to-cli fixtures --- + +func TestWelcomeResponse_Deserialize(t *testing.T) { + data := loadFixture(t, "server-to-cli/welcome_response.json") + + var welcome uplink.WelcomeResponse + if err := json.Unmarshal(data, &welcome); err != nil { + t.Fatalf("failed to unmarshal welcome_response.json: %v", err) + } + + if welcome.Type != "welcome" { + t.Errorf("type = %q, want %q", welcome.Type, "welcome") + } + if welcome.ProtocolVersion != 1 { + t.Errorf("protocol_version = %d, want 1", welcome.ProtocolVersion) + } + if welcome.DeviceID != 42 { + t.Errorf("device_id = %d, want 42", welcome.DeviceID) + } + if welcome.SessionID != "550e8400-e29b-41d4-a716-446655440000" { + t.Errorf("session_id = %q, want UUID", welcome.SessionID) + } + + // Reverb config — port MUST be an int, not a string + if welcome.Reverb.Port != 8080 { + t.Errorf("reverb.port = %d, want 8080", welcome.Reverb.Port) + } + if welcome.Reverb.Key != "test-app-key" { + t.Errorf("reverb.key = %q, want %q", welcome.Reverb.Key, "test-app-key") + } + if welcome.Reverb.Host != "127.0.0.1" { + t.Errorf("reverb.host = %q, want %q", welcome.Reverb.Host, "127.0.0.1") + } + if welcome.Reverb.Scheme != "https" { + t.Errorf("reverb.scheme = %q, want %q", welcome.Reverb.Scheme, "https") + } +} + +func TestWelcomeResponse_PortIsInt(t *testing.T) { + // Regression: PHP env() returns strings — verify port decodes as int. + data := loadFixture(t, "server-to-cli/welcome_response.json") + + var raw map[string]json.RawMessage + if err := json.Unmarshal(data, &raw); err != nil { + t.Fatal(err) + } + + var reverb map[string]json.RawMessage + json.Unmarshal(raw["reverb"], &reverb) + + // Verify port is a JSON number, not a string + portStr := string(reverb["port"]) + if portStr == `"8080"` { + t.Fatal("reverb.port is a JSON string — must be a number") + } + if portStr != "8080" { + t.Errorf("reverb.port raw JSON = %s, want 8080", portStr) + } +} + +func TestCommandCreateProject_PayloadWrapper(t *testing.T) { + data := loadFixture(t, "server-to-cli/command_create_project.json") + + // Verify the envelope has type + payload + var env struct { + Type string `json:"type"` + Payload json.RawMessage `json:"payload,omitempty"` + } + if err := json.Unmarshal(data, &env); err != nil { + t.Fatalf("failed to unmarshal command envelope: %v", err) + } + + if env.Type != "create_project" { + t.Errorf("envelope type = %q, want %q", env.Type, "create_project") + } + if len(env.Payload) == 0 { + t.Fatal("envelope payload is empty — commands must have payload wrapper") + } + + // The payload itself should parse into CreateProjectMessage fields + var req ws.CreateProjectMessage + if err := json.Unmarshal(env.Payload, &req); err != nil { + t.Fatalf("failed to unmarshal payload into CreateProjectMessage: %v", err) + } + + if req.Name != "new-project" { + t.Errorf("payload.name = %q, want %q", req.Name, "new-project") + } + if !req.GitInit { + t.Error("payload.git_init = false, want true") + } +} + +func TestCommandStartRun_PayloadWrapper(t *testing.T) { + data := loadFixture(t, "server-to-cli/command_start_run.json") + + var env struct { + Type string `json:"type"` + Payload json.RawMessage `json:"payload,omitempty"` + } + if err := json.Unmarshal(data, &env); err != nil { + t.Fatalf("failed to unmarshal command envelope: %v", err) + } + + if env.Type != "start_run" { + t.Errorf("envelope type = %q, want %q", env.Type, "start_run") + } + + var req ws.StartRunMessage + if err := json.Unmarshal(env.Payload, &req); err != nil { + t.Fatalf("failed to unmarshal payload into StartRunMessage: %v", err) + } + + if req.Project != "my-project" { + t.Errorf("payload.project = %q, want %q", req.Project, "my-project") + } + if req.PRDID != "feature-auth" { + t.Errorf("payload.prd_id = %q, want %q", req.PRDID, "feature-auth") + } +} + +func TestCommandListProjects_PayloadWrapper(t *testing.T) { + data := loadFixture(t, "server-to-cli/command_list_projects.json") + + var env struct { + Type string `json:"type"` + Payload json.RawMessage `json:"payload,omitempty"` + } + if err := json.Unmarshal(data, &env); err != nil { + t.Fatalf("failed to unmarshal command envelope: %v", err) + } + + if env.Type != "list_projects" { + t.Errorf("envelope type = %q, want %q", env.Type, "list_projects") + } +} + +// --- cli-to-server fixtures --- + +func TestStateSnapshot_Roundtrip(t *testing.T) { + data := loadFixture(t, "cli-to-server/state_snapshot.json") + + // Unmarshal into the Go struct + var snapshot ws.StateSnapshotMessage + if err := json.Unmarshal(data, &snapshot); err != nil { + t.Fatalf("failed to unmarshal state_snapshot.json: %v", err) + } + + if snapshot.Type != "state_snapshot" { + t.Errorf("type = %q, want %q", snapshot.Type, "state_snapshot") + } + if len(snapshot.Projects) != 1 { + t.Fatalf("projects count = %d, want 1", len(snapshot.Projects)) + } + + // Verify project uses "name" field, not "project_slug" + proj := snapshot.Projects[0] + if proj.Name != "my-project" { + t.Errorf("project.name = %q, want %q", proj.Name, "my-project") + } + if proj.Path != "/home/user/projects/my-project" { + t.Errorf("project.path = %q", proj.Path) + } + if !proj.HasChief { + t.Error("project.has_chief = false, want true") + } + if proj.Branch != "main" { + t.Errorf("project.branch = %q, want %q", proj.Branch, "main") + } + if proj.Commit.Hash != "abc1234" { + t.Errorf("project.commit.hash = %q, want %q", proj.Commit.Hash, "abc1234") + } + + // Re-marshal and verify it round-trips cleanly + remarshaled, err := json.Marshal(snapshot) + if err != nil { + t.Fatalf("failed to re-marshal: %v", err) + } + + var roundtrip ws.StateSnapshotMessage + if err := json.Unmarshal(remarshaled, &roundtrip); err != nil { + t.Fatalf("failed to unmarshal round-trip: %v", err) + } + if roundtrip.Projects[0].Name != "my-project" { + t.Errorf("round-trip project.name = %q, want %q", roundtrip.Projects[0].Name, "my-project") + } +} + +func TestStateSnapshot_NameFieldNotProjectSlug(t *testing.T) { + // Regression: CLI sends "name", not "project_slug". + data := loadFixture(t, "cli-to-server/state_snapshot.json") + + var raw map[string]json.RawMessage + json.Unmarshal(data, &raw) + + var projects []map[string]json.RawMessage + json.Unmarshal(raw["projects"], &projects) + + if len(projects) == 0 { + t.Fatal("no projects in fixture") + } + + proj := projects[0] + if _, hasName := proj["name"]; !hasName { + t.Error("project should have 'name' field") + } + if _, hasSlug := proj["project_slug"]; hasSlug { + t.Error("project should NOT have 'project_slug' field — CLI uses 'name'") + } +} + +func TestConnectRequest_Deserialize(t *testing.T) { + data := loadFixture(t, "cli-to-server/connect_request.json") + + var req struct { + ChiefVersion string `json:"chief_version"` + DeviceName string `json:"device_name"` + OS string `json:"os"` + Arch string `json:"arch"` + ProtocolVersion int `json:"protocol_version"` + } + if err := json.Unmarshal(data, &req); err != nil { + t.Fatalf("failed to unmarshal connect_request.json: %v", err) + } + + if req.ChiefVersion != "1.0.0" { + t.Errorf("chief_version = %q, want %q", req.ChiefVersion, "1.0.0") + } + if req.ProtocolVersion != 1 { + t.Errorf("protocol_version = %d, want 1", req.ProtocolVersion) + } + if req.OS == "" { + t.Error("os should not be empty") + } +} + +func TestCommandGetPRDs_PayloadWrapper(t *testing.T) { + data := loadFixture(t, "server-to-cli/command_get_prds.json") + + var env struct { + Type string `json:"type"` + Payload json.RawMessage `json:"payload,omitempty"` + } + if err := json.Unmarshal(data, &env); err != nil { + t.Fatalf("failed to unmarshal command envelope: %v", err) + } + + if env.Type != "get_prds" { + t.Errorf("envelope type = %q, want %q", env.Type, "get_prds") + } + + var req ws.GetPRDsMessage + if err := json.Unmarshal(env.Payload, &req); err != nil { + t.Fatalf("failed to unmarshal payload into GetPRDsMessage: %v", err) + } + + if req.Project != "my-project" { + t.Errorf("payload.project = %q, want %q", req.Project, "my-project") + } +} + +func TestCommandGetSettings_PayloadWrapper(t *testing.T) { + data := loadFixture(t, "server-to-cli/command_get_settings.json") + + var env struct { + Type string `json:"type"` + Payload json.RawMessage `json:"payload,omitempty"` + } + if err := json.Unmarshal(data, &env); err != nil { + t.Fatalf("failed to unmarshal command envelope: %v", err) + } + + if env.Type != "get_settings" { + t.Errorf("envelope type = %q, want %q", env.Type, "get_settings") + } + + var req ws.GetSettingsMessage + if err := json.Unmarshal(env.Payload, &req); err != nil { + t.Fatalf("failed to unmarshal payload into GetSettingsMessage: %v", err) + } + + if req.Project != "my-project" { + t.Errorf("payload.project = %q, want %q", req.Project, "my-project") + } +} + +func TestCommandGetDiffs_PayloadWrapper(t *testing.T) { + data := loadFixture(t, "server-to-cli/command_get_diffs.json") + + var env struct { + Type string `json:"type"` + Payload json.RawMessage `json:"payload,omitempty"` + } + if err := json.Unmarshal(data, &env); err != nil { + t.Fatalf("failed to unmarshal command envelope: %v", err) + } + + if env.Type != "get_diffs" { + t.Errorf("envelope type = %q, want %q", env.Type, "get_diffs") + } + + var req ws.GetDiffsMessage + if err := json.Unmarshal(env.Payload, &req); err != nil { + t.Fatalf("failed to unmarshal payload into GetDiffsMessage: %v", err) + } + + if req.Project != "my-project" { + t.Errorf("payload.project = %q, want %q", req.Project, "my-project") + } + if req.StoryID != "US-001" { + t.Errorf("payload.story_id = %q, want %q", req.StoryID, "US-001") + } +} + +func TestCommandNewPRD_PayloadWrapper(t *testing.T) { + data := loadFixture(t, "server-to-cli/command_new_prd.json") + + var env struct { + Type string `json:"type"` + Payload json.RawMessage `json:"payload,omitempty"` + } + if err := json.Unmarshal(data, &env); err != nil { + t.Fatalf("failed to unmarshal command envelope: %v", err) + } + + if env.Type != "new_prd" { + t.Errorf("envelope type = %q, want %q", env.Type, "new_prd") + } + + var req ws.NewPRDMessage + if err := json.Unmarshal(env.Payload, &req); err != nil { + t.Fatalf("failed to unmarshal payload into NewPRDMessage: %v", err) + } + + if req.Project != "my-project" { + t.Errorf("payload.project = %q, want %q", req.Project, "my-project") + } + if req.SessionID != "session-abc" { + t.Errorf("payload.session_id = %q, want %q", req.SessionID, "session-abc") + } + if req.Message != "Build an authentication system" { + t.Errorf("payload.message = %q, want %q", req.Message, "Build an authentication system") + } +} + +func TestCommandPRDMessage_PayloadWrapper(t *testing.T) { + data := loadFixture(t, "server-to-cli/command_prd_message.json") + + var env struct { + Type string `json:"type"` + Payload json.RawMessage `json:"payload,omitempty"` + } + if err := json.Unmarshal(data, &env); err != nil { + t.Fatalf("failed to unmarshal command envelope: %v", err) + } + + if env.Type != "prd_message" { + t.Errorf("envelope type = %q, want %q", env.Type, "prd_message") + } + + var req ws.PRDMessageMessage + if err := json.Unmarshal(env.Payload, &req); err != nil { + t.Fatalf("failed to unmarshal payload into PRDMessageMessage: %v", err) + } + + if req.Project != "my-project" { + t.Errorf("payload.project = %q, want %q", req.Project, "my-project") + } + if req.SessionID != "session-abc" { + t.Errorf("payload.session_id = %q, want %q", req.SessionID, "session-abc") + } + if req.Message != "Add OAuth support to the PRD" { + t.Errorf("payload.message = %q, want %q", req.Message, "Add OAuth support to the PRD") + } +} + +func TestCommandRefinePRD_PayloadWrapper(t *testing.T) { + data := loadFixture(t, "server-to-cli/command_refine_prd.json") + + var env struct { + Type string `json:"type"` + Payload json.RawMessage `json:"payload,omitempty"` + } + if err := json.Unmarshal(data, &env); err != nil { + t.Fatalf("failed to unmarshal command envelope: %v", err) + } + + if env.Type != "refine_prd" { + t.Errorf("envelope type = %q, want %q", env.Type, "refine_prd") + } + + var req ws.RefinePRDMessage + if err := json.Unmarshal(env.Payload, &req); err != nil { + t.Fatalf("failed to unmarshal payload into RefinePRDMessage: %v", err) + } + + if req.Project != "my-project" { + t.Errorf("payload.project = %q, want %q", req.Project, "my-project") + } + if req.SessionID != "session-abc" { + t.Errorf("payload.session_id = %q, want %q", req.SessionID, "session-abc") + } + if req.PRDID != "feature-auth" { + t.Errorf("payload.prd_id = %q, want %q", req.PRDID, "feature-auth") + } + if req.Message != "Add OAuth support to the PRD" { + t.Errorf("payload.message = %q, want %q", req.Message, "Add OAuth support to the PRD") + } +} + +// --- cli-to-server response fixtures --- + +func TestPRDsResponse_Roundtrip(t *testing.T) { + data := loadFixture(t, "cli-to-server/prds_response.json") + + var resp ws.PRDsResponseMessage + if err := json.Unmarshal(data, &resp); err != nil { + t.Fatalf("failed to unmarshal prds_response.json: %v", err) + } + + if resp.Type != "prds_response" { + t.Errorf("type = %q, want %q", resp.Type, "prds_response") + } + if resp.Payload.Project != "my-project" { + t.Errorf("payload.project = %q, want %q", resp.Payload.Project, "my-project") + } + if len(resp.Payload.PRDs) != 2 { + t.Fatalf("prds count = %d, want 2", len(resp.Payload.PRDs)) + } + + prd := resp.Payload.PRDs[0] + if prd.ID != "feature-auth" { + t.Errorf("prds[0].id = %q, want %q", prd.ID, "feature-auth") + } + if prd.Status != "active" { + t.Errorf("prds[0].status = %q, want %q", prd.Status, "active") + } + if prd.StoryCount != 5 { + t.Errorf("prds[0].story_count = %d, want 5", prd.StoryCount) + } +} + +func TestSettingsResponse_Roundtrip(t *testing.T) { + data := loadFixture(t, "cli-to-server/settings_response.json") + + var resp ws.SettingsResponseMessage + if err := json.Unmarshal(data, &resp); err != nil { + t.Fatalf("failed to unmarshal settings_response.json: %v", err) + } + + if resp.Type != "settings_response" { + t.Errorf("type = %q, want %q", resp.Type, "settings_response") + } + if resp.Payload.Project != "my-project" { + t.Errorf("payload.project = %q, want %q", resp.Payload.Project, "my-project") + } + if resp.Payload.Settings.MaxIterations != 5 { + t.Errorf("settings.max_iterations = %d, want 5", resp.Payload.Settings.MaxIterations) + } + if !resp.Payload.Settings.AutoCommit { + t.Error("settings.auto_commit = false, want true") + } +} + +func TestDiffsResponse_Roundtrip(t *testing.T) { + data := loadFixture(t, "cli-to-server/diffs_response.json") + + var resp ws.DiffsResponseMessage + if err := json.Unmarshal(data, &resp); err != nil { + t.Fatalf("failed to unmarshal diffs_response.json: %v", err) + } + + if resp.Type != "diffs_response" { + t.Errorf("type = %q, want %q", resp.Type, "diffs_response") + } + if resp.Payload.Project != "my-project" { + t.Errorf("payload.project = %q, want %q", resp.Payload.Project, "my-project") + } + if resp.Payload.StoryID != "US-001" { + t.Errorf("payload.story_id = %q, want %q", resp.Payload.StoryID, "US-001") + } + if len(resp.Payload.Files) != 1 { + t.Fatalf("files count = %d, want 1", len(resp.Payload.Files)) + } + + file := resp.Payload.Files[0] + if file.Filename != "src/auth.go" { + t.Errorf("files[0].filename = %q, want %q", file.Filename, "src/auth.go") + } + if file.Additions != 25 { + t.Errorf("files[0].additions = %d, want 25", file.Additions) + } + if file.Deletions != 3 { + t.Errorf("files[0].deletions = %d, want 3", file.Deletions) + } +} + +func TestMessagesBatch_Deserialize(t *testing.T) { + data := loadFixture(t, "cli-to-server/messages_batch.json") + + var batch struct { + BatchID string `json:"batch_id"` + Messages []json.RawMessage `json:"messages"` + } + if err := json.Unmarshal(data, &batch); err != nil { + t.Fatalf("failed to unmarshal messages_batch.json: %v", err) + } + + if batch.BatchID == "" { + t.Error("batch_id should not be empty") + } + if len(batch.Messages) != 1 { + t.Fatalf("messages count = %d, want 1", len(batch.Messages)) + } + + // First message should be a state_snapshot + var msg ws.StateSnapshotMessage + if err := json.Unmarshal(batch.Messages[0], &msg); err != nil { + t.Fatalf("failed to unmarshal first message: %v", err) + } + if msg.Type != "state_snapshot" { + t.Errorf("first message type = %q, want %q", msg.Type, "state_snapshot") + } +} diff --git a/internal/engine/engine.go b/internal/engine/engine.go new file mode 100644 index 0000000..cd45457 --- /dev/null +++ b/internal/engine/engine.go @@ -0,0 +1,248 @@ +// Package engine provides a shared orchestration layer on top of loop.Manager +// that both the TUI and the serve command (WebSocket handler) can consume. +// It supports multiple concurrent event consumers via fan-out subscription. +package engine + +import ( + "sync" + + "github.com/minicodemonkey/chief/internal/config" + "github.com/minicodemonkey/chief/internal/loop" + "github.com/minicodemonkey/chief/internal/prd" +) + +// Engine wraps loop.Manager to provide a shared interface for driving Ralph loops +// and Claude sessions. It fans out events to multiple consumers. +type Engine struct { + manager *loop.Manager + + // Fan-out event distribution + subscribers map[int]chan ManagerEvent + nextID int + subMu sync.RWMutex + + // Forwarding goroutine lifecycle + stopForward chan struct{} + forwarding bool + forwardMu sync.Mutex +} + +// ManagerEvent wraps a loop.ManagerEvent for engine consumers. +// It mirrors loop.ManagerEvent to avoid exposing the loop package directly. +type ManagerEvent = loop.ManagerEvent + +// New creates a new Engine with the given max iterations. +func New(maxIter int) *Engine { + e := &Engine{ + manager: loop.NewManager(maxIter), + subscribers: make(map[int]chan ManagerEvent), + stopForward: make(chan struct{}), + } + e.startForwarding() + return e +} + +// startForwarding starts the goroutine that reads from the manager's event +// channel and fans out to all subscribers. +func (e *Engine) startForwarding() { + e.forwardMu.Lock() + defer e.forwardMu.Unlock() + + if e.forwarding { + return + } + e.forwarding = true + + go func() { + for { + select { + case event, ok := <-e.manager.Events(): + if !ok { + return + } + e.subMu.RLock() + for _, ch := range e.subscribers { + // Non-blocking send: drop events for slow consumers + select { + case ch <- event: + default: + } + } + e.subMu.RUnlock() + + case <-e.stopForward: + return + } + } + }() +} + +// Subscribe creates a new event subscription and returns a channel and an +// unsubscribe function. The channel is buffered (100 events). The caller must +// call the returned function when done to avoid resource leaks. +func (e *Engine) Subscribe() (<-chan ManagerEvent, func()) { + ch := make(chan ManagerEvent, 100) + + e.subMu.Lock() + id := e.nextID + e.nextID++ + e.subscribers[id] = ch + e.subMu.Unlock() + + unsub := func() { + e.subMu.Lock() + delete(e.subscribers, id) + e.subMu.Unlock() + } + + return ch, unsub +} + +// Manager returns the underlying loop.Manager for direct access when needed. +// This is useful for operations like Register, UpdateWorktreeInfo, etc. +// that don't need to be abstracted by the engine. +func (e *Engine) Manager() *loop.Manager { + return e.manager +} + +// --- Delegated Manager methods --- + +// Register registers a PRD with the engine (does not start it). +func (e *Engine) Register(name, prdPath string) error { + return e.manager.Register(name, prdPath) +} + +// RegisterWithWorktree registers a PRD with worktree metadata. +func (e *Engine) RegisterWithWorktree(name, prdPath, worktreeDir, branch string) error { + return e.manager.RegisterWithWorktree(name, prdPath, worktreeDir, branch) +} + +// Unregister removes a PRD from the engine. +func (e *Engine) Unregister(name string) error { + return e.manager.Unregister(name) +} + +// Start starts the loop for a specific PRD. +func (e *Engine) Start(name string) error { + return e.manager.Start(name) +} + +// Pause pauses the loop for a specific PRD. +func (e *Engine) Pause(name string) error { + return e.manager.Pause(name) +} + +// Stop stops the loop for a specific PRD immediately. +func (e *Engine) Stop(name string) error { + return e.manager.Stop(name) +} + +// StopAll stops all running loops and waits for completion. +func (e *Engine) StopAll() { + e.manager.StopAll() +} + +// GetState returns the state of a specific PRD loop. +func (e *Engine) GetState(name string) (loop.LoopState, int, error) { + return e.manager.GetState(name) +} + +// GetInstance returns a copy of the loop instance for a specific PRD. +func (e *Engine) GetInstance(name string) *loop.LoopInstance { + return e.manager.GetInstance(name) +} + +// GetAllInstances returns a snapshot of all loop instances. +func (e *Engine) GetAllInstances() []*loop.LoopInstance { + return e.manager.GetAllInstances() +} + +// GetRunningPRDs returns the names of all currently running PRDs. +func (e *Engine) GetRunningPRDs() []string { + return e.manager.GetRunningPRDs() +} + +// GetRunningCount returns the number of currently running loops. +func (e *Engine) GetRunningCount() int { + return e.manager.GetRunningCount() +} + +// IsAnyRunning returns true if any loop is currently running. +func (e *Engine) IsAnyRunning() bool { + return e.manager.IsAnyRunning() +} + +// SetMaxIterations updates the default max iterations for new loops. +func (e *Engine) SetMaxIterations(maxIter int) { + e.manager.SetMaxIterations(maxIter) +} + +// MaxIterations returns the current default max iterations. +func (e *Engine) MaxIterations() int { + return e.manager.MaxIterations() +} + +// SetMaxIterationsForInstance updates max iterations for a running loop. +func (e *Engine) SetMaxIterationsForInstance(name string, maxIter int) error { + return e.manager.SetMaxIterationsForInstance(name, maxIter) +} + +// SetRetryConfig sets the retry configuration for new loops. +func (e *Engine) SetRetryConfig(cfg loop.RetryConfig) { + e.manager.SetRetryConfig(cfg) +} + +// DisableRetry disables automatic retry for new loops. +func (e *Engine) DisableRetry() { + e.manager.DisableRetry() +} + +// SetCompletionCallback sets a callback for when any PRD completes. +func (e *Engine) SetCompletionCallback(fn func(prdName string)) { + e.manager.SetCompletionCallback(fn) +} + +// SetPostCompleteCallback sets a callback for post-completion actions. +func (e *Engine) SetPostCompleteCallback(fn func(prdName, branch, workDir string)) { + e.manager.SetPostCompleteCallback(fn) +} + +// SetConfig sets the project config. +func (e *Engine) SetConfig(cfg *config.Config) { + e.manager.SetConfig(cfg) +} + +// Config returns the current project config. +func (e *Engine) Config() *config.Config { + return e.manager.Config() +} + +// UpdateWorktreeInfo updates the worktree directory and branch for a PRD. +func (e *Engine) UpdateWorktreeInfo(name, worktreeDir, branch string) error { + return e.manager.UpdateWorktreeInfo(name, worktreeDir, branch) +} + +// ClearWorktreeInfo clears the worktree directory and optionally branch. +func (e *Engine) ClearWorktreeInfo(name string, clearBranch bool) error { + return e.manager.ClearWorktreeInfo(name, clearBranch) +} + +// --- Project state queries --- + +// LoadPRD loads and returns a PRD from the given path. +func (e *Engine) LoadPRD(prdPath string) (*prd.PRD, error) { + return prd.LoadPRD(prdPath) +} + +// Shutdown stops all loops and the event forwarding goroutine. +func (e *Engine) Shutdown() { + e.manager.StopAll() + + e.forwardMu.Lock() + defer e.forwardMu.Unlock() + + if e.forwarding { + close(e.stopForward) + e.forwarding = false + } +} diff --git a/internal/engine/engine_test.go b/internal/engine/engine_test.go new file mode 100644 index 0000000..65f6f75 --- /dev/null +++ b/internal/engine/engine_test.go @@ -0,0 +1,504 @@ +package engine + +import ( + "os" + "path/filepath" + "sync" + "testing" + "time" + + "github.com/minicodemonkey/chief/internal/config" + "github.com/minicodemonkey/chief/internal/loop" +) + +// createTestPRD creates a minimal test PRD file and returns its path. +func createTestPRD(t *testing.T, dir, name string) string { + t.Helper() + prdDir := filepath.Join(dir, name) + if err := os.MkdirAll(prdDir, 0755); err != nil { + t.Fatal(err) + } + prdPath := filepath.Join(prdDir, "prd.json") + content := `{ + "project": "Test PRD", + "description": "Test", + "userStories": [ + {"id": "US-001", "title": "Test Story", "description": "Test", "priority": 1, "passes": false} + ] + }` + if err := os.WriteFile(prdPath, []byte(content), 0644); err != nil { + t.Fatal(err) + } + return prdPath +} + +func TestNew(t *testing.T) { + e := New(10) + defer e.Shutdown() + + if e == nil { + t.Fatal("expected non-nil engine") + } + if e.manager == nil { + t.Fatal("expected non-nil manager") + } + if e.MaxIterations() != 10 { + t.Errorf("expected maxIter 10, got %d", e.MaxIterations()) + } +} + +func TestRegisterAndGetInstance(t *testing.T) { + tmpDir := t.TempDir() + prdPath := createTestPRD(t, tmpDir, "test-prd") + + e := New(10) + defer e.Shutdown() + + if err := e.Register("test-prd", prdPath); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + instance := e.GetInstance("test-prd") + if instance == nil { + t.Fatal("expected instance to be registered") + } + if instance.Name != "test-prd" { + t.Errorf("expected name 'test-prd', got '%s'", instance.Name) + } + if instance.State != loop.LoopStateReady { + t.Errorf("expected state Ready, got %v", instance.State) + } +} + +func TestRegisterDuplicate(t *testing.T) { + tmpDir := t.TempDir() + prdPath := createTestPRD(t, tmpDir, "test-prd") + + e := New(10) + defer e.Shutdown() + + if err := e.Register("test-prd", prdPath); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + err := e.Register("test-prd", prdPath) + if err == nil { + t.Error("expected error when registering duplicate PRD") + } +} + +func TestUnregister(t *testing.T) { + tmpDir := t.TempDir() + prdPath := createTestPRD(t, tmpDir, "test-prd") + + e := New(10) + defer e.Shutdown() + + e.Register("test-prd", prdPath) + if err := e.Unregister("test-prd"); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if inst := e.GetInstance("test-prd"); inst != nil { + t.Error("expected instance to be removed") + } +} + +func TestSubscribeAndUnsubscribe(t *testing.T) { + e := New(10) + defer e.Shutdown() + + ch1, unsub1 := e.Subscribe() + ch2, unsub2 := e.Subscribe() + + if ch1 == nil || ch2 == nil { + t.Fatal("expected non-nil channels") + } + + e.subMu.RLock() + count := len(e.subscribers) + e.subMu.RUnlock() + if count != 2 { + t.Errorf("expected 2 subscribers, got %d", count) + } + + unsub1() + + e.subMu.RLock() + count = len(e.subscribers) + e.subMu.RUnlock() + if count != 1 { + t.Errorf("expected 1 subscriber after unsub, got %d", count) + } + + unsub2() + + e.subMu.RLock() + count = len(e.subscribers) + e.subMu.RUnlock() + if count != 0 { + t.Errorf("expected 0 subscribers after unsub, got %d", count) + } +} + +func TestMultipleSubscribersReceiveEvents(t *testing.T) { + e := New(10) + defer e.Shutdown() + + ch1, unsub1 := e.Subscribe() + defer unsub1() + ch2, unsub2 := e.Subscribe() + defer unsub2() + + // Inject an event directly into the manager's events channel for testing + // We do this by sending an event through the underlying manager + go func() { + // Send a synthetic event via the manager's event channel + e.manager.Events() + }() + + // Instead of trying to trigger a real loop event, test fan-out by + // directly injecting into the fan-out mechanism + testEvent := ManagerEvent{ + PRDName: "test", + Completed: false, + Event: loop.Event{ + Type: loop.EventIterationStart, + Text: "test event", + }, + } + + // Directly write to subscriber channels to verify wiring + e.subMu.RLock() + for _, ch := range e.subscribers { + ch <- testEvent + } + e.subMu.RUnlock() + + // Both subscribers should receive the event + select { + case ev := <-ch1: + if ev.PRDName != "test" { + t.Errorf("ch1: expected PRDName 'test', got '%s'", ev.PRDName) + } + case <-time.After(time.Second): + t.Error("ch1: timed out waiting for event") + } + + select { + case ev := <-ch2: + if ev.PRDName != "test" { + t.Errorf("ch2: expected PRDName 'test', got '%s'", ev.PRDName) + } + case <-time.After(time.Second): + t.Error("ch2: timed out waiting for event") + } +} + +func TestGetState(t *testing.T) { + tmpDir := t.TempDir() + prdPath := createTestPRD(t, tmpDir, "test-prd") + + e := New(10) + defer e.Shutdown() + + e.Register("test-prd", prdPath) + + state, iteration, err := e.GetState("test-prd") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if state != loop.LoopStateReady { + t.Errorf("expected Ready state, got %v", state) + } + if iteration != 0 { + t.Errorf("expected iteration 0, got %d", iteration) + } +} + +func TestGetAllInstances(t *testing.T) { + tmpDir := t.TempDir() + prd1 := createTestPRD(t, tmpDir, "prd1") + prd2 := createTestPRD(t, tmpDir, "prd2") + + e := New(10) + defer e.Shutdown() + + e.Register("prd1", prd1) + e.Register("prd2", prd2) + + instances := e.GetAllInstances() + if len(instances) != 2 { + t.Errorf("expected 2 instances, got %d", len(instances)) + } +} + +func TestGetRunningPRDs(t *testing.T) { + e := New(10) + defer e.Shutdown() + + running := e.GetRunningPRDs() + if len(running) != 0 { + t.Errorf("expected 0 running PRDs, got %d", len(running)) + } +} + +func TestIsAnyRunning(t *testing.T) { + e := New(10) + defer e.Shutdown() + + if e.IsAnyRunning() { + t.Error("expected no running loops") + } +} + +func TestSetMaxIterations(t *testing.T) { + e := New(10) + defer e.Shutdown() + + e.SetMaxIterations(20) + if e.MaxIterations() != 20 { + t.Errorf("expected 20, got %d", e.MaxIterations()) + } +} + +func TestSetConfig(t *testing.T) { + e := New(10) + defer e.Shutdown() + + if e.Config() != nil { + t.Error("expected nil config initially") + } + + cfg := &config.Config{ + OnComplete: config.OnCompleteConfig{Push: true}, + } + e.SetConfig(cfg) + + got := e.Config() + if got == nil || !got.OnComplete.Push { + t.Error("expected config with Push=true") + } +} + +func TestRetryConfig(t *testing.T) { + e := New(10) + defer e.Shutdown() + + e.SetRetryConfig(loop.RetryConfig{MaxRetries: 5, Enabled: true}) + e.DisableRetry() + // No assertion on internal state; just verify no panic +} + +func TestSetCompletionCallback(t *testing.T) { + e := New(10) + defer e.Shutdown() + + called := false + e.SetCompletionCallback(func(prdName string) { + called = true + }) + + // Manually trigger via manager to verify it's wired + e.manager.SetCompletionCallback(func(prdName string) { + called = true + }) + // The callback is set on the manager, verify it + if called { + t.Error("callback should not be called yet") + } +} + +func TestRegisterWithWorktree(t *testing.T) { + tmpDir := t.TempDir() + prdPath := createTestPRD(t, tmpDir, "test-prd") + + e := New(10) + defer e.Shutdown() + + err := e.RegisterWithWorktree("test-prd", prdPath, "/tmp/wt", "branch") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + inst := e.GetInstance("test-prd") + if inst.WorktreeDir != "/tmp/wt" { + t.Errorf("expected WorktreeDir '/tmp/wt', got '%s'", inst.WorktreeDir) + } + if inst.Branch != "branch" { + t.Errorf("expected Branch 'branch', got '%s'", inst.Branch) + } +} + +func TestUpdateAndClearWorktreeInfo(t *testing.T) { + tmpDir := t.TempDir() + prdPath := createTestPRD(t, tmpDir, "test-prd") + + e := New(10) + defer e.Shutdown() + + e.Register("test-prd", prdPath) + e.UpdateWorktreeInfo("test-prd", "/tmp/wt", "branch") + + inst := e.GetInstance("test-prd") + if inst.WorktreeDir != "/tmp/wt" { + t.Errorf("expected '/tmp/wt', got '%s'", inst.WorktreeDir) + } + + e.ClearWorktreeInfo("test-prd", true) + inst = e.GetInstance("test-prd") + if inst.WorktreeDir != "" || inst.Branch != "" { + t.Error("expected cleared worktree info") + } +} + +func TestManagerAccess(t *testing.T) { + e := New(10) + defer e.Shutdown() + + if e.Manager() == nil { + t.Error("expected non-nil manager") + } +} + +func TestLoadPRD(t *testing.T) { + tmpDir := t.TempDir() + prdPath := createTestPRD(t, tmpDir, "test-prd") + + e := New(10) + defer e.Shutdown() + + p, err := e.LoadPRD(prdPath) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if p.Project != "Test PRD" { + t.Errorf("expected 'Test PRD', got '%s'", p.Project) + } +} + +func TestStopAll(t *testing.T) { + tmpDir := t.TempDir() + prd1 := createTestPRD(t, tmpDir, "prd1") + prd2 := createTestPRD(t, tmpDir, "prd2") + + e := New(10) + defer e.Shutdown() + + e.Register("prd1", prd1) + e.Register("prd2", prd2) + + done := make(chan struct{}) + go func() { + e.StopAll() + close(done) + }() + + select { + case <-done: + case <-time.After(time.Second): + t.Error("StopAll did not complete in time") + } +} + +func TestShutdown(t *testing.T) { + e := New(10) + e.Shutdown() + + // Verify forwarding is stopped + e.forwardMu.Lock() + forwarding := e.forwarding + e.forwardMu.Unlock() + + if forwarding { + t.Error("expected forwarding to be stopped after shutdown") + } +} + +func TestConcurrentSubscribeUnsubscribe(t *testing.T) { + e := New(10) + defer e.Shutdown() + + var wg sync.WaitGroup + for i := 0; i < 50; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _, unsub := e.Subscribe() + time.Sleep(time.Millisecond) + unsub() + }() + } + wg.Wait() + + e.subMu.RLock() + count := len(e.subscribers) + e.subMu.RUnlock() + if count != 0 { + t.Errorf("expected 0 subscribers after all unsubscribed, got %d", count) + } +} + +func TestConcurrentAccess(t *testing.T) { + tmpDir := t.TempDir() + prdPath := createTestPRD(t, tmpDir, "test-prd") + + e := New(10) + defer e.Shutdown() + + e.Register("test-prd", prdPath) + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _ = e.GetInstance("test-prd") + _ = e.GetAllInstances() + _ = e.GetRunningPRDs() + _ = e.GetRunningCount() + _, _, _ = e.GetState("test-prd") + _ = e.IsAnyRunning() + _ = e.MaxIterations() + }() + } + wg.Wait() +} + +func TestStartNonExistent(t *testing.T) { + e := New(10) + defer e.Shutdown() + + err := e.Start("nonexistent") + if err == nil { + t.Error("expected error when starting non-existent PRD") + } +} + +func TestPauseNonRunning(t *testing.T) { + tmpDir := t.TempDir() + prdPath := createTestPRD(t, tmpDir, "test-prd") + + e := New(10) + defer e.Shutdown() + + e.Register("test-prd", prdPath) + err := e.Pause("test-prd") + if err == nil { + t.Error("expected error when pausing non-running PRD") + } +} + +func TestStopNonRunning(t *testing.T) { + tmpDir := t.TempDir() + prdPath := createTestPRD(t, tmpDir, "test-prd") + + e := New(10) + defer e.Shutdown() + + e.Register("test-prd", prdPath) + err := e.Stop("test-prd") + if err != nil { + t.Errorf("stop non-running should not error: %v", err) + } +} diff --git a/internal/loop/loop.go b/internal/loop/loop.go index b7d702e..7768069 100644 --- a/internal/loop/loop.go +++ b/internal/loop/loop.go @@ -6,7 +6,9 @@ package loop import ( "bufio" + "bytes" "context" + "errors" "fmt" "io" "os" @@ -138,9 +140,16 @@ func (l *Loop) Run(ctx context.Context) error { // Run a single iteration with retry logic if err := l.runIterationWithRetry(ctx); err != nil { - l.events <- Event{ - Type: EventError, - Err: err, + if errors.Is(err, ErrQuotaExhausted) { + l.events <- Event{ + Type: EventQuotaExhausted, + Err: err, + } + } else { + l.events <- Event{ + Type: EventError, + Err: err, + } } return err } @@ -250,6 +259,11 @@ func (l *Loop) runIterationWithRetry(ctx context.Context) error { return nil } + // Quota errors should NOT be retried — return immediately + if errors.Is(err, ErrQuotaExhausted) { + return err + } + lastErr = err } @@ -295,10 +309,11 @@ func (l *Loop) runIteration(ctx context.Context) error { l.processOutput(stdout) }() - // Log stderr to the log file + // Capture stderr into a buffer while also logging it + var stderrBuf bytes.Buffer go func() { defer wg.Done() - l.logStream(stderr, "[stderr] ") + l.logAndCaptureStream(stderr, "[stderr] ", &stderrBuf) }() // Wait for output processing to complete @@ -317,6 +332,11 @@ func (l *Loop) runIteration(ctx context.Context) error { if stopped { return nil } + // Check if this is a quota/rate-limit error + stderrText := stderrBuf.String() + if IsQuotaError(stderrText) || IsQuotaError(err.Error()) { + return fmt.Errorf("Claude quota exhausted: %w", ErrQuotaExhausted) + } return fmt.Errorf("Claude exited with error: %w", err) } @@ -358,6 +378,17 @@ func (l *Loop) logStream(r io.Reader, prefix string) { } } +// logAndCaptureStream logs a stream with a prefix and also captures it into a buffer. +func (l *Loop) logAndCaptureStream(r io.Reader, prefix string, buf *bytes.Buffer) { + scanner := bufio.NewScanner(r) + for scanner.Scan() { + text := scanner.Text() + l.logLine(prefix + text) + buf.WriteString(text) + buf.WriteByte('\n') + } +} + // logLine writes a line to the log file. func (l *Loop) logLine(line string) { if l.logFile != nil { diff --git a/internal/loop/manager.go b/internal/loop/manager.go index 004fb0a..3395bfa 100644 --- a/internal/loop/manager.go +++ b/internal/loop/manager.go @@ -2,6 +2,7 @@ package loop import ( "context" + "errors" "fmt" "sync" "time" @@ -296,8 +297,14 @@ func (m *Manager) runLoop(instance *LoopInstance) { // Update state based on result instance.mu.Lock() if err != nil && err != context.Canceled { - instance.State = LoopStateError - instance.Error = err + if errors.Is(err, ErrQuotaExhausted) { + // Quota exhaustion pauses the run (resumable by user) + instance.State = LoopStatePaused + instance.Error = err + } else { + instance.State = LoopStateError + instance.Error = err + } } else if instance.Loop.IsPaused() { instance.State = LoopStatePaused } else if instance.Loop.IsStopped() { diff --git a/internal/loop/parser.go b/internal/loop/parser.go index e73e1a0..610a087 100644 --- a/internal/loop/parser.go +++ b/internal/loop/parser.go @@ -2,6 +2,7 @@ package loop import ( "encoding/json" + "errors" "strings" ) @@ -31,6 +32,8 @@ const ( EventError // EventRetrying is emitted when retrying after a crash. EventRetrying + // EventQuotaExhausted is emitted when Claude exits due to quota/rate-limit errors. + EventQuotaExhausted ) // String returns the string representation of an EventType. @@ -56,6 +59,8 @@ func (e EventType) String() string { return "Error" case EventRetrying: return "Retrying" + case EventQuotaExhausted: + return "QuotaExhausted" default: return "Unknown" } @@ -216,6 +221,31 @@ func parseUserMessage(raw json.RawMessage) *Event { return nil } +// ErrQuotaExhausted is returned when Claude exits due to quota/rate-limit errors. +var ErrQuotaExhausted = errors.New("quota exhausted") + +// quotaPatterns are stderr/error patterns that indicate quota or rate-limit exhaustion. +var quotaPatterns = []string{ + "rate limit", + "rate_limit", + "quota", + "429", + "too many requests", + "resource_exhausted", + "overloaded", +} + +// IsQuotaError checks if an error text contains quota/rate-limit patterns. +func IsQuotaError(errText string) bool { + lower := strings.ToLower(errText) + for _, pattern := range quotaPatterns { + if strings.Contains(lower, pattern) { + return true + } + } + return false +} + // extractStoryID extracts a story ID from text between start and end tags. func extractStoryID(text, startTag, endTag string) string { startIdx := strings.Index(text, startTag) diff --git a/internal/loop/parser_test.go b/internal/loop/parser_test.go index 2dc53c8..19b0926 100644 --- a/internal/loop/parser_test.go +++ b/internal/loop/parser_test.go @@ -20,6 +20,7 @@ func TestEventTypeString(t *testing.T) { {EventMaxIterationsReached, "MaxIterationsReached"}, {EventError, "Error"}, {EventRetrying, "Retrying"}, + {EventQuotaExhausted, "QuotaExhausted"}, } for _, tt := range tests { @@ -305,3 +306,35 @@ func TestParseLineToolUseFirst(t *testing.T) { t.Errorf("event.Tool = %q, want %q", event.Tool, "Write") } } + +func TestIsQuotaError(t *testing.T) { + tests := []struct { + text string + expected bool + }{ + {"rate limit exceeded", true}, + {"Rate Limit Exceeded", true}, + {"rate_limit_error", true}, + {"quota exceeded for this billing period", true}, + {"HTTP 429 Too Many Requests", true}, + {"429", true}, + {"too many requests", true}, + {"Too Many Requests", true}, + {"resource_exhausted", true}, + {"model is overloaded", true}, + {"Overloaded", true}, + {"normal error message", false}, + {"exit status 1", false}, + {"connection refused", false}, + {"", false}, + {"Claude exited with error: exit status 2", false}, + } + + for _, tt := range tests { + t.Run(tt.text, func(t *testing.T) { + if got := IsQuotaError(tt.text); got != tt.expected { + t.Errorf("IsQuotaError(%q) = %v, want %v", tt.text, got, tt.expected) + } + }) + } +} diff --git a/internal/prd/generator.go b/internal/prd/generator.go index cc9080d..757fab8 100644 --- a/internal/prd/generator.go +++ b/internal/prd/generator.go @@ -144,6 +144,7 @@ func runClaudeConversion(absPRDDir string) error { cmd := exec.Command("claude", "--dangerously-skip-permissions", "--output-format", "stream-json", + "--verbose", "-p", prompt, ) cmd.Dir = absPRDDir diff --git a/internal/tui/app.go b/internal/tui/app.go index 95fddc4..31da136 100644 --- a/internal/tui/app.go +++ b/internal/tui/app.go @@ -10,6 +10,7 @@ import ( tea "github.com/charmbracelet/bubbletea" "github.com/minicodemonkey/chief/internal/config" + "github.com/minicodemonkey/chief/internal/engine" "github.com/minicodemonkey/chief/internal/git" "github.com/minicodemonkey/chief/internal/loop" "github.com/minicodemonkey/chief/internal/prd" @@ -150,10 +151,14 @@ type App struct { height int err error - // Loop manager for parallel PRD execution - manager *loop.Manager + // Shared engine for parallel PRD execution (used by both TUI and serve) + eng *engine.Engine maxIter int + // Event subscription from engine + eventCh <-chan engine.ManagerEvent + unsubFn func() + // Activity tracking lastActivity string @@ -220,6 +225,58 @@ func NewApp(prdPath string) (*App, error) { return NewAppWithOptions(prdPath, 10) // default max iterations } +// NewAppWithEngine creates a new App using a pre-existing engine. +// This is used by the serve command to share an engine between the TUI and WebSocket handler. +func NewAppWithEngine(prdPath string, eng *engine.Engine) (*App, error) { + p, err := prd.LoadPRD(prdPath) + if err != nil { + return nil, err + } + + prdName := filepath.Base(filepath.Dir(prdPath)) + if prdName == "." || prdName == "/" { + prdName = filepath.Base(prdPath) + } + + watcher, err := prd.NewWatcher(prdPath) + if err != nil { + return nil, err + } + + baseDir := filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(prdPath)))) + if !strings.Contains(prdPath, ".chief/prds/") { + baseDir, _ = os.Getwd() + } + + // Subscribe to engine events for TUI consumption + eventCh, unsubFn := eng.Subscribe() + + return &App{ + prd: p, + prdPath: prdPath, + prdName: prdName, + state: StateReady, + selectedIndex: 0, + maxIter: eng.MaxIterations(), + eng: eng, + eventCh: eventCh, + unsubFn: unsubFn, + watcher: watcher, + viewMode: ViewDashboard, + logViewer: NewLogViewer(), + diffViewer: NewDiffViewer(baseDir), + tabBar: NewTabBar(baseDir, prdName, eng.Manager()), + picker: NewPRDPicker(baseDir, prdName, eng.Manager()), + baseDir: baseDir, + config: eng.Config(), + helpOverlay: NewHelpOverlay(), + branchWarning: NewBranchWarning(), + worktreeSpinner: NewWorktreeSpinner(), + completionScreen: NewCompletionScreen(), + settingsOverlay: NewSettingsOverlay(), + }, nil +} + // NewAppWithOptions creates a new App with the given PRD and options. // If maxIter <= 0, it will be calculated dynamically based on remaining stories. func NewAppWithOptions(prdPath string, maxIter int) (*App, error) { @@ -274,18 +331,21 @@ func NewAppWithOptions(prdPath string, maxIter int) (*App, error) { _ = git.PruneWorktrees(baseDir) } - // Create loop manager for parallel PRD execution - manager := loop.NewManager(maxIter) - manager.SetConfig(cfg) + // Create shared engine for parallel PRD execution + eng := engine.New(maxIter) + eng.SetConfig(cfg) - // Register the initial PRD with the manager - manager.Register(prdName, prdPath) + // Register the initial PRD with the engine + eng.Register(prdName, prdPath) + + // Subscribe to engine events for TUI consumption + eventCh, unsubFn := eng.Subscribe() // Create tab bar for always-visible PRD tabs - tabBar := NewTabBar(baseDir, prdName, manager) + tabBar := NewTabBar(baseDir, prdName, eng.Manager()) // Create picker with manager reference (for creating new PRDs) - picker := NewPRDPicker(baseDir, prdName, manager) + picker := NewPRDPicker(baseDir, prdName, eng.Manager()) return &App{ prd: p, @@ -295,7 +355,9 @@ func NewAppWithOptions(prdPath string, maxIter int) (*App, error) { iteration: 0, selectedIndex: 0, maxIter: maxIter, - manager: manager, + eng: eng, + eventCh: eventCh, + unsubFn: unsubFn, watcher: watcher, viewMode: ViewDashboard, logViewer: NewLogViewer(), @@ -315,8 +377,8 @@ func NewAppWithOptions(prdPath string, maxIter int) (*App, error) { // SetCompletionCallback sets a callback that is called when any PRD completes. func (a *App) SetCompletionCallback(fn func(prdName string)) { a.onCompletion = fn - if a.manager != nil { - a.manager.SetCompletionCallback(fn) + if a.eng != nil { + a.eng.SetCompletionCallback(fn) } } @@ -327,8 +389,8 @@ func (a *App) SetVerbose(v bool) { // DisableRetry disables automatic retry on Claude crashes. func (a *App) DisableRetry() { - if a.manager != nil { - a.manager.DisableRetry() + if a.eng != nil { + a.eng.DisableRetry() } } @@ -349,13 +411,13 @@ func (a App) Init() tea.Cmd { ) } -// listenForManagerEvents listens for events from all managed loops. +// listenForManagerEvents listens for events from the engine's subscription. func (a *App) listenForManagerEvents() tea.Cmd { - if a.manager == nil { + if a.eventCh == nil { return nil } return func() tea.Msg { - event, ok := <-a.manager.Events() + event, ok := <-a.eventCh if !ok { return nil } @@ -693,10 +755,10 @@ func (a App) startLoopForPRD(prdName string) (tea.Model, tea.Cmd) { // isAnotherPRDRunningInSameDir checks if another PRD is running in the project root (no worktree). func (a *App) isAnotherPRDRunningInSameDir(prdName string) bool { - if a.manager == nil { + if a.eng == nil { return false } - for _, inst := range a.manager.GetAllInstances() { + for _, inst := range a.eng.GetAllInstances() { if inst.Name != prdName && inst.State == loop.LoopStateRunning && inst.WorktreeDir == "" { return true } @@ -707,14 +769,14 @@ func (a *App) isAnotherPRDRunningInSameDir(prdName string) bool { // doStartLoop actually starts the loop (after branch check). func (a App) doStartLoop(prdName, prdDir string) (tea.Model, tea.Cmd) { // Check if this PRD is registered, if not register it - if instance := a.manager.GetInstance(prdName); instance == nil { + if instance := a.eng.GetInstance(prdName); instance == nil { // Find the PRD path prdPath := filepath.Join(prdDir, "prd.json") - a.manager.Register(prdName, prdPath) + a.eng.Register(prdName, prdPath) } // Start the loop via manager - if err := a.manager.Start(prdName); err != nil { + if err := a.eng.Start(prdName); err != nil { a.lastActivity = "Error starting loop: " + err.Error() return a, nil } @@ -738,8 +800,8 @@ func (a App) pauseLoop() (tea.Model, tea.Cmd) { // pauseLoopForPRD pauses the loop for a specific PRD. func (a App) pauseLoopForPRD(prdName string) (tea.Model, tea.Cmd) { - if a.manager != nil { - a.manager.Pause(prdName) + if a.eng != nil { + a.eng.Pause(prdName) } if prdName == a.prdName { a.lastActivity = "Pausing after current iteration..." @@ -756,8 +818,8 @@ func (a *App) stopLoop() { // stopLoopForPRD stops the loop for a specific PRD immediately. func (a *App) stopLoopForPRD(prdName string) { - if a.manager != nil { - a.manager.Stop(prdName) + if a.eng != nil { + a.eng.Stop(prdName) } } @@ -778,10 +840,13 @@ func (a App) stopLoopAndUpdateForPRD(prdName string) (tea.Model, tea.Cmd) { return a, nil } -// stopAllLoops stops all running loops. +// stopAllLoops stops all running loops and unsubscribes from events. func (a *App) stopAllLoops() { - if a.manager != nil { - a.manager.StopAll() + if a.eng != nil { + a.eng.StopAll() + } + if a.unsubFn != nil { + a.unsubFn() } } @@ -880,7 +945,7 @@ func (a App) handleLoopFinished(prdName string, err error) (tea.Model, tea.Cmd) // Only update state if this is the current PRD if prdName == a.prdName { // Get the actual state from the manager - if state, _, _ := a.manager.GetState(prdName); state != 0 { + if state, _, _ := a.eng.GetState(prdName); state != 0 { switch state { case loop.LoopStateError: a.state = StateError @@ -1028,8 +1093,8 @@ func (a App) handleBranchWarningKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) { return a, nil } // Track the branch on the manager instance - if instance := a.manager.GetInstance(prdName); instance != nil { - a.manager.UpdateWorktreeInfo(prdName, "", branchName) + if instance := a.eng.GetInstance(prdName); instance != nil { + a.eng.UpdateWorktreeInfo(prdName, "", branchName) } a.lastActivity = "Created branch: " + branchName // Now start the loop @@ -1095,7 +1160,7 @@ func (a *App) showCompletionScreen(prdName string) tea.Cmd { // Get branch from manager branch := "" - if instance := a.manager.GetInstance(prdName); instance != nil { + if instance := a.eng.GetInstance(prdName); instance != nil { branch = instance.Branch } @@ -1139,7 +1204,7 @@ func (a *App) runBackgroundAutoActions(prdName string) tea.Cmd { return nil } - instance := a.manager.GetInstance(prdName) + instance := a.eng.GetInstance(prdName) if instance == nil || instance.Branch == "" { return nil } @@ -1198,7 +1263,7 @@ func (a App) handleBackgroundAutoAction(msg backgroundAutoActionResultMsg) (tea. if msg.action == "push" && a.config != nil && a.config.OnComplete.CreatePR { // Chain PR creation after successful push - instance := a.manager.GetInstance(msg.prdName) + instance := a.eng.GetInstance(msg.prdName) if instance != nil && instance.Branch != "" { prdName := msg.prdName branch := instance.Branch @@ -1225,7 +1290,7 @@ func (a *App) runAutoPush() tea.Cmd { branch := a.completionScreen.Branch() // Use worktree dir if available, otherwise base dir dir := a.baseDir - if instance := a.manager.GetInstance(a.completionScreen.PRDName()); instance != nil && instance.WorktreeDir != "" { + if instance := a.eng.GetInstance(a.completionScreen.PRDName()); instance != nil && instance.WorktreeDir != "" { dir = instance.WorktreeDir } return func() tea.Msg { @@ -1513,10 +1578,10 @@ func (a App) finishWorktreeSetup() (tea.Model, tea.Cmd) { // Register or update with worktree info prdPath := filepath.Join(prdDir, "prd.json") - if instance := a.manager.GetInstance(prdName); instance == nil { - a.manager.RegisterWithWorktree(prdName, prdPath, worktreePath, branchName) + if instance := a.eng.GetInstance(prdName); instance == nil { + a.eng.RegisterWithWorktree(prdName, prdPath, worktreePath, branchName) } else { - a.manager.UpdateWorktreeInfo(prdName, worktreePath, branchName) + a.eng.UpdateWorktreeInfo(prdName, worktreePath, branchName) } a.lastActivity = fmt.Sprintf("Created worktree at %s on branch %s", worktreePath, branchName) @@ -1631,8 +1696,8 @@ func (a App) handleCleanResult(msg cleanResultMsg) (tea.Model, tea.Cmd) { if msg.success { // Clear worktree info from manager - if a.manager != nil { - a.manager.ClearWorktreeInfo(msg.prdName, msg.clearBranch) + if a.eng != nil { + a.eng.ClearWorktreeInfo(msg.prdName, msg.clearBranch) } a.picker.Refresh() a.lastActivity = fmt.Sprintf("Cleaned worktree for %s", msg.prdName) @@ -1823,8 +1888,8 @@ func (a App) switchToPRD(name, prdPath string) (tea.Model, tea.Cmd) { } // Register with manager if not already registered - if instance := a.manager.GetInstance(name); instance == nil { - a.manager.Register(name, prdPath) + if instance := a.eng.GetInstance(name); instance == nil { + a.eng.Register(name, prdPath) } // Create new watcher for the new PRD @@ -1839,7 +1904,7 @@ func (a App) switchToPRD(name, prdPath string) (tea.Model, tea.Cmd) { } // Get the state from the manager for this PRD - loopState, iteration, loopErr := a.manager.GetState(name) + loopState, iteration, loopErr := a.eng.GetState(name) appState := StateReady switch loopState { case loop.LoopStateRunning: @@ -1864,7 +1929,7 @@ func (a App) switchToPRD(name, prdPath string) (tea.Model, tea.Cmd) { a.err = loopErr if appState == StateRunning { // Keep the existing start time if running - if instance := a.manager.GetInstance(name); instance != nil { + if instance := a.eng.GetInstance(name); instance != nil { a.startTime = instance.StartTime } } else { @@ -1957,10 +2022,10 @@ func (a *App) adjustMaxIterations(delta int) { a.maxIter = newMax // Update the manager's default - if a.manager != nil { - a.manager.SetMaxIterations(newMax) + if a.eng != nil { + a.eng.SetMaxIterations(newMax) // Also update any running loop for the current PRD - a.manager.SetMaxIterationsForInstance(a.prdName, newMax) + a.eng.SetMaxIterationsForInstance(a.prdName, newMax) } a.lastActivity = fmt.Sprintf("Max iterations: %d", newMax) diff --git a/internal/tui/dashboard.go b/internal/tui/dashboard.go index 6268363..cea7f6b 100644 --- a/internal/tui/dashboard.go +++ b/internal/tui/dashboard.go @@ -84,10 +84,10 @@ func (a *App) renderStackedDashboard() string { // getWorktreeInfo returns the branch and directory info for the current PRD. // Returns empty strings if no branch is set (backward compatible). func (a *App) getWorktreeInfo() (branch, dir string) { - if a.manager == nil { + if a.eng == nil { return "", "" } - instance := a.manager.GetInstance(a.prdName) + instance := a.eng.GetInstance(a.prdName) if instance == nil || instance.Branch == "" { return "", "" } diff --git a/internal/tui/layout_test.go b/internal/tui/layout_test.go index b82f610..c38ad61 100644 --- a/internal/tui/layout_test.go +++ b/internal/tui/layout_test.go @@ -4,7 +4,7 @@ import ( "strings" "testing" - "github.com/minicodemonkey/chief/internal/loop" + "github.com/minicodemonkey/chief/internal/engine" ) func TestIsNarrowMode(t *testing.T) { @@ -207,20 +207,36 @@ func TestMinMaxHelpers(t *testing.T) { } } +// newTestEngine creates a test engine with the given PRD registered. +func newTestEngine(name, prdPath string) *engine.Engine { + eng := engine.New(10) + if prdPath != "" { + eng.Register(name, prdPath) + } + return eng +} + +// newTestEngineWithWorktree creates a test engine with a worktree-registered PRD. +func newTestEngineWithWorktree(name, prdPath, worktreeDir, branch string) *engine.Engine { + eng := engine.New(10) + eng.RegisterWithWorktree(name, prdPath, worktreeDir, branch) + return eng +} + func TestGetWorktreeInfo_NoBranch(t *testing.T) { - // No manager - should return empty + // No engine - should return empty app := &App{prdName: "auth"} branch, dir := app.getWorktreeInfo() if branch != "" || dir != "" { - t.Errorf("expected empty worktree info without manager, got branch=%q dir=%q", branch, dir) + t.Errorf("expected empty worktree info without engine, got branch=%q dir=%q", branch, dir) } } func TestGetWorktreeInfo_WithBranch(t *testing.T) { - mgr := loop.NewManager(10) - mgr.RegisterWithWorktree("auth", "/tmp/prd.json", "/tmp/.chief/worktrees/auth", "chief/auth") + eng := newTestEngineWithWorktree("auth", "/tmp/prd.json", "/tmp/.chief/worktrees/auth", "chief/auth") + defer eng.Shutdown() - app := &App{prdName: "auth", manager: mgr} + app := &App{prdName: "auth", eng: eng} branch, dir := app.getWorktreeInfo() if branch != "chief/auth" { t.Errorf("branch = %q, want %q", branch, "chief/auth") @@ -232,10 +248,10 @@ func TestGetWorktreeInfo_WithBranch(t *testing.T) { func TestGetWorktreeInfo_WithBranchNoWorktree(t *testing.T) { // Branch set but no worktree dir (branch-only mode) - mgr := loop.NewManager(10) - mgr.RegisterWithWorktree("auth", "/tmp/prd.json", "", "chief/auth") + eng := newTestEngineWithWorktree("auth", "/tmp/prd.json", "", "chief/auth") + defer eng.Shutdown() - app := &App{prdName: "auth", manager: mgr} + app := &App{prdName: "auth", eng: eng} branch, dir := app.getWorktreeInfo() if branch != "chief/auth" { t.Errorf("branch = %q, want %q", branch, "chief/auth") @@ -247,10 +263,10 @@ func TestGetWorktreeInfo_WithBranchNoWorktree(t *testing.T) { func TestGetWorktreeInfo_RegisteredNoBranch(t *testing.T) { // Registered without worktree - should return empty (backward compatible) - mgr := loop.NewManager(10) - mgr.Register("auth", "/tmp/prd.json") + eng := newTestEngine("auth", "/tmp/prd.json") + defer eng.Shutdown() - app := &App{prdName: "auth", manager: mgr} + app := &App{prdName: "auth", eng: eng} branch, dir := app.getWorktreeInfo() if branch != "" || dir != "" { t.Errorf("expected empty worktree info for no-branch PRD, got branch=%q dir=%q", branch, dir) @@ -258,16 +274,17 @@ func TestGetWorktreeInfo_RegisteredNoBranch(t *testing.T) { } func TestHasWorktreeInfo(t *testing.T) { - // No manager + // No engine app := &App{prdName: "auth"} if app.hasWorktreeInfo() { - t.Error("expected hasWorktreeInfo=false without manager") + t.Error("expected hasWorktreeInfo=false without engine") } // With branch - mgr := loop.NewManager(10) - mgr.RegisterWithWorktree("auth", "/tmp/prd.json", "/tmp/.chief/worktrees/auth", "chief/auth") - app.manager = mgr + eng := newTestEngineWithWorktree("auth", "/tmp/prd.json", "/tmp/.chief/worktrees/auth", "chief/auth") + defer eng.Shutdown() + + app.eng = eng if !app.hasWorktreeInfo() { t.Error("expected hasWorktreeInfo=true with branch set") } @@ -281,10 +298,10 @@ func TestEffectiveHeaderHeight_NoBranch(t *testing.T) { } func TestEffectiveHeaderHeight_WithBranch(t *testing.T) { - mgr := loop.NewManager(10) - mgr.RegisterWithWorktree("auth", "/tmp/prd.json", "/tmp/.chief/worktrees/auth", "chief/auth") + eng := newTestEngineWithWorktree("auth", "/tmp/prd.json", "/tmp/.chief/worktrees/auth", "chief/auth") + defer eng.Shutdown() - app := &App{prdName: "auth", manager: mgr} + app := &App{prdName: "auth", eng: eng} if got := app.effectiveHeaderHeight(); got != headerHeight+1 { t.Errorf("effectiveHeaderHeight() = %d, want %d (with branch)", got, headerHeight+1) } @@ -298,10 +315,10 @@ func TestRenderWorktreeInfoLine_NoBranch(t *testing.T) { } func TestRenderWorktreeInfoLine_WithBranch(t *testing.T) { - mgr := loop.NewManager(10) - mgr.RegisterWithWorktree("auth", "/tmp/prd.json", "/tmp/.chief/worktrees/auth", "chief/auth") + eng := newTestEngineWithWorktree("auth", "/tmp/prd.json", "/tmp/.chief/worktrees/auth", "chief/auth") + defer eng.Shutdown() - app := &App{prdName: "auth", manager: mgr} + app := &App{prdName: "auth", eng: eng} got := app.renderWorktreeInfoLine() if got == "" { t.Error("renderWorktreeInfoLine() should not be empty with branch set") @@ -321,10 +338,10 @@ func TestRenderWorktreeInfoLine_WithBranch(t *testing.T) { } func TestRenderWorktreeInfoLine_BranchNoWorktree(t *testing.T) { - mgr := loop.NewManager(10) - mgr.RegisterWithWorktree("auth", "/tmp/prd.json", "", "chief/auth") + eng := newTestEngineWithWorktree("auth", "/tmp/prd.json", "", "chief/auth") + defer eng.Shutdown() - app := &App{prdName: "auth", manager: mgr} + app := &App{prdName: "auth", eng: eng} got := app.renderWorktreeInfoLine() if !strings.Contains(got, "current directory") { t.Errorf("renderWorktreeInfoLine() should contain 'current directory' for branch-only mode, got %q", got) diff --git a/internal/update/update.go b/internal/update/update.go new file mode 100644 index 0000000..dda3931 --- /dev/null +++ b/internal/update/update.go @@ -0,0 +1,293 @@ +package update + +import ( + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "runtime" + "strings" + "time" +) + +const ( + defaultReleasesURL = "https://api.github.com/repos/MiniCodeMonkey/chief/releases/latest" + downloadTimeout = 5 * time.Minute + checkTimeout = 10 * time.Second +) + +// Release represents a GitHub release. +type Release struct { + TagName string `json:"tag_name"` + Assets []Asset `json:"assets"` +} + +// Asset represents a release asset. +type Asset struct { + Name string `json:"name"` + BrowserDownloadURL string `json:"browser_download_url"` +} + +// CheckResult contains the result of a version check. +type CheckResult struct { + CurrentVersion string + LatestVersion string + UpdateAvailable bool +} + +// Options configures the update checker. +type Options struct { + ReleasesURL string // Override GitHub API URL (for testing) +} + +func (o Options) releasesURL() string { + if o.ReleasesURL != "" { + return o.ReleasesURL + } + return defaultReleasesURL +} + +// CheckForUpdate checks if a newer version is available. +func CheckForUpdate(currentVersion string, opts Options) (*CheckResult, error) { + client := &http.Client{Timeout: checkTimeout} + resp, err := client.Get(opts.releasesURL()) + if err != nil { + return nil, fmt.Errorf("fetching latest release: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("GitHub API returned status %d", resp.StatusCode) + } + + var release Release + if err := json.NewDecoder(resp.Body).Decode(&release); err != nil { + return nil, fmt.Errorf("parsing release response: %w", err) + } + + latest := normalizeVersion(release.TagName) + current := normalizeVersion(currentVersion) + + return &CheckResult{ + CurrentVersion: current, + LatestVersion: latest, + UpdateAvailable: baseVersion(current) != latest && current != "dev", + }, nil +} + +// PerformUpdate downloads and installs the latest version. +func PerformUpdate(currentVersion string, opts Options) (*CheckResult, error) { + client := &http.Client{Timeout: checkTimeout} + resp, err := client.Get(opts.releasesURL()) + if err != nil { + return nil, fmt.Errorf("fetching latest release: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("GitHub API returned status %d", resp.StatusCode) + } + + var release Release + if err := json.NewDecoder(resp.Body).Decode(&release); err != nil { + return nil, fmt.Errorf("parsing release response: %w", err) + } + + latest := normalizeVersion(release.TagName) + current := normalizeVersion(currentVersion) + + if baseVersion(current) == latest { + return &CheckResult{ + CurrentVersion: current, + LatestVersion: latest, + UpdateAvailable: false, + }, nil + } + + // Find the binary asset for this OS/arch + binaryAsset, checksumAsset := findAssets(release.Assets, runtime.GOOS, runtime.GOARCH) + if binaryAsset == nil { + return nil, fmt.Errorf("no binary available for %s/%s", runtime.GOOS, runtime.GOARCH) + } + + // Get the current binary path + binaryPath, err := os.Executable() + if err != nil { + return nil, fmt.Errorf("finding current binary path: %w", err) + } + binaryPath, err = filepath.EvalSymlinks(binaryPath) + if err != nil { + return nil, fmt.Errorf("resolving binary path: %w", err) + } + + // Check write permissions + dir := filepath.Dir(binaryPath) + if err := checkWritePermission(dir); err != nil { + return nil, fmt.Errorf("Permission denied. Run 'sudo chief update' to upgrade.") + } + + // Download binary to temp file + tmpFile, err := downloadToTemp(binaryAsset.BrowserDownloadURL, dir) + if err != nil { + return nil, fmt.Errorf("downloading update: %w", err) + } + defer os.Remove(tmpFile) // Clean up on failure + + // Verify checksum if available + if checksumAsset != nil { + if err := verifyChecksum(tmpFile, checksumAsset.BrowserDownloadURL); err != nil { + return nil, fmt.Errorf("checksum verification failed: %w", err) + } + } + + // Make the new binary executable + if err := os.Chmod(tmpFile, 0o755); err != nil { + return nil, fmt.Errorf("setting permissions on new binary: %w", err) + } + + // Atomic rename to replace current binary + if err := os.Rename(tmpFile, binaryPath); err != nil { + return nil, fmt.Errorf("replacing binary: %w", err) + } + + return &CheckResult{ + CurrentVersion: current, + LatestVersion: latest, + UpdateAvailable: true, + }, nil +} + +// normalizeVersion strips the "v" prefix from version strings. +func normalizeVersion(v string) string { + return strings.TrimPrefix(v, "v") +} + +// baseVersion extracts the base semver from a git-describe version string. +// For example, "0.4.0-61-gd06835b" returns "0.4.0". +// Handles dirty builds too: "0.4.0-61-gd06835b-dirty" returns "0.4.0". +// A plain version like "0.4.0" is returned unchanged. +func baseVersion(v string) string { + v = normalizeVersion(v) + // Strip "-dirty" suffix from uncommitted builds. + v = strings.TrimSuffix(v, "-dirty") + // Git describe format: "0.4.0-61-gd06835b" (tag-commits-ghash) + // Strip the "-N-gHASH" suffix to get the base semver. + parts := strings.Split(v, "-") + if len(parts) >= 3 { + last := parts[len(parts)-1] + if strings.HasPrefix(last, "g") { + return strings.Join(parts[:len(parts)-2], "-") + } + } + return v +} + +// CompareVersions returns true if latest is different from the base version of current. +// Dev builds (e.g. "0.4.0-61-gd06835b") are compared by their base tag ("0.4.0"). +func CompareVersions(current, latest string) bool { + current = normalizeVersion(current) + latest = normalizeVersion(latest) + return baseVersion(current) != latest && current != "dev" +} + +// findAssets locates the binary and checksum assets for the given OS/arch. +func findAssets(assets []Asset, goos, goarch string) (*Asset, *Asset) { + binaryName := fmt.Sprintf("chief-%s-%s", goos, goarch) + checksumName := binaryName + ".sha256" + + var binary, checksum *Asset + for i := range assets { + if assets[i].Name == binaryName { + binary = &assets[i] + } + if assets[i].Name == checksumName { + checksum = &assets[i] + } + } + return binary, checksum +} + +// checkWritePermission checks if we can write to the directory. +func checkWritePermission(dir string) error { + tmp, err := os.CreateTemp(dir, ".chief-update-check-*") + if err != nil { + return err + } + tmp.Close() + os.Remove(tmp.Name()) + return nil +} + +// downloadToTemp downloads a URL to a temporary file in the specified directory. +func downloadToTemp(url, dir string) (string, error) { + client := &http.Client{Timeout: downloadTimeout} + resp, err := client.Get(url) + if err != nil { + return "", fmt.Errorf("downloading %s: %w", url, err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("download returned status %d", resp.StatusCode) + } + + tmp, err := os.CreateTemp(dir, ".chief-update-*") + if err != nil { + return "", fmt.Errorf("creating temp file: %w", err) + } + + if _, err := io.Copy(tmp, resp.Body); err != nil { + tmp.Close() + os.Remove(tmp.Name()) + return "", fmt.Errorf("writing download: %w", err) + } + tmp.Close() + + return tmp.Name(), nil +} + +// verifyChecksum downloads the expected SHA256 checksum and verifies the file. +func verifyChecksum(filePath, checksumURL string) error { + // Download checksum file + client := &http.Client{Timeout: checkTimeout} + resp, err := client.Get(checksumURL) + if err != nil { + return fmt.Errorf("downloading checksum: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("checksum download returned status %d", resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("reading checksum: %w", err) + } + + // Parse checksum (format: "hash filename" or just "hash") + expectedHash := strings.Fields(strings.TrimSpace(string(body)))[0] + + // Calculate actual hash + f, err := os.Open(filePath) + if err != nil { + return fmt.Errorf("opening file for checksum: %w", err) + } + defer f.Close() + + h := sha256.New() + if _, err := io.Copy(h, f); err != nil { + return fmt.Errorf("computing checksum: %w", err) + } + actualHash := hex.EncodeToString(h.Sum(nil)) + + if actualHash != expectedHash { + return fmt.Errorf("expected %s, got %s", expectedHash, actualHash) + } + + return nil +} diff --git a/internal/update/update_test.go b/internal/update/update_test.go new file mode 100644 index 0000000..e1a5092 --- /dev/null +++ b/internal/update/update_test.go @@ -0,0 +1,431 @@ +package update + +import ( + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "runtime" + "testing" +) + +func TestCheckForUpdate_UpdateAvailable(t *testing.T) { + release := Release{ + TagName: "v0.5.1", + Assets: []Asset{{Name: "chief-linux-amd64", BrowserDownloadURL: "http://example.com/chief"}}, + } + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(release) + })) + defer srv.Close() + + result, err := CheckForUpdate("0.5.0", Options{ReleasesURL: srv.URL}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !result.UpdateAvailable { + t.Error("expected update to be available") + } + if result.LatestVersion != "0.5.1" { + t.Errorf("expected latest version 0.5.1, got %s", result.LatestVersion) + } + if result.CurrentVersion != "0.5.0" { + t.Errorf("expected current version 0.5.0, got %s", result.CurrentVersion) + } +} + +func TestCheckForUpdate_AlreadyLatest(t *testing.T) { + release := Release{TagName: "v0.5.0"} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(release) + })) + defer srv.Close() + + result, err := CheckForUpdate("v0.5.0", Options{ReleasesURL: srv.URL}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.UpdateAvailable { + t.Error("expected no update available") + } +} + +func TestCheckForUpdate_DevVersion(t *testing.T) { + release := Release{TagName: "v1.0.0"} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(release) + })) + defer srv.Close() + + result, err := CheckForUpdate("dev", Options{ReleasesURL: srv.URL}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.UpdateAvailable { + t.Error("dev version should not report update available") + } +} + +func TestCheckForUpdate_DevBuildSameTag(t *testing.T) { + release := Release{TagName: "v0.4.0"} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(release) + })) + defer srv.Close() + + result, err := CheckForUpdate("v0.4.0-61-gd06835b", Options{ReleasesURL: srv.URL}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.UpdateAvailable { + t.Error("dev build ahead of the same tag should not report update available") + } +} + +func TestCheckForUpdate_DevBuildOlderTag(t *testing.T) { + release := Release{TagName: "v0.5.0"} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(release) + })) + defer srv.Close() + + result, err := CheckForUpdate("v0.4.0-61-gd06835b", Options{ReleasesURL: srv.URL}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !result.UpdateAvailable { + t.Error("dev build should report update available when a newer release exists") + } +} + +func TestCheckForUpdate_APIError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer srv.Close() + + _, err := CheckForUpdate("0.5.0", Options{ReleasesURL: srv.URL}) + if err == nil { + t.Error("expected error for API failure") + } +} + +func TestCheckForUpdate_BadJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("not json")) + })) + defer srv.Close() + + _, err := CheckForUpdate("0.5.0", Options{ReleasesURL: srv.URL}) + if err == nil { + t.Error("expected error for bad JSON") + } +} + +func TestNormalizeVersion(t *testing.T) { + tests := []struct { + input, expected string + }{ + {"v0.5.0", "0.5.0"}, + {"0.5.0", "0.5.0"}, + {"v1.0.0-beta", "1.0.0-beta"}, + {"dev", "dev"}, + } + + for _, tc := range tests { + result := normalizeVersion(tc.input) + if result != tc.expected { + t.Errorf("normalizeVersion(%q) = %q, want %q", tc.input, result, tc.expected) + } + } +} + +func TestBaseVersion(t *testing.T) { + tests := []struct { + input, expected string + }{ + {"v0.4.0-61-gd06835b", "0.4.0"}, + {"0.4.0-61-gd06835b", "0.4.0"}, + {"v0.5.0", "0.5.0"}, + {"0.5.0", "0.5.0"}, + {"1.0.0-beta", "1.0.0-beta"}, + {"1.0.0-beta-3-gabcdef1", "1.0.0-beta"}, + {"dev", "dev"}, + {"v0.4.0-dirty", "0.4.0"}, + {"0.4.0-81-g1d1ebf3-dirty", "0.4.0"}, + } + + for _, tc := range tests { + result := baseVersion(tc.input) + if result != tc.expected { + t.Errorf("baseVersion(%q) = %q, want %q", tc.input, result, tc.expected) + } + } +} + +func TestCompareVersions(t *testing.T) { + tests := []struct { + current, latest string + expected bool + }{ + {"0.5.0", "0.5.1", true}, + {"0.5.0", "0.5.0", false}, + {"v0.5.0", "v0.5.0", false}, + {"dev", "1.0.0", false}, + {"1.0.0", "2.0.0", true}, + // Dev builds ahead of a tag should not report update available + {"0.4.0-61-gd06835b", "0.4.0", false}, + {"v0.4.0-61-gd06835b", "v0.4.0", false}, + // Dev builds should report update when a newer release exists + {"0.4.0-61-gd06835b", "0.5.0", true}, + } + + for _, tc := range tests { + result := CompareVersions(tc.current, tc.latest) + if result != tc.expected { + t.Errorf("CompareVersions(%q, %q) = %v, want %v", tc.current, tc.latest, result, tc.expected) + } + } +} + +func TestFindAssets(t *testing.T) { + assets := []Asset{ + {Name: "chief-linux-amd64", BrowserDownloadURL: "http://example.com/chief-linux-amd64"}, + {Name: "chief-linux-amd64.sha256", BrowserDownloadURL: "http://example.com/chief-linux-amd64.sha256"}, + {Name: "chief-darwin-arm64", BrowserDownloadURL: "http://example.com/chief-darwin-arm64"}, + } + + binary, checksum := findAssets(assets, "linux", "amd64") + if binary == nil { + t.Fatal("expected to find binary asset") + } + if binary.Name != "chief-linux-amd64" { + t.Errorf("expected chief-linux-amd64, got %s", binary.Name) + } + if checksum == nil { + t.Fatal("expected to find checksum asset") + } + if checksum.Name != "chief-linux-amd64.sha256" { + t.Errorf("expected chief-linux-amd64.sha256, got %s", checksum.Name) + } +} + +func TestFindAssets_NoMatch(t *testing.T) { + assets := []Asset{ + {Name: "chief-linux-amd64", BrowserDownloadURL: "http://example.com/chief-linux-amd64"}, + } + + binary, _ := findAssets(assets, "windows", "amd64") + if binary != nil { + t.Error("expected no binary for windows/amd64") + } +} + +func TestFindAssets_NoChecksum(t *testing.T) { + assets := []Asset{ + {Name: "chief-linux-amd64", BrowserDownloadURL: "http://example.com/chief-linux-amd64"}, + } + + binary, checksum := findAssets(assets, "linux", "amd64") + if binary == nil { + t.Fatal("expected to find binary") + } + if checksum != nil { + t.Error("expected no checksum") + } +} + +func TestCheckWritePermission_Success(t *testing.T) { + dir := t.TempDir() + if err := checkWritePermission(dir); err != nil { + t.Errorf("expected write permission check to pass: %v", err) + } +} + +func TestCheckWritePermission_Fail(t *testing.T) { + dir := t.TempDir() + os.Chmod(dir, 0o555) + defer os.Chmod(dir, 0o755) // restore for cleanup + + if err := checkWritePermission(dir); err == nil { + t.Error("expected write permission check to fail") + } +} + +func TestDownloadToTemp(t *testing.T) { + content := "binary content here" + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(content)) + })) + defer srv.Close() + + dir := t.TempDir() + tmpFile, err := downloadToTemp(srv.URL, dir) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer os.Remove(tmpFile) + + data, err := os.ReadFile(tmpFile) + if err != nil { + t.Fatalf("reading temp file: %v", err) + } + if string(data) != content { + t.Errorf("expected %q, got %q", content, string(data)) + } + + // Verify temp file is in the right directory + if filepath.Dir(tmpFile) != dir { + t.Errorf("temp file should be in %s, got %s", dir, filepath.Dir(tmpFile)) + } +} + +func TestDownloadToTemp_ServerError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer srv.Close() + + dir := t.TempDir() + _, err := downloadToTemp(srv.URL, dir) + if err == nil { + t.Error("expected error for server error") + } +} + +func TestVerifyChecksum(t *testing.T) { + // Create a temp file with known content + dir := t.TempDir() + filePath := filepath.Join(dir, "binary") + content := []byte("test binary content") + os.WriteFile(filePath, content, 0o644) + + // Calculate expected hash + h := sha256.Sum256(content) + expectedHash := hex.EncodeToString(h[:]) + + // Serve checksum + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "%s binary\n", expectedHash) + })) + defer srv.Close() + + if err := verifyChecksum(filePath, srv.URL); err != nil { + t.Errorf("expected checksum verification to pass: %v", err) + } +} + +func TestVerifyChecksum_Mismatch(t *testing.T) { + dir := t.TempDir() + filePath := filepath.Join(dir, "binary") + os.WriteFile(filePath, []byte("content"), 0o644) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "0000000000000000000000000000000000000000000000000000000000000000 binary\n") + })) + defer srv.Close() + + if err := verifyChecksum(filePath, srv.URL); err == nil { + t.Error("expected checksum verification to fail") + } +} + +func TestPerformUpdate_AlreadyLatest(t *testing.T) { + release := Release{TagName: "v0.5.0"} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(release) + })) + defer srv.Close() + + result, err := PerformUpdate("0.5.0", Options{ReleasesURL: srv.URL}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.UpdateAvailable { + t.Error("expected no update available when already on latest") + } +} + +func TestPerformUpdate_FullFlow(t *testing.T) { + // Create a fake "current binary" + dir := t.TempDir() + binaryPath := filepath.Join(dir, "chief") + os.WriteFile(binaryPath, []byte("old binary"), 0o755) + + // New binary content + newContent := []byte("new binary v0.6.0") + h := sha256.Sum256(newContent) + expectedHash := hex.EncodeToString(h[:]) + + binaryName := fmt.Sprintf("chief-%s-%s", runtime.GOOS, runtime.GOARCH) + + // Set up download server + downloadSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/binary" { + w.Write(newContent) + } else if r.URL.Path == "/checksum" { + fmt.Fprintf(w, "%s %s\n", expectedHash, binaryName) + } + })) + defer downloadSrv.Close() + + release := Release{ + TagName: "v0.6.0", + Assets: []Asset{ + {Name: binaryName, BrowserDownloadURL: downloadSrv.URL + "/binary"}, + {Name: binaryName + ".sha256", BrowserDownloadURL: downloadSrv.URL + "/checksum"}, + }, + } + releaseSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(release) + })) + defer releaseSrv.Close() + + // We can't easily test PerformUpdate because it calls os.Executable() + // Instead, test the components used by PerformUpdate individually + // (CheckForUpdate, findAssets, downloadToTemp, verifyChecksum are all tested above) + + // Test the download + checksum flow manually + tmpFile, err := downloadToTemp(downloadSrv.URL+"/binary", dir) + if err != nil { + t.Fatalf("download failed: %v", err) + } + defer os.Remove(tmpFile) + + if err := verifyChecksum(tmpFile, downloadSrv.URL+"/checksum"); err != nil { + t.Fatalf("checksum failed: %v", err) + } + + // Verify file content + data, err := os.ReadFile(tmpFile) + if err != nil { + t.Fatalf("reading downloaded file: %v", err) + } + if string(data) != string(newContent) { + t.Errorf("downloaded content mismatch") + } +} + +func TestCheckForUpdate_VersionWithVPrefix(t *testing.T) { + release := Release{TagName: "v0.5.1"} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(release) + })) + defer srv.Close() + + // Passing version with v prefix should still work + result, err := CheckForUpdate("v0.5.0", Options{ReleasesURL: srv.URL}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !result.UpdateAvailable { + t.Error("expected update to be available") + } + if result.CurrentVersion != "0.5.0" { + t.Errorf("expected normalized version 0.5.0, got %s", result.CurrentVersion) + } +} diff --git a/internal/uplink/batcher.go b/internal/uplink/batcher.go new file mode 100644 index 0000000..7a9f483 --- /dev/null +++ b/internal/uplink/batcher.go @@ -0,0 +1,271 @@ +package uplink + +import ( + "context" + "crypto/rand" + "encoding/json" + "fmt" + "log" + "sync" + "time" +) + +// Flush tier durations. +const ( + // tierImmediate flushes immediately (0ms delay). + tierImmediate = 0 * time.Millisecond + + // tierStandard flushes after 200ms. + tierStandard = 200 * time.Millisecond + + // tierLowPriority flushes after 1s. + tierLowPriority = 1 * time.Second + + // maxBatchMessages is the maximum number of messages before a forced flush. + maxBatchMessages = 20 + + // maxBufferMessages is the maximum total messages in the buffer before dropping. + maxBufferMessages = 1000 + + // maxBufferBytes is the maximum total payload size (5MB) before dropping. + maxBufferBytes = 5 * 1024 * 1024 +) + +// tier identifies a flush priority tier. +type tier int + +const ( + tierIDImmediate tier = iota + tierIDStandard + tierIDLowPriority +) + +// tierFor returns the tier for a given message type. +func tierFor(msgType string) tier { + switch msgType { + case "run_complete", "run_paused", "error", "clone_complete", "session_expired", "quota_exhausted", "prd_response_complete": + return tierIDImmediate + case "claude_output", "prd_output", "run_progress", "clone_progress": + return tierIDStandard + case "state_snapshot", "project_state", "project_list", "settings", "log_lines": + return tierIDLowPriority + default: + // Unknown types go to standard tier. + return tierIDStandard + } +} + +// tierDelay returns the flush delay for a tier. +func tierDelay(t tier) time.Duration { + switch t { + case tierIDImmediate: + return tierImmediate + case tierIDStandard: + return tierStandard + case tierIDLowPriority: + return tierLowPriority + default: + return tierStandard + } +} + +// bufferedMessage is a message waiting to be flushed. +type bufferedMessage struct { + data json.RawMessage + tier tier + size int +} + +// SendFunc is the function called on flush to send a batch of messages. +// batchID is a unique UUID for idempotency. The function should retry internally if needed. +type SendFunc func(batchID string, messages []json.RawMessage) error + +// Batcher collects outgoing messages and flushes them in batches. +// Messages are assigned to priority tiers that control flush timing. +// Flushes are sequential — the next flush waits for the current one to complete. +type Batcher struct { + sendFn SendFunc + + mu sync.Mutex + messages []bufferedMessage + totalSize int + flushNotify chan struct{} // signals the run loop that a flush may be needed + stopped bool + + // Timer management for tier-based flushing. + standardTimer *time.Timer + lowPriorityTimer *time.Timer + standardActive bool + lowPriorityActive bool +} + +// NewBatcher creates a Batcher that calls sendFn on each flush. +func NewBatcher(sendFn SendFunc) *Batcher { + return &Batcher{ + sendFn: sendFn, + flushNotify: make(chan struct{}, 1), + } +} + +// Enqueue adds a message to the appropriate tier buffer. +// It is safe to call from multiple goroutines. +func (b *Batcher) Enqueue(msg json.RawMessage, msgType string) { + t := tierFor(msgType) + size := len(msg) + + b.mu.Lock() + defer b.mu.Unlock() + + if b.stopped { + return + } + + // Check buffer limits and drop low-priority messages if full. + for b.totalSize+size > maxBufferBytes || len(b.messages)+1 > maxBufferMessages { + if !b.dropLowestPriority() { + // Nothing left to drop — reject this message too. + log.Printf("batcher: buffer full, dropping %s message (%d bytes)", msgType, size) + return + } + } + + b.messages = append(b.messages, bufferedMessage{data: msg, tier: t, size: size}) + b.totalSize += size + + // Determine if we need to flush now. + shouldFlushNow := t == tierIDImmediate || len(b.messages) >= maxBatchMessages + + if shouldFlushNow { + b.notifyFlush() + return + } + + // Start tier timers if not already running. + if t == tierIDStandard && !b.standardActive { + b.standardActive = true + if b.standardTimer == nil { + b.standardTimer = time.AfterFunc(tierStandard, func() { + b.mu.Lock() + b.standardActive = false + b.mu.Unlock() + b.notifyFlush() + }) + } else { + b.standardTimer.Reset(tierStandard) + } + } + if t == tierIDLowPriority && !b.lowPriorityActive { + b.lowPriorityActive = true + if b.lowPriorityTimer == nil { + b.lowPriorityTimer = time.AfterFunc(tierLowPriority, func() { + b.mu.Lock() + b.lowPriorityActive = false + b.mu.Unlock() + b.notifyFlush() + }) + } else { + b.lowPriorityTimer.Reset(tierLowPriority) + } + } +} + +// notifyFlush signals the run loop that a flush should happen. +// Must not be called with b.mu held if blocking is possible, but the channel is buffered. +func (b *Batcher) notifyFlush() { + select { + case b.flushNotify <- struct{}{}: + default: + // Already notified. + } +} + +// dropLowestPriority removes the last low-priority message from the buffer. +// Returns false if there are no low-priority messages to drop. +// Caller must hold b.mu. +func (b *Batcher) dropLowestPriority() bool { + // Search from the end for the lowest-priority message. + // Priority order for dropping: low priority first, then standard. + for priority := tierIDLowPriority; priority >= tierIDStandard; priority-- { + for i := len(b.messages) - 1; i >= 0; i-- { + if b.messages[i].tier == priority { + b.totalSize -= b.messages[i].size + log.Printf("batcher: buffer overflow, dropping message at index %d (tier %d, %d bytes)", i, priority, b.messages[i].size) + b.messages = append(b.messages[:i], b.messages[i+1:]...) + return true + } + } + } + return false +} + +// Run starts the background flush loop. It blocks until ctx is done. +// Flushes are sequential — only one flush runs at a time. +func (b *Batcher) Run(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case <-b.flushNotify: + b.flush() + } + } +} + +// Stop performs a final flush of all remaining messages, then marks the batcher as stopped. +func (b *Batcher) Stop() { + b.mu.Lock() + b.stopped = true + // Stop timers. + if b.standardTimer != nil { + b.standardTimer.Stop() + } + if b.lowPriorityTimer != nil { + b.lowPriorityTimer.Stop() + } + b.mu.Unlock() + + // Final flush. + b.flush() +} + +// flush collects all pending messages and sends them as a single batch. +// It is always called from a single goroutine (the Run loop or Stop), so flushes never overlap. +func (b *Batcher) flush() { + b.mu.Lock() + if len(b.messages) == 0 { + b.mu.Unlock() + return + } + + // Collect all messages. + msgs := make([]json.RawMessage, len(b.messages)) + for i, m := range b.messages { + msgs[i] = m.data + } + b.messages = b.messages[:0] + b.totalSize = 0 + + // Reset timer state since we're flushing everything. + b.standardActive = false + b.lowPriorityActive = false + b.mu.Unlock() + + batchID := generateBatchID() + if err := b.sendFn(batchID, msgs); err != nil { + log.Printf("batcher: flush failed (batch %s, %d messages): %v", batchID, len(msgs), err) + } +} + +// generateBatchID returns a new UUID v4 string for batch idempotency. +func generateBatchID() string { + var uuid [16]byte + if _, err := rand.Read(uuid[:]); err != nil { + // Fallback to timestamp-based ID if crypto/rand fails (extremely unlikely). + return fmt.Sprintf("batch-%d", time.Now().UnixNano()) + } + // Set version (4) and variant (RFC 4122). + uuid[6] = (uuid[6] & 0x0f) | 0x40 + uuid[8] = (uuid[8] & 0x3f) | 0x80 + return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x", + uuid[0:4], uuid[4:6], uuid[6:8], uuid[8:10], uuid[10:16]) +} diff --git a/internal/uplink/batcher_test.go b/internal/uplink/batcher_test.go new file mode 100644 index 0000000..5d53f06 --- /dev/null +++ b/internal/uplink/batcher_test.go @@ -0,0 +1,601 @@ +package uplink + +import ( + "context" + "encoding/json" + "strings" + "sync" + "sync/atomic" + "testing" + "time" +) + +// flushRecord captures a single flush call. +type flushRecord struct { + batchID string + messages []json.RawMessage + time time.Time +} + +// newRecordingSendFn returns a SendFunc that records all flush calls +// and a function to retrieve the records. +func newRecordingSendFn() (SendFunc, func() []flushRecord) { + var mu sync.Mutex + var records []flushRecord + + fn := func(batchID string, messages []json.RawMessage) error { + mu.Lock() + defer mu.Unlock() + // Copy messages to avoid data races. + copied := make([]json.RawMessage, len(messages)) + copy(copied, messages) + records = append(records, flushRecord{ + batchID: batchID, + messages: copied, + time: time.Now(), + }) + return nil + } + + get := func() []flushRecord { + mu.Lock() + defer mu.Unlock() + result := make([]flushRecord, len(records)) + copy(result, records) + return result + } + + return fn, get +} + +func TestBatcher_ImmediateFlush(t *testing.T) { + sendFn, getRecords := newRecordingSendFn() + b := NewBatcher(sendFn) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go b.Run(ctx) + + // Enqueue an immediate-tier message. + b.Enqueue(json.RawMessage(`{"type":"run_complete"}`), "run_complete") + + // Wait for flush. + time.Sleep(50 * time.Millisecond) + + records := getRecords() + if len(records) != 1 { + t.Fatalf("flush count = %d, want 1", len(records)) + } + if len(records[0].messages) != 1 { + t.Errorf("message count = %d, want 1", len(records[0].messages)) + } + if records[0].batchID == "" { + t.Error("batchID should not be empty") + } +} + +func TestBatcher_ImmediateFlushDrainsAllTiers(t *testing.T) { + sendFn, getRecords := newRecordingSendFn() + b := NewBatcher(sendFn) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go b.Run(ctx) + + // Enqueue messages from different tiers. + b.Enqueue(json.RawMessage(`{"type":"project_state"}`), "project_state") // low + b.Enqueue(json.RawMessage(`{"type":"claude_output"}`), "claude_output") // standard + b.Enqueue(json.RawMessage(`{"type":"error"}`), "error") // immediate + + // Wait for flush. + time.Sleep(50 * time.Millisecond) + + records := getRecords() + if len(records) != 1 { + t.Fatalf("flush count = %d, want 1 (all tiers drain together)", len(records)) + } + if len(records[0].messages) != 3 { + t.Errorf("message count = %d, want 3", len(records[0].messages)) + } +} + +func TestBatcher_AllImmediateTypes(t *testing.T) { + immediateTypes := []string{ + "run_complete", "run_paused", "error", + "clone_complete", "session_expired", "quota_exhausted", + } + + for _, msgType := range immediateTypes { + t.Run(msgType, func(t *testing.T) { + sendFn, getRecords := newRecordingSendFn() + b := NewBatcher(sendFn) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go b.Run(ctx) + + b.Enqueue(json.RawMessage(`{}`), msgType) + time.Sleep(50 * time.Millisecond) + + records := getRecords() + if len(records) != 1 { + t.Errorf("expected immediate flush for %s, got %d flushes", msgType, len(records)) + } + }) + } +} + +func TestBatcher_StandardTimerFlush(t *testing.T) { + sendFn, getRecords := newRecordingSendFn() + b := NewBatcher(sendFn) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go b.Run(ctx) + + start := time.Now() + b.Enqueue(json.RawMessage(`{"type":"claude_output"}`), "claude_output") + + // Should not flush immediately. + time.Sleep(50 * time.Millisecond) + if len(getRecords()) != 0 { + t.Fatal("standard tier should not flush immediately") + } + + // Wait for the 200ms timer. + time.Sleep(300 * time.Millisecond) + records := getRecords() + if len(records) != 1 { + t.Fatalf("flush count = %d, want 1", len(records)) + } + + elapsed := records[0].time.Sub(start) + if elapsed < 150*time.Millisecond { + t.Errorf("flushed too early: %v (expected ~200ms)", elapsed) + } +} + +func TestBatcher_LowPriorityTimerFlush(t *testing.T) { + sendFn, getRecords := newRecordingSendFn() + b := NewBatcher(sendFn) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go b.Run(ctx) + + start := time.Now() + b.Enqueue(json.RawMessage(`{"type":"project_state"}`), "project_state") + + // Should not flush at 200ms. + time.Sleep(300 * time.Millisecond) + if len(getRecords()) != 0 { + t.Fatal("low priority tier should not flush at 200ms") + } + + // Wait for the 1s timer. + time.Sleep(900 * time.Millisecond) + records := getRecords() + if len(records) != 1 { + t.Fatalf("flush count = %d, want 1", len(records)) + } + + elapsed := records[0].time.Sub(start) + if elapsed < 800*time.Millisecond { + t.Errorf("flushed too early: %v (expected ~1s)", elapsed) + } +} + +func TestBatcher_SizeBasedFlush(t *testing.T) { + sendFn, getRecords := newRecordingSendFn() + b := NewBatcher(sendFn) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go b.Run(ctx) + + // Enqueue 20 low-priority messages — should trigger size-based flush + // even though the 1s timer hasn't expired. + for i := 0; i < maxBatchMessages; i++ { + b.Enqueue(json.RawMessage(`{"type":"log_lines"}`), "log_lines") + } + + time.Sleep(50 * time.Millisecond) + + records := getRecords() + if len(records) != 1 { + t.Fatalf("flush count = %d, want 1", len(records)) + } + if len(records[0].messages) != maxBatchMessages { + t.Errorf("message count = %d, want %d", len(records[0].messages), maxBatchMessages) + } +} + +func TestBatcher_StopFlushesRemaining(t *testing.T) { + sendFn, getRecords := newRecordingSendFn() + b := NewBatcher(sendFn) + + // Don't start Run — just enqueue and stop. + b.Enqueue(json.RawMessage(`{"type":"claude_output"}`), "claude_output") + b.Enqueue(json.RawMessage(`{"type":"log_lines"}`), "log_lines") + + b.Stop() + + records := getRecords() + if len(records) != 1 { + t.Fatalf("flush count = %d, want 1 (final flush on stop)", len(records)) + } + if len(records[0].messages) != 2 { + t.Errorf("message count = %d, want 2", len(records[0].messages)) + } +} + +func TestBatcher_StopPreventsNewEnqueues(t *testing.T) { + sendFn, getRecords := newRecordingSendFn() + b := NewBatcher(sendFn) + + b.Enqueue(json.RawMessage(`{"type":"claude_output"}`), "claude_output") + b.Stop() + + // Enqueue after stop should be silently dropped. + b.Enqueue(json.RawMessage(`{"type":"error"}`), "error") + + records := getRecords() + if len(records) != 1 { + t.Fatalf("flush count = %d, want 1", len(records)) + } + if len(records[0].messages) != 1 { + t.Errorf("message count = %d, want 1 (post-stop enqueue should be dropped)", len(records[0].messages)) + } +} + +func TestBatcher_BufferOverflowDropsLowPriority(t *testing.T) { + sendFn, _ := newRecordingSendFn() + b := NewBatcher(sendFn) + + // Fill with low-priority messages up to the limit. + for i := 0; i < maxBufferMessages-1; i++ { + b.Enqueue(json.RawMessage(`{"type":"log_lines"}`), "log_lines") + } + + // Buffer is almost full. Add one standard message. + b.Enqueue(json.RawMessage(`{"type":"claude_output"}`), "claude_output") + + // Buffer is now at limit. Adding an immediate message should drop a low-priority one. + b.Enqueue(json.RawMessage(`{"type":"error"}`), "error") + + b.mu.Lock() + count := len(b.messages) + // Count messages by tier. + var immediate, standard, low int + for _, m := range b.messages { + switch m.tier { + case tierIDImmediate: + immediate++ + case tierIDStandard: + standard++ + case tierIDLowPriority: + low++ + } + } + b.mu.Unlock() + + if count != maxBufferMessages { + t.Errorf("buffer count = %d, want %d", count, maxBufferMessages) + } + if immediate != 1 { + t.Errorf("immediate count = %d, want 1", immediate) + } + if standard != 1 { + t.Errorf("standard count = %d, want 1", standard) + } + if low != maxBufferMessages-2 { + t.Errorf("low priority count = %d, want %d", low, maxBufferMessages-2) + } +} + +func TestBatcher_BufferOverflowBySize(t *testing.T) { + sendFn, _ := newRecordingSendFn() + b := NewBatcher(sendFn) + + // Create a large message (~1MB). + bigPayload := strings.Repeat("x", 1024*1024) + bigMsg := json.RawMessage(`{"type":"log_lines","data":"` + bigPayload + `"}`) + + // Fill buffer with 4 big messages (~4MB). + for i := 0; i < 4; i++ { + b.Enqueue(bigMsg, "log_lines") + } + + b.mu.Lock() + countBefore := len(b.messages) + sizeBefore := b.totalSize + b.mu.Unlock() + + if countBefore != 4 { + t.Fatalf("buffer count = %d, want 4", countBefore) + } + + // Adding another big message should trigger overflow — drops a low-priority message. + bigStandard := json.RawMessage(`{"type":"claude_output","data":"` + bigPayload + `"}`) + b.Enqueue(bigStandard, "claude_output") + + b.mu.Lock() + countAfter := len(b.messages) + sizeAfter := b.totalSize + b.mu.Unlock() + + // Should have dropped one log_lines message to make room. + if countAfter != 4 { + t.Errorf("buffer count after overflow = %d, want 4", countAfter) + } + if sizeAfter >= sizeBefore+len(bigStandard) { + t.Errorf("buffer size should not exceed limit: before=%d, after=%d", sizeBefore, sizeAfter) + } +} + +func TestBatcher_FlushesNeverOverlap(t *testing.T) { + var concurrent atomic.Int32 + var maxConcurrent atomic.Int32 + + sendFn := func(batchID string, messages []json.RawMessage) error { + n := concurrent.Add(1) + // Track max concurrency. + for { + old := maxConcurrent.Load() + if n <= old || maxConcurrent.CompareAndSwap(old, n) { + break + } + } + time.Sleep(50 * time.Millisecond) // Simulate slow send. + concurrent.Add(-1) + return nil + } + + b := NewBatcher(sendFn) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go b.Run(ctx) + + // Rapidly enqueue immediate messages to trigger many flushes. + for i := 0; i < 10; i++ { + b.Enqueue(json.RawMessage(`{"type":"error"}`), "error") + time.Sleep(10 * time.Millisecond) + } + + // Wait for all flushes to complete. + time.Sleep(600 * time.Millisecond) + + if maxConcurrent.Load() > 1 { + t.Errorf("max concurrent flushes = %d, want 1 (sequential flushes only)", maxConcurrent.Load()) + } +} + +func TestBatcher_UniqueBatchIDs(t *testing.T) { + var mu sync.Mutex + var batchIDs []string + + sendFn := func(batchID string, messages []json.RawMessage) error { + mu.Lock() + batchIDs = append(batchIDs, batchID) + mu.Unlock() + return nil + } + + b := NewBatcher(sendFn) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go b.Run(ctx) + + // Trigger multiple flushes. + for i := 0; i < 5; i++ { + b.Enqueue(json.RawMessage(`{"type":"error"}`), "error") + time.Sleep(50 * time.Millisecond) + } + + time.Sleep(100 * time.Millisecond) + + mu.Lock() + defer mu.Unlock() + + seen := make(map[string]bool) + for _, id := range batchIDs { + if seen[id] { + t.Errorf("duplicate batch ID: %s", id) + } + seen[id] = true + } +} + +func TestBatcher_EmptyFlushIsNoop(t *testing.T) { + var flushCount atomic.Int32 + + sendFn := func(batchID string, messages []json.RawMessage) error { + flushCount.Add(1) + return nil + } + + b := NewBatcher(sendFn) + + // Flush with nothing in the buffer should not call sendFn. + b.flush() + + if flushCount.Load() != 0 { + t.Errorf("flush count = %d, want 0 (empty flush should be noop)", flushCount.Load()) + } +} + +func TestBatcher_ContextCancellationStopsRun(t *testing.T) { + sendFn, _ := newRecordingSendFn() + b := NewBatcher(sendFn) + + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + go func() { + b.Run(ctx) + close(done) + }() + + cancel() + + select { + case <-done: + // Run exited. + case <-time.After(2 * time.Second): + t.Fatal("Run() did not exit after context cancellation") + } +} + +func TestBatcher_MessageOrderPreserved(t *testing.T) { + sendFn, getRecords := newRecordingSendFn() + b := NewBatcher(sendFn) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go b.Run(ctx) + + // Enqueue messages in a specific order, then trigger flush with an immediate message. + b.Enqueue(json.RawMessage(`{"id":"1","type":"project_state"}`), "project_state") + b.Enqueue(json.RawMessage(`{"id":"2","type":"claude_output"}`), "claude_output") + b.Enqueue(json.RawMessage(`{"id":"3","type":"error"}`), "error") // triggers flush + + time.Sleep(50 * time.Millisecond) + + records := getRecords() + if len(records) != 1 { + t.Fatalf("flush count = %d, want 1", len(records)) + } + + msgs := records[0].messages + if len(msgs) != 3 { + t.Fatalf("message count = %d, want 3", len(msgs)) + } + + // Verify order is preserved. + expected := []string{`{"id":"1","type":"project_state"}`, `{"id":"2","type":"claude_output"}`, `{"id":"3","type":"error"}`} + for i, msg := range msgs { + if string(msg) != expected[i] { + t.Errorf("message[%d] = %s, want %s", i, msg, expected[i]) + } + } +} + +func TestBatcher_SendErrorDoesNotLoseMessages(t *testing.T) { + // When send fails, messages are already removed from the buffer. + // This is by design — the caller (SendMessagesWithRetry) handles retries. + var flushCount atomic.Int32 + + sendFn := func(batchID string, messages []json.RawMessage) error { + flushCount.Add(1) + return nil + } + + b := NewBatcher(sendFn) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go b.Run(ctx) + + b.Enqueue(json.RawMessage(`{"type":"error"}`), "error") + time.Sleep(50 * time.Millisecond) + + if flushCount.Load() != 1 { + t.Errorf("flush count = %d, want 1", flushCount.Load()) + } + + // Buffer should be empty after flush. + b.mu.Lock() + remaining := len(b.messages) + b.mu.Unlock() + + if remaining != 0 { + t.Errorf("remaining messages = %d, want 0", remaining) + } +} + +func TestTierFor(t *testing.T) { + tests := []struct { + msgType string + want tier + }{ + // Immediate tier. + {"run_complete", tierIDImmediate}, + {"run_paused", tierIDImmediate}, + {"error", tierIDImmediate}, + {"clone_complete", tierIDImmediate}, + {"session_expired", tierIDImmediate}, + {"quota_exhausted", tierIDImmediate}, + {"prd_response_complete", tierIDImmediate}, + // Standard tier. + {"claude_output", tierIDStandard}, + {"prd_output", tierIDStandard}, + {"run_progress", tierIDStandard}, + {"clone_progress", tierIDStandard}, + // Low priority tier. + {"state_snapshot", tierIDLowPriority}, + {"project_state", tierIDLowPriority}, + {"project_list", tierIDLowPriority}, + {"settings", tierIDLowPriority}, + {"log_lines", tierIDLowPriority}, + // Unknown defaults to standard. + {"unknown_type", tierIDStandard}, + } + + for _, tt := range tests { + t.Run(tt.msgType, func(t *testing.T) { + got := tierFor(tt.msgType) + if got != tt.want { + t.Errorf("tierFor(%q) = %d, want %d", tt.msgType, got, tt.want) + } + }) + } +} + +func TestGenerateBatchID(t *testing.T) { + id1 := generateBatchID() + id2 := generateBatchID() + + if id1 == "" { + t.Error("batch ID should not be empty") + } + if id1 == id2 { + t.Errorf("batch IDs should be unique: %q == %q", id1, id2) + } + + // Verify UUID v4 format (8-4-4-4-12 hex chars). + if len(id1) != 36 { + t.Errorf("batch ID length = %d, want 36", len(id1)) + } +} + +func TestBatcher_ConcurrentEnqueue(t *testing.T) { + sendFn, getRecords := newRecordingSendFn() + b := NewBatcher(sendFn) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go b.Run(ctx) + + // Enqueue from multiple goroutines concurrently. + var wg sync.WaitGroup + for i := 0; i < 50; i++ { + wg.Add(1) + go func() { + defer wg.Done() + b.Enqueue(json.RawMessage(`{"type":"claude_output"}`), "claude_output") + }() + } + wg.Wait() + + // Wait for all timer-based flushes. + time.Sleep(500 * time.Millisecond) + + records := getRecords() + total := 0 + for _, r := range records { + total += len(r.messages) + } + + if total != 50 { + t.Errorf("total flushed messages = %d, want 50", total) + } +} diff --git a/internal/uplink/client.go b/internal/uplink/client.go new file mode 100644 index 0000000..3433f91 --- /dev/null +++ b/internal/uplink/client.go @@ -0,0 +1,356 @@ +package uplink + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "math" + "math/rand/v2" + "net/http" + "net/url" + "runtime" + "sync" + "time" + + "github.com/minicodemonkey/chief/internal/ws" +) + +const ( + // maxBackoff is the maximum reconnection delay. + maxBackoff = 60 * time.Second + + // initialBackoff is the starting reconnection delay. + initialBackoff = 1 * time.Second + + // httpTimeout is the default HTTP request timeout. + httpTimeout = 10 * time.Second +) + +// WelcomeResponse is the response from POST /api/device/connect. +type WelcomeResponse struct { + Type string `json:"type"` + ProtocolVersion int `json:"protocol_version"` + DeviceID int `json:"device_id"` + SessionID string `json:"session_id"` + Reverb ReverbConfig `json:"reverb"` +} + +// ReverbConfig contains Pusher/Reverb connection details from the connect response. +type ReverbConfig struct { + Key string `json:"key"` + Host string `json:"host"` + Port int `json:"port"` + Scheme string `json:"scheme"` +} + +// connectRequest is the JSON body sent to POST /api/device/connect. +type connectRequest struct { + ChiefVersion string `json:"chief_version"` + DeviceName string `json:"device_name"` + OS string `json:"os"` + Arch string `json:"arch"` + ProtocolVersion int `json:"protocol_version"` +} + +// errorResponse is a JSON error returned by the server. +type errorResponse struct { + Error string `json:"error"` + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` +} + +// ErrAuthFailed is returned when the server rejects authentication (401). +var ErrAuthFailed = fmt.Errorf("device deauthorized — run 'chief login' to re-authenticate") + +// ErrDeviceRevoked is returned when the device is revoked (403). +var ErrDeviceRevoked = fmt.Errorf("device revoked — run 'chief login' to re-authenticate") + +// Client is an HTTP client for the uplink device API. +type Client struct { + baseURL string + accessToken string + mu sync.RWMutex + httpClient *http.Client + + // Device metadata sent on connect. + chiefVersion string + deviceName string +} + +// Option configures a Client. +type Option func(*Client) + +// WithChiefVersion sets the chief CLI version string. +func WithChiefVersion(v string) Option { + return func(c *Client) { + c.chiefVersion = v + } +} + +// WithDeviceName sets the device name. +func WithDeviceName(name string) Option { + return func(c *Client) { + c.deviceName = name + } +} + +// WithHTTPClient sets a custom http.Client (useful for testing). +func WithHTTPClient(hc *http.Client) Option { + return func(c *Client) { + c.httpClient = hc + } +} + +// New creates a new uplink HTTP client. +// The baseURL must use HTTPS unless the host is localhost or 127.0.0.1. +func New(baseURL, accessToken string, opts ...Option) (*Client, error) { + if err := validateBaseURL(baseURL); err != nil { + return nil, err + } + + c := &Client{ + baseURL: baseURL, + accessToken: accessToken, + httpClient: &http.Client{Timeout: httpTimeout}, + } + for _, o := range opts { + o(c) + } + return c, nil +} + +// validateBaseURL ensures the URL uses HTTPS unless the host is localhost/127.0.0.1. +func validateBaseURL(rawURL string) error { + u, err := url.Parse(rawURL) + if err != nil { + return fmt.Errorf("invalid base URL: %w", err) + } + + host := u.Hostname() + if u.Scheme == "http" && host != "localhost" && host != "127.0.0.1" { + return fmt.Errorf("base URL must use HTTPS (got %s); HTTP is only allowed for localhost", rawURL) + } + if u.Scheme != "http" && u.Scheme != "https" { + return fmt.Errorf("base URL must use http or https scheme (got %s)", u.Scheme) + } + + return nil +} + +// SetAccessToken updates the access token in a thread-safe manner. +// This is called after a token refresh. +func (c *Client) SetAccessToken(token string) { + c.mu.Lock() + defer c.mu.Unlock() + c.accessToken = token +} + +// Connect calls POST /api/device/connect to register the device with the server. +// Returns the welcome response containing session ID and Reverb configuration. +func (c *Client) Connect(ctx context.Context) (*WelcomeResponse, error) { + version := c.chiefVersion + if version == "" { + version = "dev" + } + + body := connectRequest{ + ChiefVersion: version, + DeviceName: c.deviceName, + OS: runtime.GOOS, + Arch: runtime.GOARCH, + ProtocolVersion: ws.ProtocolVersion, + } + + var welcome WelcomeResponse + if err := c.doJSON(ctx, "POST", "/api/device/connect", body, &welcome); err != nil { + return nil, fmt.Errorf("connect: %w", err) + } + + return &welcome, nil +} + +// IngestResponse is the response from POST /api/device/messages. +type IngestResponse struct { + Accepted int `json:"accepted"` + BatchID string `json:"batch_id"` + SessionID string `json:"session_id"` +} + +// ingestRequest is the JSON body sent to POST /api/device/messages. +type ingestRequest struct { + BatchID string `json:"batch_id"` + Messages []json.RawMessage `json:"messages"` +} + +// SendMessages sends a batch of messages via POST /api/device/messages. +// It does NOT retry on failure — use SendMessagesWithRetry for retry behavior. +func (c *Client) SendMessages(ctx context.Context, batchID string, messages []json.RawMessage) (*IngestResponse, error) { + body := ingestRequest{ + BatchID: batchID, + Messages: messages, + } + + var resp IngestResponse + if err := c.doJSON(ctx, "POST", "/api/device/messages", body, &resp); err != nil { + return nil, fmt.Errorf("send messages: %w", err) + } + + return &resp, nil +} + +// SendMessagesWithRetry sends a message batch with exponential backoff retry on transient failures. +// It does not retry on 401/403 auth errors. Retries use the same batchID for server-side deduplication. +func (c *Client) SendMessagesWithRetry(ctx context.Context, batchID string, messages []json.RawMessage) (*IngestResponse, error) { + attempt := 0 + for { + resp, err := c.SendMessages(ctx, batchID, messages) + if err == nil { + return resp, nil + } + + // Don't retry auth errors. + if isAuthError(err) { + return nil, err + } + + attempt++ + delay := backoff(attempt) + log.Printf("SendMessages failed (attempt %d, batch %s): %v — retrying in %s", attempt, batchID, err, delay.Round(time.Millisecond)) + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(delay): + } + } +} + +// Heartbeat calls POST /api/device/heartbeat to tell the server the device is alive. +func (c *Client) Heartbeat(ctx context.Context) error { + var resp json.RawMessage + if err := c.doJSON(ctx, "POST", "/api/device/heartbeat", nil, &resp); err != nil { + return fmt.Errorf("heartbeat: %w", err) + } + return nil +} + +// Disconnect calls POST /api/device/disconnect to notify the server the device is going offline. +func (c *Client) Disconnect(ctx context.Context) error { + var resp json.RawMessage + if err := c.doJSON(ctx, "POST", "/api/device/disconnect", nil, &resp); err != nil { + return fmt.Errorf("disconnect: %w", err) + } + return nil +} + +// doJSON performs an HTTP request with JSON body and parses the JSON response. +// It handles auth headers and classifies HTTP error responses. +func (c *Client) doJSON(ctx context.Context, method, path string, body interface{}, result interface{}) error { + var bodyReader io.Reader + if body != nil { + data, err := json.Marshal(body) + if err != nil { + return fmt.Errorf("marshaling request: %w", err) + } + bodyReader = bytes.NewReader(data) + } + + req, err := http.NewRequestWithContext(ctx, method, c.baseURL+path, bodyReader) + if err != nil { + return fmt.Errorf("creating request: %w", err) + } + + c.mu.RLock() + token := c.accessToken + c.mu.RUnlock() + + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return fmt.Errorf("sending request: %w", err) + } + defer resp.Body.Close() + + // Read the full response body (capped to prevent OOM on rogue responses). + respBody, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) // 1MB max + if err != nil { + return fmt.Errorf("reading response: %w", err) + } + + // Classify HTTP errors. + if resp.StatusCode == http.StatusUnauthorized { + return ErrAuthFailed + } + if resp.StatusCode == http.StatusForbidden { + return ErrDeviceRevoked + } + if resp.StatusCode >= 400 { + var errResp errorResponse + if json.Unmarshal(respBody, &errResp) == nil && errResp.Message != "" { + return fmt.Errorf("server error %d: %s", resp.StatusCode, errResp.Message) + } + return fmt.Errorf("server error %d: %s", resp.StatusCode, string(respBody)) + } + + if result != nil && len(respBody) > 0 { + if err := json.Unmarshal(respBody, result); err != nil { + return fmt.Errorf("parsing response: %w", err) + } + } + + return nil +} + +// backoff returns a duration for the given attempt using exponential backoff + jitter. +func backoff(attempt int) time.Duration { + base := float64(initialBackoff) * math.Pow(2, float64(attempt-1)) + if base > float64(maxBackoff) { + base = float64(maxBackoff) + } + // Add jitter: 0.5x to 1.5x + jitter := 0.5 + rand.Float64() + return time.Duration(base * jitter) +} + +// ConnectWithRetry calls Connect with exponential backoff retry on transient failures. +// It does not retry on 401/403 auth errors. +func (c *Client) ConnectWithRetry(ctx context.Context) (*WelcomeResponse, error) { + attempt := 0 + for { + welcome, err := c.Connect(ctx) + if err == nil { + return welcome, nil + } + + // Don't retry auth errors. + if isAuthError(err) { + return nil, err + } + + attempt++ + delay := backoff(attempt) + log.Printf("Connect failed (attempt %d): %v — retrying in %s", attempt, err, delay.Round(time.Millisecond)) + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(delay): + } + } +} + +// isAuthError returns true if the error is a 401 or 403 that should not be retried. +// Uses errors.Is to handle wrapped errors. +func isAuthError(err error) bool { + if err == nil { + return false + } + return errors.Is(err, ErrAuthFailed) || errors.Is(err, ErrDeviceRevoked) +} diff --git a/internal/uplink/client_test.go b/internal/uplink/client_test.go new file mode 100644 index 0000000..5f11822 --- /dev/null +++ b/internal/uplink/client_test.go @@ -0,0 +1,743 @@ +package uplink + +import ( + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "runtime" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/minicodemonkey/chief/internal/ws" +) + +// testContext returns a context with a 15-second timeout for tests. +func testContext(t *testing.T) context.Context { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + t.Cleanup(cancel) + return ctx +} + +// newTestClient creates a Client pointing at a test server with the given token. +func newTestClient(t *testing.T, serverURL, token string, opts ...Option) *Client { + t.Helper() + c, err := New(serverURL, token, opts...) + if err != nil { + t.Fatalf("New() failed: %v", err) + } + return c +} + +func TestNew_ValidHTTPS(t *testing.T) { + c, err := New("https://example.com", "token123") + if err != nil { + t.Fatalf("New() with HTTPS failed: %v", err) + } + if c.baseURL != "https://example.com" { + t.Errorf("baseURL = %q, want %q", c.baseURL, "https://example.com") + } +} + +func TestNew_LocalhostHTTP(t *testing.T) { + _, err := New("http://localhost:8080", "token123") + if err != nil { + t.Fatalf("New() with localhost HTTP failed: %v", err) + } +} + +func TestNew_Loopback127HTTP(t *testing.T) { + _, err := New("http://127.0.0.1:8080", "token123") + if err != nil { + t.Fatalf("New() with 127.0.0.1 HTTP failed: %v", err) + } +} + +func TestNew_RejectsNonLocalhostHTTP(t *testing.T) { + _, err := New("http://example.com", "token123") + if err == nil { + t.Fatal("expected error for non-localhost HTTP, got nil") + } +} + +func TestNew_RejectsInvalidScheme(t *testing.T) { + _, err := New("ftp://example.com", "token123") + if err == nil { + t.Fatal("expected error for ftp scheme, got nil") + } +} + +func TestNew_WithOptions(t *testing.T) { + c, err := New("https://example.com", "token123", + WithChiefVersion("1.2.3"), + WithDeviceName("my-device"), + ) + if err != nil { + t.Fatalf("New() failed: %v", err) + } + if c.chiefVersion != "1.2.3" { + t.Errorf("chiefVersion = %q, want %q", c.chiefVersion, "1.2.3") + } + if c.deviceName != "my-device" { + t.Errorf("deviceName = %q, want %q", c.deviceName, "my-device") + } +} + +func TestConnect_Success(t *testing.T) { + var receivedBody connectRequest + var receivedAuth string + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/device/connect" { + http.NotFound(w, r) + return + } + if r.Method != "POST" { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + receivedAuth = r.Header.Get("Authorization") + + body, _ := io.ReadAll(r.Body) + json.Unmarshal(body, &receivedBody) + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(WelcomeResponse{ + Type: "welcome", + ProtocolVersion: 1, + DeviceID: 42, + SessionID: "sess-abc-123", + Reverb: ReverbConfig{ + Key: "app-key", + Host: "reverb.example.com", + Port: 443, + Scheme: "https", + }, + }) + })) + defer srv.Close() + + client := newTestClient(t, srv.URL, "test-token-abc", + WithChiefVersion("2.0.0"), + WithDeviceName("test-device"), + ) + + ctx := testContext(t) + welcome, err := client.Connect(ctx) + if err != nil { + t.Fatalf("Connect() failed: %v", err) + } + + // Verify request. + if receivedAuth != "Bearer test-token-abc" { + t.Errorf("Authorization = %q, want %q", receivedAuth, "Bearer test-token-abc") + } + if receivedBody.ChiefVersion != "2.0.0" { + t.Errorf("chief_version = %q, want %q", receivedBody.ChiefVersion, "2.0.0") + } + if receivedBody.DeviceName != "test-device" { + t.Errorf("device_name = %q, want %q", receivedBody.DeviceName, "test-device") + } + if receivedBody.OS != runtime.GOOS { + t.Errorf("os = %q, want %q", receivedBody.OS, runtime.GOOS) + } + if receivedBody.Arch != runtime.GOARCH { + t.Errorf("arch = %q, want %q", receivedBody.Arch, runtime.GOARCH) + } + if receivedBody.ProtocolVersion != ws.ProtocolVersion { + t.Errorf("protocol_version = %d, want %d", receivedBody.ProtocolVersion, ws.ProtocolVersion) + } + + // Verify response. + if welcome.Type != "welcome" { + t.Errorf("Type = %q, want %q", welcome.Type, "welcome") + } + if welcome.DeviceID != 42 { + t.Errorf("DeviceID = %d, want %d", welcome.DeviceID, 42) + } + if welcome.SessionID != "sess-abc-123" { + t.Errorf("SessionID = %q, want %q", welcome.SessionID, "sess-abc-123") + } + if welcome.Reverb.Key != "app-key" { + t.Errorf("Reverb.Key = %q, want %q", welcome.Reverb.Key, "app-key") + } + if welcome.Reverb.Host != "reverb.example.com" { + t.Errorf("Reverb.Host = %q, want %q", welcome.Reverb.Host, "reverb.example.com") + } + if welcome.Reverb.Port != 443 { + t.Errorf("Reverb.Port = %d, want %d", welcome.Reverb.Port, 443) + } + if welcome.Reverb.Scheme != "https" { + t.Errorf("Reverb.Scheme = %q, want %q", welcome.Reverb.Scheme, "https") + } +} + +func TestConnect_DefaultVersion(t *testing.T) { + var receivedBody connectRequest + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + json.Unmarshal(body, &receivedBody) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(WelcomeResponse{Type: "welcome", DeviceID: 1, SessionID: "s"}) + })) + defer srv.Close() + + client := newTestClient(t, srv.URL, "token") + ctx := testContext(t) + _, err := client.Connect(ctx) + if err != nil { + t.Fatalf("Connect() failed: %v", err) + } + + if receivedBody.ChiefVersion != "dev" { + t.Errorf("chief_version = %q, want %q (default)", receivedBody.ChiefVersion, "dev") + } +} + +func TestConnect_AuthFailed401(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + json.NewEncoder(w).Encode(errorResponse{Error: "invalid_token", Message: "Invalid token"}) + })) + defer srv.Close() + + client := newTestClient(t, srv.URL, "bad-token") + ctx := testContext(t) + _, err := client.Connect(ctx) + if err == nil { + t.Fatal("expected error for 401, got nil") + } + if !errors.Is(err, ErrAuthFailed) { + t.Errorf("error = %v, want ErrAuthFailed", err) + } +} + +func TestConnect_DeviceRevoked403(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusForbidden) + json.NewEncoder(w).Encode(errorResponse{Error: "device_revoked", Message: "Device revoked"}) + })) + defer srv.Close() + + client := newTestClient(t, srv.URL, "revoked-token") + ctx := testContext(t) + _, err := client.Connect(ctx) + if err == nil { + t.Fatal("expected error for 403, got nil") + } + if !errors.Is(err, ErrDeviceRevoked) { + t.Errorf("error = %v, want ErrDeviceRevoked", err) + } +} + +func TestConnect_ServerError5xx(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusInternalServerError) + json.NewEncoder(w).Encode(errorResponse{Message: "something went wrong"}) + })) + defer srv.Close() + + client := newTestClient(t, srv.URL, "token") + ctx := testContext(t) + _, err := client.Connect(ctx) + if err == nil { + t.Fatal("expected error for 500, got nil") + } + if isAuthError(err) { + t.Error("5xx error should not be classified as auth error") + } +} + +func TestDisconnect_Success(t *testing.T) { + var receivedMethod string + var receivedPath string + var receivedAuth string + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedMethod = r.Method + receivedPath = r.URL.Path + receivedAuth = r.Header.Get("Authorization") + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]string{"status": "disconnected"}) + })) + defer srv.Close() + + client := newTestClient(t, srv.URL, "test-token") + ctx := testContext(t) + err := client.Disconnect(ctx) + if err != nil { + t.Fatalf("Disconnect() failed: %v", err) + } + + if receivedMethod != "POST" { + t.Errorf("method = %q, want POST", receivedMethod) + } + if receivedPath != "/api/device/disconnect" { + t.Errorf("path = %q, want /api/device/disconnect", receivedPath) + } + if receivedAuth != "Bearer test-token" { + t.Errorf("Authorization = %q, want %q", receivedAuth, "Bearer test-token") + } +} + +func TestDisconnect_AuthFailed(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + })) + defer srv.Close() + + client := newTestClient(t, srv.URL, "bad-token") + ctx := testContext(t) + err := client.Disconnect(ctx) + if err == nil { + t.Fatal("expected error for 401, got nil") + } +} + +func TestSetAccessToken_ThreadSafe(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Return the received token in the response body for verification. + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]string{ + "token": r.Header.Get("Authorization"), + }) + })) + defer srv.Close() + + client := newTestClient(t, srv.URL, "token-v1") + + // Spawn goroutines that concurrently update the token and make requests. + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + client.SetAccessToken("token-v2") + }() + } + wg.Wait() + + // After all updates, the token should be v2. + client.mu.RLock() + token := client.accessToken + client.mu.RUnlock() + + if token != "token-v2" { + t.Errorf("accessToken = %q, want %q", token, "token-v2") + } +} + +func TestConnect_RequestFormat(t *testing.T) { + var receivedContentType string + var receivedAccept string + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedContentType = r.Header.Get("Content-Type") + receivedAccept = r.Header.Get("Accept") + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(WelcomeResponse{Type: "welcome", DeviceID: 1, SessionID: "s"}) + })) + defer srv.Close() + + client := newTestClient(t, srv.URL, "token") + ctx := testContext(t) + client.Connect(ctx) + + if receivedContentType != "application/json" { + t.Errorf("Content-Type = %q, want application/json", receivedContentType) + } + if receivedAccept != "application/json" { + t.Errorf("Accept = %q, want application/json", receivedAccept) + } +} + +func TestConnect_ContextCancellation(t *testing.T) { + blocked := make(chan struct{}) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Block until the test is done (context cancelled will abort the request). + <-blocked + })) + defer srv.Close() + defer close(blocked) // unblock the handler so the server can shut down cleanly + + client := newTestClient(t, srv.URL, "token") + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + + _, err := client.Connect(ctx) + if err == nil { + t.Fatal("expected error from cancelled context, got nil") + } +} + +func TestConnectWithRetry_SuccessOnSecondAttempt(t *testing.T) { + var attempt atomic.Int32 + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := attempt.Add(1) + if n == 1 { + w.WriteHeader(http.StatusServiceUnavailable) + return + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(WelcomeResponse{ + Type: "welcome", + DeviceID: 42, + SessionID: "sess-123", + }) + })) + defer srv.Close() + + client := newTestClient(t, srv.URL, "token") + ctx := testContext(t) + welcome, err := client.ConnectWithRetry(ctx) + if err != nil { + t.Fatalf("ConnectWithRetry() failed: %v", err) + } + if welcome.DeviceID != 42 { + t.Errorf("DeviceID = %d, want 42", welcome.DeviceID) + } + if attempt.Load() != 2 { + t.Errorf("attempts = %d, want 2", attempt.Load()) + } +} + +func TestConnectWithRetry_NoRetryOnAuthError(t *testing.T) { + var attempt atomic.Int32 + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempt.Add(1) + w.WriteHeader(http.StatusUnauthorized) + })) + defer srv.Close() + + client := newTestClient(t, srv.URL, "bad-token") + ctx := testContext(t) + _, err := client.ConnectWithRetry(ctx) + if err == nil { + t.Fatal("expected error, got nil") + } + if attempt.Load() != 1 { + t.Errorf("attempts = %d, want 1 (no retry on auth error)", attempt.Load()) + } +} + +func TestBackoff(t *testing.T) { + tests := []struct { + attempt int + minMs int64 + maxMs int64 + }{ + {1, 500, 1500}, // 1s * (0.5 to 1.5) + {2, 1000, 3000}, // 2s * (0.5 to 1.5) + {3, 2000, 6000}, // 4s * (0.5 to 1.5) + {4, 4000, 12000}, // 8s * (0.5 to 1.5) + {10, 30000, 90000}, // capped at 60s * (0.5 to 1.5) + } + + for _, tt := range tests { + d := backoff(tt.attempt) + ms := d.Milliseconds() + if ms < tt.minMs || ms > tt.maxMs { + t.Errorf("backoff(%d) = %dms, want [%d, %d]ms", tt.attempt, ms, tt.minMs, tt.maxMs) + } + } +} + +func TestSendMessages_Success(t *testing.T) { + var receivedBody ingestRequest + var receivedAuth string + var receivedMethod string + var receivedPath string + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedMethod = r.Method + receivedPath = r.URL.Path + receivedAuth = r.Header.Get("Authorization") + + body, _ := io.ReadAll(r.Body) + json.Unmarshal(body, &receivedBody) + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(IngestResponse{ + Accepted: 2, + BatchID: "batch-abc-123", + SessionID: "sess-xyz-789", + }) + })) + defer srv.Close() + + client := newTestClient(t, srv.URL, "test-token") + ctx := testContext(t) + + messages := []json.RawMessage{ + json.RawMessage(`{"type":"project_state","id":"m1"}`), + json.RawMessage(`{"type":"claude_output","id":"m2"}`), + } + + resp, err := client.SendMessages(ctx, "batch-abc-123", messages) + if err != nil { + t.Fatalf("SendMessages() failed: %v", err) + } + + // Verify request format. + if receivedMethod != "POST" { + t.Errorf("method = %q, want POST", receivedMethod) + } + if receivedPath != "/api/device/messages" { + t.Errorf("path = %q, want /api/device/messages", receivedPath) + } + if receivedAuth != "Bearer test-token" { + t.Errorf("Authorization = %q, want %q", receivedAuth, "Bearer test-token") + } + if receivedBody.BatchID != "batch-abc-123" { + t.Errorf("batch_id = %q, want %q", receivedBody.BatchID, "batch-abc-123") + } + if len(receivedBody.Messages) != 2 { + t.Errorf("messages count = %d, want 2", len(receivedBody.Messages)) + } + + // Verify response parsing. + if resp.Accepted != 2 { + t.Errorf("Accepted = %d, want 2", resp.Accepted) + } + if resp.BatchID != "batch-abc-123" { + t.Errorf("BatchID = %q, want %q", resp.BatchID, "batch-abc-123") + } + if resp.SessionID != "sess-xyz-789" { + t.Errorf("SessionID = %q, want %q", resp.SessionID, "sess-xyz-789") + } +} + +func TestSendMessages_AuthFailed401(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + })) + defer srv.Close() + + client := newTestClient(t, srv.URL, "bad-token") + ctx := testContext(t) + + _, err := client.SendMessages(ctx, "batch-1", []json.RawMessage{json.RawMessage(`{}`)}) + if err == nil { + t.Fatal("expected error for 401, got nil") + } + if !errors.Is(err, ErrAuthFailed) { + t.Errorf("error = %v, want ErrAuthFailed", err) + } +} + +func TestSendMessages_DeviceRevoked403(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + })) + defer srv.Close() + + client := newTestClient(t, srv.URL, "revoked-token") + ctx := testContext(t) + + _, err := client.SendMessages(ctx, "batch-1", []json.RawMessage{json.RawMessage(`{}`)}) + if err == nil { + t.Fatal("expected error for 403, got nil") + } + if !errors.Is(err, ErrDeviceRevoked) { + t.Errorf("error = %v, want ErrDeviceRevoked", err) + } +} + +func TestSendMessages_ServerError5xx(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer srv.Close() + + client := newTestClient(t, srv.URL, "token") + ctx := testContext(t) + + _, err := client.SendMessages(ctx, "batch-1", []json.RawMessage{json.RawMessage(`{}`)}) + if err == nil { + t.Fatal("expected error for 500, got nil") + } + if isAuthError(err) { + t.Error("5xx error should not be classified as auth error") + } +} + +func TestSendMessagesWithRetry_SuccessAfterRetry(t *testing.T) { + var attempt atomic.Int32 + var receivedBatchIDs []string + var mu sync.Mutex + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var body ingestRequest + data, _ := io.ReadAll(r.Body) + json.Unmarshal(data, &body) + + mu.Lock() + receivedBatchIDs = append(receivedBatchIDs, body.BatchID) + mu.Unlock() + + n := attempt.Add(1) + if n <= 2 { + w.WriteHeader(http.StatusServiceUnavailable) + return + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(IngestResponse{ + Accepted: 1, + BatchID: body.BatchID, + SessionID: "sess-1", + }) + })) + defer srv.Close() + + client := newTestClient(t, srv.URL, "token") + ctx := testContext(t) + + resp, err := client.SendMessagesWithRetry(ctx, "batch-retry-123", []json.RawMessage{json.RawMessage(`{"type":"log_lines"}`)}) + if err != nil { + t.Fatalf("SendMessagesWithRetry() failed: %v", err) + } + + if resp.Accepted != 1 { + t.Errorf("Accepted = %d, want 1", resp.Accepted) + } + if attempt.Load() != 3 { + t.Errorf("attempts = %d, want 3", attempt.Load()) + } + + // Verify same batch_id was used on all retries (for server-side deduplication). + mu.Lock() + defer mu.Unlock() + for i, id := range receivedBatchIDs { + if id != "batch-retry-123" { + t.Errorf("attempt %d batch_id = %q, want %q", i+1, id, "batch-retry-123") + } + } +} + +func TestSendMessagesWithRetry_NoRetryOnAuthError(t *testing.T) { + var attempt atomic.Int32 + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempt.Add(1) + w.WriteHeader(http.StatusUnauthorized) + })) + defer srv.Close() + + client := newTestClient(t, srv.URL, "bad-token") + ctx := testContext(t) + + _, err := client.SendMessagesWithRetry(ctx, "batch-1", []json.RawMessage{json.RawMessage(`{}`)}) + if err == nil { + t.Fatal("expected error, got nil") + } + if !errors.Is(err, ErrAuthFailed) { + t.Errorf("error = %v, want ErrAuthFailed", err) + } + if attempt.Load() != 1 { + t.Errorf("attempts = %d, want 1 (no retry on auth error)", attempt.Load()) + } +} + +func TestSendMessagesWithRetry_NoRetryOnRevoked(t *testing.T) { + var attempt atomic.Int32 + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempt.Add(1) + w.WriteHeader(http.StatusForbidden) + })) + defer srv.Close() + + client := newTestClient(t, srv.URL, "revoked-token") + ctx := testContext(t) + + _, err := client.SendMessagesWithRetry(ctx, "batch-1", []json.RawMessage{json.RawMessage(`{}`)}) + if err == nil { + t.Fatal("expected error, got nil") + } + if !errors.Is(err, ErrDeviceRevoked) { + t.Errorf("error = %v, want ErrDeviceRevoked", err) + } + if attempt.Load() != 1 { + t.Errorf("attempts = %d, want 1 (no retry on revoked)", attempt.Load()) + } +} + +func TestSendMessagesWithRetry_ContextCancellation(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer srv.Close() + + client := newTestClient(t, srv.URL, "token") + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + _, err := client.SendMessagesWithRetry(ctx, "batch-1", []json.RawMessage{json.RawMessage(`{}`)}) + if err == nil { + t.Fatal("expected error from cancelled context, got nil") + } +} + +func TestSendMessages_RequestBodyFormat(t *testing.T) { + var receivedRaw json.RawMessage + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + data, _ := io.ReadAll(r.Body) + receivedRaw = data + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(IngestResponse{Accepted: 1, BatchID: "b1", SessionID: "s1"}) + })) + defer srv.Close() + + client := newTestClient(t, srv.URL, "token") + ctx := testContext(t) + + messages := []json.RawMessage{ + json.RawMessage(`{"type":"project_state","id":"msg-1","timestamp":"2026-02-16T00:00:00Z"}`), + } + _, err := client.SendMessages(ctx, "batch-format-test", messages) + if err != nil { + t.Fatalf("SendMessages() failed: %v", err) + } + + // Verify the raw JSON structure matches what the server expects. + var parsed map[string]json.RawMessage + if err := json.Unmarshal(receivedRaw, &parsed); err != nil { + t.Fatalf("failed to parse request body: %v", err) + } + if _, ok := parsed["batch_id"]; !ok { + t.Error("request body missing batch_id field") + } + if _, ok := parsed["messages"]; !ok { + t.Error("request body missing messages field") + } +} + +func TestIsAuthError(t *testing.T) { + if isAuthError(nil) { + t.Error("nil should not be auth error") + } + if !isAuthError(ErrAuthFailed) { + t.Error("ErrAuthFailed should be auth error") + } + if !isAuthError(ErrDeviceRevoked) { + t.Error("ErrDeviceRevoked should be auth error") + } + if isAuthError(context.Canceled) { + t.Error("context.Canceled should not be auth error") + } +} diff --git a/internal/uplink/pusher.go b/internal/uplink/pusher.go new file mode 100644 index 0000000..ec33d51 --- /dev/null +++ b/internal/uplink/pusher.go @@ -0,0 +1,472 @@ +package uplink + +import ( + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/json" + "fmt" + "log" + "net/http" + "net/url" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +const ( + // pusherProtocolVersion is the Pusher protocol version to use. + pusherProtocolVersion = 7 + + // receiveBufSize is the buffer size for the receive channel. + receiveBufSize = 256 + + // pusherPingTimeout is how long to wait for a pong after sending a ping. + pusherPingTimeout = 30 * time.Second + + // pusherWriteTimeout is the timeout for WebSocket write operations. + pusherWriteTimeout = 10 * time.Second +) + +// pusherMessage is a Pusher protocol message (both sent and received). +type pusherMessage struct { + Event string `json:"event"` + Channel string `json:"channel,omitempty"` + Data json.RawMessage `json:"data"` +} + +// pusherConnectionData is the data field of pusher:connection_established. +type pusherConnectionData struct { + SocketID string `json:"socket_id"` + ActivityTimeout int `json:"activity_timeout"` +} + +// pusherAuthResponse is the response from the broadcast auth endpoint. +type pusherAuthResponse struct { + Auth string `json:"auth"` +} + +// AuthFunc is a function that authenticates a Pusher channel subscription. +// It takes a socketID and channelName and returns the auth signature string. +type AuthFunc func(ctx context.Context, socketID, channelName string) (string, error) + +// PusherClient connects to a Reverb/Pusher WebSocket and subscribes to a private channel. +type PusherClient struct { + appKey string + host string + port int + scheme string + channel string + authFn AuthFunc + dialer *websocket.Dialer + + mu sync.Mutex + conn *websocket.Conn + socketID string + recvCh chan json.RawMessage + done chan struct{} + stopped bool +} + +// NewPusherClient creates a PusherClient configured to connect to Reverb. +// +// Parameters: +// - cfg: Reverb connection config (key, host, port, scheme) from the connect response +// - channel: the private channel to subscribe to (e.g., "private-chief-server.42") +// - authFn: function to authenticate the channel subscription +func NewPusherClient(cfg ReverbConfig, channel string, authFn AuthFunc) *PusherClient { + return &PusherClient{ + appKey: cfg.Key, + host: cfg.Host, + port: cfg.Port, + scheme: cfg.Scheme, + channel: channel, + authFn: authFn, + dialer: websocket.DefaultDialer, + recvCh: make(chan json.RawMessage, receiveBufSize), + done: make(chan struct{}), + } +} + +// Connect dials the Pusher WebSocket, waits for connection_established, +// subscribes to the private channel, and starts the read loop. +func (p *PusherClient) Connect(ctx context.Context) error { + wsURL := p.buildURL() + + headers := http.Header{} + headers.Set("Origin", fmt.Sprintf("%s://%s", p.scheme, p.host)) + + conn, _, err := p.dialer.DialContext(ctx, wsURL, headers) + if err != nil { + return fmt.Errorf("pusher dial: %w", err) + } + + p.mu.Lock() + p.conn = conn + p.mu.Unlock() + + // Wait for pusher:connection_established. + socketID, activityTimeout, err := p.waitForConnectionEstablished(ctx, conn) + if err != nil { + conn.Close() + return err + } + + p.mu.Lock() + p.socketID = socketID + p.mu.Unlock() + + // Subscribe to the private channel. + if err := p.subscribe(ctx, conn, socketID); err != nil { + conn.Close() + return err + } + + // Start read loop. + go p.readLoop(ctx, conn, activityTimeout) + + return nil +} + +// Receive returns a channel that delivers incoming command payloads. +// The channel is closed when the client shuts down. +func (p *PusherClient) Receive() <-chan json.RawMessage { + return p.recvCh +} + +// Close gracefully shuts down the Pusher client. +func (p *PusherClient) Close() error { + p.mu.Lock() + if p.stopped { + p.mu.Unlock() + return nil + } + p.stopped = true + conn := p.conn + p.conn = nil + p.mu.Unlock() + + var err error + if conn != nil { + deadline := time.Now().Add(5 * time.Second) + closeMsg := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "") + _ = conn.WriteControl(websocket.CloseMessage, closeMsg, deadline) + err = conn.Close() + } + + // Wait for readLoop to finish. + <-p.done + + return err +} + +// buildURL constructs the Pusher WebSocket URL. +func (p *PusherClient) buildURL() string { + wsScheme := "wss" + if p.scheme == "http" { + wsScheme = "ws" + } + + u := url.URL{ + Scheme: wsScheme, + Host: fmt.Sprintf("%s:%d", p.host, p.port), + Path: fmt.Sprintf("/app/%s", p.appKey), + RawQuery: fmt.Sprintf("protocol=%d", pusherProtocolVersion), + } + return u.String() +} + +// waitForConnectionEstablished reads messages until it receives +// pusher:connection_established. Returns the socket ID and activity timeout. +func (p *PusherClient) waitForConnectionEstablished(ctx context.Context, conn *websocket.Conn) (string, int, error) { + // Set a read deadline for the connection established message. + conn.SetReadDeadline(time.Now().Add(10 * time.Second)) + defer conn.SetReadDeadline(time.Time{}) // Clear deadline. + + for { + select { + case <-ctx.Done(): + return "", 0, ctx.Err() + default: + } + + _, data, err := conn.ReadMessage() + if err != nil { + return "", 0, fmt.Errorf("pusher: waiting for connection_established: %w", err) + } + + var msg pusherMessage + if err := json.Unmarshal(data, &msg); err != nil { + continue // Skip unparseable messages. + } + + if msg.Event == "pusher:connection_established" { + // The data field is a JSON-encoded string inside the outer JSON, + // so we unmarshal twice: first to get the string, then to parse it. + var dataStr string + if err := json.Unmarshal(msg.Data, &dataStr); err != nil { + return "", 0, fmt.Errorf("pusher: parsing connection data wrapper: %w", err) + } + var connData pusherConnectionData + if err := json.Unmarshal([]byte(dataStr), &connData); err != nil { + return "", 0, fmt.Errorf("pusher: parsing connection data: %w", err) + } + if connData.SocketID == "" { + return "", 0, fmt.Errorf("pusher: empty socket_id in connection_established") + } + return connData.SocketID, connData.ActivityTimeout, nil + } + + if msg.Event == "pusher:error" { + return "", 0, fmt.Errorf("pusher: server error during connect: %s", string(msg.Data)) + } + } +} + +// subscribe authenticates and subscribes to the private channel. +func (p *PusherClient) subscribe(ctx context.Context, conn *websocket.Conn, socketID string) error { + // Get auth signature from the auth endpoint. + authSig, err := p.authFn(ctx, socketID, p.channel) + if err != nil { + return fmt.Errorf("pusher: channel auth failed: %w", err) + } + + // Send subscribe message. + subData, _ := json.Marshal(map[string]string{ + "auth": authSig, + "channel": p.channel, + }) + subMsg := pusherMessage{ + Event: "pusher:subscribe", + Data: subData, + } + + conn.SetWriteDeadline(time.Now().Add(pusherWriteTimeout)) + if err := conn.WriteJSON(subMsg); err != nil { + return fmt.Errorf("pusher: sending subscribe: %w", err) + } + conn.SetWriteDeadline(time.Time{}) + + // Wait for subscription_succeeded or error. + conn.SetReadDeadline(time.Now().Add(10 * time.Second)) + defer conn.SetReadDeadline(time.Time{}) + + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + _, data, err := conn.ReadMessage() + if err != nil { + return fmt.Errorf("pusher: waiting for subscription response: %w", err) + } + + var msg pusherMessage + if err := json.Unmarshal(data, &msg); err != nil { + continue + } + + if msg.Event == "pusher_internal:subscription_succeeded" && msg.Channel == p.channel { + return nil + } + + if msg.Event == "pusher:error" { + return fmt.Errorf("pusher: subscription error: %s", string(msg.Data)) + } + } +} + +// readResult is a message or error from the reader goroutine. +type readResult struct { + data []byte + err error +} + +// readLoop reads messages from the WebSocket and dispatches command events. +// +// A separate goroutine performs the blocking ReadMessage calls and feeds +// results into a channel, allowing the main loop to select on both incoming +// messages and the ping timer. Per the Pusher protocol, the client sends a +// pusher:ping after activityTimeout seconds of inactivity; if no pusher:pong +// arrives within pusherPingTimeout, the read deadline expires and the +// connection is considered dead. +func (p *PusherClient) readLoop(ctx context.Context, conn *websocket.Conn, activityTimeout int) { + defer close(p.done) + defer close(p.recvCh) + + // Ping interval is the server's advertised activity timeout — the client + // should send a ping after this many seconds of silence. + pingInterval := time.Duration(activityTimeout) * time.Second + if pingInterval <= 0 { + pingInterval = 120 * time.Second // Default Pusher activity timeout. + } + pingTimer := time.NewTimer(pingInterval) + defer pingTimer.Stop() + + // Reader goroutine: performs blocking ReadMessage calls and feeds results + // to readCh. The read deadline allows pingInterval for normal activity + // plus pusherPingTimeout for a pong response after we send a ping. + readCh := make(chan readResult, 1) + go func() { + for { + conn.SetReadDeadline(time.Now().Add(pingInterval + pusherPingTimeout)) + _, data, err := conn.ReadMessage() + select { + case readCh <- readResult{data, err}: + case <-ctx.Done(): + return + } + if err != nil { + return + } + } + }() + + for { + select { + case <-ctx.Done(): + return + + case result := <-readCh: + if result.err != nil { + select { + case <-ctx.Done(): + return + default: + } + p.mu.Lock() + stopped := p.stopped + p.mu.Unlock() + if stopped { + return + } + log.Printf("Pusher read error: %v", result.err) + return + } + + // Reset ping timer on any received data. + if !pingTimer.Stop() { + select { + case <-pingTimer.C: + default: + } + } + pingTimer.Reset(pingInterval) + + p.handleMessage(result.data) + + case <-pingTimer.C: + // No activity for pingInterval — send a ping to keep alive. + if !p.sendPusherMessage(conn, pusherMessage{ + Event: "pusher:ping", + Data: json.RawMessage("{}"), + }) { + return + } + // The read deadline (pingInterval + pusherPingTimeout) gives the + // server pusherPingTimeout to respond with pusher:pong. If no + // response arrives, ReadMessage returns a timeout error. + } + } +} + +// handleMessage processes a single Pusher protocol message. +func (p *PusherClient) handleMessage(data []byte) { + var msg pusherMessage + if err := json.Unmarshal(data, &msg); err != nil { + log.Printf("Pusher: ignoring unparseable message: %v", err) + return + } + + switch msg.Event { + case "pusher:ping": + // Respond with pong. + p.sendPusherMessage(p.getConn(), pusherMessage{ + Event: "pusher:pong", + Data: json.RawMessage("{}"), + }) + + case "pusher:pong": + // Server responded to our ping — connection confirmed alive. + + case "pusher:error": + log.Printf("Pusher server error: %s", string(msg.Data)) + + case "chief.command": + if msg.Channel == p.channel { + // Pusher wraps event data as a JSON-encoded string, so we + // must unwrap it before forwarding to the command handler. + payload := msg.Data + var dataStr string + if err := json.Unmarshal(msg.Data, &dataStr); err == nil { + payload = json.RawMessage(dataStr) + } + select { + case p.recvCh <- payload: + default: + log.Printf("Pusher: receive buffer full, dropping command") + } + } + + default: + // Ignore other event types (subscription_succeeded during reconnect, etc.). + } +} + +// sendPusherMessage writes a Pusher protocol JSON message to the connection. +// Returns false if the write failed (connection should be considered dead). +func (p *PusherClient) sendPusherMessage(conn *websocket.Conn, msg pusherMessage) bool { + if conn == nil { + return false + } + conn.SetWriteDeadline(time.Now().Add(pusherWriteTimeout)) + err := conn.WriteJSON(msg) + conn.SetWriteDeadline(time.Time{}) + if err != nil { + log.Printf("Pusher: write error: %v", err) + return false + } + return true +} + +// getConn returns the current WebSocket connection, or nil if closed. +func (p *PusherClient) getConn() *websocket.Conn { + p.mu.Lock() + defer p.mu.Unlock() + return p.conn +} + +// BroadcastAuth authenticates a Pusher channel subscription via the uplink HTTP client. +// This creates an AuthFunc that calls POST /api/device/broadcasting/auth. +func (c *Client) BroadcastAuth(ctx context.Context, socketID, channelName string) (string, error) { + body := broadcastAuthRequest{ + SocketID: socketID, + ChannelName: channelName, + } + + var resp pusherAuthResponse + if err := c.doJSON(ctx, "POST", "/api/device/broadcasting/auth", body, &resp); err != nil { + return "", fmt.Errorf("broadcast auth: %w", err) + } + + return resp.Auth, nil +} + +// broadcastAuthRequest is the JSON body sent to POST /api/device/broadcasting/auth. +type broadcastAuthRequest struct { + SocketID string `json:"socket_id"` + ChannelName string `json:"channel_name"` +} + +// GenerateAuthSignature generates a Pusher private channel auth signature locally. +// This is used in tests to verify auth signatures without hitting the server. +func GenerateAuthSignature(appKey, appSecret, socketID, channelName string) string { + data := socketID + ":" + channelName + mac := hmac.New(sha256.New, []byte(appSecret)) + mac.Write([]byte(data)) + sig := fmt.Sprintf("%x", mac.Sum(nil)) + return appKey + ":" + sig +} diff --git a/internal/uplink/pusher_test.go b/internal/uplink/pusher_test.go new file mode 100644 index 0000000..a496e42 --- /dev/null +++ b/internal/uplink/pusher_test.go @@ -0,0 +1,861 @@ +package uplink + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/gorilla/websocket" +) + +// testPusherServer is a mock Pusher/Reverb WebSocket server for testing. +type testPusherServer struct { + srv *httptest.Server + upgrader websocket.Upgrader + + mu sync.Mutex + conn *websocket.Conn + + // Configuration. + appKey string + appSecret string + socketID string + activityTimeout int + + // Control channels. + onSubscribe chan string // receives channel name when client subscribes + onMessage chan []byte // receives raw messages from client + + // Behavior flags. + rejectAuth bool + rejectSubscribe bool + skipEstablished bool +} + +func newTestPusherServer(t *testing.T) *testPusherServer { + t.Helper() + + ps := &testPusherServer{ + appKey: "test-app-key", + appSecret: "test-app-secret", + socketID: "123456.7890", + activityTimeout: 120, + onSubscribe: make(chan string, 10), + onMessage: make(chan []byte, 10), + upgrader: websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, + }, + } + + ps.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ps.handleWS(t, w, r) + })) + + t.Cleanup(func() { + ps.mu.Lock() + if ps.conn != nil { + ps.conn.Close() + } + ps.mu.Unlock() + ps.srv.Close() + }) + + return ps +} + +func (ps *testPusherServer) handleWS(t *testing.T, w http.ResponseWriter, r *http.Request) { + t.Helper() + + // Verify the URL path matches Pusher format. + expectedPath := fmt.Sprintf("/app/%s", ps.appKey) + if !strings.HasPrefix(r.URL.Path, expectedPath) { + http.Error(w, "invalid path", http.StatusNotFound) + return + } + + conn, err := ps.upgrader.Upgrade(w, r, nil) + if err != nil { + t.Logf("upgrade error: %v", err) + return + } + + ps.mu.Lock() + ps.conn = conn + ps.mu.Unlock() + + // Send connection_established unless configured not to. + if !ps.skipEstablished { + connDataJSON, _ := json.Marshal(pusherConnectionData{ + SocketID: ps.socketID, + ActivityTimeout: ps.activityTimeout, + }) + // Real Pusher/Reverb double-encodes: the data field is a JSON string. + connDataStr, _ := json.Marshal(string(connDataJSON)) + established := pusherMessage{ + Event: "pusher:connection_established", + Data: connDataStr, + } + if err := conn.WriteJSON(established); err != nil { + t.Logf("write connection_established: %v", err) + return + } + } + + // Read loop — handle subscribe messages and pass others to onMessage. + for { + _, data, err := conn.ReadMessage() + if err != nil { + return + } + + var msg pusherMessage + if json.Unmarshal(data, &msg) != nil { + continue + } + + switch msg.Event { + case "pusher:subscribe": + var subData map[string]string + json.Unmarshal(msg.Data, &subData) + channel := subData["channel"] + + select { + case ps.onSubscribe <- channel: + default: + } + + if ps.rejectSubscribe { + errData, _ := json.Marshal(map[string]interface{}{ + "message": "subscription rejected", + "code": 4009, + }) + conn.WriteJSON(pusherMessage{ + Event: "pusher:error", + Data: errData, + }) + continue + } + + // Send subscription_succeeded. + conn.WriteJSON(pusherMessage{ + Event: "pusher_internal:subscription_succeeded", + Channel: channel, + Data: json.RawMessage("{}"), + }) + + case "pusher:pong": + select { + case ps.onMessage <- data: + default: + } + + default: + select { + case ps.onMessage <- data: + default: + } + } + } +} + +// sendCommand sends a chief.command event to the connected client. +func (ps *testPusherServer) sendCommand(channel string, command json.RawMessage) error { + ps.mu.Lock() + conn := ps.conn + ps.mu.Unlock() + + if conn == nil { + return fmt.Errorf("no client connected") + } + + msg := pusherMessage{ + Event: "chief.command", + Channel: channel, + Data: command, + } + return conn.WriteJSON(msg) +} + +// sendCommandStringEncoded sends a chief.command event where the data field +// is a JSON-encoded string, matching real Reverb/Pusher wire format. +func (ps *testPusherServer) sendCommandStringEncoded(channel string, command json.RawMessage) error { + ps.mu.Lock() + conn := ps.conn + ps.mu.Unlock() + + if conn == nil { + return fmt.Errorf("no client connected") + } + + // Double-encode: wrap the JSON object as a JSON string. + encoded, err := json.Marshal(string(command)) + if err != nil { + return fmt.Errorf("encoding command: %w", err) + } + + msg := pusherMessage{ + Event: "chief.command", + Channel: channel, + Data: encoded, + } + return conn.WriteJSON(msg) +} + +// closeConnection closes the WebSocket connection from the server side, +// simulating a Pusher disconnection. +func (ps *testPusherServer) closeConnection() error { + ps.mu.Lock() + conn := ps.conn + ps.mu.Unlock() + + if conn == nil { + return fmt.Errorf("no client connected") + } + + return conn.Close() +} + +// sendPing sends a pusher:ping to the connected client. +func (ps *testPusherServer) sendPing() error { + ps.mu.Lock() + conn := ps.conn + ps.mu.Unlock() + + if conn == nil { + return fmt.Errorf("no client connected") + } + + return conn.WriteJSON(pusherMessage{ + Event: "pusher:ping", + Data: json.RawMessage("{}"), + }) +} + +// reverbConfig returns a ReverbConfig pointing at the test server. +func (ps *testPusherServer) reverbConfig() ReverbConfig { + // Extract host and port from the test server URL. + addr := ps.srv.Listener.Addr().String() + parts := strings.Split(addr, ":") + host := parts[0] + port := 0 + fmt.Sscanf(parts[1], "%d", &port) + + return ReverbConfig{ + Key: ps.appKey, + Host: host, + Port: port, + Scheme: "http", + } +} + +// testAuthFn returns an AuthFunc that uses the test server's app key/secret. +func (ps *testPusherServer) testAuthFn() AuthFunc { + return func(ctx context.Context, socketID, channelName string) (string, error) { + return GenerateAuthSignature(ps.appKey, ps.appSecret, socketID, channelName), nil + } +} + +// failingAuthFn returns an AuthFunc that always fails. +func failingAuthFn() AuthFunc { + return func(ctx context.Context, socketID, channelName string) (string, error) { + return "", fmt.Errorf("auth endpoint unavailable") + } +} + +// --- Tests --- + +func TestPusherClient_ConnectAndReceive(t *testing.T) { + ps := newTestPusherServer(t) + channel := "private-chief-server.42" + + client := NewPusherClient(ps.reverbConfig(), channel, ps.testAuthFn()) + + ctx := testContext(t) + if err := client.Connect(ctx); err != nil { + t.Fatalf("Connect() failed: %v", err) + } + defer client.Close() + + // Wait for subscription. + select { + case ch := <-ps.onSubscribe: + if ch != channel { + t.Errorf("subscribed to %q, want %q", ch, channel) + } + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for subscription") + } + + // Send a command and verify receipt. + cmd := json.RawMessage(`{"type":"start_run","project":"test"}`) + if err := ps.sendCommand(channel, cmd); err != nil { + t.Fatalf("sendCommand failed: %v", err) + } + + select { + case received := <-client.Receive(): + var parsed map[string]interface{} + json.Unmarshal(received, &parsed) + if parsed["type"] != "start_run" { + t.Errorf("received type = %v, want start_run", parsed["type"]) + } + if parsed["project"] != "test" { + t.Errorf("received project = %v, want test", parsed["project"]) + } + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for command") + } +} + +func TestPusherClient_MultipleCommands(t *testing.T) { + ps := newTestPusherServer(t) + channel := "private-chief-server.99" + + client := NewPusherClient(ps.reverbConfig(), channel, ps.testAuthFn()) + + ctx := testContext(t) + if err := client.Connect(ctx); err != nil { + t.Fatalf("Connect() failed: %v", err) + } + defer client.Close() + + // Wait for subscription. + select { + case <-ps.onSubscribe: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for subscription") + } + + // Send multiple commands. + commands := []string{ + `{"type":"start_run","id":"1"}`, + `{"type":"pause_run","id":"2"}`, + `{"type":"stop_run","id":"3"}`, + } + + for _, cmd := range commands { + if err := ps.sendCommand(channel, json.RawMessage(cmd)); err != nil { + t.Fatalf("sendCommand failed: %v", err) + } + } + + // Receive all commands in order. + for i, expected := range commands { + select { + case received := <-client.Receive(): + var expectedMap, receivedMap map[string]interface{} + json.Unmarshal([]byte(expected), &expectedMap) + json.Unmarshal(received, &receivedMap) + if receivedMap["id"] != expectedMap["id"] { + t.Errorf("command %d: id = %v, want %v", i, receivedMap["id"], expectedMap["id"]) + } + case <-time.After(5 * time.Second): + t.Fatalf("timeout waiting for command %d", i) + } + } +} + +func TestPusherClient_IgnoresOtherChannels(t *testing.T) { + ps := newTestPusherServer(t) + myChannel := "private-chief-server.42" + + client := NewPusherClient(ps.reverbConfig(), myChannel, ps.testAuthFn()) + + ctx := testContext(t) + if err := client.Connect(ctx); err != nil { + t.Fatalf("Connect() failed: %v", err) + } + defer client.Close() + + // Wait for subscription. + select { + case <-ps.onSubscribe: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for subscription") + } + + // Send a command on a different channel — should be ignored. + if err := ps.sendCommand("private-chief-server.99", json.RawMessage(`{"type":"other"}`)); err != nil { + t.Fatalf("sendCommand failed: %v", err) + } + + // Send a command on our channel — should be received. + if err := ps.sendCommand(myChannel, json.RawMessage(`{"type":"mine"}`)); err != nil { + t.Fatalf("sendCommand failed: %v", err) + } + + select { + case received := <-client.Receive(): + var parsed map[string]interface{} + json.Unmarshal(received, &parsed) + if parsed["type"] != "mine" { + t.Errorf("received type = %v, want mine (wrong channel message leaked through)", parsed["type"]) + } + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for command") + } +} + +func TestPusherClient_PingPong(t *testing.T) { + ps := newTestPusherServer(t) + channel := "private-chief-server.42" + + client := NewPusherClient(ps.reverbConfig(), channel, ps.testAuthFn()) + + ctx := testContext(t) + if err := client.Connect(ctx); err != nil { + t.Fatalf("Connect() failed: %v", err) + } + defer client.Close() + + // Wait for subscription. + select { + case <-ps.onSubscribe: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for subscription") + } + + // Send a ping and verify pong response. + if err := ps.sendPing(); err != nil { + t.Fatalf("sendPing failed: %v", err) + } + + // The client should send back a pusher:pong. + select { + case data := <-ps.onMessage: + var msg pusherMessage + if err := json.Unmarshal(data, &msg); err != nil { + t.Fatalf("failed to parse pong message: %v", err) + } + if msg.Event != "pusher:pong" { + t.Errorf("response event = %q, want pusher:pong", msg.Event) + } + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for pong") + } +} + +func TestPusherClient_Close(t *testing.T) { + ps := newTestPusherServer(t) + channel := "private-chief-server.42" + + client := NewPusherClient(ps.reverbConfig(), channel, ps.testAuthFn()) + + ctx := testContext(t) + if err := client.Connect(ctx); err != nil { + t.Fatalf("Connect() failed: %v", err) + } + + // Wait for subscription. + select { + case <-ps.onSubscribe: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for subscription") + } + + // Close and verify the receive channel closes. + if err := client.Close(); err != nil { + t.Fatalf("Close() failed: %v", err) + } + + // Receive channel should be closed. + select { + case _, ok := <-client.Receive(): + if ok { + t.Error("expected receive channel to be closed after Close()") + } + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for receive channel to close") + } + + // Double-close should be safe. + if err := client.Close(); err != nil { + t.Fatalf("double Close() failed: %v", err) + } +} + +func TestPusherClient_ContextCancellation(t *testing.T) { + ps := newTestPusherServer(t) + channel := "private-chief-server.42" + + client := NewPusherClient(ps.reverbConfig(), channel, ps.testAuthFn()) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + if err := client.Connect(ctx); err != nil { + t.Fatalf("Connect() failed: %v", err) + } + + // Wait for subscription. + select { + case <-ps.onSubscribe: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for subscription") + } + + // Cancel context — the read loop should stop. + cancel() + + // Give the readLoop time to notice the cancellation and close. + select { + case <-client.Receive(): + // Channel closed or drained — good. + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for shutdown after context cancellation") + } +} + +func TestPusherClient_AuthFailure(t *testing.T) { + ps := newTestPusherServer(t) + channel := "private-chief-server.42" + + client := NewPusherClient(ps.reverbConfig(), channel, failingAuthFn()) + + ctx := testContext(t) + err := client.Connect(ctx) + if err == nil { + t.Fatal("expected error when auth fails, got nil") + client.Close() + } + if !strings.Contains(err.Error(), "auth endpoint unavailable") { + t.Errorf("error = %v, want containing 'auth endpoint unavailable'", err) + } +} + +func TestPusherClient_SubscriptionRejected(t *testing.T) { + ps := newTestPusherServer(t) + ps.rejectSubscribe = true + channel := "private-chief-server.42" + + client := NewPusherClient(ps.reverbConfig(), channel, ps.testAuthFn()) + + ctx := testContext(t) + err := client.Connect(ctx) + if err == nil { + t.Fatal("expected error when subscription is rejected, got nil") + client.Close() + } + if !strings.Contains(err.Error(), "subscription error") { + t.Errorf("error = %v, want containing 'subscription error'", err) + } +} + +func TestPusherClient_BuildURL(t *testing.T) { + tests := []struct { + name string + cfg ReverbConfig + expect string + }{ + { + name: "HTTPS scheme", + cfg: ReverbConfig{ + Key: "my-key", + Host: "reverb.example.com", + Port: 443, + Scheme: "https", + }, + expect: "wss://reverb.example.com:443/app/my-key?protocol=7", + }, + { + name: "HTTP scheme", + cfg: ReverbConfig{ + Key: "local-key", + Host: "localhost", + Port: 8080, + Scheme: "http", + }, + expect: "ws://localhost:8080/app/local-key?protocol=7", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := NewPusherClient(tt.cfg, "private-test", nil) + got := p.buildURL() + if got != tt.expect { + t.Errorf("buildURL() = %q, want %q", got, tt.expect) + } + }) + } +} + +func TestPusherClient_ReceiveChannelBuffered(t *testing.T) { + ps := newTestPusherServer(t) + channel := "private-chief-server.42" + + client := NewPusherClient(ps.reverbConfig(), channel, ps.testAuthFn()) + + ctx := testContext(t) + if err := client.Connect(ctx); err != nil { + t.Fatalf("Connect() failed: %v", err) + } + defer client.Close() + + // Wait for subscription. + select { + case <-ps.onSubscribe: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for subscription") + } + + // The receive channel should be buffered. + if cap(client.recvCh) != receiveBufSize { + t.Errorf("receive channel capacity = %d, want %d", cap(client.recvCh), receiveBufSize) + } +} + +func TestPusherClient_ServerError(t *testing.T) { + ps := newTestPusherServer(t) + channel := "private-chief-server.42" + + client := NewPusherClient(ps.reverbConfig(), channel, ps.testAuthFn()) + + ctx := testContext(t) + if err := client.Connect(ctx); err != nil { + t.Fatalf("Connect() failed: %v", err) + } + defer client.Close() + + // Wait for subscription. + select { + case <-ps.onSubscribe: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for subscription") + } + + // Send a pusher:error — client should log it but not crash. + ps.mu.Lock() + conn := ps.conn + ps.mu.Unlock() + + errData, _ := json.Marshal(map[string]interface{}{"message": "test error", "code": 4100}) + conn.WriteJSON(pusherMessage{ + Event: "pusher:error", + Data: errData, + }) + + // Send a command after the error — client should still be functioning. + if err := ps.sendCommand(channel, json.RawMessage(`{"type":"after_error"}`)); err != nil { + t.Fatalf("sendCommand failed: %v", err) + } + + select { + case received := <-client.Receive(): + var parsed map[string]interface{} + json.Unmarshal(received, &parsed) + if parsed["type"] != "after_error" { + t.Errorf("received type = %v, want after_error", parsed["type"]) + } + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for command after error") + } +} + +func TestBroadcastAuth_Success(t *testing.T) { + var receivedSocketID, receivedChannel string + var receivedAuth string + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/device/broadcasting/auth" { + http.NotFound(w, r) + return + } + + receivedAuth = r.Header.Get("Authorization") + + var body broadcastAuthRequest + json.NewDecoder(r.Body).Decode(&body) + receivedSocketID = body.SocketID + receivedChannel = body.ChannelName + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(pusherAuthResponse{ + Auth: "app-key:test-signature", + }) + })) + defer srv.Close() + + client := newTestClient(t, srv.URL, "test-token") + ctx := testContext(t) + + auth, err := client.BroadcastAuth(ctx, "12345.67890", "private-chief-server.42") + if err != nil { + t.Fatalf("BroadcastAuth() failed: %v", err) + } + + if receivedAuth != "Bearer test-token" { + t.Errorf("Authorization = %q, want %q", receivedAuth, "Bearer test-token") + } + if receivedSocketID != "12345.67890" { + t.Errorf("socket_id = %q, want %q", receivedSocketID, "12345.67890") + } + if receivedChannel != "private-chief-server.42" { + t.Errorf("channel_name = %q, want %q", receivedChannel, "private-chief-server.42") + } + if auth != "app-key:test-signature" { + t.Errorf("auth = %q, want %q", auth, "app-key:test-signature") + } +} + +func TestBroadcastAuth_AuthFailed(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + })) + defer srv.Close() + + client := newTestClient(t, srv.URL, "bad-token") + ctx := testContext(t) + + _, err := client.BroadcastAuth(ctx, "12345.67890", "private-chief-server.42") + if err == nil { + t.Fatal("expected error for 401, got nil") + } +} + +func TestGenerateAuthSignature(t *testing.T) { + // Known test vectors for Pusher private channel auth. + sig := GenerateAuthSignature("278d425bdf160313ff76", "7ad3773142a6692b25b8", "1234.1234", "private-foobar") + + // The format should be "key:hex_signature". + if !strings.HasPrefix(sig, "278d425bdf160313ff76:") { + t.Errorf("signature should start with app key, got %q", sig) + } + + parts := strings.SplitN(sig, ":", 2) + if len(parts) != 2 { + t.Fatalf("signature should have format key:sig, got %q", sig) + } + if len(parts[1]) != 64 { // SHA256 hex = 64 chars + t.Errorf("signature hex length = %d, want 64", len(parts[1])) + } +} + +func TestPusherClient_ConnectionEstablishedTimeout(t *testing.T) { + ps := newTestPusherServer(t) + ps.skipEstablished = true // Server won't send connection_established. + channel := "private-chief-server.42" + + client := NewPusherClient(ps.reverbConfig(), channel, ps.testAuthFn()) + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + err := client.Connect(ctx) + if err == nil { + t.Fatal("expected error when connection_established is not received, got nil") + client.Close() + } +} + +func TestPusherClient_ConcurrentClose(t *testing.T) { + ps := newTestPusherServer(t) + channel := "private-chief-server.42" + + client := NewPusherClient(ps.reverbConfig(), channel, ps.testAuthFn()) + + ctx := testContext(t) + if err := client.Connect(ctx); err != nil { + t.Fatalf("Connect() failed: %v", err) + } + + // Wait for subscription. + select { + case <-ps.onSubscribe: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for subscription") + } + + // Close from multiple goroutines concurrently. + var wg sync.WaitGroup + var closeErrors atomic.Int32 + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + defer wg.Done() + if err := client.Close(); err != nil { + closeErrors.Add(1) + } + }() + } + wg.Wait() + + // At most one goroutine should get an error (the connection close), rest should be nil. + // This test mainly verifies no panics from concurrent access. +} + +func TestPusherClient_DialFailure(t *testing.T) { + // Connect to a non-existent server. + cfg := ReverbConfig{ + Key: "test-key", + Host: "127.0.0.1", + Port: 1, // Port 1 should be unreachable. + Scheme: "http", + } + channel := "private-chief-server.42" + client := NewPusherClient(cfg, channel, failingAuthFn()) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + err := client.Connect(ctx) + if err == nil { + t.Fatal("expected error connecting to unreachable server, got nil") + client.Close() + } +} + +// TestPusherClient_DoubleEncodedData verifies that commands with Pusher's +// real wire format (data as JSON string) are correctly unwrapped. +func TestPusherClient_DoubleEncodedData(t *testing.T) { + ps := newTestPusherServer(t) + channel := "private-chief-server.42" + + client := NewPusherClient(ps.reverbConfig(), channel, ps.testAuthFn()) + + ctx := testContext(t) + if err := client.Connect(ctx); err != nil { + t.Fatalf("Connect() failed: %v", err) + } + defer client.Close() + + // Wait for subscription. + select { + case <-ps.onSubscribe: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for subscription") + } + + // Send command using Reverb's real format (data as JSON string). + cmd := json.RawMessage(`{"type":"start_run","payload":{"project_slug":"my-project"}}`) + if err := ps.sendCommandStringEncoded(channel, cmd); err != nil { + t.Fatalf("sendCommandStringEncoded failed: %v", err) + } + + select { + case received := <-client.Receive(): + var parsed map[string]interface{} + if err := json.Unmarshal(received, &parsed); err != nil { + t.Fatalf("failed to parse received command: %v (raw: %s)", err, string(received)) + } + if parsed["type"] != "start_run" { + t.Errorf("received type = %v, want start_run", parsed["type"]) + } + payload, ok := parsed["payload"].(map[string]interface{}) + if !ok { + t.Fatalf("payload is not an object: %v", parsed["payload"]) + } + if payload["project_slug"] != "my-project" { + t.Errorf("received project_slug = %v, want my-project", payload["project_slug"]) + } + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for command") + } +} diff --git a/internal/uplink/uplink.go b/internal/uplink/uplink.go new file mode 100644 index 0000000..58a972b --- /dev/null +++ b/internal/uplink/uplink.go @@ -0,0 +1,553 @@ +package uplink + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log" + "sync" + "time" +) + +const ( + // heartbeatInterval is how often heartbeats are sent. + heartbeatInterval = 30 * time.Second + + // heartbeatRetryDelay is the delay before retrying a failed heartbeat. + heartbeatRetryDelay = 5 * time.Second + + // heartbeatSkipWindow is the duration after a message send within which + // we skip the explicit heartbeat (server treats message receipt as implicit heartbeat). + heartbeatSkipWindow = 25 * time.Second + + // heartbeatMaxFailures is the number of consecutive failures before triggering reconnection. + heartbeatMaxFailures = 3 +) + +// Uplink composes the HTTP client, message batcher, and Pusher client +// into a unified Send/Receive interface. +type Uplink struct { + client *Client + batcher *Batcher + pusher *PusherClient + + mu sync.RWMutex + sessionID string + deviceID int + connected bool + + // lastSendTime records when the batcher last successfully sent a batch. + // Used by the heartbeat goroutine to skip heartbeats when messages + // were recently sent (implicit heartbeat optimization). + lastSendTime time.Time + + // Heartbeat timing (overridable for tests, default to package constants). + hbInterval time.Duration + hbRetryDelay time.Duration + hbSkipWindow time.Duration + hbMaxFails int + + // recvCh is a stable receive channel that outlives individual Pusher clients. + // Commands from each Pusher client are forwarded into this channel so callers + // of Receive() don't need to re-subscribe after reconnection. + recvCh chan json.RawMessage + + // onReconnect is called after each successful reconnection. + onReconnect func() + + // onAuthFailure is called when a 401 auth error occurs during reconnection. + // The callback should perform a token refresh and call SetAccessToken() before returning. + // If the callback returns nil, reconnection retries with the new token. + // If it returns an error, reconnection aborts. + onAuthFailure func() error + + // onHeartbeatMaxFailures is called when consecutive heartbeat failures + // reach hbMaxFails. If nil, triggerReconnect() is called directly. + // Tests can set this to override the default reconnection behavior. + onHeartbeatMaxFailures func() + + // reconnecting tracks whether a reconnection is in progress to prevent concurrent reconnects. + reconnecting bool + + // parentCtx is the context passed to Connect() — used as the parent for reconnection contexts. + parentCtx context.Context + + // cancel stops the batcher run loop and heartbeat goroutine. + cancel context.CancelFunc +} + +// UplinkOption configures an Uplink. +type UplinkOption func(*Uplink) + +// WithOnReconnect sets a callback invoked after each successful reconnection. +// This matches the ws.WithOnReconnect pattern — serve.go uses it to re-send +// a full state snapshot after reconnecting. +func WithOnReconnect(fn func()) UplinkOption { + return func(u *Uplink) { + u.onReconnect = fn + } +} + +// WithOnAuthFailure sets a callback invoked when a 401 auth error occurs during +// reconnection. The callback should perform a token refresh and call +// SetAccessToken() before returning. Return nil to retry, or an error to abort. +func WithOnAuthFailure(fn func() error) UplinkOption { + return func(u *Uplink) { + u.onAuthFailure = fn + } +} + +// NewUplink creates a new Uplink that uses the given HTTP client. +// The Uplink does not connect until Connect is called. +func NewUplink(client *Client, opts ...UplinkOption) *Uplink { + u := &Uplink{ + client: client, + hbInterval: heartbeatInterval, + hbRetryDelay: heartbeatRetryDelay, + hbSkipWindow: heartbeatSkipWindow, + hbMaxFails: heartbeatMaxFailures, + recvCh: make(chan json.RawMessage, receiveBufSize), + } + for _, o := range opts { + o(u) + } + return u +} + +// Connect establishes the full uplink connection: +// 1. HTTP connect (registers device, gets session ID + Reverb config) +// 2. Pusher connect (subscribes to private command channel) +// 3. Batcher start (begins background flush loop) +// 4. Heartbeat start (sends periodic heartbeats to server) +// 5. Pusher monitor (detects disconnection and triggers reconnection) +func (u *Uplink) Connect(ctx context.Context) error { + // Step 1: HTTP connect to register the device. + welcome, err := u.client.Connect(ctx) + if err != nil { + return fmt.Errorf("uplink connect: %w", err) + } + + u.mu.Lock() + u.sessionID = welcome.SessionID + u.deviceID = welcome.DeviceID + u.connected = true + u.parentCtx = ctx + u.mu.Unlock() + + // Step 2: Start the Pusher client for receiving commands. + channel := fmt.Sprintf("private-chief-server.%d", welcome.DeviceID) + u.pusher = NewPusherClient(welcome.Reverb, channel, u.client.BroadcastAuth) + + if err := u.pusher.Connect(ctx); err != nil { + // Clean up: disconnect from HTTP since Pusher failed. + disconnectCtx, cancel := context.WithTimeout(context.Background(), httpTimeout) + defer cancel() + if dErr := u.client.Disconnect(disconnectCtx); dErr != nil { + log.Printf("uplink: failed to disconnect after Pusher error: %v", dErr) + } + return fmt.Errorf("uplink pusher connect: %w", err) + } + + // Step 3: Start the batcher for outgoing messages. + batchCtx, batchCancel := context.WithCancel(ctx) + u.cancel = batchCancel + + u.batcher = NewBatcher(func(batchID string, messages []json.RawMessage) error { + _, err := u.client.SendMessagesWithRetry(batchCtx, batchID, messages) + if err == nil { + u.mu.Lock() + u.lastSendTime = time.Now() + u.mu.Unlock() + } + return err + }) + go u.batcher.Run(batchCtx) + + // Step 4: Start the heartbeat goroutine. + go u.runHeartbeat(batchCtx) + + // Step 5: Monitor Pusher for disconnection. + go u.monitorPusher(batchCtx) + + log.Printf("Uplink connected (device=%d, session=%s)", welcome.DeviceID, welcome.SessionID) + return nil +} + +// Send enqueues a message into the batcher for batched delivery. +// The batcher handles flush timing. +// During reconnection, messages are buffered locally in the batcher. +func (u *Uplink) Send(msg json.RawMessage, msgType string) { + u.mu.RLock() + connected := u.connected + u.mu.RUnlock() + + if !connected { + log.Printf("uplink: dropping message (type=%s) — not connected", msgType) + return + } + + u.batcher.Enqueue(msg, msgType) +} + +// Receive returns a channel that delivers incoming command payloads. +// This channel is stable across reconnections — new Pusher clients +// forward commands into the same channel. +func (u *Uplink) Receive() <-chan json.RawMessage { + return u.recvCh +} + +// Close performs graceful shutdown: +// 1. Stop the batcher (flushes remaining messages) +// 2. Close the Pusher client +// 3. HTTP disconnect +func (u *Uplink) Close() error { + return u.doClose() +} + +// CloseWithTimeout performs graceful shutdown with a deadline. +// If the timeout expires before the batcher flush completes, the flush is +// abandoned and shutdown continues with Pusher close and HTTP disconnect. +// This prevents shutdown from hanging when the server is unreachable. +func (u *Uplink) CloseWithTimeout(timeout time.Duration) error { + done := make(chan error, 1) + go func() { + done <- u.doClose() + }() + + select { + case err := <-done: + return err + case <-time.After(timeout): + log.Printf("uplink: graceful close timed out after %s — forcing shutdown", timeout) + // Force-cancel the batcher/heartbeat/monitor contexts to unblock doClose. + u.mu.Lock() + u.connected = false + u.mu.Unlock() + if u.cancel != nil { + u.cancel() + } + // Wait briefly for doClose to finish after cancellation. + select { + case err := <-done: + return err + case <-time.After(2 * time.Second): + log.Printf("uplink: forced shutdown complete") + return nil + } + } +} + +// doClose is the internal close implementation shared by Close and CloseWithTimeout. +func (u *Uplink) doClose() error { + u.mu.Lock() + if !u.connected { + u.mu.Unlock() + return nil + } + u.connected = false + u.mu.Unlock() + + // Step 1: Stop the batcher — this flushes remaining messages. + if u.batcher != nil { + u.batcher.Stop() + } + + // Cancel the batcher context to stop the Run loop, heartbeat, and Pusher monitor. + if u.cancel != nil { + u.cancel() + } + + // Step 2: Close the Pusher client. + var pusherErr error + if u.pusher != nil { + pusherErr = u.pusher.Close() + } + + // Step 3: HTTP disconnect. + disconnectCtx, cancel := context.WithTimeout(context.Background(), httpTimeout) + defer cancel() + if err := u.client.Disconnect(disconnectCtx); err != nil { + log.Printf("uplink: disconnect failed: %v", err) + } + + log.Printf("Uplink disconnected") + return pusherErr +} + +// monitorPusher watches the Pusher client's receive channel. When it closes +// (Pusher readLoop exited due to an error), it triggers a full reconnection. +func (u *Uplink) monitorPusher(ctx context.Context) { + if u.pusher == nil { + return + } + pusherRecv := u.pusher.Receive() + + for { + select { + case <-ctx.Done(): + return + case msg, ok := <-pusherRecv: + if !ok { + // Pusher channel closed — readLoop exited. + // Check if we're shutting down. + select { + case <-ctx.Done(): + return + default: + } + + u.mu.RLock() + connected := u.connected + u.mu.RUnlock() + if !connected { + return + } + + u.triggerReconnect("Pusher disconnected") + return + } + // Forward the command to the stable recvCh. + select { + case u.recvCh <- msg: + default: + log.Printf("uplink: receive buffer full, dropping command") + } + } + } +} + +// triggerReconnect initiates a reconnection attempt in the background. +// It is safe to call from multiple goroutines — only one reconnection runs at a time. +func (u *Uplink) triggerReconnect(reason string) { + u.mu.Lock() + if u.reconnecting || !u.connected { + u.mu.Unlock() + return + } + u.reconnecting = true + parentCtx := u.parentCtx + u.mu.Unlock() + + log.Printf("uplink: triggering reconnection (%s)", reason) + go u.reconnect(parentCtx) +} + +// reconnect tears down the existing connection and re-establishes it with backoff. +// On success, it fires the onReconnect callback so the caller can re-send state. +func (u *Uplink) reconnect(ctx context.Context) { + defer func() { + u.mu.Lock() + u.reconnecting = false + u.mu.Unlock() + }() + + // Step 1: Tear down old batcher and Pusher. + // Stop the batcher — this flushes remaining messages. + if u.batcher != nil { + u.batcher.Stop() + } + + // Cancel old batcher context to stop the old Run loop, heartbeat, and monitor. + if u.cancel != nil { + u.cancel() + } + + // Close the old Pusher client. + if u.pusher != nil { + if err := u.pusher.Close(); err != nil { + log.Printf("uplink: error closing Pusher during reconnection: %v", err) + } + } + + // Step 2: Reconnect with exponential backoff. + attempt := 0 + for { + select { + case <-ctx.Done(): + log.Printf("uplink: reconnection cancelled") + return + default: + } + + u.mu.RLock() + connected := u.connected + u.mu.RUnlock() + if !connected { + // Close() was called — stop reconnecting. + return + } + + attempt++ + delay := backoff(attempt) + log.Printf("uplink: reconnection attempt %d — retrying in %s", attempt, delay.Round(time.Millisecond)) + + select { + case <-ctx.Done(): + log.Printf("uplink: reconnection cancelled") + return + case <-time.After(delay): + } + + // Try HTTP connect. + welcome, err := u.client.Connect(ctx) + if err != nil { + if errors.Is(err, ErrAuthFailed) { + // Auth failure — try token refresh if callback is set. + if u.onAuthFailure != nil { + log.Printf("uplink: auth failed during reconnection — requesting token refresh") + if refreshErr := u.onAuthFailure(); refreshErr != nil { + log.Printf("uplink: token refresh failed: %v — aborting reconnection", refreshErr) + return + } + // Token refreshed — retry without incrementing attempt. + attempt-- + continue + } + log.Printf("uplink: auth failed during reconnection (no refresh callback) — aborting") + return + } + log.Printf("uplink: reconnection attempt %d HTTP connect failed: %v", attempt, err) + continue + } + + // Update session/device. + u.mu.Lock() + u.sessionID = welcome.SessionID + u.deviceID = welcome.DeviceID + u.mu.Unlock() + + // Try Pusher connect. + channel := fmt.Sprintf("private-chief-server.%d", welcome.DeviceID) + pusher := NewPusherClient(welcome.Reverb, channel, u.client.BroadcastAuth) + + if err := pusher.Connect(ctx); err != nil { + log.Printf("uplink: reconnection attempt %d Pusher connect failed: %v — disconnecting HTTP", attempt, err) + disconnectCtx, cancel := context.WithTimeout(context.Background(), httpTimeout) + if dErr := u.client.Disconnect(disconnectCtx); dErr != nil { + log.Printf("uplink: failed to disconnect after Pusher reconnect error: %v", dErr) + } + cancel() + continue + } + + // Step 3: Start new batcher and heartbeat. + batchCtx, batchCancel := context.WithCancel(ctx) + + u.mu.Lock() + u.pusher = pusher + u.cancel = batchCancel + u.lastSendTime = time.Time{} // Reset — force next heartbeat to fire. + u.mu.Unlock() + + u.batcher = NewBatcher(func(batchID string, messages []json.RawMessage) error { + _, err := u.client.SendMessagesWithRetry(batchCtx, batchID, messages) + if err == nil { + u.mu.Lock() + u.lastSendTime = time.Now() + u.mu.Unlock() + } + return err + }) + go u.batcher.Run(batchCtx) + + // Restart heartbeat. + go u.runHeartbeat(batchCtx) + + // Restart Pusher monitor. + go u.monitorPusher(batchCtx) + + log.Printf("Uplink reconnected (attempt %d, device=%d, session=%s)", attempt, welcome.DeviceID, welcome.SessionID) + + // Fire the OnReconnect callback so serve.go can re-send state. + if u.onReconnect != nil { + u.onReconnect() + } + + return + } +} + +// runHeartbeat sends periodic heartbeats to the server every heartbeatInterval. +// It skips the heartbeat if a message batch was sent within heartbeatSkipWindow. +// On transient failure, it retries once after heartbeatRetryDelay. +// After heartbeatMaxFailures consecutive failures, it triggers reconnection. +func (u *Uplink) runHeartbeat(ctx context.Context) { + ticker := time.NewTicker(u.hbInterval) + defer ticker.Stop() + + consecutiveFailures := 0 + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + // Skip heartbeat if a message batch was sent recently. + u.mu.RLock() + lastSend := u.lastSendTime + u.mu.RUnlock() + + if !lastSend.IsZero() && time.Since(lastSend) < u.hbSkipWindow { + consecutiveFailures = 0 + continue + } + + // Send heartbeat. + err := u.client.Heartbeat(ctx) + if err == nil { + consecutiveFailures = 0 + continue + } + + // First failure — retry once after a short delay. + log.Printf("uplink: heartbeat failed: %v — retrying in %s", err, u.hbRetryDelay) + select { + case <-ctx.Done(): + return + case <-time.After(u.hbRetryDelay): + } + + err = u.client.Heartbeat(ctx) + if err == nil { + consecutiveFailures = 0 + continue + } + + // Retry also failed — count as a failure. + consecutiveFailures++ + log.Printf("uplink: heartbeat retry failed (%d/%d consecutive): %v", consecutiveFailures, u.hbMaxFails, err) + + if consecutiveFailures >= u.hbMaxFails { + log.Printf("uplink: %d consecutive heartbeat failures — triggering reconnection", consecutiveFailures) + if u.onHeartbeatMaxFailures != nil { + u.onHeartbeatMaxFailures() + } else { + u.triggerReconnect("heartbeat failures") + } + consecutiveFailures = 0 + } + } + } +} + +// SessionID returns the current session ID from the connect response. +func (u *Uplink) SessionID() string { + u.mu.RLock() + defer u.mu.RUnlock() + return u.sessionID +} + +// DeviceID returns the device ID from the connect response. +func (u *Uplink) DeviceID() int { + u.mu.RLock() + defer u.mu.RUnlock() + return u.deviceID +} + +// SetAccessToken updates the access token on the HTTP client. +// This is called after a token refresh — the new token will be used +// for subsequent HTTP requests and Pusher auth calls. +func (u *Uplink) SetAccessToken(token string) { + u.client.SetAccessToken(token) +} diff --git a/internal/uplink/uplink_test.go b/internal/uplink/uplink_test.go new file mode 100644 index 0000000..012207e --- /dev/null +++ b/internal/uplink/uplink_test.go @@ -0,0 +1,1658 @@ +package uplink + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "sync" + "sync/atomic" + "testing" + "time" +) + +// testUplinkServer combines a mock HTTP API server and a mock Pusher WebSocket server +// for end-to-end Uplink testing. +type testUplinkServer struct { + httpSrv *httptest.Server + pusherSrv *testPusherServer + + mu sync.Mutex + connectCalls atomic.Int32 + disconnectCalls atomic.Int32 + heartbeatCalls atomic.Int32 + messageBatches []messageBatch + + // Last received connect metadata. + lastConnectBody map[string]interface{} + + // heartbeatStatus controls the HTTP status code returned by heartbeat. + // 0 or 200 means success. + heartbeatStatus atomic.Int32 + + // connectStatus controls the HTTP status code returned by connect. + // 0 or 200 means success. + connectStatus atomic.Int32 + + // sessionCounter increments on each connect — used for unique session IDs. + sessionCounter atomic.Int32 +} + +type messageBatch struct { + BatchID string + Messages []json.RawMessage +} + +func newTestUplinkServer(t *testing.T) *testUplinkServer { + t.Helper() + + ps := newTestPusherServer(t) + + us := &testUplinkServer{ + pusherSrv: ps, + } + + // Build the Reverb config from the Pusher server. + reverbCfg := ps.reverbConfig() + + us.httpSrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + us.handleHTTP(t, w, r, reverbCfg) + })) + t.Cleanup(func() { us.httpSrv.Close() }) + + return us +} + +func (us *testUplinkServer) handleHTTP(t *testing.T, w http.ResponseWriter, r *http.Request, reverbCfg ReverbConfig) { + t.Helper() + + // Check auth header. + auth := r.Header.Get("Authorization") + if !strings.HasPrefix(auth, "Bearer ") { + w.WriteHeader(http.StatusUnauthorized) + json.NewEncoder(w).Encode(map[string]string{"error": "missing token"}) + return + } + + w.Header().Set("Content-Type", "application/json") + + switch r.URL.Path { + case "/api/device/connect": + us.connectCalls.Add(1) + + status := int(us.connectStatus.Load()) + if status >= 400 { + w.WriteHeader(status) + json.NewEncoder(w).Encode(map[string]string{"error": "connect failed"}) + return + } + + var body map[string]interface{} + json.NewDecoder(r.Body).Decode(&body) + us.mu.Lock() + us.lastConnectBody = body + us.mu.Unlock() + + n := us.sessionCounter.Add(1) + sessionID := fmt.Sprintf("test-session-%d", n) + + json.NewEncoder(w).Encode(WelcomeResponse{ + Type: "welcome", + ProtocolVersion: 1, + DeviceID: 42, + SessionID: sessionID, + Reverb: reverbCfg, + }) + + case "/api/device/disconnect": + us.disconnectCalls.Add(1) + json.NewEncoder(w).Encode(map[string]string{"status": "disconnected"}) + + case "/api/device/heartbeat": + us.heartbeatCalls.Add(1) + status := int(us.heartbeatStatus.Load()) + if status >= 400 { + w.WriteHeader(status) + json.NewEncoder(w).Encode(map[string]string{"error": "heartbeat failed"}) + return + } + json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) + + case "/api/device/messages": + var req ingestRequest + json.NewDecoder(r.Body).Decode(&req) + + us.mu.Lock() + us.messageBatches = append(us.messageBatches, messageBatch{ + BatchID: req.BatchID, + Messages: req.Messages, + }) + us.mu.Unlock() + + currentSession := fmt.Sprintf("test-session-%d", us.sessionCounter.Load()) + json.NewEncoder(w).Encode(IngestResponse{ + Accepted: len(req.Messages), + BatchID: req.BatchID, + SessionID: currentSession, + }) + + case "/api/device/broadcasting/auth": + var body broadcastAuthRequest + json.NewDecoder(r.Body).Decode(&body) + + sig := GenerateAuthSignature( + us.pusherSrv.appKey, + us.pusherSrv.appSecret, + body.SocketID, + body.ChannelName, + ) + json.NewEncoder(w).Encode(pusherAuthResponse{Auth: sig}) + + default: + http.NotFound(w, r) + } +} + +func (us *testUplinkServer) getMessageBatches() []messageBatch { + us.mu.Lock() + defer us.mu.Unlock() + result := make([]messageBatch, len(us.messageBatches)) + copy(result, us.messageBatches) + return result +} + +// newTestUplink creates an Uplink connected to the test servers. +func newTestUplink(t *testing.T, us *testUplinkServer, opts ...UplinkOption) *Uplink { + t.Helper() + + client := newTestClient(t, us.httpSrv.URL, "test-token") + u := NewUplink(client, opts...) + return u +} + +// --- Tests --- + +func TestUplink_FullLifecycle(t *testing.T) { + us := newTestUplinkServer(t) + u := newTestUplink(t, us) + + ctx := testContext(t) + + // Connect. + if err := u.Connect(ctx); err != nil { + t.Fatalf("Connect() failed: %v", err) + } + + // Wait for Pusher subscription. + select { + case <-us.pusherSrv.onSubscribe: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for Pusher subscription") + } + + // Verify connect was called. + if got := us.connectCalls.Load(); got != 1 { + t.Errorf("connect calls = %d, want 1", got) + } + + // Verify session/device IDs. + if got := u.SessionID(); !strings.HasPrefix(got, "test-session-") { + t.Errorf("SessionID() = %q, want prefix %q", got, "test-session-") + } + if got := u.DeviceID(); got != 42 { + t.Errorf("DeviceID() = %d, want 42", got) + } + + // Send a message (immediate tier — should flush right away). + msg := json.RawMessage(`{"type":"run_complete","project":"test"}`) + u.Send(msg, "run_complete") + + // Wait for the batcher to flush. + deadline := time.After(5 * time.Second) + for { + batches := us.getMessageBatches() + if len(batches) > 0 { + if len(batches[0].Messages) != 1 { + t.Errorf("batch has %d messages, want 1", len(batches[0].Messages)) + } + var parsed map[string]interface{} + json.Unmarshal(batches[0].Messages[0], &parsed) + if parsed["type"] != "run_complete" { + t.Errorf("message type = %v, want run_complete", parsed["type"]) + } + break + } + select { + case <-deadline: + t.Fatal("timeout waiting for message batch to be sent") + case <-time.After(10 * time.Millisecond): + } + } + + // Receive a command from the server via Pusher. + channel := fmt.Sprintf("private-chief-server.%d", u.DeviceID()) + cmd := json.RawMessage(`{"type":"start_run","project":"myapp"}`) + if err := us.pusherSrv.sendCommand(channel, cmd); err != nil { + t.Fatalf("sendCommand failed: %v", err) + } + + select { + case received := <-u.Receive(): + var parsed map[string]interface{} + json.Unmarshal(received, &parsed) + if parsed["type"] != "start_run" { + t.Errorf("received type = %v, want start_run", parsed["type"]) + } + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for command") + } + + // Close. + if err := u.Close(); err != nil { + t.Fatalf("Close() failed: %v", err) + } + + // Verify disconnect was called. + if got := us.disconnectCalls.Load(); got != 1 { + t.Errorf("disconnect calls = %d, want 1", got) + } +} + +func TestUplink_SessionIDAndDeviceID(t *testing.T) { + us := newTestUplinkServer(t) + u := newTestUplink(t, us) + + // Before connect, values should be zero/empty. + if got := u.SessionID(); got != "" { + t.Errorf("SessionID() before connect = %q, want empty", got) + } + if got := u.DeviceID(); got != 0 { + t.Errorf("DeviceID() before connect = %d, want 0", got) + } + + ctx := testContext(t) + if err := u.Connect(ctx); err != nil { + t.Fatalf("Connect() failed: %v", err) + } + + // Wait for Pusher subscription. + select { + case <-us.pusherSrv.onSubscribe: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for subscription") + } + + if got := u.SessionID(); !strings.HasPrefix(got, "test-session-") { + t.Errorf("SessionID() = %q, want prefix %q", got, "test-session-") + } + if got := u.DeviceID(); got != 42 { + t.Errorf("DeviceID() = %d, want 42", got) + } + + u.Close() +} + +func TestUplink_SendEnqueuesToBatcher(t *testing.T) { + us := newTestUplinkServer(t) + u := newTestUplink(t, us) + + ctx := testContext(t) + if err := u.Connect(ctx); err != nil { + t.Fatalf("Connect() failed: %v", err) + } + defer u.Close() + + // Wait for Pusher subscription. + select { + case <-us.pusherSrv.onSubscribe: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for subscription") + } + + // Send multiple messages of different tiers. + u.Send(json.RawMessage(`{"type":"error","msg":"oops"}`), "error") // immediate + u.Send(json.RawMessage(`{"type":"claude_output","data":"hello"}`), "claude_output") // standard + u.Send(json.RawMessage(`{"type":"project_state","data":"state"}`), "project_state") // low priority + + // The immediate message triggers a flush that drains all tiers. + deadline := time.After(5 * time.Second) + for { + batches := us.getMessageBatches() + if len(batches) > 0 { + // All three messages should be in the first batch (immediate triggers drain of all). + if len(batches[0].Messages) != 3 { + t.Errorf("batch has %d messages, want 3", len(batches[0].Messages)) + } + break + } + select { + case <-deadline: + t.Fatal("timeout waiting for batched messages") + case <-time.After(10 * time.Millisecond): + } + } +} + +func TestUplink_SendBeforeConnect(t *testing.T) { + us := newTestUplinkServer(t) + u := newTestUplink(t, us) + + // Send before connect — should be silently dropped. + u.Send(json.RawMessage(`{"type":"error"}`), "error") + + // No crash, no messages sent. + time.Sleep(100 * time.Millisecond) + batches := us.getMessageBatches() + if len(batches) != 0 { + t.Errorf("expected 0 batches before connect, got %d", len(batches)) + } +} + +func TestUplink_ReceiveFromPusher(t *testing.T) { + us := newTestUplinkServer(t) + u := newTestUplink(t, us) + + ctx := testContext(t) + if err := u.Connect(ctx); err != nil { + t.Fatalf("Connect() failed: %v", err) + } + defer u.Close() + + // Wait for Pusher subscription. + select { + case <-us.pusherSrv.onSubscribe: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for subscription") + } + + channel := fmt.Sprintf("private-chief-server.%d", u.DeviceID()) + + // Send 3 commands. + for i := 0; i < 3; i++ { + cmd := json.RawMessage(fmt.Sprintf(`{"type":"cmd","id":"%d"}`, i)) + if err := us.pusherSrv.sendCommand(channel, cmd); err != nil { + t.Fatalf("sendCommand(%d) failed: %v", i, err) + } + } + + // Receive all 3 in order. + for i := 0; i < 3; i++ { + select { + case received := <-u.Receive(): + var parsed map[string]interface{} + json.Unmarshal(received, &parsed) + want := fmt.Sprintf("%d", i) + if parsed["id"] != want { + t.Errorf("command %d: id = %v, want %v", i, parsed["id"], want) + } + case <-time.After(5 * time.Second): + t.Fatalf("timeout waiting for command %d", i) + } + } +} + +func TestUplink_Close_FlushesAndDisconnects(t *testing.T) { + us := newTestUplinkServer(t) + u := newTestUplink(t, us) + + ctx := testContext(t) + if err := u.Connect(ctx); err != nil { + t.Fatalf("Connect() failed: %v", err) + } + + // Wait for Pusher subscription. + select { + case <-us.pusherSrv.onSubscribe: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for subscription") + } + + // Enqueue a low-priority message (wouldn't normally flush for 1s). + u.Send(json.RawMessage(`{"type":"settings","data":"config"}`), "settings") + + // Close should flush the remaining message before disconnecting. + if err := u.Close(); err != nil { + t.Fatalf("Close() failed: %v", err) + } + + // Verify the message was flushed. + batches := us.getMessageBatches() + if len(batches) == 0 { + t.Error("expected at least 1 batch after Close(), got 0") + } else { + found := false + for _, batch := range batches { + for _, msg := range batch.Messages { + var parsed map[string]interface{} + json.Unmarshal(msg, &parsed) + if parsed["type"] == "settings" { + found = true + } + } + } + if !found { + t.Error("settings message was not flushed on Close()") + } + } + + // Verify disconnect was called. + if got := us.disconnectCalls.Load(); got != 1 { + t.Errorf("disconnect calls = %d, want 1", got) + } +} + +func TestUplink_Close_DoubleCloseIsSafe(t *testing.T) { + us := newTestUplinkServer(t) + u := newTestUplink(t, us) + + ctx := testContext(t) + if err := u.Connect(ctx); err != nil { + t.Fatalf("Connect() failed: %v", err) + } + + // Wait for Pusher subscription. + select { + case <-us.pusherSrv.onSubscribe: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for subscription") + } + + // First close. + if err := u.Close(); err != nil { + t.Fatalf("first Close() failed: %v", err) + } + + // Second close should be a no-op. + if err := u.Close(); err != nil { + t.Fatalf("second Close() failed: %v", err) + } + + // Only one disconnect call. + if got := us.disconnectCalls.Load(); got != 1 { + t.Errorf("disconnect calls = %d, want 1", got) + } +} + +func TestUplink_SetAccessToken(t *testing.T) { + us := newTestUplinkServer(t) + u := newTestUplink(t, us) + + ctx := testContext(t) + if err := u.Connect(ctx); err != nil { + t.Fatalf("Connect() failed: %v", err) + } + defer u.Close() + + // Wait for Pusher subscription. + select { + case <-us.pusherSrv.onSubscribe: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for subscription") + } + + // Update the token. + u.SetAccessToken("new-token-xyz") + + // The internal client should use the new token. + // We can verify this by checking the client's token directly. + u.client.mu.RLock() + token := u.client.accessToken + u.client.mu.RUnlock() + + if token != "new-token-xyz" { + t.Errorf("accessToken = %q, want %q", token, "new-token-xyz") + } +} + +func TestUplink_OnReconnectCallback(t *testing.T) { + us := newTestUplinkServer(t) + + var callCount atomic.Int32 + u := newTestUplink(t, us, WithOnReconnect(func() { + callCount.Add(1) + })) + + // Verify the callback is stored. + if u.onReconnect == nil { + t.Fatal("onReconnect should be set") + } + + // The callback itself is used by the reconnection logic (US-020). + // For now just verify it can be invoked. + u.onReconnect() + if got := callCount.Load(); got != 1 { + t.Errorf("callback count = %d, want 1", got) + } +} + +func TestUplink_ConnectFailure_HTTPError(t *testing.T) { + // HTTP server that rejects connect. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + })) + defer srv.Close() + + client := newTestClient(t, srv.URL, "bad-token") + u := NewUplink(client) + + ctx := testContext(t) + err := u.Connect(ctx) + if err == nil { + t.Fatal("expected error when connect fails, got nil") + } + if !strings.Contains(err.Error(), "uplink connect") { + t.Errorf("error = %v, want containing 'uplink connect'", err) + } + + // Should not be connected. + if u.SessionID() != "" { + t.Error("SessionID should be empty after failed connect") + } +} + +func TestUplink_ConnectFailure_PusherError(t *testing.T) { + // HTTP server that succeeds for connect but Pusher server that rejects auth. + ps := newTestPusherServer(t) + ps.rejectSubscribe = true + reverbCfg := ps.reverbConfig() + + var disconnectCalled atomic.Int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + switch r.URL.Path { + case "/api/device/connect": + json.NewEncoder(w).Encode(WelcomeResponse{ + Type: "welcome", + ProtocolVersion: 1, + DeviceID: 42, + SessionID: "sess-123", + Reverb: reverbCfg, + }) + case "/api/device/disconnect": + disconnectCalled.Add(1) + json.NewEncoder(w).Encode(map[string]string{"status": "disconnected"}) + case "/api/device/broadcasting/auth": + sig := GenerateAuthSignature(ps.appKey, ps.appSecret, "unused", "unused") + json.NewEncoder(w).Encode(pusherAuthResponse{Auth: sig}) + default: + http.NotFound(w, r) + } + })) + defer srv.Close() + + client := newTestClient(t, srv.URL, "test-token") + u := NewUplink(client) + + ctx := testContext(t) + err := u.Connect(ctx) + if err == nil { + t.Fatal("expected error when Pusher subscription fails, got nil") + } + if !strings.Contains(err.Error(), "pusher") { + t.Errorf("error = %v, want containing 'pusher'", err) + } + + // HTTP disconnect should have been called as cleanup. + time.Sleep(100 * time.Millisecond) + if got := disconnectCalled.Load(); got != 1 { + t.Errorf("disconnect calls after Pusher failure = %d, want 1", got) + } +} + +func TestUplink_ChannelName(t *testing.T) { + us := newTestUplinkServer(t) + u := newTestUplink(t, us) + + ctx := testContext(t) + if err := u.Connect(ctx); err != nil { + t.Fatalf("Connect() failed: %v", err) + } + defer u.Close() + + // Verify the Pusher client subscribes to the correct channel. + select { + case channel := <-us.pusherSrv.onSubscribe: + expected := "private-chief-server.42" + if channel != expected { + t.Errorf("subscribed to %q, want %q", channel, expected) + } + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for subscription") + } +} + +// --- Heartbeat Tests --- + +// newTestUplinkWithHeartbeat creates a connected Uplink with fast heartbeat timing for tests. +func newTestUplinkWithHeartbeat(t *testing.T, us *testUplinkServer, interval, retryDelay, skipWindow time.Duration, maxFails int, opts ...UplinkOption) *Uplink { + t.Helper() + + client := newTestClient(t, us.httpSrv.URL, "test-token") + u := NewUplink(client, opts...) + + // Override heartbeat timing for fast tests. + u.hbInterval = interval + u.hbRetryDelay = retryDelay + u.hbSkipWindow = skipWindow + u.hbMaxFails = maxFails + + return u +} + +func TestUplink_Heartbeat_SendsPeriodically(t *testing.T) { + us := newTestUplinkServer(t) + u := newTestUplinkWithHeartbeat(t, us, 50*time.Millisecond, 10*time.Millisecond, 0, 3) + + ctx := testContext(t) + if err := u.Connect(ctx); err != nil { + t.Fatalf("Connect() failed: %v", err) + } + + // Wait for Pusher subscription. + select { + case <-us.pusherSrv.onSubscribe: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for subscription") + } + + // Wait for at least 3 heartbeats. + deadline := time.After(2 * time.Second) + for { + if us.heartbeatCalls.Load() >= 3 { + break + } + select { + case <-deadline: + t.Fatalf("expected at least 3 heartbeats, got %d", us.heartbeatCalls.Load()) + case <-time.After(10 * time.Millisecond): + } + } + + u.Close() +} + +func TestUplink_Heartbeat_StopsOnClose(t *testing.T) { + us := newTestUplinkServer(t) + u := newTestUplinkWithHeartbeat(t, us, 50*time.Millisecond, 10*time.Millisecond, 0, 3) + + ctx := testContext(t) + if err := u.Connect(ctx); err != nil { + t.Fatalf("Connect() failed: %v", err) + } + + // Wait for Pusher subscription. + select { + case <-us.pusherSrv.onSubscribe: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for subscription") + } + + // Wait for at least 1 heartbeat. + deadline := time.After(2 * time.Second) + for { + if us.heartbeatCalls.Load() >= 1 { + break + } + select { + case <-deadline: + t.Fatal("timeout waiting for first heartbeat") + case <-time.After(10 * time.Millisecond): + } + } + + // Close the uplink. + u.Close() + + // Record count and wait to confirm no more heartbeats are sent. + countAfterClose := us.heartbeatCalls.Load() + time.Sleep(200 * time.Millisecond) + + if got := us.heartbeatCalls.Load(); got != countAfterClose { + t.Errorf("heartbeat calls after close: got %d more (total %d), want 0 more", got-countAfterClose, got) + } +} + +func TestUplink_Heartbeat_SkipsWhenMessagesSentRecently(t *testing.T) { + us := newTestUplinkServer(t) + // skipWindow of 5s — any message sent within 5s skips heartbeat. + u := newTestUplinkWithHeartbeat(t, us, 50*time.Millisecond, 10*time.Millisecond, 5*time.Second, 3) + + ctx := testContext(t) + if err := u.Connect(ctx); err != nil { + t.Fatalf("Connect() failed: %v", err) + } + + // Wait for Pusher subscription. + select { + case <-us.pusherSrv.onSubscribe: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for subscription") + } + + // Send a message to trigger the lastSendTime update. + msg := json.RawMessage(`{"type":"run_complete","data":"done"}`) + u.Send(msg, "run_complete") + + // Wait for the message batch to be sent (sets lastSendTime). + deadline := time.After(2 * time.Second) + for { + batches := us.getMessageBatches() + if len(batches) > 0 { + break + } + select { + case <-deadline: + t.Fatal("timeout waiting for message batch") + case <-time.After(10 * time.Millisecond): + } + } + + // Record heartbeat count now. + countBeforeSkip := us.heartbeatCalls.Load() + + // Wait 300ms — multiple heartbeat intervals would have passed (50ms each). + time.Sleep(300 * time.Millisecond) + + // Heartbeats should have been skipped because lastSendTime is recent. + countAfterWait := us.heartbeatCalls.Load() + if countAfterWait != countBeforeSkip { + t.Errorf("expected heartbeats to be skipped, but %d extra heartbeats were sent", countAfterWait-countBeforeSkip) + } + + u.Close() +} + +func TestUplink_Heartbeat_RetryOnTransientFailure(t *testing.T) { + us := newTestUplinkServer(t) + u := newTestUplinkWithHeartbeat(t, us, 50*time.Millisecond, 10*time.Millisecond, 0, 3) + + ctx := testContext(t) + if err := u.Connect(ctx); err != nil { + t.Fatalf("Connect() failed: %v", err) + } + + // Wait for Pusher subscription. + select { + case <-us.pusherSrv.onSubscribe: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for subscription") + } + + // Make heartbeat fail with 500 (transient). + us.heartbeatStatus.Store(500) + + // Wait for heartbeat calls to accumulate (initial call + retry). + deadline := time.After(2 * time.Second) + for { + // Each heartbeat tick produces 2 calls (initial + retry). + if us.heartbeatCalls.Load() >= 4 { + break + } + select { + case <-deadline: + t.Fatalf("expected at least 4 heartbeat calls (2 ticks × 2 attempts), got %d", us.heartbeatCalls.Load()) + case <-time.After(10 * time.Millisecond): + } + } + + u.Close() +} + +func TestUplink_Heartbeat_RetrySucceedsResetsFailureCount(t *testing.T) { + us := newTestUplinkServer(t) + + var maxFailuresCalled atomic.Int32 + u := newTestUplinkWithHeartbeat(t, us, 50*time.Millisecond, 10*time.Millisecond, 0, 3) + u.onHeartbeatMaxFailures = func() { + maxFailuresCalled.Add(1) + } + + ctx := testContext(t) + if err := u.Connect(ctx); err != nil { + t.Fatalf("Connect() failed: %v", err) + } + + // Wait for Pusher subscription. + select { + case <-us.pusherSrv.onSubscribe: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for subscription") + } + + // Let heartbeats succeed — no max failures callback should fire. + time.Sleep(200 * time.Millisecond) + + if got := maxFailuresCalled.Load(); got != 0 { + t.Errorf("maxFailures callback called %d times, want 0 (all heartbeats succeeded)", got) + } + + u.Close() +} + +func TestUplink_Heartbeat_MaxFailuresTriggersCallback(t *testing.T) { + us := newTestUplinkServer(t) + + maxFailuresCh := make(chan struct{}, 1) + u := newTestUplinkWithHeartbeat(t, us, 50*time.Millisecond, 10*time.Millisecond, 0, 3) + u.onHeartbeatMaxFailures = func() { + select { + case maxFailuresCh <- struct{}{}: + default: + } + } + + ctx := testContext(t) + if err := u.Connect(ctx); err != nil { + t.Fatalf("Connect() failed: %v", err) + } + + // Wait for Pusher subscription. + select { + case <-us.pusherSrv.onSubscribe: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for subscription") + } + + // Make all heartbeats fail. + us.heartbeatStatus.Store(500) + + // Wait for the max-failures callback. With 50ms interval and 10ms retry delay, + // each tick is ~60ms. We need 3 consecutive failures → ~180ms. + select { + case <-maxFailuresCh: + // Success — the callback was triggered. + case <-time.After(3 * time.Second): + t.Fatal("timeout waiting for heartbeat max failures callback") + } + + u.Close() +} + +func TestUplink_Heartbeat_ContextCancellationStops(t *testing.T) { + us := newTestUplinkServer(t) + u := newTestUplinkWithHeartbeat(t, us, 50*time.Millisecond, 10*time.Millisecond, 0, 3) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + if err := u.Connect(ctx); err != nil { + t.Fatalf("Connect() failed: %v", err) + } + + // Wait for at least 1 heartbeat. + deadline := time.After(2 * time.Second) + for { + if us.heartbeatCalls.Load() >= 1 { + break + } + select { + case <-deadline: + t.Fatal("timeout waiting for first heartbeat") + case <-time.After(10 * time.Millisecond): + } + } + + // Cancel the context. + cancel() + + countAfterCancel := us.heartbeatCalls.Load() + time.Sleep(200 * time.Millisecond) + + if got := us.heartbeatCalls.Load(); got != countAfterCancel { + t.Errorf("heartbeat calls after cancel: got %d more, want 0 more", got-countAfterCancel) + } + + // Clean up: close still works even though context is cancelled. + u.Close() +} + +// --- Reconnection Tests --- + +// waitForSubscription drains the onSubscribe channel and returns the channel name. +func waitForSubscription(t *testing.T, ps *testPusherServer) string { + t.Helper() + select { + case ch := <-ps.onSubscribe: + return ch + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for Pusher subscription") + return "" + } +} + +func TestUplink_Reconnect_PusherDisconnection(t *testing.T) { + us := newTestUplinkServer(t) + reconnectCh := make(chan struct{}, 1) + + u := newTestUplink(t, us, WithOnReconnect(func() { + select { + case reconnectCh <- struct{}{}: + default: + } + })) + + ctx := testContext(t) + if err := u.Connect(ctx); err != nil { + t.Fatalf("Connect() failed: %v", err) + } + defer u.Close() + + // Wait for initial subscription. + waitForSubscription(t, us.pusherSrv) + + initialSession := u.SessionID() + connectsBefore := us.connectCalls.Load() + + // Close the Pusher WebSocket from the server side to simulate disconnection. + if err := us.pusherSrv.closeConnection(); err != nil { + t.Fatalf("closeConnection() failed: %v", err) + } + + // Wait for the OnReconnect callback — this means full reconnection succeeded. + select { + case <-reconnectCh: + case <-time.After(10 * time.Second): + t.Fatal("timeout waiting for reconnection after Pusher disconnect") + } + + // Wait for re-subscription. + waitForSubscription(t, us.pusherSrv) + + // Verify a new connect call was made. + if got := us.connectCalls.Load(); got <= connectsBefore { + t.Errorf("connect calls after reconnect = %d, want > %d", got, connectsBefore) + } + + // Verify session ID was refreshed. + newSession := u.SessionID() + if newSession == initialSession { + t.Errorf("session ID should change after reconnection, got same: %q", newSession) + } + + // Verify commands can still be received after reconnection. + channel := fmt.Sprintf("private-chief-server.%d", u.DeviceID()) + cmd := json.RawMessage(`{"type":"post_reconnect_cmd"}`) + // Give the Pusher client a moment to be ready. + time.Sleep(100 * time.Millisecond) + if err := us.pusherSrv.sendCommand(channel, cmd); err != nil { + t.Fatalf("sendCommand after reconnect failed: %v", err) + } + + select { + case received := <-u.Receive(): + var parsed map[string]interface{} + json.Unmarshal(received, &parsed) + if parsed["type"] != "post_reconnect_cmd" { + t.Errorf("received type = %v, want post_reconnect_cmd", parsed["type"]) + } + case <-time.After(5 * time.Second): + t.Fatal("timeout receiving command after reconnection") + } +} + +func TestUplink_Reconnect_HeartbeatFailuresTriggersReconnect(t *testing.T) { + us := newTestUplinkServer(t) + reconnectCh := make(chan struct{}, 1) + + // Use fast heartbeat timing to trigger reconnection quickly. + client := newTestClient(t, us.httpSrv.URL, "test-token") + u := NewUplink(client, WithOnReconnect(func() { + select { + case reconnectCh <- struct{}{}: + default: + } + })) + u.hbInterval = 50 * time.Millisecond + u.hbRetryDelay = 10 * time.Millisecond + u.hbSkipWindow = 0 + u.hbMaxFails = 2 + + ctx := testContext(t) + if err := u.Connect(ctx); err != nil { + t.Fatalf("Connect() failed: %v", err) + } + defer u.Close() + + // Wait for initial subscription. + waitForSubscription(t, us.pusherSrv) + + initialSession := u.SessionID() + + // Make heartbeats fail. + us.heartbeatStatus.Store(500) + + // Wait for reconnection (heartbeat failures → reconnect → OnReconnect). + select { + case <-reconnectCh: + case <-time.After(10 * time.Second): + t.Fatal("timeout waiting for reconnection triggered by heartbeat failures") + } + + // Verify session was refreshed. + newSession := u.SessionID() + if newSession == initialSession { + t.Errorf("session should change after reconnection, got same: %q", newSession) + } +} + +func TestUplink_Reconnect_HTTPConnectFailureThenRecover(t *testing.T) { + us := newTestUplinkServer(t) + reconnectCh := make(chan struct{}, 1) + + u := newTestUplink(t, us, WithOnReconnect(func() { + select { + case reconnectCh <- struct{}{}: + default: + } + })) + + ctx := testContext(t) + if err := u.Connect(ctx); err != nil { + t.Fatalf("Connect() failed: %v", err) + } + defer u.Close() + + // Wait for initial subscription. + waitForSubscription(t, us.pusherSrv) + + // Make HTTP connect fail to simulate server outage during reconnection. + us.connectStatus.Store(500) + + // Trigger Pusher disconnection to start reconnection. + if err := us.pusherSrv.closeConnection(); err != nil { + t.Fatalf("closeConnection() failed: %v", err) + } + + // Wait a bit for the first reconnection attempt to fail. + time.Sleep(2 * time.Second) + + // Now restore HTTP connect. + us.connectStatus.Store(0) + + // Wait for successful reconnection. + select { + case <-reconnectCh: + case <-time.After(10 * time.Second): + t.Fatal("timeout waiting for reconnection after server recovery") + } + + // Multiple connect attempts should have been made (at least the failed one + the successful one). + if got := us.connectCalls.Load(); got < 3 { + t.Errorf("connect calls = %d, want >= 3 (initial + failed + success)", got) + } +} + +func TestUplink_Reconnect_AuthFailureTriggersTokenRefresh(t *testing.T) { + us := newTestUplinkServer(t) + reconnectCh := make(chan struct{}, 1) + refreshCh := make(chan struct{}, 1) + + client := newTestClient(t, us.httpSrv.URL, "test-token") + u := NewUplink(client, + WithOnReconnect(func() { + select { + case reconnectCh <- struct{}{}: + default: + } + }), + WithOnAuthFailure(func() error { + // Simulate token refresh: restore connect and update token. + us.connectStatus.Store(0) + client.SetAccessToken("refreshed-token") + select { + case refreshCh <- struct{}{}: + default: + } + return nil + }), + ) + + ctx := testContext(t) + if err := u.Connect(ctx); err != nil { + t.Fatalf("Connect() failed: %v", err) + } + defer u.Close() + + // Wait for initial subscription. + waitForSubscription(t, us.pusherSrv) + + // Make connect return 401 to trigger auth failure during reconnection. + us.connectStatus.Store(401) + + // Trigger reconnection via Pusher disconnect. + if err := us.pusherSrv.closeConnection(); err != nil { + t.Fatalf("closeConnection() failed: %v", err) + } + + // Wait for the token refresh callback. + select { + case <-refreshCh: + case <-time.After(10 * time.Second): + t.Fatal("timeout waiting for token refresh callback") + } + + // Wait for successful reconnection with new token. + select { + case <-reconnectCh: + case <-time.After(10 * time.Second): + t.Fatal("timeout waiting for reconnection after token refresh") + } +} + +func TestUplink_Reconnect_OnReconnectCallbackFires(t *testing.T) { + us := newTestUplinkServer(t) + + var callbackCount atomic.Int32 + u := newTestUplink(t, us, WithOnReconnect(func() { + callbackCount.Add(1) + })) + + ctx := testContext(t) + if err := u.Connect(ctx); err != nil { + t.Fatalf("Connect() failed: %v", err) + } + defer u.Close() + + // Wait for initial subscription. + waitForSubscription(t, us.pusherSrv) + + // No reconnect callback yet. + if got := callbackCount.Load(); got != 0 { + t.Errorf("callback count before reconnect = %d, want 0", got) + } + + // Trigger reconnection. + if err := us.pusherSrv.closeConnection(); err != nil { + t.Fatalf("closeConnection() failed: %v", err) + } + + // Wait for reconnection. + deadline := time.After(10 * time.Second) + for { + if callbackCount.Load() >= 1 { + break + } + select { + case <-deadline: + t.Fatal("timeout waiting for OnReconnect callback") + case <-time.After(50 * time.Millisecond): + } + } + + // Drain subscription channel. + waitForSubscription(t, us.pusherSrv) + + if got := callbackCount.Load(); got != 1 { + t.Errorf("callback count = %d, want 1", got) + } +} + +func TestUplink_Reconnect_SendBuffersDuringOutage(t *testing.T) { + us := newTestUplinkServer(t) + reconnectCh := make(chan struct{}, 1) + + u := newTestUplink(t, us, WithOnReconnect(func() { + select { + case reconnectCh <- struct{}{}: + default: + } + })) + + ctx := testContext(t) + if err := u.Connect(ctx); err != nil { + t.Fatalf("Connect() failed: %v", err) + } + defer u.Close() + + // Wait for initial subscription. + waitForSubscription(t, us.pusherSrv) + + // Send a message and verify it arrives. + u.Send(json.RawMessage(`{"type":"run_complete","data":"before"}`), "run_complete") + deadline := time.After(5 * time.Second) + for { + batches := us.getMessageBatches() + if len(batches) > 0 { + break + } + select { + case <-deadline: + t.Fatal("timeout waiting for initial message") + case <-time.After(10 * time.Millisecond): + } + } + + batchCountBefore := len(us.getMessageBatches()) + + // Trigger reconnection. + if err := us.pusherSrv.closeConnection(); err != nil { + t.Fatalf("closeConnection() failed: %v", err) + } + + // Wait for reconnection to complete. + select { + case <-reconnectCh: + case <-time.After(10 * time.Second): + t.Fatal("timeout waiting for reconnection") + } + + // Wait for re-subscription. + waitForSubscription(t, us.pusherSrv) + + // Send a message after reconnection — it should be delivered. + u.Send(json.RawMessage(`{"type":"run_complete","data":"after"}`), "run_complete") + + deadline = time.After(5 * time.Second) + for { + batches := us.getMessageBatches() + if len(batches) > batchCountBefore { + // Found a new batch after reconnection. + lastBatch := batches[len(batches)-1] + found := false + for _, msg := range lastBatch.Messages { + var parsed map[string]interface{} + json.Unmarshal(msg, &parsed) + if parsed["data"] == "after" { + found = true + } + } + if !found { + t.Error("expected 'after' message in batch after reconnection") + } + break + } + select { + case <-deadline: + t.Fatal("timeout waiting for message after reconnection") + case <-time.After(10 * time.Millisecond): + } + } +} + +func TestUplink_Reconnect_ConcurrentTriggersPrevented(t *testing.T) { + us := newTestUplinkServer(t) + + var reconnectCount atomic.Int32 + u := newTestUplink(t, us, WithOnReconnect(func() { + reconnectCount.Add(1) + })) + + ctx := testContext(t) + if err := u.Connect(ctx); err != nil { + t.Fatalf("Connect() failed: %v", err) + } + defer u.Close() + + // Wait for initial subscription. + waitForSubscription(t, us.pusherSrv) + + // Trigger reconnection multiple times concurrently — only one should run. + for i := 0; i < 5; i++ { + u.triggerReconnect("concurrent test") + } + + // Wait for exactly 1 reconnection. + deadline := time.After(10 * time.Second) + for { + if reconnectCount.Load() >= 1 { + break + } + select { + case <-deadline: + t.Fatal("timeout waiting for reconnection") + case <-time.After(50 * time.Millisecond): + } + } + + // Wait a bit more to confirm no additional reconnections happen. + waitForSubscription(t, us.pusherSrv) + time.Sleep(500 * time.Millisecond) + + if got := reconnectCount.Load(); got != 1 { + t.Errorf("reconnect count = %d, want 1 (concurrent triggers should be prevented)", got) + } +} + +func TestUplink_Reconnect_CloseDuringReconnectStops(t *testing.T) { + us := newTestUplinkServer(t) + + u := newTestUplink(t, us) + + ctx := testContext(t) + if err := u.Connect(ctx); err != nil { + t.Fatalf("Connect() failed: %v", err) + } + + // Wait for initial subscription. + waitForSubscription(t, us.pusherSrv) + + // Make connect fail so reconnection keeps retrying. + us.connectStatus.Store(500) + + // Trigger Pusher disconnection. + if err := us.pusherSrv.closeConnection(); err != nil { + t.Fatalf("closeConnection() failed: %v", err) + } + + // Wait a moment for reconnection to start. + time.Sleep(500 * time.Millisecond) + + // Close the uplink — should stop reconnection. + err := u.Close() + if err != nil { + // Close may return an error from the already-closed Pusher — that's OK. + t.Logf("Close() returned: %v (expected for already-closed Pusher)", err) + } + + // Verify it doesn't hang or panic. Record connect calls and wait. + connectsAfterClose := us.connectCalls.Load() + time.Sleep(2 * time.Second) + + // There should be no more connect attempts after Close(). + if got := us.connectCalls.Load(); got > connectsAfterClose+1 { + t.Errorf("connect calls after Close: %d more than expected (got %d, started at %d)", got-connectsAfterClose, got, connectsAfterClose) + } +} + +func TestUplink_Reconnect_LogsAttemptCountAndDelay(t *testing.T) { + // This test verifies the reconnection logic makes multiple attempts with backoff. + // We can't easily capture log output, so we verify the behavior indirectly + // by checking the number of connect attempts and timing. + us := newTestUplinkServer(t) + reconnectCh := make(chan struct{}, 1) + + u := newTestUplink(t, us, WithOnReconnect(func() { + select { + case reconnectCh <- struct{}{}: + default: + } + })) + + ctx := testContext(t) + if err := u.Connect(ctx); err != nil { + t.Fatalf("Connect() failed: %v", err) + } + defer u.Close() + + waitForSubscription(t, us.pusherSrv) + + // Make connect fail twice, then succeed. + failCount := atomic.Int32{} + originalStatus := us.connectStatus.Load() + us.connectStatus.Store(500) + + go func() { + for { + current := us.connectCalls.Load() + if current >= 3 { // initial + 2 failed + if failCount.Add(1) == 1 { + us.connectStatus.Store(int32(originalStatus)) + } + return + } + time.Sleep(50 * time.Millisecond) + } + }() + + // Trigger reconnection. + if err := us.pusherSrv.closeConnection(); err != nil { + t.Fatalf("closeConnection() failed: %v", err) + } + + select { + case <-reconnectCh: + case <-time.After(15 * time.Second): + t.Fatal("timeout waiting for reconnection after transient failures") + } + + // Multiple connect calls (initial + retries + success). + if got := us.connectCalls.Load(); got < 3 { + t.Errorf("connect calls = %d, want >= 3", got) + } +} + +func TestUplink_Reconnect_WithOnAuthFailureOption(t *testing.T) { + us := newTestUplinkServer(t) + + var authFailureCalled atomic.Int32 + u := newTestUplink(t, us, WithOnAuthFailure(func() error { + authFailureCalled.Add(1) + return nil + })) + + // Verify the option was set. + if u.onAuthFailure == nil { + t.Fatal("onAuthFailure should be set by WithOnAuthFailure option") + } + + // Invoke and verify. + if err := u.onAuthFailure(); err != nil { + t.Errorf("onAuthFailure() = %v, want nil", err) + } + if got := authFailureCalled.Load(); got != 1 { + t.Errorf("authFailureCalled = %d, want 1", got) + } +} + +func TestUplink_Reconnect_StableReceiveChannel(t *testing.T) { + // Verify that Receive() returns the same channel before and after reconnection. + us := newTestUplinkServer(t) + reconnectCh := make(chan struct{}, 1) + + u := newTestUplink(t, us, WithOnReconnect(func() { + select { + case reconnectCh <- struct{}{}: + default: + } + })) + + // Receive channel is created at construction time. + recvBefore := u.Receive() + + ctx := testContext(t) + if err := u.Connect(ctx); err != nil { + t.Fatalf("Connect() failed: %v", err) + } + defer u.Close() + + waitForSubscription(t, us.pusherSrv) + + // Verify it's the same channel after connect. + recvAfterConnect := u.Receive() + if recvBefore != recvAfterConnect { + t.Error("Receive() channel changed after Connect — should be stable") + } + + // Trigger reconnection. + if err := us.pusherSrv.closeConnection(); err != nil { + t.Fatalf("closeConnection() failed: %v", err) + } + + select { + case <-reconnectCh: + case <-time.After(10 * time.Second): + t.Fatal("timeout waiting for reconnection") + } + + waitForSubscription(t, us.pusherSrv) + + // Verify it's still the same channel after reconnection. + recvAfterReconnect := u.Receive() + if recvBefore != recvAfterReconnect { + t.Error("Receive() channel changed after reconnection — should be stable") + } +} + +// --- CloseWithTimeout Tests --- + +func TestUplink_CloseWithTimeout_NormalShutdown(t *testing.T) { + us := newTestUplinkServer(t) + u := newTestUplink(t, us) + + ctx := testContext(t) + if err := u.Connect(ctx); err != nil { + t.Fatalf("Connect() failed: %v", err) + } + + // Wait for Pusher subscription. + waitForSubscription(t, us.pusherSrv) + + // Enqueue a low-priority message — flush should happen during close. + u.Send(json.RawMessage(`{"type":"settings","data":"config"}`), "settings") + + // Close with a generous timeout — should complete well within it. + start := time.Now() + if err := u.CloseWithTimeout(5 * time.Second); err != nil { + t.Fatalf("CloseWithTimeout() failed: %v", err) + } + elapsed := time.Since(start) + + // Should have completed quickly (under 2 seconds). + if elapsed > 2*time.Second { + t.Errorf("CloseWithTimeout took %s, expected < 2s", elapsed) + } + + // Verify the message was flushed. + batches := us.getMessageBatches() + found := false + for _, batch := range batches { + for _, msg := range batch.Messages { + var parsed map[string]interface{} + json.Unmarshal(msg, &parsed) + if parsed["type"] == "settings" { + found = true + } + } + } + if !found { + t.Error("settings message was not flushed during CloseWithTimeout") + } + + // Verify disconnect was called. + if got := us.disconnectCalls.Load(); got != 1 { + t.Errorf("disconnect calls = %d, want 1", got) + } +} + +func TestUplink_CloseWithTimeout_TimesOut(t *testing.T) { + // Create a server that hangs on message sending to simulate an unreachable server. + ps := newTestPusherServer(t) + reverbCfg := ps.reverbConfig() + + // hangDone is closed before the server closes — allows the hanging handler to exit + // so the httptest.Server can close cleanly. Registered AFTER srv.Close() in cleanup + // (LIFO order means hangDone closes first, then srv.Close proceeds). + hangDone := make(chan struct{}) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + auth := r.Header.Get("Authorization") + if !strings.HasPrefix(auth, "Bearer ") { + w.WriteHeader(http.StatusUnauthorized) + return + } + + w.Header().Set("Content-Type", "application/json") + + switch r.URL.Path { + case "/api/device/connect": + json.NewEncoder(w).Encode(WelcomeResponse{ + Type: "welcome", + ProtocolVersion: 1, + DeviceID: 42, + SessionID: "sess-timeout-test", + Reverb: reverbCfg, + }) + + case "/api/device/disconnect": + json.NewEncoder(w).Encode(map[string]string{"status": "disconnected"}) + + case "/api/device/messages": + // Hang until test cleanup to simulate unreachable server during batcher flush. + select { + case <-hangDone: + case <-r.Context().Done(): + } + + case "/api/device/broadcasting/auth": + var body broadcastAuthRequest + json.NewDecoder(r.Body).Decode(&body) + sig := GenerateAuthSignature(ps.appKey, ps.appSecret, body.SocketID, body.ChannelName) + json.NewEncoder(w).Encode(pusherAuthResponse{Auth: sig}) + + default: + http.NotFound(w, r) + } + })) + // Register srv.Close first (runs second in LIFO), then hangDone (runs first). + t.Cleanup(func() { srv.Close() }) + t.Cleanup(func() { close(hangDone) }) + + client := newTestClient(t, srv.URL, "test-token") + u := NewUplink(client) + + ctx := testContext(t) + if err := u.Connect(ctx); err != nil { + t.Fatalf("Connect() failed: %v", err) + } + + // Wait for Pusher subscription. + select { + case <-ps.onSubscribe: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for subscription") + } + + // Enqueue an immediate message — batcher will try to flush on Stop() + // but the server hangs, so the flush blocks. + u.Send(json.RawMessage(`{"type":"run_complete","data":"test"}`), "run_complete") + + // Give the batcher time to attempt sending (and block on the hanging server). + time.Sleep(200 * time.Millisecond) + + // CloseWithTimeout should return within the timeout even though flush is stuck. + start := time.Now() + err := u.CloseWithTimeout(1 * time.Second) + elapsed := time.Since(start) + + // Should complete near the timeout (1-4 seconds, accounting for the force-close 2s grace). + if elapsed > 5*time.Second { + t.Errorf("CloseWithTimeout took %s, expected < 5s", elapsed) + } + + // No error is expected — timeout is handled internally. + if err != nil { + t.Logf("CloseWithTimeout returned: %v (acceptable)", err) + } + + t.Logf("CloseWithTimeout completed in %s", elapsed.Round(time.Millisecond)) +} + +func TestUplink_CloseWithTimeout_DoubleCloseIsSafe(t *testing.T) { + us := newTestUplinkServer(t) + u := newTestUplink(t, us) + + ctx := testContext(t) + if err := u.Connect(ctx); err != nil { + t.Fatalf("Connect() failed: %v", err) + } + + // Wait for Pusher subscription. + waitForSubscription(t, us.pusherSrv) + + // First close. + if err := u.CloseWithTimeout(5 * time.Second); err != nil { + t.Fatalf("first CloseWithTimeout() failed: %v", err) + } + + // Second close should be a no-op. + if err := u.CloseWithTimeout(5 * time.Second); err != nil { + t.Fatalf("second CloseWithTimeout() failed: %v", err) + } + + // Only one disconnect call. + if got := us.disconnectCalls.Load(); got != 1 { + t.Errorf("disconnect calls = %d, want 1", got) + } +} diff --git a/internal/workspace/scanner.go b/internal/workspace/scanner.go new file mode 100644 index 0000000..7a323e1 --- /dev/null +++ b/internal/workspace/scanner.go @@ -0,0 +1,339 @@ +// Package workspace provides workspace directory scanning for discovering +// git repositories and tracking their state. +package workspace + +import ( + "context" + "encoding/json" + "fmt" + "log" + "os" + "os/exec" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/minicodemonkey/chief/internal/ws" +) + +// ScanInterval is how often the scanner re-scans the workspace. +const ScanInterval = 60 * time.Second + +// MessageSender is an interface for sending messages to the server. +type MessageSender interface { + Send(msg interface{}) error +} + +// Scanner discovers and tracks git repositories in a workspace directory. +type Scanner struct { + workspace string + sender MessageSender + interval time.Duration + + mu sync.RWMutex + projects []ws.ProjectSummary +} + +// New creates a new Scanner for the given workspace directory. +func New(workspace string, sender MessageSender) *Scanner { + return &Scanner{ + workspace: workspace, + sender: sender, + interval: ScanInterval, + } +} + +// WorkspacePath returns the workspace directory path. +func (s *Scanner) WorkspacePath() string { + return s.workspace +} + +// SetSender sets the message sender on the scanner. +// This allows creating the scanner before the sender is fully set up. +func (s *Scanner) SetSender(sender MessageSender) { + s.mu.Lock() + defer s.mu.Unlock() + s.sender = sender +} + +// Projects returns the current list of discovered projects. +func (s *Scanner) Projects() []ws.ProjectSummary { + s.mu.RLock() + defer s.mu.RUnlock() + result := make([]ws.ProjectSummary, len(s.projects)) + copy(result, s.projects) + return result +} + +// FindProject looks up a single project by name. +// Returns the project and true if found, or a zero value and false if not found. +func (s *Scanner) FindProject(name string) (ws.ProjectSummary, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + for _, p := range s.projects { + if p.Name == name { + return p, true + } + } + return ws.ProjectSummary{}, false +} + +// Scan performs a single scan of the workspace directory and returns discovered projects. +func (s *Scanner) Scan() []ws.ProjectSummary { + entries, err := os.ReadDir(s.workspace) + if err != nil { + log.Printf("Warning: failed to read workspace directory: %v", err) + return nil + } + + var projects []ws.ProjectSummary + for _, entry := range entries { + if !entry.IsDir() { + continue + } + + dirPath := filepath.Join(s.workspace, entry.Name()) + + // Check for .git/ directory + gitDir := filepath.Join(dirPath, ".git") + info, err := os.Stat(gitDir) + if err != nil { + if os.IsPermission(err) { + log.Printf("Warning: permission denied accessing %s, skipping", dirPath) + } + continue + } + // .git can be a directory (normal repo) or a file (worktree) + if !info.IsDir() { + // .git file means it's a worktree link, still a valid git repo + _ = info + } + + project := scanProject(dirPath, entry.Name()) + projects = append(projects, project) + } + + return projects +} + +// ScanAndUpdate performs a scan and updates the stored project list. +// Returns true if the project list changed. +func (s *Scanner) ScanAndUpdate() bool { + newProjects := s.Scan() + + s.mu.Lock() + defer s.mu.Unlock() + + if projectsEqual(s.projects, newProjects) { + return false + } + + s.projects = newProjects + return true +} + +// Run starts the periodic scanning loop. It performs an initial scan immediately, +// then re-scans at the configured interval. It sends project_list updates over +// WebSocket when projects change. +func (s *Scanner) Run(ctx context.Context) { + // Initial scan + if s.ScanAndUpdate() { + s.sendProjectList() + } + + ticker := time.NewTicker(s.interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if s.ScanAndUpdate() { + log.Println("Workspace projects changed, sending update") + s.sendProjectList() + } + } + } +} + +// sendProjectList sends a project_list message. +func (s *Scanner) sendProjectList() { + if s.sender == nil { + return + } + + s.mu.RLock() + projects := make([]ws.ProjectSummary, len(s.projects)) + copy(projects, s.projects) + s.mu.RUnlock() + + msg := ws.NewMessage(ws.TypeProjectList) + plMsg := ws.ProjectListMessage{ + Type: msg.Type, + ID: msg.ID, + Timestamp: msg.Timestamp, + Projects: projects, + } + + if err := s.sender.Send(plMsg); err != nil { + log.Printf("Error sending project_list: %v", err) + } +} + +// scanProject gathers information about a single project directory. +func scanProject(dirPath, name string) ws.ProjectSummary { + project := ws.ProjectSummary{ + Name: name, + Path: dirPath, + } + + // Check for .chief/ directory + chiefDir := filepath.Join(dirPath, ".chief") + if info, err := os.Stat(chiefDir); err == nil && info.IsDir() { + project.HasChief = true + } + + // Get git branch + branch, err := gitCurrentBranch(dirPath) + if err == nil { + project.Branch = branch + } + + // Get last commit info + commit, err := gitLastCommit(dirPath) + if err == nil { + project.Commit = commit + } + + // Get PRD list if .chief/ exists + if project.HasChief { + project.PRDs = scanPRDs(dirPath) + } + + return project +} + +// gitCurrentBranch returns the current branch for a git repo. +func gitCurrentBranch(dir string) (string, error) { + cmd := exec.Command("git", "rev-parse", "--abbrev-ref", "HEAD") + cmd.Dir = dir + output, err := cmd.Output() + if err != nil { + return "", err + } + return strings.TrimSpace(string(output)), nil +} + +// gitLastCommit returns the last commit info for a git repo. +func gitLastCommit(dir string) (ws.CommitInfo, error) { + // Use git log with a specific format to get hash, message, author, timestamp + cmd := exec.Command("git", "log", "-1", "--format=%H%n%s%n%an%n%aI") + cmd.Dir = dir + output, err := cmd.Output() + if err != nil { + return ws.CommitInfo{}, err + } + + lines := strings.SplitN(strings.TrimSpace(string(output)), "\n", 4) + if len(lines) < 4 { + return ws.CommitInfo{}, fmt.Errorf("unexpected git log output") + } + + return ws.CommitInfo{ + Hash: lines[0], + Message: lines[1], + Author: lines[2], + Timestamp: lines[3], + }, nil +} + +// scanPRDs discovers PRDs in a project's .chief/prds/ directory. +func scanPRDs(dirPath string) []ws.PRDInfo { + prdsDir := filepath.Join(dirPath, ".chief", "prds") + entries, err := os.ReadDir(prdsDir) + if err != nil { + return nil + } + + var prds []ws.PRDInfo + for _, entry := range entries { + if !entry.IsDir() { + continue + } + + prdJSON := filepath.Join(prdsDir, entry.Name(), "prd.json") + data, err := os.ReadFile(prdJSON) + if err != nil { + continue + } + + var prdData struct { + Project string `json:"project"` + UserStories []struct { + ID string `json:"id"` + Passes bool `json:"passes"` + } `json:"userStories"` + } + if err := json.Unmarshal(data, &prdData); err != nil { + continue + } + + total := len(prdData.UserStories) + passed := 0 + for _, s := range prdData.UserStories { + if s.Passes { + passed++ + } + } + + status := fmt.Sprintf("%d/%d", passed, total) + + prds = append(prds, ws.PRDInfo{ + ID: entry.Name(), + Name: prdData.Project, + StoryCount: total, + CompletionStatus: status, + }) + } + + return prds +} + +// projectsEqual compares two project lists for equality. +func projectsEqual(a, b []ws.ProjectSummary) bool { + if len(a) != len(b) { + return false + } + + // Build maps for comparison + aMap := make(map[string]ws.ProjectSummary, len(a)) + for _, p := range a { + aMap[p.Name] = p + } + + for _, pb := range b { + pa, ok := aMap[pb.Name] + if !ok { + return false + } + if pa.Path != pb.Path || + pa.HasChief != pb.HasChief || + pa.Branch != pb.Branch || + pa.Commit.Hash != pb.Commit.Hash || + len(pa.PRDs) != len(pb.PRDs) { + return false + } + // Compare PRDs + for i := range pa.PRDs { + if pa.PRDs[i].ID != pb.PRDs[i].ID || + pa.PRDs[i].StoryCount != pb.PRDs[i].StoryCount || + pa.PRDs[i].CompletionStatus != pb.PRDs[i].CompletionStatus { + return false + } + } + } + + return true +} diff --git a/internal/workspace/scanner_test.go b/internal/workspace/scanner_test.go new file mode 100644 index 0000000..f6c33c0 --- /dev/null +++ b/internal/workspace/scanner_test.go @@ -0,0 +1,504 @@ +package workspace + +import ( + "context" + "encoding/json" + "os" + "os/exec" + "path/filepath" + "sync" + "testing" + "time" + + "github.com/minicodemonkey/chief/internal/ws" +) + +// testSender is a mock MessageSender that captures sent messages. +type testSender struct { + mu sync.Mutex + messages []json.RawMessage +} + +func (s *testSender) Send(msg interface{}) error { + data, err := json.Marshal(msg) + if err != nil { + return err + } + s.mu.Lock() + s.messages = append(s.messages, data) + s.mu.Unlock() + return nil +} + +func (s *testSender) getMessages() []json.RawMessage { + s.mu.Lock() + defer s.mu.Unlock() + cp := make([]json.RawMessage, len(s.messages)) + copy(cp, s.messages) + return cp +} + +func (s *testSender) waitForType(msgType string, timeout time.Duration) (json.RawMessage, bool) { + deadline := time.After(timeout) + for { + msgs := s.getMessages() + for _, raw := range msgs { + var m struct{ Type string `json:"type"` } + if json.Unmarshal(raw, &m) == nil && m.Type == msgType { + return raw, true + } + } + select { + case <-deadline: + return nil, false + case <-time.After(50 * time.Millisecond): + } + } +} + +// initGitRepo initializes a git repo with an initial commit. +func initGitRepo(t *testing.T, dir string) { + t.Helper() + cmds := [][]string{ + {"git", "init"}, + {"git", "config", "user.email", "test@example.com"}, + {"git", "config", "user.name", "Test User"}, + {"git", "commit", "--allow-empty", "-m", "initial commit"}, + } + for _, args := range cmds { + cmd := exec.Command(args[0], args[1:]...) + cmd.Dir = dir + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("git command %v failed: %v\n%s", args, err, out) + } + } +} + +func TestScan_DiscoversGitRepos(t *testing.T) { + workspace := t.TempDir() + + // Create a git repo + repoDir := filepath.Join(workspace, "my-project") + if err := os.MkdirAll(repoDir, 0o755); err != nil { + t.Fatal(err) + } + initGitRepo(t, repoDir) + + // Create a non-git directory (should be ignored) + nonGitDir := filepath.Join(workspace, "not-a-repo") + if err := os.MkdirAll(nonGitDir, 0o755); err != nil { + t.Fatal(err) + } + + // Create a file (should be ignored) + if err := os.WriteFile(filepath.Join(workspace, "some-file.txt"), []byte("hello"), 0o644); err != nil { + t.Fatal(err) + } + + scanner := New(workspace, nil) + projects := scanner.Scan() + + if len(projects) != 1 { + t.Fatalf("expected 1 project, got %d", len(projects)) + } + + p := projects[0] + if p.Name != "my-project" { + t.Errorf("expected name 'my-project', got %q", p.Name) + } + if p.Path != repoDir { + t.Errorf("expected path %q, got %q", repoDir, p.Path) + } + if p.HasChief { + t.Error("expected has_chief to be false") + } + if p.Branch == "" { + t.Error("expected branch to be set") + } + if p.Commit.Hash == "" { + t.Error("expected commit hash to be set") + } + if p.Commit.Message != "initial commit" { + t.Errorf("expected commit message 'initial commit', got %q", p.Commit.Message) + } + if p.Commit.Author != "Test User" { + t.Errorf("expected commit author 'Test User', got %q", p.Commit.Author) + } + if p.Commit.Timestamp == "" { + t.Error("expected commit timestamp to be set") + } +} + +func TestScan_DetectsChiefDirectory(t *testing.T) { + workspace := t.TempDir() + + repoDir := filepath.Join(workspace, "chief-project") + if err := os.MkdirAll(repoDir, 0o755); err != nil { + t.Fatal(err) + } + initGitRepo(t, repoDir) + + // Create .chief/ directory + if err := os.MkdirAll(filepath.Join(repoDir, ".chief", "prds"), 0o755); err != nil { + t.Fatal(err) + } + + scanner := New(workspace, nil) + projects := scanner.Scan() + + if len(projects) != 1 { + t.Fatalf("expected 1 project, got %d", len(projects)) + } + if !projects[0].HasChief { + t.Error("expected has_chief to be true") + } +} + +func TestScan_DiscoversPRDs(t *testing.T) { + workspace := t.TempDir() + + repoDir := filepath.Join(workspace, "prd-project") + if err := os.MkdirAll(repoDir, 0o755); err != nil { + t.Fatal(err) + } + initGitRepo(t, repoDir) + + // Create .chief/prds/my-feature/prd.json + prdDir := filepath.Join(repoDir, ".chief", "prds", "my-feature") + if err := os.MkdirAll(prdDir, 0o755); err != nil { + t.Fatal(err) + } + + prdData := map[string]interface{}{ + "project": "My Feature", + "userStories": []map[string]interface{}{ + {"id": "US-001", "passes": true}, + {"id": "US-002", "passes": false}, + {"id": "US-003", "passes": true}, + }, + } + data, _ := json.Marshal(prdData) + if err := os.WriteFile(filepath.Join(prdDir, "prd.json"), data, 0o644); err != nil { + t.Fatal(err) + } + + scanner := New(workspace, nil) + projects := scanner.Scan() + + if len(projects) != 1 { + t.Fatalf("expected 1 project, got %d", len(projects)) + } + + p := projects[0] + if len(p.PRDs) != 1 { + t.Fatalf("expected 1 PRD, got %d", len(p.PRDs)) + } + + prd := p.PRDs[0] + if prd.ID != "my-feature" { + t.Errorf("expected PRD ID 'my-feature', got %q", prd.ID) + } + if prd.Name != "My Feature" { + t.Errorf("expected PRD name 'My Feature', got %q", prd.Name) + } + if prd.StoryCount != 3 { + t.Errorf("expected 3 stories, got %d", prd.StoryCount) + } + if prd.CompletionStatus != "2/3" { + t.Errorf("expected completion '2/3', got %q", prd.CompletionStatus) + } +} + +func TestScan_MultipleProjects(t *testing.T) { + workspace := t.TempDir() + + for _, name := range []string{"alpha", "beta", "gamma"} { + dir := filepath.Join(workspace, name) + if err := os.MkdirAll(dir, 0o755); err != nil { + t.Fatal(err) + } + initGitRepo(t, dir) + } + + scanner := New(workspace, nil) + projects := scanner.Scan() + + if len(projects) != 3 { + t.Fatalf("expected 3 projects, got %d", len(projects)) + } + + names := make(map[string]bool) + for _, p := range projects { + names[p.Name] = true + } + for _, name := range []string{"alpha", "beta", "gamma"} { + if !names[name] { + t.Errorf("expected project %q to be discovered", name) + } + } +} + +func TestScan_EmptyWorkspace(t *testing.T) { + workspace := t.TempDir() + + scanner := New(workspace, nil) + projects := scanner.Scan() + + if len(projects) != 0 { + t.Errorf("expected 0 projects, got %d", len(projects)) + } +} + +func TestScan_PermissionError(t *testing.T) { + // Skip if running as root (permissions are not enforced) + if os.Getuid() == 0 { + t.Skip("skipping permission test when running as root") + } + + workspace := t.TempDir() + + // Create a directory with .git inside, then remove traverse permission on parent + // so os.Stat on .git fails with permission denied + restrictedDir := filepath.Join(workspace, "restricted") + if err := os.MkdirAll(filepath.Join(restrictedDir, ".git"), 0o755); err != nil { + t.Fatal(err) + } + if err := os.Chmod(restrictedDir, 0o000); err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + os.Chmod(restrictedDir, 0o755) + }) + + // Create a normal git repo too + goodDir := filepath.Join(workspace, "good-project") + if err := os.MkdirAll(goodDir, 0o755); err != nil { + t.Fatal(err) + } + initGitRepo(t, goodDir) + + scanner := New(workspace, nil) + projects := scanner.Scan() + + // Should still discover the good project even if restricted one has issues + if len(projects) != 1 { + t.Fatalf("expected 1 project, got %d", len(projects)) + } + if projects[0].Name != "good-project" { + t.Errorf("expected 'good-project', got %q", projects[0].Name) + } +} + +func TestScanAndUpdate_DetectsChanges(t *testing.T) { + workspace := t.TempDir() + + scanner := New(workspace, nil) + + // First scan: empty + changed := scanner.ScanAndUpdate() + if changed { + t.Error("expected no change on first scan of empty workspace") + } + + // Add a project + repoDir := filepath.Join(workspace, "new-project") + if err := os.MkdirAll(repoDir, 0o755); err != nil { + t.Fatal(err) + } + initGitRepo(t, repoDir) + + // Second scan: should detect the new project + changed = scanner.ScanAndUpdate() + if !changed { + t.Error("expected change after adding a project") + } + + projects := scanner.Projects() + if len(projects) != 1 { + t.Fatalf("expected 1 project, got %d", len(projects)) + } + + // Third scan: no changes + changed = scanner.ScanAndUpdate() + if changed { + t.Error("expected no change on repeat scan") + } +} + +func TestScanAndUpdate_DetectsRemoval(t *testing.T) { + workspace := t.TempDir() + + repoDir := filepath.Join(workspace, "removable") + if err := os.MkdirAll(repoDir, 0o755); err != nil { + t.Fatal(err) + } + initGitRepo(t, repoDir) + + scanner := New(workspace, nil) + scanner.ScanAndUpdate() + + if len(scanner.Projects()) != 1 { + t.Fatal("expected 1 project initially") + } + + // Remove the project + if err := os.RemoveAll(repoDir); err != nil { + t.Fatal(err) + } + + changed := scanner.ScanAndUpdate() + if !changed { + t.Error("expected change after removing project") + } + if len(scanner.Projects()) != 0 { + t.Error("expected 0 projects after removal") + } +} + +func TestRun_SendsProjectListOnChange(t *testing.T) { + workspace := t.TempDir() + + sender := &testSender{} + + // Create a project before starting the scanner + repoDir := filepath.Join(workspace, "starter") + if err := os.MkdirAll(repoDir, 0o755); err != nil { + t.Fatal(err) + } + initGitRepo(t, repoDir) + + scanner := New(workspace, sender) + scanner.interval = 100 * time.Millisecond // Speed up for testing + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Run scanner in background + go scanner.Run(ctx) + + // Wait for initial project_list message + raw, ok := sender.waitForType(ws.TypeProjectList, 5*time.Second) + if !ok { + t.Fatal("timed out waiting for initial project_list message") + } + + var first ws.ProjectListMessage + if err := json.Unmarshal(raw, &first); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if len(first.Projects) != 1 { + t.Errorf("expected 1 project in initial scan, got %d", len(first.Projects)) + } else if first.Projects[0].Name != "starter" { + t.Errorf("expected project name 'starter', got %q", first.Projects[0].Name) + } + + // Add another project + newDir := filepath.Join(workspace, "newcomer") + if err := os.MkdirAll(newDir, 0o755); err != nil { + t.Fatal(err) + } + initGitRepo(t, newDir) + + // Wait for periodic scan to detect the new project + deadline := time.After(5 * time.Second) + for { + msgs := sender.getMessages() + for _, raw := range msgs { + var msg ws.ProjectListMessage + if json.Unmarshal(raw, &msg) == nil && msg.Type == ws.TypeProjectList && len(msg.Projects) == 2 { + return // Success + } + } + select { + case <-deadline: + t.Fatal("timed out waiting for updated project_list message") + case <-time.After(50 * time.Millisecond): + } + } +} + +func TestRun_StopsOnContextCancel(t *testing.T) { + workspace := t.TempDir() + + scanner := New(workspace, nil) + scanner.interval = 50 * time.Millisecond + + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + go func() { + scanner.Run(ctx) + close(done) + }() + + // Let it run briefly + time.Sleep(100 * time.Millisecond) + cancel() + + select { + case <-done: + // Good, it stopped + case <-time.After(2 * time.Second): + t.Fatal("scanner did not stop after context cancel") + } +} + +func TestProjectsEqual(t *testing.T) { + a := []ws.ProjectSummary{ + {Name: "proj1", Path: "/a/proj1", Branch: "main", Commit: ws.CommitInfo{Hash: "abc"}}, + {Name: "proj2", Path: "/a/proj2", Branch: "dev", Commit: ws.CommitInfo{Hash: "def"}}, + } + b := []ws.ProjectSummary{ + {Name: "proj1", Path: "/a/proj1", Branch: "main", Commit: ws.CommitInfo{Hash: "abc"}}, + {Name: "proj2", Path: "/a/proj2", Branch: "dev", Commit: ws.CommitInfo{Hash: "def"}}, + } + + if !projectsEqual(a, b) { + t.Error("expected equal project lists to be equal") + } + + // Change a commit hash + b[1].Commit.Hash = "changed" + if projectsEqual(a, b) { + t.Error("expected project lists with different commit hashes to be unequal") + } + + // Different lengths + if projectsEqual(a, a[:1]) { + t.Error("expected project lists of different lengths to be unequal") + } + + // Both nil/empty + if !projectsEqual(nil, nil) { + t.Error("expected two nil lists to be equal") + } + if !projectsEqual(nil, []ws.ProjectSummary{}) { + t.Error("expected nil and empty to be equal") + } +} + +func TestScan_GitBranch(t *testing.T) { + workspace := t.TempDir() + + repoDir := filepath.Join(workspace, "branched") + if err := os.MkdirAll(repoDir, 0o755); err != nil { + t.Fatal(err) + } + initGitRepo(t, repoDir) + + // Create and switch to a feature branch + cmd := exec.Command("git", "checkout", "-b", "feature/test") + cmd.Dir = repoDir + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("git checkout failed: %v\n%s", err, out) + } + + scanner := New(workspace, nil) + projects := scanner.Scan() + + if len(projects) != 1 { + t.Fatalf("expected 1 project, got %d", len(projects)) + } + if projects[0].Branch != "feature/test" { + t.Errorf("expected branch 'feature/test', got %q", projects[0].Branch) + } +} diff --git a/internal/workspace/watcher.go b/internal/workspace/watcher.go new file mode 100644 index 0000000..ffddbe7 --- /dev/null +++ b/internal/workspace/watcher.go @@ -0,0 +1,311 @@ +package workspace + +import ( + "context" + "log" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/fsnotify/fsnotify" + "github.com/minicodemonkey/chief/internal/ws" +) + +// InactivityTimeout is how long a project remains active without interaction. +const InactivityTimeout = 10 * time.Minute + +// Watcher watches filesystem changes in the workspace using fsnotify. +// It watches the workspace root for new/removed projects and sets up deep +// watchers (.chief/, .git/HEAD) only for active projects. +type Watcher struct { + workspace string + scanner *Scanner + sender MessageSender + watcher *fsnotify.Watcher + + mu sync.Mutex + activeProjects map[string]*activeProject // project name → state + inactiveTimeout time.Duration +} + +// activeProject tracks an actively-watched project. +type activeProject struct { + name string + path string + lastActive time.Time + watching bool // whether deep watchers are set up +} + +// NewWatcher creates a new Watcher for the given workspace directory. +func NewWatcher(workspace string, scanner *Scanner, sender MessageSender) (*Watcher, error) { + fsw, err := fsnotify.NewWatcher() + if err != nil { + return nil, err + } + + return &Watcher{ + workspace: workspace, + scanner: scanner, + sender: sender, + watcher: fsw, + activeProjects: make(map[string]*activeProject), + inactiveTimeout: InactivityTimeout, + }, nil +} + +// Activate marks a project as active, setting up deep watchers if not already watching. +// Call this when a run is started, a Claude session is opened, or get_project is requested. +func (w *Watcher) Activate(projectName string) { + w.mu.Lock() + defer w.mu.Unlock() + + ap, exists := w.activeProjects[projectName] + if exists { + ap.lastActive = time.Now() + log.Printf("[debug] Project %q activity refreshed", projectName) + return + } + + // Find the project path from the scanner + projectPath := "" + for _, p := range w.scanner.Projects() { + if p.Name == projectName { + projectPath = p.Path + break + } + } + if projectPath == "" { + log.Printf("[debug] Project %q not found in scanner, cannot activate watcher", projectName) + return + } + + ap = &activeProject{ + name: projectName, + path: projectPath, + lastActive: time.Now(), + } + w.activeProjects[projectName] = ap + + w.setupDeepWatchers(ap) +} + +// setupDeepWatchers adds fsnotify watches for .chief/ and .git/HEAD for a project. +func (w *Watcher) setupDeepWatchers(ap *activeProject) { + if ap.watching { + return + } + + chiefDir := filepath.Join(ap.path, ".chief") + prdsDir := filepath.Join(ap.path, ".chief", "prds") + gitDir := filepath.Join(ap.path, ".git") + + // Watch .chief/ directory + if err := w.watcher.Add(chiefDir); err != nil { + log.Printf("[debug] Could not watch %s: %v", chiefDir, err) + } else { + log.Printf("[debug] Watching %s for project %q", chiefDir, ap.name) + } + + // Watch .chief/prds/ directory and each PRD subdirectory + // fsnotify does not recurse, so we must add each subdirectory explicitly + if err := w.watcher.Add(prdsDir); err != nil { + log.Printf("[debug] Could not watch %s: %v", prdsDir, err) + } else { + log.Printf("[debug] Watching %s for project %q", prdsDir, ap.name) + // Also watch each PRD subdirectory (e.g., .chief/prds/feature/) + entries, err := os.ReadDir(prdsDir) + if err == nil { + for _, entry := range entries { + if entry.IsDir() { + subDir := filepath.Join(prdsDir, entry.Name()) + if err := w.watcher.Add(subDir); err != nil { + log.Printf("[debug] Could not watch %s: %v", subDir, err) + } else { + log.Printf("[debug] Watching %s for project %q", subDir, ap.name) + } + } + } + } + } + + // Watch .git/ directory (for HEAD changes = branch switches) + if err := w.watcher.Add(gitDir); err != nil { + log.Printf("[debug] Could not watch %s: %v", gitDir, err) + } else { + log.Printf("[debug] Watching .git/ for project %q", ap.name) + } + + ap.watching = true +} + +// removeDeepWatchers removes fsnotify watches for a project. +func (w *Watcher) removeDeepWatchers(ap *activeProject) { + if !ap.watching { + return + } + + chiefDir := filepath.Join(ap.path, ".chief") + prdsDir := filepath.Join(ap.path, ".chief", "prds") + gitDir := filepath.Join(ap.path, ".git") + + _ = w.watcher.Remove(chiefDir) + _ = w.watcher.Remove(prdsDir) + _ = w.watcher.Remove(gitDir) + + ap.watching = false + log.Printf("[debug] Removed watchers for project %q", ap.name) +} + +// Run starts the watcher event loop. It watches the workspace root for project +// additions/removals and handles deep watcher events for active projects. +func (w *Watcher) Run(ctx context.Context) error { + // Watch workspace root for new/removed project directories + if err := w.watcher.Add(w.workspace); err != nil { + return err + } + log.Printf("[debug] Watching workspace root: %s", w.workspace) + + // Start inactivity checker + inactivityTicker := time.NewTicker(1 * time.Minute) + defer inactivityTicker.Stop() + + for { + select { + case <-ctx.Done(): + return w.watcher.Close() + + case event, ok := <-w.watcher.Events: + if !ok { + return nil + } + w.handleEvent(event) + + case err, ok := <-w.watcher.Errors: + if !ok { + return nil + } + log.Printf("Watcher error: %v", err) + + case <-inactivityTicker.C: + w.cleanupInactive() + } + } +} + +// handleEvent processes a single fsnotify event. +func (w *Watcher) handleEvent(event fsnotify.Event) { + path := event.Name + + // Check if this is a workspace-root-level event (new/removed project) + if filepath.Dir(path) == w.workspace { + if event.Has(fsnotify.Create) || event.Has(fsnotify.Remove) || event.Has(fsnotify.Rename) { + log.Printf("[debug] Workspace root change detected: %s (%s)", filepath.Base(path), event.Op) + // Trigger a re-scan to detect new/removed projects + if w.scanner.ScanAndUpdate() { + w.scanner.sendProjectList() + } + } + return + } + + // For deep watcher events, find which project this belongs to + projectName := w.projectForPath(path) + if projectName == "" { + return + } + + // Determine what changed + rel, err := filepath.Rel(w.workspace, path) + if err != nil { + return + } + parts := strings.SplitN(rel, string(filepath.Separator), 3) + if len(parts) < 2 { + return + } + + subPath := strings.Join(parts[1:], string(filepath.Separator)) + + switch { + case strings.HasPrefix(subPath, filepath.Join(".chief", "prds")): + log.Printf("[debug] PRD change detected in project %q: %s", projectName, subPath) + w.sendProjectState(projectName) + + case subPath == filepath.Join(".git", "HEAD"): + log.Printf("[debug] Git HEAD change detected in project %q", projectName) + w.sendProjectState(projectName) + + case strings.HasPrefix(subPath, ".chief"): + log.Printf("[debug] Chief config change in project %q: %s", projectName, subPath) + w.sendProjectState(projectName) + + case strings.HasPrefix(subPath, ".git"): + // Other .git changes (like refs) — check if HEAD changed + if strings.Contains(subPath, "HEAD") { + log.Printf("[debug] Git ref change in project %q: %s", projectName, subPath) + w.sendProjectState(projectName) + } + } +} + +// projectForPath finds which active project a file path belongs to. +func (w *Watcher) projectForPath(path string) string { + w.mu.Lock() + defer w.mu.Unlock() + + for _, ap := range w.activeProjects { + if strings.HasPrefix(path, ap.path+string(filepath.Separator)) || path == ap.path { + return ap.name + } + } + return "" +} + +// sendProjectState re-scans a single project and sends a project_state update. +func (w *Watcher) sendProjectState(projectName string) { + if w.sender == nil { + return + } + + // Re-scan the project to get updated state + w.scanner.ScanAndUpdate() + + // Find the project in the scanner's list + for _, p := range w.scanner.Projects() { + if p.Name == projectName { + msg := ws.NewMessage(ws.TypeProjectState) + psMsg := ws.ProjectStateMessage{ + Type: msg.Type, + ID: msg.ID, + Timestamp: msg.Timestamp, + Project: p, + } + if err := w.sender.Send(psMsg); err != nil { + log.Printf("Error sending project_state for %q: %v", projectName, err) + } + return + } + } +} + +// cleanupInactive removes watchers for projects that have been inactive. +func (w *Watcher) cleanupInactive() { + w.mu.Lock() + defer w.mu.Unlock() + + now := time.Now() + for name, ap := range w.activeProjects { + if now.Sub(ap.lastActive) > w.inactiveTimeout { + log.Printf("[debug] Project %q inactive for %s, removing watchers", name, w.inactiveTimeout) + w.removeDeepWatchers(ap) + delete(w.activeProjects, name) + } + } +} + +// Close closes the underlying fsnotify watcher. +func (w *Watcher) Close() error { + return w.watcher.Close() +} diff --git a/internal/workspace/watcher_test.go b/internal/workspace/watcher_test.go new file mode 100644 index 0000000..df00d51 --- /dev/null +++ b/internal/workspace/watcher_test.go @@ -0,0 +1,431 @@ +package workspace + +import ( + "context" + "encoding/json" + "os" + "os/exec" + "path/filepath" + "testing" + "time" + + "github.com/minicodemonkey/chief/internal/ws" +) + +func TestWatcher_WorkspaceRootChanges(t *testing.T) { + workspace := t.TempDir() + + // Create initial project + repoDir := filepath.Join(workspace, "existing") + if err := os.MkdirAll(repoDir, 0o755); err != nil { + t.Fatal(err) + } + initGitRepo(t, repoDir) + + sender := &testSender{} + + scanner := New(workspace, sender) + scanner.ScanAndUpdate() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + watcher, err := NewWatcher(workspace, scanner, sender) + if err != nil { + t.Fatalf("NewWatcher failed: %v", err) + } + + go watcher.Run(ctx) + + // Allow watcher to start + time.Sleep(100 * time.Millisecond) + + // Create a new project directory with git + newDir := filepath.Join(workspace, "new-project") + if err := os.MkdirAll(newDir, 0o755); err != nil { + t.Fatal(err) + } + initGitRepo(t, newDir) + + // Wait for the project_list message with 2 projects + deadline := time.After(5 * time.Second) + for { + msgs := sender.getMessages() + for _, raw := range msgs { + var msg struct { + Type string `json:"type"` + } + if json.Unmarshal(raw, &msg) == nil && msg.Type == ws.TypeProjectList { + var plMsg ws.ProjectListMessage + if err := json.Unmarshal(raw, &plMsg); err != nil { + t.Fatalf("unmarshal project_list: %v", err) + } + if len(plMsg.Projects) == 2 { + return // Success + } + } + } + select { + case <-deadline: + t.Fatal("timed out waiting for project_list with new project") + case <-time.After(50 * time.Millisecond): + } + } +} + +func TestWatcher_ActivateProject(t *testing.T) { + workspace := t.TempDir() + + // Create a project with .chief and .git + repoDir := filepath.Join(workspace, "my-project") + if err := os.MkdirAll(repoDir, 0o755); err != nil { + t.Fatal(err) + } + initGitRepo(t, repoDir) + if err := os.MkdirAll(filepath.Join(repoDir, ".chief", "prds"), 0o755); err != nil { + t.Fatal(err) + } + + scanner := New(workspace, nil) + scanner.ScanAndUpdate() + + watcher, err := NewWatcher(workspace, scanner, nil) + if err != nil { + t.Fatalf("NewWatcher failed: %v", err) + } + defer watcher.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go watcher.Run(ctx) + + // Initially no active projects + watcher.mu.Lock() + if len(watcher.activeProjects) != 0 { + t.Error("expected 0 active projects initially") + } + watcher.mu.Unlock() + + // Activate a project + watcher.Activate("my-project") + + watcher.mu.Lock() + ap, exists := watcher.activeProjects["my-project"] + watcher.mu.Unlock() + + if !exists { + t.Fatal("expected my-project to be active") + } + if !ap.watching { + t.Error("expected deep watchers to be set up") + } +} + +func TestWatcher_ActivateUnknownProject(t *testing.T) { + workspace := t.TempDir() + + scanner := New(workspace, nil) + scanner.ScanAndUpdate() + + watcher, err := NewWatcher(workspace, scanner, nil) + if err != nil { + t.Fatalf("NewWatcher failed: %v", err) + } + defer watcher.Close() + + // Activating unknown project should not panic or add to active list + watcher.Activate("nonexistent") + + watcher.mu.Lock() + defer watcher.mu.Unlock() + if len(watcher.activeProjects) != 0 { + t.Error("expected 0 active projects for unknown project") + } +} + +func TestWatcher_ActivateRefreshesActivity(t *testing.T) { + workspace := t.TempDir() + + repoDir := filepath.Join(workspace, "proj") + if err := os.MkdirAll(repoDir, 0o755); err != nil { + t.Fatal(err) + } + initGitRepo(t, repoDir) + + scanner := New(workspace, nil) + scanner.ScanAndUpdate() + + watcher, err := NewWatcher(workspace, scanner, nil) + if err != nil { + t.Fatalf("NewWatcher failed: %v", err) + } + defer watcher.Close() + + watcher.Activate("proj") + + watcher.mu.Lock() + firstActive := watcher.activeProjects["proj"].lastActive + watcher.mu.Unlock() + + time.Sleep(10 * time.Millisecond) + + watcher.Activate("proj") + + watcher.mu.Lock() + secondActive := watcher.activeProjects["proj"].lastActive + watcher.mu.Unlock() + + if !secondActive.After(firstActive) { + t.Error("expected lastActive to be refreshed on re-activation") + } +} + +func TestWatcher_InactivityCleanup(t *testing.T) { + workspace := t.TempDir() + + repoDir := filepath.Join(workspace, "proj") + if err := os.MkdirAll(repoDir, 0o755); err != nil { + t.Fatal(err) + } + initGitRepo(t, repoDir) + + scanner := New(workspace, nil) + scanner.ScanAndUpdate() + + watcher, err := NewWatcher(workspace, scanner, nil) + if err != nil { + t.Fatalf("NewWatcher failed: %v", err) + } + defer watcher.Close() + + // Use a very short timeout for testing + watcher.inactiveTimeout = 50 * time.Millisecond + + watcher.Activate("proj") + + watcher.mu.Lock() + if len(watcher.activeProjects) != 1 { + t.Fatal("expected 1 active project") + } + watcher.mu.Unlock() + + // Wait for the project to become inactive + time.Sleep(100 * time.Millisecond) + + watcher.cleanupInactive() + + watcher.mu.Lock() + defer watcher.mu.Unlock() + if len(watcher.activeProjects) != 0 { + t.Error("expected project to be cleaned up after inactivity timeout") + } +} + +func TestWatcher_ChiefPRDChangeSendsProjectState(t *testing.T) { + workspace := t.TempDir() + + // Create project with .chief/prds + repoDir := filepath.Join(workspace, "proj") + if err := os.MkdirAll(repoDir, 0o755); err != nil { + t.Fatal(err) + } + initGitRepo(t, repoDir) + + prdDir := filepath.Join(repoDir, ".chief", "prds", "feature") + if err := os.MkdirAll(prdDir, 0o755); err != nil { + t.Fatal(err) + } + + prdData := map[string]interface{}{ + "project": "Feature", + "userStories": []map[string]interface{}{ + {"id": "US-001", "passes": false}, + }, + } + data, _ := json.Marshal(prdData) + if err := os.WriteFile(filepath.Join(prdDir, "prd.json"), data, 0o644); err != nil { + t.Fatal(err) + } + + sender := &testSender{} + + scanner := New(workspace, sender) + scanner.ScanAndUpdate() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + watcher, err := NewWatcher(workspace, scanner, sender) + if err != nil { + t.Fatalf("NewWatcher failed: %v", err) + } + + go watcher.Run(ctx) + time.Sleep(100 * time.Millisecond) + + // Activate the project to set up deep watchers + watcher.Activate("proj") + time.Sleep(100 * time.Millisecond) + + // Modify the PRD file + prdData["userStories"] = []map[string]interface{}{ + {"id": "US-001", "passes": true}, + } + data, _ = json.Marshal(prdData) + if err := os.WriteFile(filepath.Join(prdDir, "prd.json"), data, 0o644); err != nil { + t.Fatal(err) + } + + // Wait for project_state message + deadline := time.After(5 * time.Second) + for { + msgs := sender.getMessages() + for _, raw := range msgs { + var msg struct { + Type string `json:"type"` + } + if json.Unmarshal(raw, &msg) == nil && msg.Type == ws.TypeProjectState { + var psMsg ws.ProjectStateMessage + if err := json.Unmarshal(raw, &psMsg); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if psMsg.Project.Name == "proj" { + return // Success + } + } + } + select { + case <-deadline: + t.Fatal("timed out waiting for project_state message after PRD change") + case <-time.After(50 * time.Millisecond): + } + } +} + +func TestWatcher_GitHEADChangeSendsProjectState(t *testing.T) { + workspace := t.TempDir() + + repoDir := filepath.Join(workspace, "proj") + if err := os.MkdirAll(repoDir, 0o755); err != nil { + t.Fatal(err) + } + initGitRepo(t, repoDir) + + sender := &testSender{} + + scanner := New(workspace, sender) + scanner.ScanAndUpdate() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + watcher, err := NewWatcher(workspace, scanner, sender) + if err != nil { + t.Fatalf("NewWatcher failed: %v", err) + } + + go watcher.Run(ctx) + time.Sleep(100 * time.Millisecond) + + // Activate the project + watcher.Activate("proj") + time.Sleep(100 * time.Millisecond) + + // Switch branch (changes .git/HEAD) + cmd := exec.Command("git", "checkout", "-b", "feature/new-branch") + cmd.Dir = repoDir + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("git checkout failed: %v\n%s", err, out) + } + + // Wait for project_state message + deadline := time.After(5 * time.Second) + for { + msgs := sender.getMessages() + for _, raw := range msgs { + var msg struct { + Type string `json:"type"` + } + if json.Unmarshal(raw, &msg) == nil && msg.Type == ws.TypeProjectState { + var psMsg ws.ProjectStateMessage + if err := json.Unmarshal(raw, &psMsg); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if psMsg.Project.Name == "proj" { + return // Success + } + } + } + select { + case <-deadline: + t.Fatal("timed out waiting for project_state message after branch switch") + case <-time.After(50 * time.Millisecond): + } + } +} + +func TestWatcher_ContextCancellation(t *testing.T) { + workspace := t.TempDir() + + scanner := New(workspace, nil) + + watcher, err := NewWatcher(workspace, scanner, nil) + if err != nil { + t.Fatalf("NewWatcher failed: %v", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan error, 1) + go func() { + done <- watcher.Run(ctx) + }() + + // Let it start + time.Sleep(50 * time.Millisecond) + cancel() + + select { + case <-done: + // Good, it stopped + case <-time.After(2 * time.Second): + t.Fatal("watcher did not stop after context cancel") + } +} + +func TestWatcher_NoDeepWatchersForInactiveProjects(t *testing.T) { + workspace := t.TempDir() + + // Create two projects + for _, name := range []string{"active-proj", "inactive-proj"} { + dir := filepath.Join(workspace, name) + if err := os.MkdirAll(dir, 0o755); err != nil { + t.Fatal(err) + } + initGitRepo(t, dir) + if err := os.MkdirAll(filepath.Join(dir, ".chief", "prds"), 0o755); err != nil { + t.Fatal(err) + } + } + + scanner := New(workspace, nil) + scanner.ScanAndUpdate() + + watcher, err := NewWatcher(workspace, scanner, nil) + if err != nil { + t.Fatalf("NewWatcher failed: %v", err) + } + defer watcher.Close() + + // Only activate one project + watcher.Activate("active-proj") + + watcher.mu.Lock() + defer watcher.mu.Unlock() + + if _, exists := watcher.activeProjects["active-proj"]; !exists { + t.Error("expected active-proj to be in active projects") + } + if _, exists := watcher.activeProjects["inactive-proj"]; exists { + t.Error("expected inactive-proj to NOT be in active projects") + } +} diff --git a/internal/ws/messages.go b/internal/ws/messages.go new file mode 100644 index 0000000..54c46bd --- /dev/null +++ b/internal/ws/messages.go @@ -0,0 +1,631 @@ +package ws + +import ( + "crypto/rand" + "encoding/json" + "fmt" + "time" +) + +// ProtocolVersion is the current protocol version. +const ProtocolVersion = 1 + +// Message represents a protocol message envelope. +type Message struct { + Type string `json:"type"` + ID string `json:"id,omitempty"` + Timestamp string `json:"timestamp,omitempty"` + Raw json.RawMessage `json:"-"` +} + +// NewMessage creates a new message envelope with type, UUID, and ISO8601 timestamp. +func NewMessage(msgType string) Message { + return Message{ + Type: msgType, + ID: newUUID(), + Timestamp: time.Now().UTC().Format(time.RFC3339), + } +} + +// newUUID generates a random UUID v4 string. +func newUUID() string { + var uuid [16]byte + _, _ = rand.Read(uuid[:]) + // Set version 4 bits. + uuid[6] = (uuid[6] & 0x0f) | 0x40 + // Set variant bits. + uuid[8] = (uuid[8] & 0x3f) | 0x80 + return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x", + uuid[0:4], uuid[4:6], uuid[6:8], uuid[8:10], uuid[10:16]) +} + +// Error codes for protocol error messages. +const ( + ErrCodeAuthFailed = "AUTH_FAILED" + ErrCodeProjectNotFound = "PROJECT_NOT_FOUND" + ErrCodePRDNotFound = "PRD_NOT_FOUND" + ErrCodeRunAlreadyActive = "RUN_ALREADY_ACTIVE" + ErrCodeRunNotActive = "RUN_NOT_ACTIVE" + ErrCodeSessionNotFound = "SESSION_NOT_FOUND" + ErrCodeCloneFailed = "CLONE_FAILED" + ErrCodeQuotaExhausted = "QUOTA_EXHAUSTED" + ErrCodeFilesystemError = "FILESYSTEM_ERROR" + ErrCodeClaudeError = "CLAUDE_ERROR" + ErrCodeUpdateFailed = "UPDATE_FAILED" + ErrCodeIncompatibleVersion = "INCOMPATIBLE_VERSION" + ErrCodeRateLimited = "RATE_LIMITED" +) + +// Message type constants for the protocol catalog. +const ( + // Server → Web App message types. + TypeHello = "hello" + TypeStateSnapshot = "state_snapshot" + TypeProjectList = "project_list" + TypeProjectState = "project_state" + TypePRDContent = "prd_content" + TypeClaudeOutput = "claude_output" + TypeRunProgress = "run_progress" + TypeRunComplete = "run_complete" + TypeRunPaused = "run_paused" + TypeDiff = "diff" + TypeDiffsResponse = "diffs_response" + TypePRDsResponse = "prds_response" + TypeCloneProgress = "clone_progress" + TypeCloneComplete = "clone_complete" + TypeError = "error" + TypeQuotaExhausted = "quota_exhausted" + TypeLogLines = "log_lines" + TypeSessionTimeoutWarning = "session_timeout_warning" + TypeSessionExpired = "session_expired" + TypeSettings = "settings" + TypeSettingsResponse = "settings_response" + TypeSettingsUpdated = "settings_updated" + TypeUpdateAvailable = "update_available" + TypePRDOutput = "prd_output" + TypePRDResponseComplete = "prd_response_complete" + + // Web App → Server message types. + TypeWelcome = "welcome" + TypeIncompatible = "incompatible" + TypeListProjects = "list_projects" + TypeGetProject = "get_project" + TypeGetPRD = "get_prd" + TypeGetPRDs = "get_prds" + TypeNewPRD = "new_prd" + TypeRefinePRD = "refine_prd" + TypePRDMessage = "prd_message" + TypeClosePRDSession = "close_prd_session" + TypeStartRun = "start_run" + TypePauseRun = "pause_run" + TypeResumeRun = "resume_run" + TypeStopRun = "stop_run" + TypeCloneRepo = "clone_repo" + TypeCreateProject = "create_project" + TypeGetDiff = "get_diff" + TypeGetDiffs = "get_diffs" + TypeGetLogs = "get_logs" + TypeGetSettings = "get_settings" + TypeUpdateSettings = "update_settings" + TypeTriggerUpdate = "trigger_update" + TypePing = "ping" + + // Bidirectional. + TypePong = "pong" +) + +// --- Server → Web App messages --- + +// StateSnapshotMessage is sent on connect/reconnect with full state. +type StateSnapshotMessage struct { + Type string `json:"type"` + ID string `json:"id"` + Timestamp string `json:"timestamp"` + Projects []ProjectSummary `json:"projects"` + Runs []RunState `json:"runs"` + Sessions []SessionState `json:"sessions"` +} + +// ProjectSummary describes a project in the workspace. +type ProjectSummary struct { + Name string `json:"name"` + Path string `json:"path"` + HasChief bool `json:"has_chief"` + Branch string `json:"branch"` + Commit CommitInfo `json:"commit"` + PRDs []PRDInfo `json:"prds"` +} + +// CommitInfo describes a git commit. +type CommitInfo struct { + Hash string `json:"hash"` + Message string `json:"message"` + Author string `json:"author"` + Timestamp string `json:"timestamp"` +} + +// PRDInfo describes a PRD in a project. +type PRDInfo struct { + ID string `json:"id"` + Name string `json:"name"` + StoryCount int `json:"story_count"` + CompletionStatus string `json:"completion_status"` +} + +// RunState describes an active run. +type RunState struct { + Project string `json:"project"` + PRDID string `json:"prd_id"` + StoryID string `json:"story_id"` + Status string `json:"status"` + Iteration int `json:"iteration"` +} + +// SessionState describes an active Claude session. +type SessionState struct { + SessionID string `json:"session_id"` + Project string `json:"project"` + PRDID string `json:"prd_id"` +} + +// ProjectListMessage lists all discovered projects. +type ProjectListMessage struct { + Type string `json:"type"` + ID string `json:"id"` + Timestamp string `json:"timestamp"` + Projects []ProjectSummary `json:"projects"` +} + +// ProjectStateMessage returns state for a single project. +type ProjectStateMessage struct { + Type string `json:"type"` + ID string `json:"id"` + Timestamp string `json:"timestamp"` + Project ProjectSummary `json:"project"` +} + +// PRDContentMessage returns PRD content and state. +type PRDContentMessage struct { + Type string `json:"type"` + ID string `json:"id"` + Timestamp string `json:"timestamp"` + Project string `json:"project"` + PRDID string `json:"prd_id"` + Content string `json:"content"` + State interface{} `json:"state"` +} + +// ClaudeOutputMessage streams Claude output. +type ClaudeOutputMessage struct { + Type string `json:"type"` + ID string `json:"id"` + Timestamp string `json:"timestamp"` + SessionID string `json:"session_id,omitempty"` + Project string `json:"project,omitempty"` + PRDID string `json:"prd_id,omitempty"` + StoryID string `json:"story_id,omitempty"` + Data string `json:"data"` + Done bool `json:"done"` +} + +// PRDOutputPayload is the payload of a PRD output message. +type PRDOutputPayload struct { + Content string `json:"content"` + SessionID string `json:"session_id"` + Project string `json:"project"` +} + +// PRDOutputMessage streams PRD session output (text chunks from Claude). +type PRDOutputMessage struct { + Type string `json:"type"` + Payload PRDOutputPayload `json:"payload"` +} + +// PRDResponseCompletePayload is the payload of a PRD response complete message. +type PRDResponseCompletePayload struct { + SessionID string `json:"session_id"` + Project string `json:"project"` +} + +// PRDResponseCompleteMessage signals that a PRD session's Claude process has finished. +type PRDResponseCompleteMessage struct { + Type string `json:"type"` + Payload PRDResponseCompletePayload `json:"payload"` +} + +// RunProgressMessage reports run state changes. +type RunProgressMessage struct { + Type string `json:"type"` + ID string `json:"id"` + Timestamp string `json:"timestamp"` + Project string `json:"project"` + PRDID string `json:"prd_id"` + StoryID string `json:"story_id"` + Status string `json:"status"` + Iteration int `json:"iteration"` + Attempt int `json:"attempt"` +} + +// RunCompleteMessage reports run completion. +type RunCompleteMessage struct { + Type string `json:"type"` + ID string `json:"id"` + Timestamp string `json:"timestamp"` + Project string `json:"project"` + PRDID string `json:"prd_id"` + StoriesCompleted int `json:"stories_completed"` + Duration string `json:"duration"` + PassCount int `json:"pass_count"` + FailCount int `json:"fail_count"` +} + +// RunPausedMessage reports a paused run. +type RunPausedMessage struct { + Type string `json:"type"` + ID string `json:"id"` + Timestamp string `json:"timestamp"` + Project string `json:"project"` + PRDID string `json:"prd_id"` + StoryID string `json:"story_id"` + Reason string `json:"reason"` +} + +// DiffMessage contains a story's diff. +type DiffMessage struct { + Type string `json:"type"` + ID string `json:"id"` + Timestamp string `json:"timestamp"` + Project string `json:"project"` + PRDID string `json:"prd_id"` + StoryID string `json:"story_id"` + Files []string `json:"files"` + DiffText string `json:"diff_text"` +} + +// CloneProgressMessage reports git clone progress. +type CloneProgressMessage struct { + Type string `json:"type"` + ID string `json:"id"` + Timestamp string `json:"timestamp"` + URL string `json:"url"` + ProgressText string `json:"progress_text"` + Percent int `json:"percent"` +} + +// CloneCompleteMessage reports clone completion. +type CloneCompleteMessage struct { + Type string `json:"type"` + ID string `json:"id"` + Timestamp string `json:"timestamp"` + URL string `json:"url"` + Success bool `json:"success"` + Error string `json:"error,omitempty"` + Project string `json:"project,omitempty"` +} + +// ErrorMessage reports an error. +type ErrorMessage struct { + Type string `json:"type"` + ID string `json:"id"` + Timestamp string `json:"timestamp"` + Code string `json:"code"` + Message string `json:"message"` + RequestID string `json:"request_id,omitempty"` +} + +// QuotaExhaustedMessage reports quota exhaustion. +type QuotaExhaustedMessage struct { + Type string `json:"type"` + ID string `json:"id"` + Timestamp string `json:"timestamp"` + Runs []string `json:"runs"` + Sessions []string `json:"sessions"` +} + +// LogLinesMessage returns log content. +type LogLinesMessage struct { + Type string `json:"type"` + ID string `json:"id"` + Timestamp string `json:"timestamp"` + Project string `json:"project"` + PRDID string `json:"prd_id"` + StoryID string `json:"story_id"` + Lines []string `json:"lines"` + Level string `json:"level"` +} + +// SessionTimeoutWarningMessage warns of impending session timeout. +type SessionTimeoutWarningMessage struct { + Type string `json:"type"` + ID string `json:"id"` + Timestamp string `json:"timestamp"` + SessionID string `json:"session_id"` + MinutesRemaining int `json:"minutes_remaining"` +} + +// SessionExpiredMessage reports that a session has timed out. +type SessionExpiredMessage struct { + Type string `json:"type"` + ID string `json:"id"` + Timestamp string `json:"timestamp"` + SessionID string `json:"session_id"` + SavedState string `json:"saved_state,omitempty"` +} + +// SettingsMessage returns project settings. +type SettingsMessage struct { + Type string `json:"type"` + ID string `json:"id"` + Timestamp string `json:"timestamp"` + Project string `json:"project"` + MaxIterations int `json:"max_iterations"` + AutoCommit bool `json:"auto_commit"` + CommitPrefix string `json:"commit_prefix"` + ClaudeModel string `json:"claude_model"` + TestCommand string `json:"test_command"` +} + +// UpdateAvailableMessage reports an available update. +type UpdateAvailableMessage struct { + Type string `json:"type"` + ID string `json:"id"` + Timestamp string `json:"timestamp"` + CurrentVersion string `json:"current_version"` + LatestVersion string `json:"latest_version"` +} + +// --- Web App → Server messages --- + +// ListProjectsMessage requests the project list. +type ListProjectsMessage struct { + Type string `json:"type"` + ID string `json:"id"` + Timestamp string `json:"timestamp"` +} + +// GetProjectMessage requests a single project's state. +type GetProjectMessage struct { + Type string `json:"type"` + ID string `json:"id"` + Timestamp string `json:"timestamp"` + Project string `json:"project"` +} + +// GetPRDsMessage requests a list of all PRDs for a project. +type GetPRDsMessage struct { + Project string `json:"project"` +} + +// PRDsResponseMessage returns a list of PRDs for a project. +type PRDsResponseMessage struct { + Type string `json:"type"` + Payload PRDsResponsePayload `json:"payload"` +} + +// PRDsResponsePayload is the payload of a PRDs response. +type PRDsResponsePayload struct { + Project string `json:"project"` + PRDs []PRDItem `json:"prds"` +} + +// PRDItem describes a PRD in the response list. +type PRDItem struct { + ID string `json:"id"` + Name string `json:"name"` + StoryCount int `json:"story_count"` + Status string `json:"status"` +} + +// SettingsResponseMessage wraps settings for browser delivery. +type SettingsResponseMessage struct { + Type string `json:"type"` + Payload SettingsResponsePayload `json:"payload"` +} + +// SettingsResponsePayload is the payload of a settings response. +type SettingsResponsePayload struct { + Project string `json:"project"` + Settings SettingsData `json:"settings"` +} + +// SettingsData contains project settings fields. +type SettingsData struct { + MaxIterations int `json:"max_iterations"` + AutoCommit bool `json:"auto_commit"` + CommitPrefix string `json:"commit_prefix"` + ClaudeModel string `json:"claude_model"` + TestCommand string `json:"test_command"` +} + +// DiffsResponseMessage wraps diff data for browser delivery. +type DiffsResponseMessage struct { + Type string `json:"type"` + Payload DiffsResponsePayload `json:"payload"` +} + +// DiffsResponsePayload is the payload of a diffs response. +type DiffsResponsePayload struct { + Project string `json:"project"` + StoryID string `json:"story_id"` + Files []DiffFileDetail `json:"files"` +} + +// DiffFileDetail represents a single file's diff information. +type DiffFileDetail struct { + Filename string `json:"filename"` + Additions int `json:"additions"` + Deletions int `json:"deletions"` + Patch string `json:"patch"` +} + +// GetDiffsMessage requests diffs for a story (without requiring prd_id). +type GetDiffsMessage struct { + Project string `json:"project"` + StoryID string `json:"story_id"` +} + +// GetPRDMessage requests a PRD's content. +type GetPRDMessage struct { + Type string `json:"type"` + ID string `json:"id"` + Timestamp string `json:"timestamp"` + Project string `json:"project"` + PRDID string `json:"prd_id"` +} + +// NewPRDMessage requests creation of a new PRD via Claude. +type NewPRDMessage struct { + Type string `json:"type"` + ID string `json:"id"` + Timestamp string `json:"timestamp"` + Project string `json:"project"` + SessionID string `json:"session_id"` + Message string `json:"message"` +} + +// RefinePRDMessage requests editing an existing PRD via Claude. +type RefinePRDMessage struct { + Type string `json:"type"` + ID string `json:"id"` + Timestamp string `json:"timestamp"` + Project string `json:"project"` + SessionID string `json:"session_id"` + PRDID string `json:"prd_id"` + Message string `json:"message"` +} + +// PRDMessageMessage sends a user message to an active PRD session. +type PRDMessageMessage struct { + Type string `json:"type"` + ID string `json:"id"` + Timestamp string `json:"timestamp"` + Project string `json:"project"` + SessionID string `json:"session_id"` + Message string `json:"message"` +} + +// ClosePRDSessionMessage closes a PRD session. +type ClosePRDSessionMessage struct { + Type string `json:"type"` + ID string `json:"id"` + Timestamp string `json:"timestamp"` + SessionID string `json:"session_id"` + Save bool `json:"save"` +} + +// StartRunMessage starts a Ralph loop. +type StartRunMessage struct { + Type string `json:"type"` + ID string `json:"id"` + Timestamp string `json:"timestamp"` + Project string `json:"project"` + PRDID string `json:"prd_id"` +} + +// PauseRunMessage pauses a running loop. +type PauseRunMessage struct { + Type string `json:"type"` + ID string `json:"id"` + Timestamp string `json:"timestamp"` + Project string `json:"project"` + PRDID string `json:"prd_id"` +} + +// ResumeRunMessage resumes a paused loop. +type ResumeRunMessage struct { + Type string `json:"type"` + ID string `json:"id"` + Timestamp string `json:"timestamp"` + Project string `json:"project"` + PRDID string `json:"prd_id"` +} + +// StopRunMessage stops a running loop. +type StopRunMessage struct { + Type string `json:"type"` + ID string `json:"id"` + Timestamp string `json:"timestamp"` + Project string `json:"project"` + PRDID string `json:"prd_id"` +} + +// CloneRepoMessage requests a git clone. +type CloneRepoMessage struct { + Type string `json:"type"` + ID string `json:"id"` + Timestamp string `json:"timestamp"` + URL string `json:"url"` + DirectoryName string `json:"directory_name,omitempty"` +} + +// CreateProjectMessage creates a new project. +type CreateProjectMessage struct { + Type string `json:"type"` + ID string `json:"id"` + Timestamp string `json:"timestamp"` + Name string `json:"name"` + GitInit bool `json:"git_init"` +} + +// GetDiffMessage requests a story's diff. +type GetDiffMessage struct { + Type string `json:"type"` + ID string `json:"id"` + Timestamp string `json:"timestamp"` + Project string `json:"project"` + PRDID string `json:"prd_id"` + StoryID string `json:"story_id"` +} + +// GetLogsMessage requests log lines. +type GetLogsMessage struct { + Type string `json:"type"` + ID string `json:"id"` + Timestamp string `json:"timestamp"` + Project string `json:"project"` + PRDID string `json:"prd_id"` + StoryID string `json:"story_id,omitempty"` + Lines int `json:"lines,omitempty"` +} + +// GetSettingsMessage requests project settings. +type GetSettingsMessage struct { + Type string `json:"type"` + ID string `json:"id"` + Timestamp string `json:"timestamp"` + Project string `json:"project"` +} + +// UpdateSettingsMessage updates project settings. +type UpdateSettingsMessage struct { + Type string `json:"type"` + ID string `json:"id"` + Timestamp string `json:"timestamp"` + Project string `json:"project"` + MaxIterations *int `json:"max_iterations,omitempty"` + AutoCommit *bool `json:"auto_commit,omitempty"` + CommitPrefix *string `json:"commit_prefix,omitempty"` + ClaudeModel *string `json:"claude_model,omitempty"` + TestCommand *string `json:"test_command,omitempty"` +} + +// TriggerUpdateMessage requests a self-update. +type TriggerUpdateMessage struct { + Type string `json:"type"` + ID string `json:"id"` + Timestamp string `json:"timestamp"` +} + +// PingMessage is a keepalive ping. +type PingMessage struct { + Type string `json:"type"` + ID string `json:"id"` + Timestamp string `json:"timestamp"` +} + +// PongMessage is a keepalive pong response. +type PongMessage struct { + Type string `json:"type"` + ID string `json:"id"` + Timestamp string `json:"timestamp"` +} diff --git a/internal/ws/messages_test.go b/internal/ws/messages_test.go new file mode 100644 index 0000000..efb8ed8 --- /dev/null +++ b/internal/ws/messages_test.go @@ -0,0 +1,1038 @@ +package ws + +import ( + "encoding/json" + "testing" +) + +func TestStateSnapshotRoundTrip(t *testing.T) { + msg := StateSnapshotMessage{ + Type: TypeStateSnapshot, + ID: newUUID(), + Timestamp: "2026-02-15T10:00:00Z", + Projects: []ProjectSummary{ + { + Name: "my-project", + Path: "/home/user/projects/my-project", + HasChief: true, + Branch: "main", + Commit: CommitInfo{ + Hash: "abc123", + Message: "initial commit", + Author: "dev", + Timestamp: "2026-02-15T09:00:00Z", + }, + PRDs: []PRDInfo{ + {ID: "auth", Name: "Authentication", StoryCount: 5, CompletionStatus: "3/5"}, + }, + }, + }, + Runs: []RunState{{Project: "my-project", PRDID: "auth", StoryID: "US-003", Status: "running", Iteration: 2}}, + Sessions: []SessionState{{SessionID: "sess-1", Project: "my-project", PRDID: "auth"}}, + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + var got StateSnapshotMessage + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if got.Type != TypeStateSnapshot { + t.Errorf("type = %q, want %q", got.Type, TypeStateSnapshot) + } + if len(got.Projects) != 1 { + t.Fatalf("projects count = %d, want 1", len(got.Projects)) + } + if got.Projects[0].Name != "my-project" { + t.Errorf("project name = %q, want %q", got.Projects[0].Name, "my-project") + } + if got.Projects[0].Commit.Hash != "abc123" { + t.Errorf("commit hash = %q, want %q", got.Projects[0].Commit.Hash, "abc123") + } + if len(got.Projects[0].PRDs) != 1 || got.Projects[0].PRDs[0].StoryCount != 5 { + t.Errorf("unexpected PRD info: %+v", got.Projects[0].PRDs) + } + if len(got.Runs) != 1 || got.Runs[0].Iteration != 2 { + t.Errorf("unexpected run state: %+v", got.Runs) + } + if len(got.Sessions) != 1 || got.Sessions[0].SessionID != "sess-1" { + t.Errorf("unexpected session state: %+v", got.Sessions) + } +} + +func TestClaudeOutputRoundTrip(t *testing.T) { + msg := ClaudeOutputMessage{ + Type: TypeClaudeOutput, + ID: newUUID(), + Timestamp: "2026-02-15T10:00:00Z", + SessionID: "session-123", + Project: "my-project", + Data: "Hello from Claude!\nLine 2.", + Done: false, + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + var got ClaudeOutputMessage + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if got.Type != TypeClaudeOutput { + t.Errorf("type = %q, want %q", got.Type, TypeClaudeOutput) + } + if got.SessionID != "session-123" { + t.Errorf("session_id = %q, want %q", got.SessionID, "session-123") + } + if got.Data != "Hello from Claude!\nLine 2." { + t.Errorf("data = %q, want %q", got.Data, "Hello from Claude!\nLine 2.") + } + if got.Done { + t.Error("done should be false") + } +} + +func TestRunProgressRoundTrip(t *testing.T) { + msg := RunProgressMessage{ + Type: TypeRunProgress, + ID: newUUID(), + Timestamp: "2026-02-15T10:00:00Z", + Project: "my-project", + PRDID: "auth", + StoryID: "US-003", + Status: "running", + Iteration: 3, + Attempt: 1, + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + var got RunProgressMessage + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if got.Project != "my-project" { + t.Errorf("project = %q, want %q", got.Project, "my-project") + } + if got.StoryID != "US-003" { + t.Errorf("story_id = %q, want %q", got.StoryID, "US-003") + } + if got.Iteration != 3 { + t.Errorf("iteration = %d, want 3", got.Iteration) + } + if got.Attempt != 1 { + t.Errorf("attempt = %d, want 1", got.Attempt) + } +} + +func TestRunCompleteRoundTrip(t *testing.T) { + msg := RunCompleteMessage{ + Type: TypeRunComplete, + ID: newUUID(), + Timestamp: "2026-02-15T10:00:00Z", + Project: "my-project", + PRDID: "auth", + StoriesCompleted: 5, + Duration: "12m34s", + PassCount: 4, + FailCount: 1, + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + var got RunCompleteMessage + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if got.StoriesCompleted != 5 { + t.Errorf("stories_completed = %d, want 5", got.StoriesCompleted) + } + if got.Duration != "12m34s" { + t.Errorf("duration = %q, want %q", got.Duration, "12m34s") + } + if got.PassCount != 4 || got.FailCount != 1 { + t.Errorf("pass/fail = %d/%d, want 4/1", got.PassCount, got.FailCount) + } +} + +func TestErrorMessageRoundTrip(t *testing.T) { + msg := ErrorMessage{ + Type: TypeError, + ID: newUUID(), + Timestamp: "2026-02-15T10:00:00Z", + Code: ErrCodeProjectNotFound, + Message: "Project 'foobar' not found", + RequestID: "req-456", + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + var got ErrorMessage + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if got.Code != ErrCodeProjectNotFound { + t.Errorf("code = %q, want %q", got.Code, ErrCodeProjectNotFound) + } + if got.Message != "Project 'foobar' not found" { + t.Errorf("message = %q", got.Message) + } + if got.RequestID != "req-456" { + t.Errorf("request_id = %q, want %q", got.RequestID, "req-456") + } +} + +func TestErrorMessageWithoutRequestID(t *testing.T) { + msg := ErrorMessage{ + Type: TypeError, + ID: newUUID(), + Timestamp: "2026-02-15T10:00:00Z", + Code: ErrCodeClaudeError, + Message: "Claude process crashed", + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + // Verify request_id is omitted. + var raw map[string]interface{} + json.Unmarshal(data, &raw) + if _, ok := raw["request_id"]; ok { + t.Error("request_id should be omitted when empty") + } +} + +func TestDiffMessageRoundTrip(t *testing.T) { + msg := DiffMessage{ + Type: TypeDiff, + ID: newUUID(), + Timestamp: "2026-02-15T10:00:00Z", + Project: "my-project", + PRDID: "auth", + StoryID: "US-003", + Files: []string{"internal/auth/auth.go", "internal/auth/auth_test.go"}, + DiffText: "--- a/internal/auth/auth.go\n+++ b/internal/auth/auth.go\n@@ -1,3 +1,5 @@\n+// new code\n", + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + var got DiffMessage + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if len(got.Files) != 2 { + t.Fatalf("files count = %d, want 2", len(got.Files)) + } + if got.Files[0] != "internal/auth/auth.go" { + t.Errorf("files[0] = %q", got.Files[0]) + } + if got.DiffText == "" { + t.Error("diff_text should not be empty") + } +} + +func TestStartRunRoundTrip(t *testing.T) { + msg := StartRunMessage{ + Type: TypeStartRun, + ID: newUUID(), + Timestamp: "2026-02-15T10:00:00Z", + Project: "my-project", + PRDID: "auth", + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + var got StartRunMessage + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if got.Type != TypeStartRun { + t.Errorf("type = %q, want %q", got.Type, TypeStartRun) + } + if got.Project != "my-project" { + t.Errorf("project = %q", got.Project) + } + if got.PRDID != "auth" { + t.Errorf("prd_id = %q", got.PRDID) + } +} + +func TestNewPRDRoundTrip(t *testing.T) { + msg := NewPRDMessage{ + Type: TypeNewPRD, + ID: newUUID(), + Timestamp: "2026-02-15T10:00:00Z", + Project: "my-project", + SessionID: "session-abc", + Message: "Build an authentication system", + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + var got NewPRDMessage + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if got.SessionID != "session-abc" { + t.Errorf("session_id = %q", got.SessionID) + } + if got.Message != "Build an authentication system" { + t.Errorf("message = %q", got.Message) + } +} + +func TestRefinePRDRoundTrip(t *testing.T) { + msg := RefinePRDMessage{ + Type: TypeRefinePRD, + ID: newUUID(), + Timestamp: "2026-02-15T10:00:00Z", + Project: "my-project", + SessionID: "session-abc", + PRDID: "feature-auth", + Message: "Add OAuth support", + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + var got RefinePRDMessage + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if got.Type != TypeRefinePRD { + t.Errorf("type = %q, want %q", got.Type, TypeRefinePRD) + } + if got.SessionID != "session-abc" { + t.Errorf("session_id = %q", got.SessionID) + } + if got.PRDID != "feature-auth" { + t.Errorf("prd_id = %q", got.PRDID) + } + if got.Message != "Add OAuth support" { + t.Errorf("message = %q", got.Message) + } +} + +func TestCloneRepoRoundTrip(t *testing.T) { + msg := CloneRepoMessage{ + Type: TypeCloneRepo, + ID: newUUID(), + Timestamp: "2026-02-15T10:00:00Z", + URL: "git@github.com:user/repo.git", + DirectoryName: "my-repo", + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + var got CloneRepoMessage + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if got.URL != "git@github.com:user/repo.git" { + t.Errorf("url = %q", got.URL) + } + if got.DirectoryName != "my-repo" { + t.Errorf("directory_name = %q", got.DirectoryName) + } +} + +func TestCloneRepoOmitsEmptyDirectoryName(t *testing.T) { + msg := CloneRepoMessage{ + Type: TypeCloneRepo, + ID: newUUID(), + Timestamp: "2026-02-15T10:00:00Z", + URL: "git@github.com:user/repo.git", + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + var raw map[string]interface{} + json.Unmarshal(data, &raw) + if _, ok := raw["directory_name"]; ok { + t.Error("directory_name should be omitted when empty") + } +} + +func TestUpdateSettingsPartialFields(t *testing.T) { + // Only updating max_iterations and auto_commit. + maxIter := 10 + autoCommit := false + msg := UpdateSettingsMessage{ + Type: TypeUpdateSettings, + ID: newUUID(), + Timestamp: "2026-02-15T10:00:00Z", + Project: "my-project", + MaxIterations: &maxIter, + AutoCommit: &autoCommit, + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + var got UpdateSettingsMessage + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if got.MaxIterations == nil || *got.MaxIterations != 10 { + t.Errorf("max_iterations = %v, want 10", got.MaxIterations) + } + if got.AutoCommit == nil || *got.AutoCommit != false { + t.Errorf("auto_commit = %v, want false", got.AutoCommit) + } + if got.CommitPrefix != nil { + t.Errorf("commit_prefix should be nil, got %v", got.CommitPrefix) + } + if got.ClaudeModel != nil { + t.Errorf("claude_model should be nil, got %v", got.ClaudeModel) + } + if got.TestCommand != nil { + t.Errorf("test_command should be nil, got %v", got.TestCommand) + } +} + +func TestSessionTimeoutWarningRoundTrip(t *testing.T) { + msg := SessionTimeoutWarningMessage{ + Type: TypeSessionTimeoutWarning, + ID: newUUID(), + Timestamp: "2026-02-15T10:00:00Z", + SessionID: "session-xyz", + MinutesRemaining: 5, + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + var got SessionTimeoutWarningMessage + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if got.MinutesRemaining != 5 { + t.Errorf("minutes_remaining = %d, want 5", got.MinutesRemaining) + } + if got.SessionID != "session-xyz" { + t.Errorf("session_id = %q", got.SessionID) + } +} + +func TestQuotaExhaustedRoundTrip(t *testing.T) { + msg := QuotaExhaustedMessage{ + Type: TypeQuotaExhausted, + ID: newUUID(), + Timestamp: "2026-02-15T10:00:00Z", + Runs: []string{"run-1", "run-2"}, + Sessions: []string{"session-1"}, + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + var got QuotaExhaustedMessage + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if len(got.Runs) != 2 { + t.Errorf("runs count = %d, want 2", len(got.Runs)) + } + if len(got.Sessions) != 1 { + t.Errorf("sessions count = %d, want 1", len(got.Sessions)) + } +} + +func TestSettingsRoundTrip(t *testing.T) { + msg := SettingsMessage{ + Type: TypeSettings, + ID: newUUID(), + Timestamp: "2026-02-15T10:00:00Z", + Project: "my-project", + MaxIterations: 5, + AutoCommit: true, + CommitPrefix: "feat:", + ClaudeModel: "claude-opus-4-6", + TestCommand: "go test ./...", + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + var got SettingsMessage + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if got.MaxIterations != 5 { + t.Errorf("max_iterations = %d, want 5", got.MaxIterations) + } + if !got.AutoCommit { + t.Error("auto_commit should be true") + } + if got.TestCommand != "go test ./..." { + t.Errorf("test_command = %q", got.TestCommand) + } +} + +func TestLogLinesRoundTrip(t *testing.T) { + msg := LogLinesMessage{ + Type: TypeLogLines, + ID: newUUID(), + Timestamp: "2026-02-15T10:00:00Z", + Project: "my-project", + PRDID: "auth", + StoryID: "US-003", + Lines: []string{"Starting iteration 1...", "Running tests...", "All tests passed."}, + Level: "info", + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + var got LogLinesMessage + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if len(got.Lines) != 3 { + t.Fatalf("lines count = %d, want 3", len(got.Lines)) + } + if got.Level != "info" { + t.Errorf("level = %q, want %q", got.Level, "info") + } +} + +func TestRunPausedRoundTrip(t *testing.T) { + msg := RunPausedMessage{ + Type: TypeRunPaused, + ID: newUUID(), + Timestamp: "2026-02-15T10:00:00Z", + Project: "my-project", + PRDID: "auth", + StoryID: "US-003", + Reason: "quota_exhausted", + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + var got RunPausedMessage + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if got.Reason != "quota_exhausted" { + t.Errorf("reason = %q, want %q", got.Reason, "quota_exhausted") + } +} + +func TestCloneProgressRoundTrip(t *testing.T) { + msg := CloneProgressMessage{ + Type: TypeCloneProgress, + ID: newUUID(), + Timestamp: "2026-02-15T10:00:00Z", + URL: "git@github.com:user/repo.git", + ProgressText: "Receiving objects: 45%", + Percent: 45, + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + var got CloneProgressMessage + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if got.Percent != 45 { + t.Errorf("percent = %d, want 45", got.Percent) + } + if got.ProgressText != "Receiving objects: 45%" { + t.Errorf("progress_text = %q", got.ProgressText) + } +} + +func TestUpdateAvailableRoundTrip(t *testing.T) { + msg := UpdateAvailableMessage{ + Type: TypeUpdateAvailable, + ID: newUUID(), + Timestamp: "2026-02-15T10:00:00Z", + CurrentVersion: "0.5.0", + LatestVersion: "0.5.1", + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + var got UpdateAvailableMessage + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if got.CurrentVersion != "0.5.0" { + t.Errorf("current_version = %q", got.CurrentVersion) + } + if got.LatestVersion != "0.5.1" { + t.Errorf("latest_version = %q", got.LatestVersion) + } +} + +func TestPingPongRoundTrip(t *testing.T) { + ping := PingMessage{ + Type: TypePing, + ID: newUUID(), + Timestamp: "2026-02-15T10:00:00Z", + } + + data, err := json.Marshal(ping) + if err != nil { + t.Fatalf("marshal ping: %v", err) + } + + var gotPing PingMessage + if err := json.Unmarshal(data, &gotPing); err != nil { + t.Fatalf("unmarshal ping: %v", err) + } + + if gotPing.Type != TypePing { + t.Errorf("ping type = %q, want %q", gotPing.Type, TypePing) + } + + pong := PongMessage{ + Type: TypePong, + ID: newUUID(), + Timestamp: "2026-02-15T10:00:00Z", + } + + data, err = json.Marshal(pong) + if err != nil { + t.Fatalf("marshal pong: %v", err) + } + + var gotPong PongMessage + if err := json.Unmarshal(data, &gotPong); err != nil { + t.Fatalf("unmarshal pong: %v", err) + } + + if gotPong.Type != TypePong { + t.Errorf("pong type = %q, want %q", gotPong.Type, TypePong) + } +} + +func TestGenericMessageEnvelopeParsing(t *testing.T) { + // Verify that any message can be parsed as the generic Message type for routing. + msg := RunProgressMessage{ + Type: TypeRunProgress, + ID: newUUID(), + Timestamp: "2026-02-15T10:00:00Z", + Project: "my-project", + PRDID: "auth", + StoryID: "US-003", + Status: "running", + Iteration: 2, + Attempt: 1, + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + var envelope Message + if err := json.Unmarshal(data, &envelope); err != nil { + t.Fatalf("unmarshal envelope: %v", err) + } + + if envelope.Type != TypeRunProgress { + t.Errorf("type = %q, want %q", envelope.Type, TypeRunProgress) + } + if envelope.ID == "" { + t.Error("id should be set") + } + if envelope.Timestamp == "" { + t.Error("timestamp should be set") + } +} + +func TestClosePRDSessionRoundTrip(t *testing.T) { + msg := ClosePRDSessionMessage{ + Type: TypeClosePRDSession, + ID: newUUID(), + Timestamp: "2026-02-15T10:00:00Z", + SessionID: "session-abc", + Save: true, + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + var got ClosePRDSessionMessage + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if !got.Save { + t.Error("save should be true") + } + if got.SessionID != "session-abc" { + t.Errorf("session_id = %q", got.SessionID) + } +} + +func TestSessionExpiredRoundTrip(t *testing.T) { + msg := SessionExpiredMessage{ + Type: TypeSessionExpired, + ID: newUUID(), + Timestamp: "2026-02-15T10:00:00Z", + SessionID: "session-abc", + SavedState: "partial PRD content here", + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + var got SessionExpiredMessage + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if got.SavedState != "partial PRD content here" { + t.Errorf("saved_state = %q", got.SavedState) + } +} + +func TestCreateProjectRoundTrip(t *testing.T) { + msg := CreateProjectMessage{ + Type: TypeCreateProject, + ID: newUUID(), + Timestamp: "2026-02-15T10:00:00Z", + Name: "new-project", + GitInit: true, + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + var got CreateProjectMessage + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if got.Name != "new-project" { + t.Errorf("name = %q", got.Name) + } + if !got.GitInit { + t.Error("git_init should be true") + } +} + +func TestGetLogsWithOptionalFields(t *testing.T) { + // With all fields. + msg := GetLogsMessage{ + Type: TypeGetLogs, + ID: newUUID(), + Timestamp: "2026-02-15T10:00:00Z", + Project: "my-project", + PRDID: "auth", + StoryID: "US-003", + Lines: 100, + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + var got GetLogsMessage + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if got.StoryID != "US-003" { + t.Errorf("story_id = %q", got.StoryID) + } + if got.Lines != 100 { + t.Errorf("lines = %d, want 100", got.Lines) + } + + // Without optional fields. + msg2 := GetLogsMessage{ + Type: TypeGetLogs, + ID: newUUID(), + Timestamp: "2026-02-15T10:00:00Z", + Project: "my-project", + PRDID: "auth", + } + + data2, _ := json.Marshal(msg2) + var raw map[string]interface{} + json.Unmarshal(data2, &raw) + + if _, ok := raw["story_id"]; ok { + t.Error("story_id should be omitted when empty") + } + if v, ok := raw["lines"]; ok && v != float64(0) { + t.Error("lines should be omitted when zero") + } +} + +func TestCloneCompleteRoundTrip(t *testing.T) { + // Success case. + msg := CloneCompleteMessage{ + Type: TypeCloneComplete, + ID: newUUID(), + Timestamp: "2026-02-15T10:00:00Z", + URL: "git@github.com:user/repo.git", + Success: true, + Project: "repo", + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + var got CloneCompleteMessage + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if !got.Success { + t.Error("success should be true") + } + if got.Project != "repo" { + t.Errorf("project = %q", got.Project) + } + + // Failure case. + msg2 := CloneCompleteMessage{ + Type: TypeCloneComplete, + ID: newUUID(), + Timestamp: "2026-02-15T10:00:00Z", + URL: "git@github.com:user/repo.git", + Success: false, + Error: "repository not found", + } + + data2, _ := json.Marshal(msg2) + var got2 CloneCompleteMessage + json.Unmarshal(data2, &got2) + + if got2.Success { + t.Error("success should be false") + } + if got2.Error != "repository not found" { + t.Errorf("error = %q", got2.Error) + } +} + +func TestPRDContentRoundTrip(t *testing.T) { + msg := PRDContentMessage{ + Type: TypePRDContent, + ID: newUUID(), + Timestamp: "2026-02-15T10:00:00Z", + Project: "my-project", + PRDID: "auth", + Content: "# Authentication PRD\n\nBuild a login system.", + State: map[string]interface{}{"stories": 5, "completed": 3}, + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + var got PRDContentMessage + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if got.Content != "# Authentication PRD\n\nBuild a login system." { + t.Errorf("content = %q", got.Content) + } + if got.PRDID != "auth" { + t.Errorf("prd_id = %q", got.PRDID) + } +} + +func TestProjectListRoundTrip(t *testing.T) { + msg := ProjectListMessage{ + Type: TypeProjectList, + ID: newUUID(), + Timestamp: "2026-02-15T10:00:00Z", + Projects: []ProjectSummary{ + {Name: "project-a", Path: "/home/user/projects/project-a", HasChief: true, Branch: "main"}, + {Name: "project-b", Path: "/home/user/projects/project-b", HasChief: false, Branch: "develop"}, + }, + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + var got ProjectListMessage + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if len(got.Projects) != 2 { + t.Fatalf("projects count = %d, want 2", len(got.Projects)) + } + if got.Projects[1].Branch != "develop" { + t.Errorf("projects[1].branch = %q", got.Projects[1].Branch) + } +} + +func TestPRDOutputRoundTrip(t *testing.T) { + msg := PRDOutputMessage{ + Type: TypePRDOutput, + Payload: PRDOutputPayload{ + Content: "Here is the PRD content\n", + SessionID: "session-123", + Project: "my-project", + }, + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + var got PRDOutputMessage + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if got.Type != TypePRDOutput { + t.Errorf("type = %q, want %q", got.Type, TypePRDOutput) + } + if got.Payload.SessionID != "session-123" { + t.Errorf("payload.session_id = %q, want %q", got.Payload.SessionID, "session-123") + } + if got.Payload.Project != "my-project" { + t.Errorf("payload.project = %q, want %q", got.Payload.Project, "my-project") + } + if got.Payload.Content != "Here is the PRD content\n" { + t.Errorf("payload.content = %q, want %q", got.Payload.Content, "Here is the PRD content\n") + } +} + +func TestPRDResponseCompleteRoundTrip(t *testing.T) { + msg := PRDResponseCompleteMessage{ + Type: TypePRDResponseComplete, + Payload: PRDResponseCompletePayload{ + SessionID: "session-123", + Project: "my-project", + }, + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + var got PRDResponseCompleteMessage + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if got.Type != TypePRDResponseComplete { + t.Errorf("type = %q, want %q", got.Type, TypePRDResponseComplete) + } + if got.Payload.SessionID != "session-123" { + t.Errorf("payload.session_id = %q, want %q", got.Payload.SessionID, "session-123") + } + if got.Payload.Project != "my-project" { + t.Errorf("payload.project = %q, want %q", got.Payload.Project, "my-project") + } +} + +func TestPRDMessageRoundTrip(t *testing.T) { + msg := PRDMessageMessage{ + Type: TypePRDMessage, + ID: newUUID(), + Timestamp: "2026-02-15T10:00:00Z", + Project: "my-project", + SessionID: "session-abc", + Message: "Add OAuth support to the PRD", + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + var got PRDMessageMessage + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if got.Message != "Add OAuth support to the PRD" { + t.Errorf("message = %q", got.Message) + } +} diff --git a/internal/ws/ratelimit.go b/internal/ws/ratelimit.go new file mode 100644 index 0000000..5ca70e2 --- /dev/null +++ b/internal/ws/ratelimit.go @@ -0,0 +1,207 @@ +package ws + +import ( + "fmt" + "sync" + "time" +) + +// Rate limiter constants (hardcoded for V1). +const ( + // Global token bucket: 30 burst, 10/second sustained. + globalBurst = 30 + globalRate = 10.0 // tokens per second + globalInterval = time.Second / 10 + + // Expensive operations: 2 per minute. + expensiveLimit = 2 + expensiveWindow = time.Minute + expensiveInterval = expensiveWindow / 2 // 30 seconds between allowed ops +) + +// expensiveTypes are message types with stricter per-type rate limits. +var expensiveTypes = map[string]bool{ + TypeCloneRepo: true, + TypeStartRun: true, + TypeNewPRD: true, +} + +// exemptTypes are message types exempt from rate limiting. +var exemptTypes = map[string]bool{ + TypePing: true, +} + +// IsExpensiveType returns true if the message type has a stricter per-type rate limit. +func IsExpensiveType(msgType string) bool { + return expensiveTypes[msgType] +} + +// IsExemptType returns true if the message type is exempt from rate limiting. +func IsExemptType(msgType string) bool { + return exemptTypes[msgType] +} + +// tokenBucket implements a simple token bucket rate limiter. +type tokenBucket struct { + tokens float64 + capacity float64 + rate float64 // tokens per second + lastTime time.Time +} + +func newTokenBucket(capacity float64, rate float64) *tokenBucket { + return &tokenBucket{ + tokens: capacity, + capacity: capacity, + rate: rate, + lastTime: time.Now(), + } +} + +// allow checks if a token is available and consumes one if so. +// Returns true if allowed, false if rate limited. +func (tb *tokenBucket) allow(now time.Time) bool { + elapsed := now.Sub(tb.lastTime).Seconds() + tb.tokens += elapsed * tb.rate + if tb.tokens > tb.capacity { + tb.tokens = tb.capacity + } + tb.lastTime = now + + if tb.tokens >= 1 { + tb.tokens-- + return true + } + return false +} + +// retryAfter returns the duration until the next token is available. +func (tb *tokenBucket) retryAfter() time.Duration { + if tb.tokens >= 1 { + return 0 + } + needed := 1.0 - tb.tokens + return time.Duration(needed / tb.rate * float64(time.Second)) +} + +// expensiveTracker tracks per-type rate limiting for expensive operations. +type expensiveTracker struct { + timestamps []time.Time + limit int + window time.Duration +} + +func newExpensiveTracker(limit int, window time.Duration) *expensiveTracker { + return &expensiveTracker{ + limit: limit, + window: window, + } +} + +// allow checks if the operation is allowed within the rate limit window. +func (et *expensiveTracker) allow(now time.Time) bool { + // Remove expired timestamps + cutoff := now.Add(-et.window) + valid := et.timestamps[:0] + for _, ts := range et.timestamps { + if ts.After(cutoff) { + valid = append(valid, ts) + } + } + et.timestamps = valid + + if len(et.timestamps) >= et.limit { + return false + } + et.timestamps = append(et.timestamps, now) + return true +} + +// retryAfter returns the duration until the next operation would be allowed. +func (et *expensiveTracker) retryAfter(now time.Time) time.Duration { + if len(et.timestamps) < et.limit { + return 0 + } + oldest := et.timestamps[0] + return oldest.Add(et.window).Sub(now) +} + +// RateLimiter provides rate limiting for incoming WebSocket messages. +type RateLimiter struct { + mu sync.Mutex + global *tokenBucket + expensive map[string]*expensiveTracker +} + +// NewRateLimiter creates a new rate limiter with default settings. +func NewRateLimiter() *RateLimiter { + return &RateLimiter{ + global: newTokenBucket(globalBurst, globalRate), + expensive: make(map[string]*expensiveTracker), + } +} + +// RateLimitResult contains the result of a rate limit check. +type RateLimitResult struct { + Allowed bool + RetryAfter time.Duration +} + +// Allow checks if a message of the given type should be allowed. +// Returns RateLimitResult indicating whether the message is allowed and retry-after hint. +func (rl *RateLimiter) Allow(msgType string) RateLimitResult { + // Exempt types bypass all rate limiting + if IsExemptType(msgType) { + return RateLimitResult{Allowed: true} + } + + rl.mu.Lock() + defer rl.mu.Unlock() + + now := time.Now() + + // Check expensive operation limit first + if IsExpensiveType(msgType) { + tracker, ok := rl.expensive[msgType] + if !ok { + tracker = newExpensiveTracker(expensiveLimit, expensiveWindow) + rl.expensive[msgType] = tracker + } + if !tracker.allow(now) { + retryAfter := tracker.retryAfter(now) + return RateLimitResult{ + Allowed: false, + RetryAfter: retryAfter, + } + } + } + + // Check global rate limit + if !rl.global.allow(now) { + retryAfter := rl.global.retryAfter() + return RateLimitResult{ + Allowed: false, + RetryAfter: retryAfter, + } + } + + return RateLimitResult{Allowed: true} +} + +// Reset clears all rate limiter state (called on reconnection). +func (rl *RateLimiter) Reset() { + rl.mu.Lock() + defer rl.mu.Unlock() + + rl.global = newTokenBucket(globalBurst, globalRate) + rl.expensive = make(map[string]*expensiveTracker) +} + +// FormatRetryAfter returns a human-readable retry-after string. +func FormatRetryAfter(d time.Duration) string { + secs := int(d.Seconds()) + 1 // round up + if secs <= 0 { + secs = 1 + } + return fmt.Sprintf("%ds", secs) +} diff --git a/internal/ws/ratelimit_test.go b/internal/ws/ratelimit_test.go new file mode 100644 index 0000000..a6015da --- /dev/null +++ b/internal/ws/ratelimit_test.go @@ -0,0 +1,293 @@ +package ws + +import ( + "testing" + "time" +) + +func TestRateLimiter_AllowsNormalMessages(t *testing.T) { + rl := NewRateLimiter() + + // Should allow a burst of messages up to the burst limit + for i := 0; i < globalBurst; i++ { + result := rl.Allow(TypeGetProject) + if !result.Allowed { + t.Fatalf("expected message %d to be allowed within burst", i+1) + } + } +} + +func TestRateLimiter_BlocksAfterBurstExhausted(t *testing.T) { + rl := NewRateLimiter() + + // Exhaust the burst + for i := 0; i < globalBurst; i++ { + rl.Allow(TypeGetProject) + } + + // Next message should be blocked + result := rl.Allow(TypeGetProject) + if result.Allowed { + t.Fatal("expected message to be blocked after burst exhausted") + } + if result.RetryAfter <= 0 { + t.Error("expected positive retry-after duration") + } +} + +func TestRateLimiter_RefillsTokensOverTime(t *testing.T) { + rl := NewRateLimiter() + + // Exhaust the burst + for i := 0; i < globalBurst; i++ { + rl.Allow(TypeGetProject) + } + + // Manually advance the token bucket's last time to simulate time passing + rl.mu.Lock() + rl.global.lastTime = time.Now().Add(-200 * time.Millisecond) // should refill ~2 tokens + rl.mu.Unlock() + + result := rl.Allow(TypeGetProject) + if !result.Allowed { + t.Fatal("expected message to be allowed after token refill") + } +} + +func TestRateLimiter_PingExempt(t *testing.T) { + rl := NewRateLimiter() + + // Exhaust the burst + for i := 0; i < globalBurst+10; i++ { + rl.Allow(TypeGetProject) + } + + // Ping should still be allowed + result := rl.Allow(TypePing) + if !result.Allowed { + t.Fatal("expected ping to be exempt from rate limiting") + } +} + +func TestRateLimiter_ExpensiveOperationsLimited(t *testing.T) { + rl := NewRateLimiter() + + // First two clone_repo should be allowed + result1 := rl.Allow(TypeCloneRepo) + result2 := rl.Allow(TypeCloneRepo) + if !result1.Allowed || !result2.Allowed { + t.Fatal("expected first two expensive operations to be allowed") + } + + // Third should be blocked + result3 := rl.Allow(TypeCloneRepo) + if result3.Allowed { + t.Fatal("expected third expensive operation to be blocked") + } + if result3.RetryAfter <= 0 { + t.Error("expected positive retry-after for expensive operation") + } +} + +func TestRateLimiter_ExpensiveTypesIndependent(t *testing.T) { + rl := NewRateLimiter() + + // Use up clone_repo limit + rl.Allow(TypeCloneRepo) + rl.Allow(TypeCloneRepo) + + // start_run should still be allowed (independent tracker) + result := rl.Allow(TypeStartRun) + if !result.Allowed { + t.Fatal("expected start_run to be allowed independently of clone_repo") + } + + // new_prd should also be allowed + result = rl.Allow(TypeNewPRD) + if !result.Allowed { + t.Fatal("expected new_prd to be allowed independently") + } +} + +func TestRateLimiter_ExpensiveWindowExpiry(t *testing.T) { + rl := NewRateLimiter() + + // Use up the limit + rl.Allow(TypeStartRun) + rl.Allow(TypeStartRun) + + // Should be blocked + result := rl.Allow(TypeStartRun) + if result.Allowed { + t.Fatal("expected to be blocked") + } + + // Simulate time passing beyond the window + rl.mu.Lock() + tracker := rl.expensive[TypeStartRun] + for i := range tracker.timestamps { + tracker.timestamps[i] = time.Now().Add(-expensiveWindow - time.Second) + } + rl.mu.Unlock() + + // Should be allowed again + result = rl.Allow(TypeStartRun) + if !result.Allowed { + t.Fatal("expected to be allowed after window expires") + } +} + +func TestRateLimiter_Reset(t *testing.T) { + rl := NewRateLimiter() + + // Exhaust burst and expensive limits + for i := 0; i < globalBurst+5; i++ { + rl.Allow(TypeGetProject) + } + rl.Allow(TypeCloneRepo) + rl.Allow(TypeCloneRepo) + + // Verify blocked + result := rl.Allow(TypeGetProject) + if result.Allowed { + t.Fatal("expected blocked before reset") + } + result = rl.Allow(TypeCloneRepo) + if result.Allowed { + t.Fatal("expected expensive blocked before reset") + } + + // Reset + rl.Reset() + + // Should be allowed again + result = rl.Allow(TypeGetProject) + if !result.Allowed { + t.Fatal("expected allowed after reset") + } + result = rl.Allow(TypeCloneRepo) + if !result.Allowed { + t.Fatal("expected expensive allowed after reset") + } +} + +func TestRateLimiter_ExpensiveAlsoConsumesGlobal(t *testing.T) { + rl := NewRateLimiter() + + // Exhaust global bucket + for i := 0; i < globalBurst; i++ { + rl.Allow(TypeGetProject) + } + + // Expensive operation should be blocked by global limit even though + // the expensive tracker would allow it + result := rl.Allow(TypeCloneRepo) + if result.Allowed { + t.Fatal("expected expensive operation to be blocked by global limit") + } +} + +func TestFormatRetryAfter(t *testing.T) { + tests := []struct { + d time.Duration + want string + }{ + {100 * time.Millisecond, "1s"}, + {500 * time.Millisecond, "1s"}, + {1500 * time.Millisecond, "2s"}, + {30 * time.Second, "31s"}, + {0, "1s"}, + } + + for _, tt := range tests { + got := FormatRetryAfter(tt.d) + if got != tt.want { + t.Errorf("FormatRetryAfter(%v) = %q, want %q", tt.d, got, tt.want) + } + } +} + +func TestIsExpensiveType(t *testing.T) { + if !IsExpensiveType(TypeCloneRepo) { + t.Error("expected clone_repo to be expensive") + } + if !IsExpensiveType(TypeStartRun) { + t.Error("expected start_run to be expensive") + } + if !IsExpensiveType(TypeNewPRD) { + t.Error("expected new_prd to be expensive") + } + if IsExpensiveType(TypeGetProject) { + t.Error("expected get_project to NOT be expensive") + } + if IsExpensiveType(TypePing) { + t.Error("expected ping to NOT be expensive") + } +} + +func TestIsExemptType(t *testing.T) { + if !IsExemptType(TypePing) { + t.Error("expected ping to be exempt") + } + if IsExemptType(TypeGetProject) { + t.Error("expected get_project to NOT be exempt") + } +} + +func TestTokenBucket_RetryAfter(t *testing.T) { + tb := newTokenBucket(1, 10) // 1 burst, 10/sec + + // Use the token + now := time.Now() + tb.allow(now) + + // Should need ~100ms for next token + retryAfter := tb.retryAfter() + if retryAfter <= 0 { + t.Error("expected positive retry-after") + } + if retryAfter > 200*time.Millisecond { + t.Errorf("retry-after too large: %v", retryAfter) + } +} + +func TestExpensiveTracker_RetryAfter(t *testing.T) { + et := newExpensiveTracker(2, time.Minute) + + now := time.Now() + et.allow(now) + et.allow(now.Add(time.Second)) + + // Should be blocked with retry-after pointing to when oldest expires + retryAfter := et.retryAfter(now.Add(2 * time.Second)) + if retryAfter <= 0 { + t.Error("expected positive retry-after") + } + // Should be about 58 seconds (60 - 2 seconds elapsed) + if retryAfter < 55*time.Second || retryAfter > 61*time.Second { + t.Errorf("unexpected retry-after: %v", retryAfter) + } +} + +func TestRateLimiter_ConcurrentAccess(t *testing.T) { + rl := NewRateLimiter() + + done := make(chan struct{}) + for i := 0; i < 10; i++ { + go func() { + defer func() { done <- struct{}{} }() + for j := 0; j < 100; j++ { + rl.Allow(TypeGetProject) + rl.Allow(TypeCloneRepo) + rl.Allow(TypePing) + } + }() + } + + for i := 0; i < 10; i++ { + <-done + } + + // Just ensure no panics or races + rl.Reset() +}