diff --git a/packages/appkit/src/index.ts b/packages/appkit/src/index.ts index a5e6f50c..3e603e47 100644 --- a/packages/appkit/src/index.ts +++ b/packages/appkit/src/index.ts @@ -47,7 +47,7 @@ export { } from "./errors"; // Plugin authoring export { Plugin, type ToPlugin, toPlugin } from "./plugin"; -export { analytics, lakebase, server } from "./plugins"; +export { analytics, genie, lakebase, server } from "./plugins"; // Registry types and utilities for plugin manifests export type { ConfigSchema, diff --git a/packages/appkit/src/plugins/genie/defaults.ts b/packages/appkit/src/plugins/genie/defaults.ts new file mode 100644 index 00000000..7d2655bc --- /dev/null +++ b/packages/appkit/src/plugins/genie/defaults.ts @@ -0,0 +1,19 @@ +import type { StreamExecutionSettings } from "shared"; + +export const genieStreamDefaults: StreamExecutionSettings = { + default: { + // Cache disabled: chat messages are conversational and stateful, not repeatable queries. + cache: { + enabled: false, + }, + // Retry disabled: Genie calls are not idempotent (retries could create duplicate + // conversations/messages), and the SDK Waiter already handles transient polling failures. + retry: { + enabled: false, + }, + timeout: 120_000, + }, + stream: { + bufferSize: 100, + }, +}; diff --git a/packages/appkit/src/plugins/genie/genie.ts b/packages/appkit/src/plugins/genie/genie.ts new file mode 100644 index 00000000..8c9e9ee5 --- /dev/null +++ b/packages/appkit/src/plugins/genie/genie.ts @@ -0,0 +1,491 @@ +import { randomUUID } from "node:crypto"; +import { Time, TimeUnits } from "@databricks/sdk-experimental"; +import type { + GenieMessage, + GenieStartConversationResponse, +} from "@databricks/sdk-experimental/dist/apis/dashboards"; +import type { Waiter } from "@databricks/sdk-experimental/dist/wait"; +import type express from "express"; +import type { IAppRouter, StreamExecutionSettings } from "shared"; +import { getWorkspaceClient } from "../../context"; +import { createLogger } from "../../logging"; +import { Plugin, toPlugin } from "../../plugin"; +import { genieStreamDefaults } from "./defaults"; +import { genieManifest } from "./manifest"; +import { pollWaiter } from "./poll-waiter"; +import type { + GenieAttachmentResponse, + GenieConversationHistoryResponse, + GenieMessageResponse, + GenieSendMessageRequest, + GenieStreamEvent, + IGenieConfig, +} from "./types"; + +const logger = createLogger("genie"); + +type StartConversationWaiter = Waiter< + GenieStartConversationResponse, + GenieMessage +>; +type CreateMessageWaiter = Waiter; + +/** Extract our cleaned attachment response from a raw SDK GenieMessage */ +function mapAttachments(message: GenieMessage): GenieAttachmentResponse[] { + return ( + message.attachments?.map((att) => ({ + attachmentId: att.attachment_id, + query: att.query + ? { + title: att.query.title, + description: att.query.description, + query: att.query.query, + statementId: att.query.statement_id, + } + : undefined, + text: att.text ? { content: att.text.content } : undefined, + suggestedQuestions: att.suggested_questions?.questions, + })) ?? [] + ); +} + +/** Build a GenieMessageResponse from a raw SDK GenieMessage */ +function toMessageResponse(message: GenieMessage): GenieMessageResponse { + return { + messageId: message.message_id, + conversationId: message.conversation_id, + spaceId: message.space_id, + status: message.status ?? "COMPLETED", + content: message.content, + attachments: mapAttachments(message), + error: message.error?.error, + }; +} + +export class GeniePlugin extends Plugin { + name = "genie"; + + static manifest = genieManifest; + + protected static description = + "AI/BI Genie space integration for natural language data queries"; + protected declare config: IGenieConfig; + + constructor(config: IGenieConfig) { + super(config); + this.config = { + ...config, + spaces: config.spaces ?? this.defaultSpaces(), + }; + } + + private defaultSpaces(): Record { + const spaceId = process.env.DATABRICKS_GENIE_SPACE_ID; + return spaceId ? { default: spaceId } : {}; + } + + private resolveSpaceId(alias: string): string | null { + return this.config.spaces?.[alias] ?? null; + } + + injectRoutes(router: IAppRouter) { + this.route(router, { + name: "sendMessage", + method: "post", + path: "/:alias/messages", + handler: async (req: express.Request, res: express.Response) => { + await this.asUser(req)._handleSendMessage(req, res); + }, + }); + + this.route(router, { + name: "getConversation", + method: "get", + path: "/:alias/conversations/:conversationId", + handler: async (req: express.Request, res: express.Response) => { + await this.asUser(req)._handleGetConversation(req, res); + }, + }); + } + + async _handleSendMessage( + req: express.Request, + res: express.Response, + ): Promise { + const { alias } = req.params; + const spaceId = this.resolveSpaceId(alias); + + if (!spaceId) { + res.status(404).json({ error: `Unknown space alias: ${alias}` }); + return; + } + + const { content, conversationId } = req.body as GenieSendMessageRequest; + + if (!content) { + res.status(400).json({ error: "content is required" }); + return; + } + + logger.debug( + "Sending message to space %s (alias=%s, conversationId=%s)", + spaceId, + alias, + conversationId ?? "new", + ); + + const timeout = this.config.timeout ?? 120_000; + const requestId = (req.query.requestId as string) || randomUUID(); + + const streamSettings: StreamExecutionSettings = { + ...genieStreamDefaults, + default: { + ...genieStreamDefaults.default, + // timeout: 0 means indefinite (no TimeoutInterceptor) + timeout, + }, + stream: { + ...genieStreamDefaults.stream, + streamId: requestId, + }, + }; + + await this.executeStream( + res, + async function* () { + const workspaceClient = getWorkspaceClient(); + + try { + // Step 1: API call → get waiter + IDs + let messageWaiter: CreateMessageWaiter; + let resultConversationId: string; + let resultMessageId: string; + + if (conversationId) { + messageWaiter = await workspaceClient.genie.createMessage({ + space_id: spaceId, + conversation_id: conversationId, + content, + }); + resultConversationId = conversationId; + resultMessageId = messageWaiter.message_id ?? ""; + } else { + const startWaiter: StartConversationWaiter = + await workspaceClient.genie.startConversation({ + space_id: spaceId, + content, + }); + resultConversationId = startWaiter.conversation_id; + resultMessageId = startWaiter.message_id; + messageWaiter = startWaiter as unknown as CreateMessageWaiter; + } + + // Step 2: Yield message_start immediately — IDs are available from API response + yield { + type: "message_start" as const, + conversationId: resultConversationId, + messageId: resultMessageId, + spaceId, + }; + + // Step 3: Poll for status updates and completion + let completedMessage!: GenieMessage; + for await (const event of pollWaiter(messageWaiter)) { + if (event.type === "progress") { + if (event.value.status) { + yield { type: "status" as const, status: event.value.status }; + } + } else { + completedMessage = event.value; + resultMessageId = event.value.message_id; + } + } + + // Step 4: Build cleaned message response + const messageResponse = toMessageResponse(completedMessage); + + yield { + type: "message_result" as const, + message: messageResponse, + }; + + // Step 5: Fetch query results for each query attachment + const attachments = messageResponse.attachments ?? []; + for (const att of attachments) { + if (att.query?.statementId && att.attachmentId) { + try { + const queryResult = + await workspaceClient.genie.getMessageAttachmentQueryResult({ + space_id: spaceId, + conversation_id: resultConversationId, + message_id: resultMessageId, + attachment_id: att.attachmentId, + }); + + yield { + type: "query_result" as const, + attachmentId: att.attachmentId, + statementId: att.query.statementId, + data: queryResult.statement_response, + }; + } catch (error) { + logger.error( + "Failed to fetch query result for attachment %s: %O", + att.attachmentId, + error, + ); + yield { + type: "error" as const, + error: `Failed to fetch query result for attachment ${att.attachmentId}`, + }; + } + } + } + } catch (error) { + logger.error("Genie message error: %O", error); + yield { + type: "error" as const, + error: + error instanceof Error ? error.message : "Genie request failed", + }; + } + }, + streamSettings, + ); + } + + private async _fetchAllMessages( + spaceId: string, + conversationId: string, + ): Promise { + const workspaceClient = getWorkspaceClient(); + const allMessages: GenieMessage[] = []; + let pageToken: string | undefined; + const maxMessages = 200; + + do { + const response = await workspaceClient.genie.listConversationMessages({ + space_id: spaceId, + conversation_id: conversationId, + page_size: 100, + ...(pageToken ? { page_token: pageToken } : {}), + }); + + if (response.messages) { + allMessages.push(...response.messages); + } + + pageToken = response.next_page_token; + } while (pageToken && allMessages.length < maxMessages); + + return allMessages.slice(0, maxMessages); + } + + async _handleGetConversation( + req: express.Request, + res: express.Response, + ): Promise { + const { alias, conversationId } = req.params; + const spaceId = this.resolveSpaceId(alias); + + if (!spaceId) { + res.status(404).json({ error: `Unknown space alias: ${alias}` }); + return; + } + + const includeQueryResults = req.query.includeQueryResults !== "false"; + const requestId = (req.query.requestId as string) || randomUUID(); + + logger.debug( + "Fetching conversation %s from space %s (alias=%s, includeQueryResults=%s)", + conversationId, + spaceId, + alias, + includeQueryResults, + ); + + const self = this; + + const streamSettings: StreamExecutionSettings = { + ...genieStreamDefaults, + stream: { + ...genieStreamDefaults.stream, + streamId: requestId, + }, + }; + + await this.executeStream( + res, + async function* () { + try { + const messages = await self._fetchAllMessages( + spaceId, + conversationId, + ); + + const messageResponses: GenieMessageResponse[] = []; + + for (const message of messages) { + const messageResponse = toMessageResponse(message); + messageResponses.push(messageResponse); + + yield { + type: "message_result" as const, + message: messageResponse, + }; + } + + if (includeQueryResults) { + // Collect all query attachments across all messages + const queryAttachments: Array<{ + messageId: string; + attachmentId: string; + statementId: string; + }> = []; + + for (const msg of messageResponses) { + for (const att of msg.attachments ?? []) { + if (att.query?.statementId && att.attachmentId) { + queryAttachments.push({ + messageId: msg.messageId, + attachmentId: att.attachmentId, + statementId: att.query.statementId, + }); + } + } + } + + // Fetch all query results in parallel + const workspaceClient = getWorkspaceClient(); + const results = await Promise.allSettled( + queryAttachments.map(async (att) => { + const queryResult = + await workspaceClient.genie.getMessageAttachmentQueryResult({ + space_id: spaceId, + conversation_id: conversationId, + message_id: att.messageId, + attachment_id: att.attachmentId, + }); + return { + attachmentId: att.attachmentId, + statementId: att.statementId, + data: queryResult.statement_response, + }; + }), + ); + + for (const result of results) { + if (result.status === "fulfilled") { + yield { + type: "query_result" as const, + attachmentId: result.value.attachmentId, + statementId: result.value.statementId, + data: result.value.data, + }; + } else { + logger.error("Failed to fetch query result: %O", result.reason); + yield { + type: "error" as const, + error: + result.reason instanceof Error + ? result.reason.message + : "Failed to fetch query result", + }; + } + } + } + } catch (error) { + logger.error("Genie getConversation error: %O", error); + yield { + type: "error" as const, + error: + error instanceof Error + ? error.message + : "Failed to fetch conversation", + }; + } + }, + streamSettings, + ); + } + + async getConversation( + alias: string, + conversationId: string, + ): Promise { + const spaceId = this.resolveSpaceId(alias); + if (!spaceId) { + throw new Error(`Unknown space alias: ${alias}`); + } + + const messages = await this._fetchAllMessages(spaceId, conversationId); + + return { + conversationId, + spaceId, + messages: messages.map(toMessageResponse), + }; + } + + async sendMessage( + alias: string, + content: string, + conversationId?: string, + ): Promise { + const spaceId = this.resolveSpaceId(alias); + if (!spaceId) { + throw new Error(`Unknown space alias: ${alias}`); + } + + const workspaceClient = getWorkspaceClient(); + const timeout = this.config.timeout ?? 120_000; + + let messageWaiter: CreateMessageWaiter; + let resultConversationId: string; + + if (conversationId) { + messageWaiter = await workspaceClient.genie.createMessage({ + space_id: spaceId, + conversation_id: conversationId, + content, + }); + resultConversationId = conversationId; + } else { + const startWaiter: StartConversationWaiter = + await workspaceClient.genie.startConversation({ + space_id: spaceId, + content, + }); + resultConversationId = startWaiter.conversation_id; + messageWaiter = startWaiter as unknown as CreateMessageWaiter; + } + + const waitOptions = + timeout > 0 ? { timeout: new Time(timeout, TimeUnits.milliseconds) } : {}; + const completedMessage = await messageWaiter.wait(waitOptions); + + return { + ...toMessageResponse(completedMessage), + conversationId: resultConversationId, + }; + } + + async shutdown(): Promise { + this.streamManager.abortAll(); + } + + exports() { + return { + sendMessage: this.sendMessage, + getConversation: this.getConversation, + }; + } +} + +/** + * @internal + */ +export const genie = toPlugin( + GeniePlugin, + "genie", +); diff --git a/packages/appkit/src/plugins/genie/index.ts b/packages/appkit/src/plugins/genie/index.ts new file mode 100644 index 00000000..6726262f --- /dev/null +++ b/packages/appkit/src/plugins/genie/index.ts @@ -0,0 +1,3 @@ +export * from "./genie"; +export * from "./manifest"; +export * from "./types"; diff --git a/packages/appkit/src/plugins/genie/manifest.json b/packages/appkit/src/plugins/genie/manifest.json new file mode 100644 index 00000000..a269795d --- /dev/null +++ b/packages/appkit/src/plugins/genie/manifest.json @@ -0,0 +1,43 @@ +{ + "name": "genie", + "displayName": "Genie Plugin", + "description": "AI/BI Genie space integration for natural language data queries", + "resources": { + "required": [ + { + "type": "genie_space", + "alias": "Genie Space", + "resourceKey": "genie-space", + "description": "Genie Space for AI-powered data queries. Space IDs configured via plugin config.", + "permission": "CAN_RUN", + "fields": { + "id": { + "env": "DATABRICKS_GENIE_SPACE_ID", + "description": "Default Genie Space ID" + } + } + } + ], + "optional": [] + }, + "config": { + "schema": { + "type": "object", + "properties": { + "spaces": { + "type": "object", + "description": "Map of alias names to Genie Space IDs", + "additionalProperties": { + "type": "string" + } + }, + "timeout": { + "type": "number", + "default": 120000, + "description": "Genie polling timeout in ms. Set to 0 for indefinite." + } + }, + "required": ["spaces"] + } + } +} diff --git a/packages/appkit/src/plugins/genie/manifest.ts b/packages/appkit/src/plugins/genie/manifest.ts new file mode 100644 index 00000000..cf3d98fb --- /dev/null +++ b/packages/appkit/src/plugins/genie/manifest.ts @@ -0,0 +1,10 @@ +import { readFileSync } from "node:fs"; +import { dirname, join } from "node:path"; +import { fileURLToPath } from "node:url"; +import type { PluginManifest } from "../../registry"; + +const __dirname = dirname(fileURLToPath(import.meta.url)); + +export const genieManifest: PluginManifest = JSON.parse( + readFileSync(join(__dirname, "manifest.json"), "utf-8"), +) as PluginManifest; diff --git a/packages/appkit/src/plugins/genie/poll-waiter.ts b/packages/appkit/src/plugins/genie/poll-waiter.ts new file mode 100644 index 00000000..ab29bbdd --- /dev/null +++ b/packages/appkit/src/plugins/genie/poll-waiter.ts @@ -0,0 +1,92 @@ +/** + * Structural interface matching the SDK's `Waiter.wait()` shape + * without importing the SDK directly. + */ +export interface Pollable

{ + wait(options?: { + onProgress?: (p: P) => Promise; + timeout?: unknown; + }): Promise

; +} + +export type PollEvent

= + | { type: "progress"; value: P } + | { type: "completed"; value: P }; + +/** + * Bridges a callback-based waiter into an async generator. + * + * The SDK's `waiter.wait({ onProgress })` API uses a callback to report + * progress and returns a promise that resolves with the final result. + * This function converts that push-based model into a pull-based async + * generator so callers can simply `for await (const event of pollWaiter(w))`. + * + * Yields `{ type: "progress", value }` for each `onProgress` callback, + * then `{ type: "completed", value }` for the final result. + * Throws if the waiter rejects. + */ +export async function* pollWaiter

( + waiter: Pollable

, + options?: { timeout?: unknown }, +): AsyncGenerator> { + // --- shared state between the onProgress callback and the generator loop --- + const queue: P[] = []; // progress values waiting to be yielded + let notify: () => void = () => {}; // resolves the generator's "sleep" promise + let done = false; // true once waiter.wait() settles (success or error) + let result!: P; + let error: unknown = null; + + // Start the waiter in the background (not awaited — runs concurrently + // with the generator loop below). The onProgress callback pushes values + // into the queue and wakes the generator via notify(). + waiter + .wait({ + onProgress: async (p: P) => { + queue.push(p); + notify(); + }, + ...(options?.timeout != null ? { timeout: options.timeout } : {}), + }) + .then((r) => { + result = r; + done = true; + notify(); + }) + .catch((err) => { + error = err; + done = true; + notify(); + }); + + // Drain progress events as they arrive. The loop exits once the waiter + // has settled AND the queue is empty. + while (!done || queue.length > 0) { + // Yield all queued progress values before sleeping. + while (queue.length > 0) { + const value = queue.shift() as P; + yield { type: "progress", value }; + } + + // Nothing in the queue yet and the waiter hasn't settled — sleep until + // the next onProgress call or waiter settlement wakes us via notify(). + // + // Race-condition guard: after setting `notify = resolve`, we re-check + // `done` and `queue.length`. If either changed between the outer while + // check and this point (possible via microtask), we resolve immediately + // so the loop doesn't hang. + if (!done) { + await new Promise((resolve) => { + notify = resolve; + if (done || queue.length > 0) resolve(); + }); + } + } + + // The waiter settled. If it rejected, propagate the error. + if (error !== null) { + throw error; + } + + // Final event: the completed result from waiter.wait(). + yield { type: "completed", value: result }; +} diff --git a/packages/appkit/src/plugins/genie/tests/genie.test.ts b/packages/appkit/src/plugins/genie/tests/genie.test.ts new file mode 100644 index 00000000..58acfcfe --- /dev/null +++ b/packages/appkit/src/plugins/genie/tests/genie.test.ts @@ -0,0 +1,1099 @@ +import { + createMockRequest, + createMockResponse, + createMockRouter, + mockServiceContext, + setupDatabricksEnv, +} from "@tools/test-helpers"; +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { ServiceContext } from "../../../context/service-context"; +import { Plugin } from "../../../plugin"; +import { GeniePlugin, genie } from "../genie"; +import type { IGenieConfig } from "../types"; + +// Mock CacheManager singleton +const { mockCacheInstance } = vi.hoisted(() => { + const instance = { + get: vi.fn(), + set: vi.fn(), + delete: vi.fn(), + getOrExecute: vi + .fn() + .mockImplementation( + async (_key: unknown[], fn: () => Promise) => { + return await fn(); + }, + ), + generateKey: vi.fn((...args: unknown[]) => JSON.stringify(args)), + }; + + return { mockCacheInstance: instance }; +}); + +vi.mock("../../../cache", () => ({ + CacheManager: { + getInstanceSync: vi.fn(() => mockCacheInstance), + }, +})); + +function createMockGenieService() { + const getMessageAttachmentQueryResult = vi.fn(); + + const createWaiter = ( + conversationId: string, + messageId: string, + attachments: any[] = [], + status = "COMPLETED", + ) => ({ + wait: vi.fn().mockImplementation(async ({ onProgress }: any) => { + if (onProgress) { + await onProgress({ status: "ASKING_AI" }); + await onProgress({ status: "EXECUTING_QUERY" }); + } + return { + message_id: messageId, + conversation_id: conversationId, + space_id: "test-space-id", + content: "Here are your results", + status, + attachments, + error: undefined, + }; + }), + }); + + const startConversation = vi.fn().mockImplementation(async () => ({ + conversation_id: "new-conv-id", + message_id: "new-msg-id", + ...createWaiter("new-conv-id", "new-msg-id", [ + { + attachment_id: "att-1", + query: { + title: "Top Customers", + description: "Query for top customers", + query: "SELECT * FROM customers", + statement_id: "stmt-1", + }, + }, + ]), + })); + + const createMessage = vi.fn().mockImplementation(async () => + createWaiter("existing-conv-id", "followup-msg-id", [ + { + attachment_id: "att-2", + query: { + title: "Follow-up Query", + query: "SELECT * FROM orders", + statement_id: "stmt-2", + }, + }, + ]), + ); + + const listConversationMessages = vi.fn(); + + return { + startConversation, + createMessage, + getMessageAttachmentQueryResult, + listConversationMessages, + createWaiter, + }; +} + +describe("Genie Plugin", () => { + let config: IGenieConfig; + let serviceContextMock: Awaited>; + let mockGenieService: ReturnType; + + beforeEach(async () => { + config = { + spaces: { + myspace: "test-space-id", + salesbot: "sales-space-id", + }, + timeout: 5000, + }; + setupDatabricksEnv(); + ServiceContext.reset(); + + mockGenieService = createMockGenieService(); + + mockGenieService.getMessageAttachmentQueryResult.mockResolvedValue({ + statement_response: { + status: { state: "SUCCEEDED" }, + result: { + data_array: [ + ["Acme Corp", "1000000"], + ["Globex", "500000"], + ], + }, + manifest: { + schema: { + columns: [ + { name: "customer", type_name: "STRING" }, + { name: "revenue", type_name: "DECIMAL" }, + ], + }, + }, + }, + }); + + serviceContextMock = await mockServiceContext({ + userDatabricksClient: { + genie: mockGenieService, + }, + }); + }); + + afterEach(() => { + serviceContextMock?.restore(); + }); + + test("genie factory should have correct name", () => { + const pluginData = genie({ spaces: { test: "id" } }); + expect(pluginData.name).toBe("genie"); + }); + + test("plugin instance should be created with correct name", () => { + const plugin = new GeniePlugin(config); + expect(plugin.name).toBe("genie"); + }); + + describe("injectRoutes", () => { + test("should register POST and GET routes", () => { + const plugin = new GeniePlugin(config); + const { router } = createMockRouter(); + + plugin.injectRoutes(router); + + expect(router.post).toHaveBeenCalledTimes(1); + expect(router.post).toHaveBeenCalledWith( + "/:alias/messages", + expect.any(Function), + ); + + expect(router.get).toHaveBeenCalledTimes(1); + expect(router.get).toHaveBeenCalledWith( + "/:alias/conversations/:conversationId", + expect.any(Function), + ); + }); + }); + + describe("space alias resolution", () => { + test("should return 404 for unknown alias", async () => { + const plugin = new GeniePlugin(config); + const { router, getHandler } = createMockRouter(); + + plugin.injectRoutes(router); + + const handler = getHandler("POST", "/:alias/messages"); + const mockReq = createMockRequest({ + params: { alias: "unknown" }, + body: { content: "test question" }, + headers: { + "x-forwarded-access-token": "user-token", + "x-forwarded-user": "user-1", + }, + }); + const mockRes = createMockResponse(); + + await handler(mockReq, mockRes); + + expect(mockRes.status).toHaveBeenCalledWith(404); + expect(mockRes.json).toHaveBeenCalledWith({ + error: "Unknown space alias: unknown", + }); + }); + + test("should resolve valid alias", async () => { + const plugin = new GeniePlugin(config); + const { router, getHandler } = createMockRouter(); + + plugin.injectRoutes(router); + + const handler = getHandler("POST", "/:alias/messages"); + const mockReq = createMockRequest({ + params: { alias: "myspace" }, + body: { content: "What are my top customers?" }, + headers: { + "x-forwarded-access-token": "user-token", + "x-forwarded-user": "user-1", + }, + }); + const mockRes = createMockResponse(); + + await handler(mockReq, mockRes); + + expect(mockRes.status).not.toHaveBeenCalledWith(404); + expect(mockGenieService.startConversation).toHaveBeenCalledWith( + expect.objectContaining({ + space_id: "test-space-id", + content: "What are my top customers?", + }), + ); + }); + }); + + describe("validation", () => { + test("should return 400 when content is missing", async () => { + const plugin = new GeniePlugin(config); + const { router, getHandler } = createMockRouter(); + + plugin.injectRoutes(router); + + const handler = getHandler("POST", "/:alias/messages"); + const mockReq = createMockRequest({ + params: { alias: "myspace" }, + body: {}, + headers: { + "x-forwarded-access-token": "user-token", + "x-forwarded-user": "user-1", + }, + }); + const mockRes = createMockResponse(); + + await handler(mockReq, mockRes); + + expect(mockRes.status).toHaveBeenCalledWith(400); + expect(mockRes.json).toHaveBeenCalledWith({ + error: "content is required", + }); + }); + }); + + describe("send message - new conversation", () => { + test("should call startConversation and stream SSE events", async () => { + const plugin = new GeniePlugin(config); + const { router, getHandler } = createMockRouter(); + + plugin.injectRoutes(router); + + const handler = getHandler("POST", "/:alias/messages"); + const mockReq = createMockRequest({ + params: { alias: "myspace" }, + body: { content: "What are my top customers?" }, + headers: { + "x-forwarded-access-token": "user-token", + "x-forwarded-user": "user-1", + }, + }); + const mockRes = createMockResponse(); + + await handler(mockReq, mockRes); + + expect(mockGenieService.startConversation).toHaveBeenCalledWith( + expect.objectContaining({ + space_id: "test-space-id", + content: "What are my top customers?", + }), + ); + + // Verify SSE headers + expect(mockRes.setHeader).toHaveBeenCalledWith( + "Content-Type", + "text/event-stream", + ); + expect(mockRes.setHeader).toHaveBeenCalledWith( + "Cache-Control", + "no-cache", + ); + + // Verify SSE events are written + const writeCalls = mockRes.write.mock.calls.map((call: any[]) => call[0]); + const allWritten = writeCalls.join(""); + + // Should have message_start event + expect(allWritten).toContain("message_start"); + expect(allWritten).toContain("new-conv-id"); + + // Should have status events + expect(allWritten).toContain("status"); + expect(allWritten).toContain("ASKING_AI"); + + // Should have message_result event + expect(allWritten).toContain("message_result"); + + // Should have query_result event + expect(allWritten).toContain("query_result"); + + expect(mockRes.end).toHaveBeenCalled(); + }); + }); + + describe("send message - follow-up", () => { + test("should call createMessage with conversationId", async () => { + const plugin = new GeniePlugin(config); + const { router, getHandler } = createMockRouter(); + + plugin.injectRoutes(router); + + const handler = getHandler("POST", "/:alias/messages"); + const mockReq = createMockRequest({ + params: { alias: "myspace" }, + body: { + content: "Show me more details", + conversationId: "existing-conv-id", + }, + headers: { + "x-forwarded-access-token": "user-token", + "x-forwarded-user": "user-1", + }, + }); + const mockRes = createMockResponse(); + + await handler(mockReq, mockRes); + + expect(mockGenieService.createMessage).toHaveBeenCalledWith( + expect.objectContaining({ + space_id: "test-space-id", + conversation_id: "existing-conv-id", + content: "Show me more details", + }), + ); + + expect(mockGenieService.startConversation).not.toHaveBeenCalled(); + + const writeCalls = mockRes.write.mock.calls.map((call: any[]) => call[0]); + const allWritten = writeCalls.join(""); + + expect(allWritten).toContain("message_start"); + expect(allWritten).toContain("existing-conv-id"); + expect(mockRes.end).toHaveBeenCalled(); + }); + }); + + describe("multiple attachments", () => { + test("should yield query_result for each query attachment", async () => { + // Override startConversation to return multiple query attachments + mockGenieService.startConversation.mockImplementation(async () => ({ + conversation_id: "multi-conv-id", + message_id: "multi-msg-id", + wait: vi.fn().mockImplementation(async ({ onProgress }: any) => { + if (onProgress) { + await onProgress({ status: "ASKING_AI" }); + } + return { + message_id: "multi-msg-id", + conversation_id: "multi-conv-id", + space_id: "test-space-id", + content: "Here are two queries", + status: "COMPLETED", + attachments: [ + { + attachment_id: "att-q1", + query: { + title: "Query 1", + query: "SELECT 1", + statement_id: "stmt-q1", + }, + }, + { + attachment_id: "att-q2", + query: { + title: "Query 2", + query: "SELECT 2", + statement_id: "stmt-q2", + }, + }, + { + attachment_id: "att-text", + text: { content: "Some explanation" }, + }, + ], + }; + }), + })); + + mockGenieService.getMessageAttachmentQueryResult + .mockResolvedValueOnce({ + statement_response: { result: { data: [["row1"]] } }, + }) + .mockResolvedValueOnce({ + statement_response: { result: { data: [["row2"]] } }, + }); + + const plugin = new GeniePlugin(config); + const { router, getHandler } = createMockRouter(); + + plugin.injectRoutes(router); + + const handler = getHandler("POST", "/:alias/messages"); + const mockReq = createMockRequest({ + params: { alias: "myspace" }, + body: { content: "Run two queries" }, + headers: { + "x-forwarded-access-token": "user-token", + "x-forwarded-user": "user-1", + }, + }); + const mockRes = createMockResponse(); + + await handler(mockReq, mockRes); + + // getMessageAttachmentQueryResult should be called twice (once per query attachment) + expect( + mockGenieService.getMessageAttachmentQueryResult, + ).toHaveBeenCalledTimes(2); + + expect( + mockGenieService.getMessageAttachmentQueryResult, + ).toHaveBeenCalledWith( + expect.objectContaining({ attachment_id: "att-q1" }), + ); + expect( + mockGenieService.getMessageAttachmentQueryResult, + ).toHaveBeenCalledWith( + expect.objectContaining({ attachment_id: "att-q2" }), + ); + + const writeCalls = mockRes.write.mock.calls.map((call: any[]) => call[0]); + const allWritten = writeCalls.join(""); + + // Should have two query_result events + const queryResultCount = (allWritten.match(/query_result/g) || []).length; + expect(queryResultCount).toBeGreaterThanOrEqual(2); + + expect(mockRes.end).toHaveBeenCalled(); + }); + }); + + describe("error handling", () => { + test("should yield error event on SDK failure", async () => { + mockGenieService.startConversation.mockRejectedValue( + new Error("Genie service unavailable"), + ); + + const plugin = new GeniePlugin(config); + const { router, getHandler } = createMockRouter(); + + plugin.injectRoutes(router); + + const handler = getHandler("POST", "/:alias/messages"); + const mockReq = createMockRequest({ + params: { alias: "myspace" }, + body: { content: "test question" }, + headers: { + "x-forwarded-access-token": "user-token", + "x-forwarded-user": "user-1", + }, + }); + const mockRes = createMockResponse(); + + await handler(mockReq, mockRes); + + const writeCalls = mockRes.write.mock.calls.map((call: any[]) => call[0]); + const allWritten = writeCalls.join(""); + + expect(allWritten).toContain("error"); + expect(allWritten).toContain("Genie service unavailable"); + + expect(mockRes.end).toHaveBeenCalled(); + }); + }); + + describe("getConversation", () => { + function createConversationRequest(overrides: Record = {}) { + return createMockRequest({ + params: { alias: "myspace", conversationId: "conv-123" }, + query: {}, + headers: { + "x-forwarded-access-token": "user-token", + "x-forwarded-user": "user-1", + }, + ...overrides, + }); + } + + function mockMessages(messages: any[]) { + mockGenieService.listConversationMessages.mockResolvedValue({ + messages, + next_page_token: undefined, + }); + } + + test("should return 404 for unknown alias", async () => { + const plugin = new GeniePlugin(config); + const { router, getHandler } = createMockRouter(); + + plugin.injectRoutes(router); + + const handler = getHandler( + "GET", + "/:alias/conversations/:conversationId", + ); + const mockReq = createConversationRequest({ + params: { alias: "unknown", conversationId: "conv-123" }, + }); + const mockRes = createMockResponse(); + + await handler(mockReq, mockRes); + + expect(mockRes.status).toHaveBeenCalledWith(404); + expect(mockRes.json).toHaveBeenCalledWith({ + error: "Unknown space alias: unknown", + }); + }); + + test("should stream message_result events for each message", async () => { + mockMessages([ + { + message_id: "msg-1", + conversation_id: "conv-123", + space_id: "test-space-id", + content: "What are the top customers?", + status: "COMPLETED", + attachments: [], + }, + { + message_id: "msg-2", + conversation_id: "conv-123", + space_id: "test-space-id", + content: "Here are the results", + status: "COMPLETED", + attachments: [ + { + attachment_id: "att-1", + query: { + title: "Top Customers", + query: "SELECT * FROM customers", + statement_id: "stmt-1", + }, + }, + ], + }, + ]); + + const plugin = new GeniePlugin(config); + const { router, getHandler } = createMockRouter(); + + plugin.injectRoutes(router); + + const handler = getHandler( + "GET", + "/:alias/conversations/:conversationId", + ); + const mockReq = createConversationRequest(); + const mockRes = createMockResponse(); + + await handler(mockReq, mockRes); + + expect(mockGenieService.listConversationMessages).toHaveBeenCalledWith( + expect.objectContaining({ + space_id: "test-space-id", + conversation_id: "conv-123", + page_size: 100, + }), + ); + + const writeCalls = mockRes.write.mock.calls.map((call: any[]) => call[0]); + const allWritten = writeCalls.join(""); + + // Should have two message_result events + const messageResultCount = ( + allWritten.match(/"type":"message_result"/g) || [] + ).length; + expect(messageResultCount).toBe(2); + + // Should contain message content + expect(allWritten).toContain("What are the top customers?"); + expect(allWritten).toContain("Here are the results"); + + expect(mockRes.end).toHaveBeenCalled(); + }); + + test("should stream query_result events when includeQueryResults is true (default)", async () => { + mockMessages([ + { + message_id: "msg-1", + conversation_id: "conv-123", + space_id: "test-space-id", + content: "Results", + status: "COMPLETED", + attachments: [ + { + attachment_id: "att-1", + query: { + title: "Query 1", + query: "SELECT 1", + statement_id: "stmt-1", + }, + }, + ], + }, + ]); + + const plugin = new GeniePlugin(config); + const { router, getHandler } = createMockRouter(); + + plugin.injectRoutes(router); + + const handler = getHandler( + "GET", + "/:alias/conversations/:conversationId", + ); + const mockReq = createConversationRequest(); + const mockRes = createMockResponse(); + + await handler(mockReq, mockRes); + + expect( + mockGenieService.getMessageAttachmentQueryResult, + ).toHaveBeenCalledWith( + expect.objectContaining({ + space_id: "test-space-id", + conversation_id: "conv-123", + message_id: "msg-1", + attachment_id: "att-1", + }), + ); + + const writeCalls = mockRes.write.mock.calls.map((call: any[]) => call[0]); + const allWritten = writeCalls.join(""); + + expect(allWritten).toContain("message_result"); + expect(allWritten).toContain("query_result"); + expect(mockRes.end).toHaveBeenCalled(); + }); + + test("should NOT stream query_result events when includeQueryResults is false", async () => { + mockMessages([ + { + message_id: "msg-1", + conversation_id: "conv-123", + space_id: "test-space-id", + content: "Results", + status: "COMPLETED", + attachments: [ + { + attachment_id: "att-1", + query: { + title: "Query 1", + query: "SELECT 1", + statement_id: "stmt-1", + }, + }, + ], + }, + ]); + + const plugin = new GeniePlugin(config); + const { router, getHandler } = createMockRouter(); + + plugin.injectRoutes(router); + + const handler = getHandler( + "GET", + "/:alias/conversations/:conversationId", + ); + const mockReq = createConversationRequest({ + query: { includeQueryResults: "false" }, + }); + const mockRes = createMockResponse(); + + await handler(mockReq, mockRes); + + expect( + mockGenieService.getMessageAttachmentQueryResult, + ).not.toHaveBeenCalled(); + + const writeCalls = mockRes.write.mock.calls.map((call: any[]) => call[0]); + const allWritten = writeCalls.join(""); + + expect(allWritten).toContain("message_result"); + expect(allWritten).not.toContain("query_result"); + expect(mockRes.end).toHaveBeenCalled(); + }); + + test("should paginate through all messages", async () => { + mockGenieService.listConversationMessages + .mockResolvedValueOnce({ + messages: [ + { + message_id: "msg-1", + conversation_id: "conv-123", + space_id: "test-space-id", + content: "Page 1 message", + status: "COMPLETED", + attachments: [], + }, + ], + next_page_token: "page-2-token", + }) + .mockResolvedValueOnce({ + messages: [ + { + message_id: "msg-2", + conversation_id: "conv-123", + space_id: "test-space-id", + content: "Page 2 message", + status: "COMPLETED", + attachments: [], + }, + ], + next_page_token: undefined, + }); + + const plugin = new GeniePlugin(config); + const { router, getHandler } = createMockRouter(); + + plugin.injectRoutes(router); + + const handler = getHandler( + "GET", + "/:alias/conversations/:conversationId", + ); + const mockReq = createConversationRequest({ + query: { includeQueryResults: "false" }, + }); + const mockRes = createMockResponse(); + + await handler(mockReq, mockRes); + + expect(mockGenieService.listConversationMessages).toHaveBeenCalledTimes( + 2, + ); + + // First call without page_token + expect(mockGenieService.listConversationMessages).toHaveBeenNthCalledWith( + 1, + expect.objectContaining({ + space_id: "test-space-id", + conversation_id: "conv-123", + page_size: 100, + }), + ); + + // Second call with page_token + expect(mockGenieService.listConversationMessages).toHaveBeenNthCalledWith( + 2, + expect.objectContaining({ + space_id: "test-space-id", + conversation_id: "conv-123", + page_size: 100, + page_token: "page-2-token", + }), + ); + + const writeCalls = mockRes.write.mock.calls.map((call: any[]) => call[0]); + const allWritten = writeCalls.join(""); + + expect(allWritten).toContain("Page 1 message"); + expect(allWritten).toContain("Page 2 message"); + expect(mockRes.end).toHaveBeenCalled(); + }); + + test("should handle empty conversation", async () => { + mockMessages([]); + + const plugin = new GeniePlugin(config); + const { router, getHandler } = createMockRouter(); + + plugin.injectRoutes(router); + + const handler = getHandler( + "GET", + "/:alias/conversations/:conversationId", + ); + const mockReq = createConversationRequest(); + const mockRes = createMockResponse(); + + await handler(mockReq, mockRes); + + const writeCalls = mockRes.write.mock.calls.map((call: any[]) => call[0]); + const allWritten = writeCalls.join(""); + + expect(allWritten).not.toContain("message_result"); + expect(allWritten).not.toContain("query_result"); + expect(mockRes.end).toHaveBeenCalled(); + }); + + test("should yield error event on SDK failure", async () => { + mockGenieService.listConversationMessages.mockRejectedValue( + new Error("Conversation not found"), + ); + + const plugin = new GeniePlugin(config); + const { router, getHandler } = createMockRouter(); + + plugin.injectRoutes(router); + + const handler = getHandler( + "GET", + "/:alias/conversations/:conversationId", + ); + const mockReq = createConversationRequest(); + const mockRes = createMockResponse(); + + await handler(mockReq, mockRes); + + const writeCalls = mockRes.write.mock.calls.map((call: any[]) => call[0]); + const allWritten = writeCalls.join(""); + + expect(allWritten).toContain("error"); + expect(allWritten).toContain("Conversation not found"); + expect(mockRes.end).toHaveBeenCalled(); + }); + + test("should fetch query results in parallel for multiple attachments across messages", async () => { + mockMessages([ + { + message_id: "msg-1", + conversation_id: "conv-123", + space_id: "test-space-id", + content: "First query", + status: "COMPLETED", + attachments: [ + { + attachment_id: "att-1", + query: { + title: "Query 1", + query: "SELECT 1", + statement_id: "stmt-1", + }, + }, + ], + }, + { + message_id: "msg-2", + conversation_id: "conv-123", + space_id: "test-space-id", + content: "Second query", + status: "COMPLETED", + attachments: [ + { + attachment_id: "att-2", + query: { + title: "Query 2", + query: "SELECT 2", + statement_id: "stmt-2", + }, + }, + ], + }, + ]); + + const plugin = new GeniePlugin(config); + const { router, getHandler } = createMockRouter(); + + plugin.injectRoutes(router); + + const handler = getHandler( + "GET", + "/:alias/conversations/:conversationId", + ); + const mockReq = createConversationRequest(); + const mockRes = createMockResponse(); + + await handler(mockReq, mockRes); + + expect( + mockGenieService.getMessageAttachmentQueryResult, + ).toHaveBeenCalledTimes(2); + + expect( + mockGenieService.getMessageAttachmentQueryResult, + ).toHaveBeenCalledWith( + expect.objectContaining({ + message_id: "msg-1", + attachment_id: "att-1", + }), + ); + expect( + mockGenieService.getMessageAttachmentQueryResult, + ).toHaveBeenCalledWith( + expect.objectContaining({ + message_id: "msg-2", + attachment_id: "att-2", + }), + ); + + const writeCalls = mockRes.write.mock.calls.map((call: any[]) => call[0]); + const allWritten = writeCalls.join(""); + + const queryResultCount = ( + allWritten.match(/"type":"query_result"/g) || [] + ).length; + expect(queryResultCount).toBe(2); + expect(mockRes.end).toHaveBeenCalled(); + }); + }); + + describe("default spaces from DATABRICKS_GENIE_SPACE_ID", () => { + test("should use env var as default space when spaces is omitted", async () => { + process.env.DATABRICKS_GENIE_SPACE_ID = "env-space-id"; + + const plugin = new GeniePlugin({ timeout: 5000 }); + const { router, getHandler } = createMockRouter(); + + plugin.injectRoutes(router); + + const handler = getHandler("POST", "/:alias/messages"); + const mockReq = createMockRequest({ + params: { alias: "default" }, + body: { content: "test question" }, + headers: { + "x-forwarded-access-token": "user-token", + "x-forwarded-user": "user-1", + }, + }); + const mockRes = createMockResponse(); + + await handler(mockReq, mockRes); + + expect(mockRes.status).not.toHaveBeenCalledWith(404); + expect(mockGenieService.startConversation).toHaveBeenCalledWith( + expect.objectContaining({ + space_id: "env-space-id", + content: "test question", + }), + ); + + delete process.env.DATABRICKS_GENIE_SPACE_ID; + }); + + test("should 404 for any alias when spaces is omitted and env var is unset", async () => { + delete process.env.DATABRICKS_GENIE_SPACE_ID; + + const plugin = new GeniePlugin({ timeout: 5000 }); + const { router, getHandler } = createMockRouter(); + + plugin.injectRoutes(router); + + const handler = getHandler("POST", "/:alias/messages"); + const mockReq = createMockRequest({ + params: { alias: "default" }, + body: { content: "test question" }, + headers: { + "x-forwarded-access-token": "user-token", + "x-forwarded-user": "user-1", + }, + }); + const mockRes = createMockResponse(); + + await handler(mockReq, mockRes); + + expect(mockRes.status).toHaveBeenCalledWith(404); + expect(mockRes.json).toHaveBeenCalledWith({ + error: "Unknown space alias: default", + }); + }); + }); + + describe("SSE reconnection streamId", () => { + let executeStreamSpy: ReturnType; + + beforeEach(() => { + executeStreamSpy = vi.spyOn(Plugin.prototype as any, "executeStream"); + executeStreamSpy.mockResolvedValue(undefined); + }); + + afterEach(() => { + executeStreamSpy.mockRestore(); + }); + + test("sendMessage should use requestId query param as streamId", async () => { + const plugin = new GeniePlugin(config); + const { router, getHandler } = createMockRouter(); + + plugin.injectRoutes(router); + + const handler = getHandler("POST", "/:alias/messages"); + const mockReq = createMockRequest({ + params: { alias: "myspace" }, + query: { requestId: "req-uuid-123" }, + body: { + content: "follow-up question", + conversationId: "conv-42", + }, + headers: { + "x-forwarded-access-token": "user-token", + "x-forwarded-user": "user-1", + }, + }); + const mockRes = createMockResponse(); + + await handler(mockReq, mockRes); + + expect(executeStreamSpy).toHaveBeenCalledWith( + mockRes, + expect.any(Function), + expect.objectContaining({ + stream: expect.objectContaining({ + streamId: "req-uuid-123", + bufferSize: 100, + }), + }), + ); + }); + + test("sendMessage without requestId should generate a random streamId", async () => { + const plugin = new GeniePlugin(config); + const { router, getHandler } = createMockRouter(); + + plugin.injectRoutes(router); + + const handler = getHandler("POST", "/:alias/messages"); + const mockReq = createMockRequest({ + params: { alias: "myspace" }, + body: { content: "new question" }, + headers: { + "x-forwarded-access-token": "user-token", + "x-forwarded-user": "user-1", + }, + }); + const mockRes = createMockResponse(); + + await handler(mockReq, mockRes); + + expect(executeStreamSpy).toHaveBeenCalledWith( + mockRes, + expect.any(Function), + expect.objectContaining({ + stream: expect.objectContaining({ + streamId: expect.stringMatching( + /^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/, + ), + bufferSize: 100, + }), + }), + ); + }); + + test("getConversation should use requestId query param as streamId", async () => { + const plugin = new GeniePlugin(config); + const { router, getHandler } = createMockRouter(); + + plugin.injectRoutes(router); + + const handler = getHandler( + "GET", + "/:alias/conversations/:conversationId", + ); + const mockReq = createMockRequest({ + params: { alias: "myspace", conversationId: "conv-99" }, + query: { requestId: "req-uuid-456" }, + headers: { + "x-forwarded-access-token": "user-token", + "x-forwarded-user": "user-1", + }, + }); + const mockRes = createMockResponse(); + + await handler(mockReq, mockRes); + + expect(executeStreamSpy).toHaveBeenCalledWith( + mockRes, + expect.any(Function), + expect.objectContaining({ + stream: expect.objectContaining({ + streamId: "req-uuid-456", + bufferSize: 100, + }), + }), + ); + }); + }); +}); diff --git a/packages/appkit/src/plugins/genie/tests/poll-waiter.test.ts b/packages/appkit/src/plugins/genie/tests/poll-waiter.test.ts new file mode 100644 index 00000000..751f2e8a --- /dev/null +++ b/packages/appkit/src/plugins/genie/tests/poll-waiter.test.ts @@ -0,0 +1,125 @@ +import { describe, expect, test, vi } from "vitest"; +import { type Pollable, type PollEvent, pollWaiter } from "../poll-waiter"; + +function createMockWaiter

(opts: { + progressValues?: P[]; + result: P; + error?: Error; + delay?: number; +}): Pollable

{ + return { + wait: vi.fn().mockImplementation(async (options: any = {}) => { + if (opts.progressValues) { + for (const value of opts.progressValues) { + if (opts.delay) { + await new Promise((r) => setTimeout(r, opts.delay)); + } + if (options.onProgress) { + await options.onProgress(value); + } + } + } + if (opts.error) throw opts.error; + return opts.result; + }), + }; +} + +async function collect

( + gen: AsyncGenerator>, +): Promise[]> { + const events: PollEvent

[] = []; + for await (const event of gen) { + events.push(event); + } + return events; +} + +describe("pollWaiter", () => { + test("yields progress events then completed", async () => { + const waiter = createMockWaiter({ + progressValues: [{ status: "A" }, { status: "B" }], + result: { status: "DONE" }, + }); + + const events = await collect(pollWaiter(waiter)); + + expect(events).toEqual([ + { type: "progress", value: { status: "A" } }, + { type: "progress", value: { status: "B" } }, + { type: "completed", value: { status: "DONE" } }, + ]); + }); + + test("yields only completed when no progress events", async () => { + const waiter = createMockWaiter({ + result: { value: 42 }, + }); + + const events = await collect(pollWaiter(waiter)); + + expect(events).toEqual([{ type: "completed", value: { value: 42 } }]); + }); + + test("throws when waiter rejects", async () => { + const waiter = createMockWaiter({ + result: null as any, + error: new Error("boom"), + }); + + const events: PollEvent[] = []; + await expect(async () => { + for await (const event of pollWaiter(waiter)) { + events.push(event); + } + }).rejects.toThrow("boom"); + + expect(events).toEqual([]); + }); + + test("throws after yielding progress if waiter fails mid-poll", async () => { + const waiter = createMockWaiter({ + progressValues: [{ status: "A" }], + result: null as any, + error: new Error("mid-poll failure"), + }); + + const events: PollEvent[] = []; + await expect(async () => { + for await (const event of pollWaiter(waiter)) { + events.push(event); + } + }).rejects.toThrow("mid-poll failure"); + + expect(events).toEqual([{ type: "progress", value: { status: "A" } }]); + }); + + test("handles async delays between progress callbacks", async () => { + const waiter = createMockWaiter({ + progressValues: [{ n: 1 }, { n: 2 }, { n: 3 }], + result: { n: 99 }, + delay: 10, + }); + + const events = await collect(pollWaiter(waiter)); + + expect(events).toHaveLength(4); + expect(events[0]).toEqual({ type: "progress", value: { n: 1 } }); + expect(events[1]).toEqual({ type: "progress", value: { n: 2 } }); + expect(events[2]).toEqual({ type: "progress", value: { n: 3 } }); + expect(events[3]).toEqual({ type: "completed", value: { n: 99 } }); + }); + + test("passes timeout option through to waiter.wait()", async () => { + const waiter = createMockWaiter({ + result: { done: true }, + }); + + const timeoutValue = { ms: 5000 }; + await collect(pollWaiter(waiter, { timeout: timeoutValue })); + + expect(waiter.wait).toHaveBeenCalledWith( + expect.objectContaining({ timeout: timeoutValue }), + ); + }); +}); diff --git a/packages/appkit/src/plugins/genie/types.ts b/packages/appkit/src/plugins/genie/types.ts new file mode 100644 index 00000000..ebb0debd --- /dev/null +++ b/packages/appkit/src/plugins/genie/types.ts @@ -0,0 +1,60 @@ +import type { BasePluginConfig } from "shared"; + +export interface IGenieConfig extends BasePluginConfig { + /** Map of alias → Genie Space ID. Defaults to { default: DATABRICKS_GENIE_SPACE_ID } if omitted. */ + spaces?: Record; + /** Genie polling timeout in ms. Set to 0 for indefinite. Default: 120000 (2 min) */ + timeout?: number; +} + +export interface GenieSendMessageRequest { + content: string; + conversationId?: string; +} + +/** SSE event discriminated union */ +export type GenieStreamEvent = + | { + type: "message_start"; + conversationId: string; + messageId: string; + spaceId: string; + } + | { type: "status"; status: string } + | { type: "message_result"; message: GenieMessageResponse } + | { + type: "query_result"; + attachmentId: string; + statementId: string; + data: unknown; + } + | { type: "error"; error: string }; + +/** Cleaned response — subset of SDK's GenieMessage */ +export interface GenieMessageResponse { + messageId: string; + conversationId: string; + spaceId: string; + status: string; + content: string; + attachments?: GenieAttachmentResponse[]; + error?: string; +} + +export interface GenieConversationHistoryResponse { + conversationId: string; + spaceId: string; + messages: GenieMessageResponse[]; +} + +export interface GenieAttachmentResponse { + attachmentId?: string; + query?: { + title?: string; + description?: string; + query?: string; + statementId?: string; + }; + text?: { content?: string }; + suggestedQuestions?: string[]; +} diff --git a/packages/appkit/src/plugins/index.ts b/packages/appkit/src/plugins/index.ts index f6a9e2c5..fafd11eb 100644 --- a/packages/appkit/src/plugins/index.ts +++ b/packages/appkit/src/plugins/index.ts @@ -1,3 +1,4 @@ export * from "./analytics"; +export * from "./genie"; export * from "./lakebase"; export * from "./server"; diff --git a/packages/appkit/tsdown.config.ts b/packages/appkit/tsdown.config.ts index 32600ee7..5fa71e97 100644 --- a/packages/appkit/tsdown.config.ts +++ b/packages/appkit/tsdown.config.ts @@ -42,6 +42,10 @@ export default defineConfig([ from: "src/plugins/analytics/manifest.json", to: "dist/plugins/analytics/manifest.json", }, + { + from: "src/plugins/genie/manifest.json", + to: "dist/plugins/genie/manifest.json", + }, { from: "src/plugins/lakebase/manifest.json", to: "dist/plugins/lakebase/manifest.json",