diff --git a/.github/actions/spelling/allow.txt b/.github/actions/spelling/allow.txt index 525dae910..d59515930 100644 --- a/.github/actions/spelling/allow.txt +++ b/.github/actions/spelling/allow.txt @@ -92,6 +92,7 @@ openapiv2 opensource otherurl pb2 +poolclass postgres POSTGRES postgresql diff --git a/alembic.ini b/alembic.ini new file mode 100644 index 000000000..58249b073 --- /dev/null +++ b/alembic.ini @@ -0,0 +1,45 @@ +# A generic, single database configuration. + +[alembic] + +# database URL. This is consumed by the user-maintained env.py script only. +# other means of configuring database URLs may be customized within the env.py +# file. +# IMPORTANT: This is a placeholder and an example, and should be replaced with your actual database URL. +sqlalchemy.url = sqlite+aiosqlite:///./test.db + + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARNING +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARNING +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/alembic/README b/alembic/README new file mode 100644 index 000000000..06ec9e9a8 --- /dev/null +++ b/alembic/README @@ -0,0 +1,59 @@ +# Database Migrations with Alembic + +This directory contains database migration scripts for the A2A SDK, managed by [Alembic](https://alembic.sqlalchemy.org/). + +## Configuration + +- `alembic.ini`: Global configuration for Alembic, including the database URL. +- `env.py`: Python script that runs when the Alembic environment is invoked. It configures the SQLAlchemy engine and connects it to the migration context. +- `versions/`: Directory containing individual migration scripts. + +## Common Commands + +All commands should be run from the project root using `uv run`. + +### Viewing Status +```bash +# View current migration version of the database +uv run alembic current + +# View migration history +uv run alembic history --verbose +``` + +### Running Migrations +```bash +# Upgrade to the latest version +uv run alembic upgrade head + +# Downgrade by one version +uv run alembic downgrade -1 + +# Revert all migrations +uv run alembic downgrade base +``` + +### Creating Migrations +```bash +# Create a new migration manually +uv run alembic revision -m "description of changes" + +# Create a new migration automatically (detects changes in models.py) +uv run alembic revision --autogenerate -m "description of changes" +``` + +## Troubleshooting + +### "duplicate column name" error +If you see an error like `sqlalchemy.exc.OperationalError: (sqlite3.OperationalError) duplicate column name: owner`, it usually means the column was already created (perhaps by `Base.metadata.create_all()` in tests or development) but Alembic doesn't know about it yet. + +To fix this, "stamp" the database to tell Alembic it is already at the latest version: +```bash +uv run alembic stamp head +``` + +## How to add a new migration +1. Modify the models in `src/a2a/server/models.py`. +2. Run `uv run alembic revision --autogenerate -m "Add new field to Task"`. +3. Review the generated script in `alembic/versions/`. +4. Apply the migration with `uv run alembic upgrade head`. diff --git a/alembic/__init__.py b/alembic/__init__.py new file mode 100644 index 000000000..7b55fb93e --- /dev/null +++ b/alembic/__init__.py @@ -0,0 +1 @@ +"Alembic database migration package." diff --git a/alembic/env.py b/alembic/env.py new file mode 100644 index 000000000..07864de4d --- /dev/null +++ b/alembic/env.py @@ -0,0 +1,96 @@ +import asyncio + +from logging.config import fileConfig + +from sqlalchemy import Connection, pool +from sqlalchemy.ext.asyncio import async_engine_from_config + +from a2a.server.models import Base +from alembic import context + + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# add your model's MetaData object here for 'autogenerate' support +target_metadata = Base.metadata + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") # noqa: ERA001 +# ... etc. + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = config.get_main_option('sqlalchemy.url') + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={'paramstyle': 'named'}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def do_run_migrations(connection: Connection) -> None: + """Run migrations in 'online' mode. + + This function is called within a synchronous context (via run_sync) + to configure the migration context with the provided connection + and target metadata, then execute the migrations within a transaction. + + Args: + connection: The SQLAlchemy connection to use for the migrations. + """ + context.configure(connection=connection, target_metadata=target_metadata) + + with context.begin_transaction(): + context.run_migrations() + + +async def run_async_migrations() -> None: + """Run migrations using an Engine. + + In this scenario we need to create an Engine + and associate a connection with the context. + """ + connectable = async_engine_from_config( + config.get_section(config.config_ini_section), + prefix='sqlalchemy.', + poolclass=pool.NullPool, + ) + + async with connectable.connect() as connection: + await connection.run_sync(do_run_migrations) + + await connectable.dispose() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode.""" + asyncio.run(run_async_migrations()) + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/alembic/script.py.mako b/alembic/script.py.mako new file mode 100644 index 000000000..11016301e --- /dev/null +++ b/alembic/script.py.mako @@ -0,0 +1,28 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + """Upgrade schema.""" + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + """Downgrade schema.""" + ${downgrades if downgrades else "pass"} diff --git a/alembic/versions/6419d2d130f6_add_owner_to_task.py b/alembic/versions/6419d2d130f6_add_owner_to_task.py new file mode 100644 index 000000000..6e2ede603 --- /dev/null +++ b/alembic/versions/6419d2d130f6_add_owner_to_task.py @@ -0,0 +1,38 @@ +"""add_owner_to_task. + +Revision ID: 6419d2d130f6 +Revises: +Create Date: 2026-02-17 09:23:06.758085 + +""" + +from collections.abc import Sequence + +import sqlalchemy as sa + +from alembic import op + + +# revision identifiers, used by Alembic. +revision: str = '6419d2d130f6' +down_revision: str | Sequence[str] | None = None +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Upgrade schema.""" + op.add_column( + 'tasks', + sa.Column( + 'owner', + sa.String(255), + nullable=False, + server_default='unknown', # Set your desired default value here + ), + ) + + +def downgrade() -> None: + """Downgrade schema.""" + op.drop_column('tasks', 'owner') diff --git a/alembic/versions/__init__.py b/alembic/versions/__init__.py new file mode 100644 index 000000000..574828c67 --- /dev/null +++ b/alembic/versions/__init__.py @@ -0,0 +1 @@ +"""Alembic migrations scripts for the A2A project.""" diff --git a/pyproject.toml b/pyproject.toml index f5b02ab65..7e3b6a2f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,6 +106,7 @@ style = "pep440" [dependency-groups] dev = [ + "alembic>=1.14.0", "mypy>=1.15.0", "PyJWT>=2.0.0", "pytest>=8.3.5", @@ -347,3 +348,89 @@ docstring-code-format = true docstring-code-line-length = "dynamic" quote-style = "single" indent-style = "space" + + +[tool.alembic] + +# path to migration scripts. +# this is typically a path given in POSIX (e.g. forward slashes) +# format, relative to the token %(here)s which refers to the location of this +# ini file +script_location = "%(here)s/alembic" + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +# file_template = "%%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s" +# Or organize into date-based subdirectories (requires recursive_version_locations = true) +# file_template = "%%(year)d/%%(month).2d/%%(day).2d_%%(hour).2d%%(minute).2d_%%(second).2d_%%(rev)s_%%(slug)s" + +# additional paths to be prepended to sys.path. defaults to the current working directory. +prepend_sys_path = [ + "." +] + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the tzdata library which can be installed by adding +# `alembic[tz]` to the pip requirements. +# string value is passed to ZoneInfo() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to /versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# version_locations = [ +# "%(here)s/alembic/versions", +# "%(here)s/foo/bar" +# ] + + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = "utf-8" + +# This section defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples +# [[tool.alembic.post_write_hooks]] +# format using "black" - use the console_scripts runner, +# against the "black" entrypoint +# name = "black" +# type = "console_scripts" +# entrypoint = "black" +# options = "-l 79 REVISION_SCRIPT_FILENAME" +# +# [[tool.alembic.post_write_hooks]] +# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module +# name = "ruff" +# type = "module" +# module = "ruff" +# options = "check --fix REVISION_SCRIPT_FILENAME" +# +# [[tool.alembic.post_write_hooks]] +# Alternatively, use the exec runner to execute a binary found on your PATH +# name = "ruff" +# type = "exec" +# executable = "ruff" +# options = "check --fix REVISION_SCRIPT_FILENAME" + diff --git a/src/a2a/server/models.py b/src/a2a/server/models.py index b8e1904ed..a7e80d81c 100644 --- a/src/a2a/server/models.py +++ b/src/a2a/server/models.py @@ -18,12 +18,7 @@ def override(func): # noqa: ANN001, ANN201 try: - from sqlalchemy import ( - JSON, - Dialect, - LargeBinary, - String, - ) + from sqlalchemy import JSON, Dialect, Index, LargeBinary, String from sqlalchemy.orm import ( DeclarativeBase, Mapped, @@ -153,6 +148,8 @@ class TaskMixin: kind: Mapped[str] = mapped_column( String(16), nullable=False, default='task' ) + owner: Mapped[str] = mapped_column(String(255), nullable=False) + last_updated: Mapped[str] = mapped_column(String(22), nullable=True) # Properly typed Pydantic fields with automatic serialization status: Mapped[TaskStatus] = mapped_column(PydanticType(TaskStatus)) @@ -178,6 +175,17 @@ def __repr__(self) -> str: f'context_id="{self.context_id}", status="{self.status}")>' ) + @declared_attr.directive + @classmethod + def __table_args__(cls) -> tuple[Any, ...]: + """Define a composite index (owner, last_updated) for each table that uses the mixin.""" + tablename = getattr(cls, '__tablename__', 'tasks') + return ( + Index( + f'idx_{tablename}_owner_last_updated', 'owner', 'last_updated' + ), + ) + def create_task_model( table_name: str = 'tasks', base: type[DeclarativeBase] = Base @@ -238,6 +246,7 @@ class PushNotificationConfigMixin: task_id: Mapped[str] = mapped_column(String(36), primary_key=True) config_id: Mapped[str] = mapped_column(String(255), primary_key=True) config_data: Mapped[bytes] = mapped_column(LargeBinary, nullable=False) + owner: Mapped[str] = mapped_column(String(255), nullable=False, index=True) @override def __repr__(self) -> str: diff --git a/src/a2a/server/owner_resolver.py b/src/a2a/server/owner_resolver.py new file mode 100644 index 000000000..4fa310b92 --- /dev/null +++ b/src/a2a/server/owner_resolver.py @@ -0,0 +1,18 @@ +from collections.abc import Callable + +from a2a.server.context import ServerCallContext + + +# Definition +OwnerResolver = Callable[[ServerCallContext | None], str] + + +# Example Default Implementation +def resolve_user_scope(context: ServerCallContext | None) -> str: + """Resolves the owner scope based on the user in the context.""" + if not context: + return 'unknown' + if not context.user: + raise ValueError('User not found in context.') + # Example: Basic user name. Adapt as needed for your user model. + return context.user.user_name diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index 41425457f..104b256de 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -227,7 +227,7 @@ async def _run_event_stream( async def _setup_message_execution( self, params: SendMessageRequest, - context: ServerCallContext | None = None, + context: ServerCallContext | None, ) -> tuple[TaskManager, str, EventQueue, ResultAggregator, asyncio.Task]: """Common setup logic for both streaming and non-streaming message handling. @@ -283,7 +283,9 @@ async def _setup_message_execution( and params.configuration.push_notification_config ): await self._push_config_store.set_info( - task_id, params.configuration.push_notification_config + task_id, + params.configuration.push_notification_config, + context or ServerCallContext(), ) queue = await self._queue_manager.create_or_tap(task_id) @@ -495,6 +497,7 @@ async def on_create_task_push_notification_config( await self._push_config_store.set_info( task_id, params.config, + context or ServerCallContext(), ) return TaskPushNotificationConfig( @@ -521,7 +524,10 @@ async def on_get_task_push_notification_config( raise ServerError(error=TaskNotFoundError()) push_notification_configs: list[PushNotificationConfig] = ( - await self._push_config_store.get_info(task_id) or [] + await self._push_config_store.get_info( + task_id, context or ServerCallContext() + ) + or [] ) for config in push_notification_configs: @@ -597,7 +603,7 @@ async def on_list_task_push_notification_configs( raise ServerError(error=TaskNotFoundError()) push_notification_config_list = await self._push_config_store.get_info( - task_id + task_id, context or ServerCallContext() ) return ListTaskPushNotificationConfigsResponse( @@ -628,4 +634,6 @@ async def on_delete_task_push_notification_config( if not task: raise ServerError(error=TaskNotFoundError()) - await self._push_config_store.delete_info(task_id, config_id) + await self._push_config_store.delete_info( + task_id, context or ServerCallContext(), config_id + ) diff --git a/src/a2a/server/tasks/base_push_notification_sender.py b/src/a2a/server/tasks/base_push_notification_sender.py index 4e4444923..84f544f5e 100644 --- a/src/a2a/server/tasks/base_push_notification_sender.py +++ b/src/a2a/server/tasks/base_push_notification_sender.py @@ -5,6 +5,7 @@ from google.protobuf.json_format import MessageToDict +from a2a.server.context import ServerCallContext from a2a.server.tasks.push_notification_config_store import ( PushNotificationConfigStore, ) @@ -22,19 +23,24 @@ def __init__( self, httpx_client: httpx.AsyncClient, config_store: PushNotificationConfigStore, + context: ServerCallContext, ) -> None: """Initializes the BasePushNotificationSender. Args: httpx_client: An async HTTP client instance to send notifications. config_store: A PushNotificationConfigStore instance to retrieve configurations. + context: The `ServerCallContext` that this push notification is produced under. """ self._client = httpx_client self._config_store = config_store + self._call_context: ServerCallContext = context async def send_notification(self, task: Task) -> None: """Sends a push notification for a task if configuration exists.""" - push_configs = await self._config_store.get_info(task.id) + push_configs = await self._config_store.get_info( + task.id, self._call_context + ) if not push_configs: return diff --git a/src/a2a/server/tasks/database_push_notification_config_store.py b/src/a2a/server/tasks/database_push_notification_config_store.py index 14f3bb162..be8f16121 100644 --- a/src/a2a/server/tasks/database_push_notification_config_store.py +++ b/src/a2a/server/tasks/database_push_notification_config_store.py @@ -8,11 +8,7 @@ try: - from sqlalchemy import ( - Table, - delete, - select, - ) + from sqlalchemy import Table, and_, delete, select from sqlalchemy.ext.asyncio import ( AsyncEngine, AsyncSession, @@ -31,11 +27,13 @@ "or 'pip install a2a-sdk[sql]'" ) from e +from a2a.server.context import ServerCallContext from a2a.server.models import ( Base, PushNotificationConfigModel, create_push_notification_config_model, ) +from a2a.server.owner_resolver import OwnerResolver, resolve_user_scope from a2a.server.tasks.push_notification_config_store import ( PushNotificationConfigStore, ) @@ -61,6 +59,7 @@ class DatabasePushNotificationConfigStore(PushNotificationConfigStore): _initialized: bool config_model: type[PushNotificationConfigModel] _fernet: 'Fernet | None' + owner_resolver: OwnerResolver def __init__( self, @@ -68,6 +67,7 @@ def __init__( create_table: bool = True, table_name: str = 'push_notification_configs', encryption_key: str | bytes | None = None, + owner_resolver: OwnerResolver = resolve_user_scope, ) -> None: """Initializes the DatabasePushNotificationConfigStore. @@ -78,6 +78,7 @@ def __init__( encryption_key: A key for encrypting sensitive configuration data. If provided, `config_data` will be encrypted in the database. The key must be a URL-safe base64-encoded 32-byte key. + owner_resolver: Function to resolve the owner from the context. """ logger.debug( 'Initializing DatabasePushNotificationConfigStore with existing engine, table: %s', @@ -89,6 +90,7 @@ def __init__( ) self.create_table = create_table self._initialized = False + self.owner_resolver = owner_resolver self.config_model = ( PushNotificationConfigModel if table_name == 'push_notification_configs' @@ -143,7 +145,7 @@ async def _ensure_initialized(self) -> None: await self.initialize() def _to_orm( - self, task_id: str, config: PushNotificationConfig + self, task_id: str, config: PushNotificationConfig, owner: str ) -> PushNotificationConfigModel: """Maps a PushNotificationConfig proto to a SQLAlchemy model instance. @@ -159,6 +161,7 @@ def _to_orm( return self.config_model( task_id=task_id, config_id=config.id, + owner=owner, config_data=data_to_store, ) @@ -235,10 +238,14 @@ def _from_orm( ) from e async def set_info( - self, task_id: str, notification_config: PushNotificationConfig + self, + task_id: str, + notification_config: PushNotificationConfig, + context: ServerCallContext, ) -> None: """Sets or updates the push notification configuration for a task.""" await self._ensure_initialized() + owner = self.owner_resolver(context) # Create a copy of the config using proto CopyFrom config_to_save = PushNotificationConfig() @@ -246,21 +253,30 @@ async def set_info( if not config_to_save.id: config_to_save.id = task_id - db_config = self._to_orm(task_id, config_to_save) + db_config = self._to_orm(task_id, config_to_save, owner) async with self.async_session_maker.begin() as session: await session.merge(db_config) logger.debug( - 'Push notification config for task %s with config id %s saved/updated.', + 'Push notification config for task %s with config id %s for owner %s saved/updated.', task_id, config_to_save.id, + owner, ) - async def get_info(self, task_id: str) -> list[PushNotificationConfig]: - """Retrieves all push notification configurations for a task.""" + async def get_info( + self, + task_id: str, + context: ServerCallContext, + ) -> list[PushNotificationConfig]: + """Retrieves all push notification configurations for a task, for the given owner.""" await self._ensure_initialized() + owner = self.owner_resolver(context) async with self.async_session_maker() as session: stmt = select(self.config_model).where( - self.config_model.task_id == task_id + and_( + self.config_model.task_id == task_id, + self.config_model.owner == owner, + ) ) result = await session.execute(stmt) models = result.scalars().all() @@ -271,24 +287,32 @@ async def get_info(self, task_id: str) -> list[PushNotificationConfig]: configs.append(self._from_orm(model)) except ValueError: # noqa: PERF203 logger.exception( - 'Could not deserialize push notification config for task %s, config %s', + 'Could not deserialize push notification config for task %s, config %s, owner %s', model.task_id, model.config_id, + owner, ) return configs async def delete_info( - self, task_id: str, config_id: str | None = None + self, + task_id: str, + context: ServerCallContext, + config_id: str | None = None, ) -> None: """Deletes push notification configurations for a task. If config_id is provided, only that specific configuration is deleted. - If config_id is None, all configurations for the task are deleted. + If config_id is None, all configurations for the task for the owner are deleted. """ await self._ensure_initialized() + owner = self.owner_resolver(context) async with self.async_session_maker.begin() as session: stmt = delete(self.config_model).where( - self.config_model.task_id == task_id + and_( + self.config_model.task_id == task_id, + self.config_model.owner == owner, + ) ) if config_id is not None: stmt = stmt.where(self.config_model.config_id == config_id) @@ -297,13 +321,15 @@ async def delete_info( if result.rowcount > 0: # type: ignore[attr-defined] logger.info( - 'Deleted %s push notification config(s) for task %s.', + 'Deleted %s push notification config(s) for task %s, owner %s.', result.rowcount, # type: ignore[attr-defined] task_id, + owner, ) else: logger.warning( - 'Attempted to delete push notification config for task %s with config_id: %s that does not exist.', + 'Attempted to delete push notification config for task %s, owner %s with config_id: %s that does not exist.', task_id, + owner, config_id, ) diff --git a/src/a2a/server/tasks/database_task_store.py b/src/a2a/server/tasks/database_task_store.py index 0acb9c2d4..1e8330f10 100644 --- a/src/a2a/server/tasks/database_task_store.py +++ b/src/a2a/server/tasks/database_task_store.py @@ -34,6 +34,7 @@ from a2a.server.context import ServerCallContext from a2a.server.models import Base, TaskModel, create_task_model +from a2a.server.owner_resolver import OwnerResolver, resolve_user_scope from a2a.server.tasks.task_store import TaskStore from a2a.types import a2a_pb2 from a2a.types.a2a_pb2 import Task @@ -55,12 +56,14 @@ class DatabaseTaskStore(TaskStore): create_table: bool _initialized: bool task_model: type[TaskModel] + owner_resolver: OwnerResolver def __init__( self, engine: AsyncEngine, create_table: bool = True, table_name: str = 'tasks', + owner_resolver: OwnerResolver = resolve_user_scope, ) -> None: """Initializes the DatabaseTaskStore. @@ -68,6 +71,7 @@ def __init__( engine: An existing SQLAlchemy AsyncEngine to be used by Task Store create_table: If true, create tasks table on initialization. table_name: Name of the database table. Defaults to 'tasks'. + owner_resolver: Function to resolve the owner from the context. """ logger.debug( 'Initializing DatabaseTaskStore with existing engine, table: %s', @@ -79,6 +83,7 @@ def __init__( ) self.create_table = create_table self._initialized = False + self.owner_resolver = owner_resolver self.task_model = ( TaskModel @@ -109,7 +114,7 @@ async def _ensure_initialized(self) -> None: if not self._initialized: await self.initialize() - def _to_orm(self, task: Task) -> TaskModel: + def _to_orm(self, task: Task, owner: str) -> TaskModel: """Maps a Proto Task to a SQLAlchemy TaskModel instance.""" # Pass proto objects directly - PydanticType/PydanticListType # handle serialization via process_bind_param @@ -117,6 +122,12 @@ def _to_orm(self, task: Task) -> TaskModel: id=task.id, context_id=task.context_id, kind='task', # Default kind for tasks + owner=owner, + last_updated=( + task.status.timestamp.ToJsonString() + if task.HasField('status') and task.status.HasField('timestamp') + else None + ), status=task.status if task.HasField('status') else None, artifacts=list(task.artifacts) if task.artifacts else [], history=list(task.history) if task.history else [], @@ -148,28 +159,45 @@ def _from_orm(self, task_model: TaskModel) -> Task: async def save( self, task: Task, context: ServerCallContext | None = None ) -> None: - """Saves or updates a task in the database.""" + """Saves or updates a task in the database for the resolved owner.""" await self._ensure_initialized() - db_task = self._to_orm(task) + owner = self.owner_resolver(context) + db_task = self._to_orm(task, owner) async with self.async_session_maker.begin() as session: await session.merge(db_task) - logger.debug('Task %s saved/updated successfully.', task.id) + logger.debug( + 'Task %s for owner %s saved/updated successfully.', + task.id, + owner, + ) async def get( self, task_id: str, context: ServerCallContext | None = None ) -> Task | None: - """Retrieves a task from the database by ID.""" + """Retrieves a task from the database by ID, for the given owner.""" await self._ensure_initialized() + owner = self.owner_resolver(context) async with self.async_session_maker() as session: - stmt = select(self.task_model).where(self.task_model.id == task_id) + stmt = select(self.task_model).where( + and_( + self.task_model.id == task_id, + self.task_model.owner == owner, + ) + ) result = await session.execute(stmt) task_model = result.scalar_one_or_none() if task_model: task = self._from_orm(task_model) - logger.debug('Task %s retrieved successfully.', task_id) + logger.debug( + 'Task %s retrieved successfully for owner %s.', + task_id, + owner, + ) return task - logger.debug('Task %s not found in store.', task_id) + logger.debug( + 'Task %s not found in store for owner %s.', task_id, owner + ) return None async def list( @@ -177,11 +205,16 @@ async def list( params: a2a_pb2.ListTasksRequest, context: ServerCallContext | None = None, ) -> a2a_pb2.ListTasksResponse: - """Retrieves all tasks from the database.""" + """Retrieves tasks from the database based on provided parameters, for the given owner.""" await self._ensure_initialized() + owner = self.owner_resolver(context) + logger.debug('Listing tasks for owner %s with params %s', owner, params) + async with self.async_session_maker() as session: - timestamp_col = self.task_model.status['timestamp'].as_string() - base_stmt = select(self.task_model) + timestamp_col = self.task_model.last_updated + base_stmt = select(self.task_model).where( + self.task_model.owner == owner + ) # Add filters if params.context_id: @@ -218,33 +251,36 @@ async def list( start_task = ( await session.execute( select(self.task_model).where( - self.task_model.id == start_task_id + and_( + self.task_model.id == start_task_id, + self.task_model.owner == owner, + ) ) ) ).scalar_one_or_none() if not start_task: raise ValueError(f'Invalid page token: {params.page_token}') - if start_task.status.HasField('timestamp'): - start_timestamp_iso = ( - start_task.status.timestamp.ToJsonString() - ) - stmt = stmt.where( - or_( - and_( - timestamp_col == start_timestamp_iso, - self.task_model.id <= start_task.id, - ), - timestamp_col < start_timestamp_iso, - timestamp_col.is_(None), + + start_task_timestamp = start_task.last_updated + where_clauses = [] + if start_task_timestamp: + where_clauses.append( + and_( + timestamp_col == start_task_timestamp, + self.task_model.id <= start_task_id, ) ) + where_clauses.append(timestamp_col < start_task_timestamp) + where_clauses.append(timestamp_col.is_(None)) else: - stmt = stmt.where( + where_clauses.append( and_( timestamp_col.is_(None), - self.task_model.id <= start_task.id, + self.task_model.id <= start_task_id, ) ) + stmt = stmt.where(or_(*where_clauses)) + page_size = params.page_size or DEFAULT_LIST_TASKS_PAGE_SIZE stmt = stmt.limit(page_size + 1) # Add 1 for next page token @@ -268,17 +304,27 @@ async def list( async def delete( self, task_id: str, context: ServerCallContext | None = None ) -> None: - """Deletes a task from the database by ID.""" + """Deletes a task from the database by ID, for the given owner.""" await self._ensure_initialized() + owner = self.owner_resolver(context) async with self.async_session_maker.begin() as session: - stmt = delete(self.task_model).where(self.task_model.id == task_id) + stmt = delete(self.task_model).where( + and_( + self.task_model.id == task_id, + self.task_model.owner == owner, + ) + ) result = await session.execute(stmt) # Commit is automatic when using session.begin() if result.rowcount > 0: # type: ignore[attr-defined] - logger.info('Task %s deleted successfully.', task_id) + logger.info( + 'Task %s deleted successfully for owner %s.', task_id, owner + ) else: logger.warning( - 'Attempted to delete nonexistent task with id: %s', task_id + 'Attempted to delete nonexistent task with id: %s and owner %s', + task_id, + owner, ) diff --git a/src/a2a/server/tasks/inmemory_push_notification_config_store.py b/src/a2a/server/tasks/inmemory_push_notification_config_store.py index 707156593..eb336e329 100644 --- a/src/a2a/server/tasks/inmemory_push_notification_config_store.py +++ b/src/a2a/server/tasks/inmemory_push_notification_config_store.py @@ -1,6 +1,10 @@ import asyncio import logging +from collections import defaultdict + +from a2a.server.context import ServerCallContext +from a2a.server.owner_resolver import OwnerResolver, resolve_user_scope from a2a.server.tasks.push_notification_config_store import ( PushNotificationConfigStore, ) @@ -13,56 +17,115 @@ class InMemoryPushNotificationConfigStore(PushNotificationConfigStore): """In-memory implementation of PushNotificationConfigStore interface. - Stores push notification configurations in memory + Stores push notification configurations in a nested dictionary in memory, + keyed by owner then task_id. """ - def __init__(self) -> None: + def __init__( + self, + owner_resolver: OwnerResolver = resolve_user_scope, + ) -> None: """Initializes the InMemoryPushNotificationConfigStore.""" self.lock = asyncio.Lock() self._push_notification_infos: dict[ - str, list[PushNotificationConfig] - ] = {} + str, dict[str, list[PushNotificationConfig]] + ] = defaultdict(dict) + self.owner_resolver = owner_resolver async def set_info( - self, task_id: str, notification_config: PushNotificationConfig + self, + task_id: str, + notification_config: PushNotificationConfig, + context: ServerCallContext, ) -> None: """Sets or updates the push notification configuration for a task in memory.""" + owner = self.owner_resolver(context) async with self.lock: - if task_id not in self._push_notification_infos: - self._push_notification_infos[task_id] = [] + owner_infos = self._push_notification_infos[owner] + if task_id not in owner_infos: + owner_infos[task_id] = [] if not notification_config.id: notification_config.id = task_id - for config in self._push_notification_infos[task_id]: + # Remove existing config with the same ID + for config in owner_infos[task_id]: if config.id == notification_config.id: - self._push_notification_infos[task_id].remove(config) + owner_infos[task_id].remove(config) break - self._push_notification_infos[task_id].append(notification_config) - - async def get_info(self, task_id: str) -> list[PushNotificationConfig]: - """Retrieves the push notification configuration for a task from memory.""" + owner_infos[task_id].append(notification_config) + logger.debug( + 'Push notification config for task %s with config id %s for owner %s saved/updated.', + task_id, + notification_config.id, + owner, + ) + + async def get_info( + self, + task_id: str, + context: ServerCallContext, + ) -> list[PushNotificationConfig]: + """Retrieves all push notification configurations for a task from memory, for the given owner.""" + owner = self.owner_resolver(context) async with self.lock: - return self._push_notification_infos.get(task_id) or [] + owner_infos = self._push_notification_infos.get(owner, {}) + return list(owner_infos.get(task_id, [])) async def delete_info( - self, task_id: str, config_id: str | None = None + self, + task_id: str, + context: ServerCallContext, + config_id: str | None = None, ) -> None: - """Deletes the push notification configuration for a task from memory.""" - async with self.lock: - if config_id is None: - config_id = task_id + """Deletes push notification configurations for a task from memory. - if task_id in self._push_notification_infos: - configurations = self._push_notification_infos[task_id] - if not configurations: - return + If config_id is provided, only that specific configuration is deleted. + If config_id is None, all configurations for the task for the owner are deleted. + """ + owner = self.owner_resolver(context) + async with self.lock: + owner_infos = self._push_notification_infos.get(owner, {}) + if task_id not in owner_infos: + logger.warning( + 'Attempted to delete push notification config for task %s, owner %s that does not exist.', + task_id, + owner, + ) + return + if config_id is None: + del owner_infos[task_id] + logger.info( + 'Deleted all push notification configs for task %s, owner %s.', + task_id, + owner, + ) + else: + configurations = owner_infos[task_id] + found = False for config in configurations: if config.id == config_id: configurations.remove(config) + found = True break - - if len(configurations) == 0: - del self._push_notification_infos[task_id] + if found: + logger.info( + 'Deleted push notification config %s for task %s, owner %s.', + config_id, + task_id, + owner, + ) + if len(configurations) == 0: + del owner_infos[task_id] + else: + logger.warning( + 'Attempted to delete push notification config %s for task %s, owner %s that does not exist.', + config_id, + task_id, + owner, + ) + + if not owner_infos: + del self._push_notification_infos[owner] diff --git a/src/a2a/server/tasks/inmemory_task_store.py b/src/a2a/server/tasks/inmemory_task_store.py index 241d9899e..019fd773e 100644 --- a/src/a2a/server/tasks/inmemory_task_store.py +++ b/src/a2a/server/tasks/inmemory_task_store.py @@ -1,7 +1,10 @@ import asyncio import logging +from collections import defaultdict + from a2a.server.context import ServerCallContext +from a2a.server.owner_resolver import OwnerResolver, resolve_user_scope from a2a.server.tasks.task_store import TaskStore from a2a.types import a2a_pb2 from a2a.types.a2a_pb2 import Task @@ -15,45 +18,69 @@ class InMemoryTaskStore(TaskStore): """In-memory implementation of TaskStore. - Stores task objects in a dictionary in memory. Task data is lost when the - server process stops. + Stores task objects in a nested dictionary in memory, keyed by owner then task_id. + Task data is lost when the server process stops. """ - def __init__(self) -> None: + def __init__( + self, + owner_resolver: OwnerResolver = resolve_user_scope, + ) -> None: """Initializes the InMemoryTaskStore.""" logger.debug('Initializing InMemoryTaskStore') - self.tasks: dict[str, Task] = {} + self.tasks: dict[str, dict[str, Task]] = defaultdict(dict) self.lock = asyncio.Lock() + self.owner_resolver = owner_resolver async def save( self, task: Task, context: ServerCallContext | None = None ) -> None: - """Saves or updates a task in the in-memory store.""" + """Saves or updates a task in the in-memory store for the resolved owner.""" + owner = self.owner_resolver(context) + async with self.lock: - self.tasks[task.id] = task - logger.debug('Task %s saved successfully.', task.id) + self.tasks[owner][task.id] = task + logger.debug( + 'Task %s for owner %s saved successfully.', task.id, owner + ) async def get( self, task_id: str, context: ServerCallContext | None = None ) -> Task | None: - """Retrieves a task from the in-memory store by ID.""" + """Retrieves a task from the in-memory store by ID, for the given owner.""" + owner = self.owner_resolver(context) async with self.lock: - logger.debug('Attempting to get task with id: %s', task_id) - task = self.tasks.get(task_id) + logger.debug( + 'Attempting to get task with id: %s for owner: %s', + task_id, + owner, + ) + owner_tasks = self.tasks.get(owner, {}) + task = owner_tasks.get(task_id) if task: - logger.debug('Task %s retrieved successfully.', task_id) - else: - logger.debug('Task %s not found in store.', task_id) - return task + logger.debug( + 'Task %s retrieved successfully for owner %s.', + task_id, + owner, + ) + return task + logger.debug( + 'Task %s not found in store for owner %s.', task_id, owner + ) + return None async def list( self, params: a2a_pb2.ListTasksRequest, context: ServerCallContext | None = None, ) -> a2a_pb2.ListTasksResponse: - """Retrieves a list of tasks from the store.""" + """Retrieves a list of tasks from the store, for the given owner.""" + owner = self.owner_resolver(context) + logger.debug('Listing tasks for owner %s with params %s', owner, params) + async with self.lock: - tasks = list(self.tasks.values()) + owner_tasks = self.tasks.get(owner, {}) + tasks = list(owner_tasks.values()) # Filter tasks if params.context_id: @@ -125,13 +152,28 @@ async def list( async def delete( self, task_id: str, context: ServerCallContext | None = None ) -> None: - """Deletes a task from the in-memory store by ID.""" + """Deletes a task from the in-memory store by ID, for the given owner.""" + owner = self.owner_resolver(context) async with self.lock: - logger.debug('Attempting to delete task with id: %s', task_id) - if task_id in self.tasks: - del self.tasks[task_id] - logger.debug('Task %s deleted successfully.', task_id) - else: + logger.debug( + 'Attempting to delete task with id: %s for owner %s', + task_id, + owner, + ) + + owner_tasks = self.tasks.get(owner, {}) + if task_id not in owner_tasks: logger.warning( - 'Attempted to delete nonexistent task with id: %s', task_id + 'Attempted to delete nonexistent task with id: %s for owner %s', + task_id, + owner, ) + return + + del owner_tasks[task_id] + logger.debug( + 'Task %s deleted successfully for owner %s.', task_id, owner + ) + if not owner_tasks: + del self.tasks[owner] + logger.debug('Removed empty owner %s from store.', owner) diff --git a/src/a2a/server/tasks/push_notification_config_store.py b/src/a2a/server/tasks/push_notification_config_store.py index a1c049e90..f1db64664 100644 --- a/src/a2a/server/tasks/push_notification_config_store.py +++ b/src/a2a/server/tasks/push_notification_config_store.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod +from a2a.server.context import ServerCallContext from a2a.types.a2a_pb2 import PushNotificationConfig @@ -8,16 +9,26 @@ class PushNotificationConfigStore(ABC): @abstractmethod async def set_info( - self, task_id: str, notification_config: PushNotificationConfig + self, + task_id: str, + notification_config: PushNotificationConfig, + context: ServerCallContext, ) -> None: """Sets or updates the push notification configuration for a task.""" @abstractmethod - async def get_info(self, task_id: str) -> list[PushNotificationConfig]: + async def get_info( + self, + task_id: str, + context: ServerCallContext, + ) -> list[PushNotificationConfig]: """Retrieves the push notification configuration for a task.""" @abstractmethod async def delete_info( - self, task_id: str, config_id: str | None = None + self, + task_id: str, + context: ServerCallContext, + config_id: str | None = None, ) -> None: """Deletes the push notification configuration for a task.""" diff --git a/tests/e2e/push_notifications/agent_app.py b/tests/e2e/push_notifications/agent_app.py index ef8276c4e..dfe71566a 100644 --- a/tests/e2e/push_notifications/agent_app.py +++ b/tests/e2e/push_notifications/agent_app.py @@ -4,6 +4,7 @@ from a2a.server.agent_execution import AgentExecutor, RequestContext from a2a.server.apps import A2ARESTFastAPIApplication +from a2a.server.context import ServerCallContext from a2a.server.events import EventQueue from a2a.server.request_handlers import DefaultRequestHandler from a2a.server.tasks import ( @@ -148,6 +149,7 @@ def create_agent_app( push_sender=BasePushNotificationSender( httpx_client=notification_client, config_store=push_config_store, + context=ServerCallContext(), ), ), ) diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index 350d595a4..6e2d30f87 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ b/tests/server/request_handlers/test_default_request_handler.py @@ -546,6 +546,7 @@ async def mock_current_result(): lambda self: mock_current_result() ) + context = create_server_call_context() with ( patch( 'a2a.server.request_handlers.default_request_handler.ResultAggregator', @@ -560,12 +561,10 @@ async def mock_current_result(): return_value=sample_initial_task, ), ): # Ensure task object is returned - await request_handler.on_message_send( - params, create_server_call_context() - ) + await request_handler.on_message_send(params, context) mock_push_notification_store.set_info.assert_awaited_once_with( - task_id, push_config + task_id, push_config, context ) # Other assertions for full flow if needed (e.g., agent execution) mock_agent_executor.execute.assert_awaited_once() @@ -665,6 +664,7 @@ async def mock_consume_and_break_on_interrupt( mock_consume_and_break_on_interrupt ) + context = create_server_call_context() with ( patch( 'a2a.server.request_handlers.default_request_handler.ResultAggregator', @@ -680,9 +680,7 @@ async def mock_consume_and_break_on_interrupt( ), ): # Execute the non-blocking request - result = await request_handler.on_message_send( - params, create_server_call_context() - ) + result = await request_handler.on_message_send(params, context) # Verify the result is the initial task (non-blocking behavior) assert result == initial_task @@ -700,7 +698,7 @@ async def mock_consume_and_break_on_interrupt( # Verify that the push notification config was stored mock_push_notification_store.set_info.assert_awaited_once_with( - task_id, push_config + task_id, push_config, context ) @@ -763,6 +761,7 @@ async def mock_current_result(): lambda self: mock_current_result() ) + context = create_server_call_context() with ( patch( 'a2a.server.request_handlers.default_request_handler.ResultAggregator', @@ -773,12 +772,10 @@ async def mock_current_result(): return_value=None, ), ): - await request_handler.on_message_send( - params, create_server_call_context() - ) + await request_handler.on_message_send(params, context) mock_push_notification_store.set_info.assert_awaited_once_with( - task_id, push_config + task_id, push_config, context ) # Other assertions for full flow if needed (e.g., agent execution) mock_agent_executor.execute.assert_awaited_once() @@ -938,9 +935,8 @@ async def test_on_message_send_non_blocking(): ), ) - result = await request_handler.on_message_send( - params, create_server_call_context() - ) + context = create_server_call_context() + result = await request_handler.on_message_send(params, context) assert result is not None assert isinstance(result, Task) @@ -950,7 +946,7 @@ async def test_on_message_send_non_blocking(): task: Task | None = None for _ in range(5): await asyncio.sleep(0.1) - task = await task_store.get(result.id) + task = await task_store.get(result.id, context) assert task is not None if task.status.state == TaskState.TASK_STATE_COMPLETED: break @@ -987,9 +983,8 @@ async def test_on_message_send_limit_history(): ), ) - result = await request_handler.on_message_send( - params, create_server_call_context() - ) + context = create_server_call_context() + result = await request_handler.on_message_send(params, context) # verify that history_length is honored assert result is not None @@ -998,7 +993,7 @@ async def test_on_message_send_limit_history(): assert result.status.state == TaskState.TASK_STATE_COMPLETED # verify that history is still persisted to the store - task = await task_store.get(result.id) + task = await task_store.get(result.id, context) assert task is not None assert task.history is not None and len(task.history) > 1 @@ -1384,6 +1379,7 @@ def sync_get_event_stream_gen_for_prop_test(*args, **kwargs): side_effect=[get_current_result_coro1(), get_current_result_coro2()] ) + context = create_server_call_context() with ( patch( 'a2a.server.request_handlers.default_request_handler.ResultAggregator', @@ -1399,16 +1395,16 @@ def sync_get_event_stream_gen_for_prop_test(*args, **kwargs): ), ): # Consume the stream - async for _ in request_handler.on_message_send_stream( - params, create_server_call_context() - ): + async for _ in request_handler.on_message_send_stream(params, context): pass await asyncio.wait_for(execute_called.wait(), timeout=0.1) # Assertions # 1. set_info called once at the beginning if task exists (or after task is created from message) - mock_push_config_store.set_info.assert_any_call(task_id, push_config) + mock_push_config_store.set_info.assert_any_call( + task_id, push_config, context + ) # 2. send_notification called for each task event yielded by aggregator assert mock_push_sender.send_notification.await_count == 2 @@ -2087,7 +2083,9 @@ async def test_get_task_push_notification_config_info_not_found(): exc_info.value.error, InternalError ) # Current code raises InternalError mock_task_store.get.assert_awaited_once_with('non_existent_task', context) - mock_push_store.get_info.assert_awaited_once_with('non_existent_task') + mock_push_store.get_info.assert_awaited_once_with( + 'non_existent_task', context + ) @pytest.mark.asyncio @@ -2241,7 +2239,7 @@ async def test_on_message_send_stream(): async def consume_stream(): events = [] async for event in request_handler.on_message_send_stream( - message_params + message_params, create_server_call_context() ): events.append(event) if len(events) >= 3: @@ -2345,8 +2343,9 @@ async def test_list_task_push_notification_config_info_with_config(): ) push_store = InMemoryPushNotificationConfigStore() - await push_store.set_info('task_1', push_config1) - await push_store.set_info('task_1', push_config2) + context = create_server_call_context() + await push_store.set_info('task_1', push_config1, context) + await push_store.set_info('task_1', push_config2, context) request_handler = DefaultRequestHandler( agent_executor=MockAgentExecutor(), @@ -2472,6 +2471,7 @@ async def test_delete_no_task_push_notification_config_info(): await push_store.set_info( 'task_2', PushNotificationConfig(id='config_1', url='http://example.com'), + create_server_call_context(), ) request_handler = DefaultRequestHandler( @@ -2514,9 +2514,10 @@ async def test_delete_task_push_notification_config_info_with_config(): ) push_store = InMemoryPushNotificationConfigStore() - await push_store.set_info('task_1', push_config1) - await push_store.set_info('task_1', push_config2) - await push_store.set_info('task_2', push_config1) + context = create_server_call_context() + await push_store.set_info('task_1', push_config1, context) + await push_store.set_info('task_1', push_config2, context) + await push_store.set_info('task_2', push_config1, context) request_handler = DefaultRequestHandler( agent_executor=MockAgentExecutor(), @@ -2555,8 +2556,9 @@ async def test_delete_task_push_notification_config_info_with_config_and_no_id() # insertion without id should replace the existing config push_store = InMemoryPushNotificationConfigStore() - await push_store.set_info('task_1', push_config) - await push_store.set_info('task_1', push_config) + context = create_server_call_context() + await push_store.set_info('task_1', push_config, context) + await push_store.set_info('task_1', push_config, context) request_handler = DefaultRequestHandler( agent_executor=MockAgentExecutor(), diff --git a/tests/server/request_handlers/test_jsonrpc_handler.py b/tests/server/request_handlers/test_jsonrpc_handler.py index a9e940a03..aa448f354 100644 --- a/tests/server/request_handlers/test_jsonrpc_handler.py +++ b/tests/server/request_handlers/test_jsonrpc_handler.py @@ -550,11 +550,12 @@ async def test_set_push_notification_success(self) -> None: task_id=mock_task.id, config=push_config, ) - response = await handler.set_push_notification_config(request) + context = ServerCallContext() + response = await handler.set_push_notification_config(request, context) self.assertIsInstance(response, dict) self.assertTrue(is_success_response(response)) mock_push_notification_store.set_info.assert_called_once_with( - mock_task.id, push_config + mock_task.id, push_config, context ) async def test_get_push_notification_success(self) -> None: @@ -601,7 +602,7 @@ async def test_on_message_stream_new_message_send_push_notification_success( mock_httpx_client = AsyncMock(spec=httpx.AsyncClient) push_notification_store = InMemoryPushNotificationConfigStore() push_notification_sender = BasePushNotificationSender( - mock_httpx_client, push_notification_store + mock_httpx_client, push_notification_store, ServerCallContext() ) request_handler = DefaultRequestHandler( mock_agent_executor, diff --git a/tests/server/tasks/test_database_push_notification_config_store.py b/tests/server/tasks/test_database_push_notification_config_store.py index b0445d8fd..042ff8000 100644 --- a/tests/server/tasks/test_database_push_notification_config_store.py +++ b/tests/server/tasks/test_database_push_notification_config_store.py @@ -3,6 +3,8 @@ from collections.abc import AsyncGenerator import pytest +from a2a.server.context import ServerCallContext +from a2a.auth.user import User # Skip entire test module if SQLAlchemy is not installed @@ -102,6 +104,24 @@ def _create_timestamp() -> Timestamp: ) +class SampleUser(User): + """A test implementation of the User interface.""" + + def __init__(self, user_name: str): + self._user_name = user_name + + @property + def is_authenticated(self) -> bool: + return True + + @property + def user_name(self) -> str: + return self._user_name + + +MINIMAL_CALL_CONTEXT = ServerCallContext(user=SampleUser(user_name='user')) + + @pytest_asyncio.fixture(params=DB_CONFIGS) async def db_store_parameterized( request, @@ -181,8 +201,10 @@ async def test_set_and_get_info_single_config( task_id = 'task-1' config = PushNotificationConfig(id='config-1', url='http://example.com') - await db_store_parameterized.set_info(task_id, config) - retrieved_configs = await db_store_parameterized.get_info(task_id) + await db_store_parameterized.set_info(task_id, config, MINIMAL_CALL_CONTEXT) + retrieved_configs = await db_store_parameterized.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) assert len(retrieved_configs) == 1 assert retrieved_configs[0] == config @@ -198,9 +220,15 @@ async def test_set_and_get_info_multiple_configs( config1 = PushNotificationConfig(id='config-1', url='http://example.com/1') config2 = PushNotificationConfig(id='config-2', url='http://example.com/2') - await db_store_parameterized.set_info(task_id, config1) - await db_store_parameterized.set_info(task_id, config2) - retrieved_configs = await db_store_parameterized.get_info(task_id) + await db_store_parameterized.set_info( + task_id, config1, MINIMAL_CALL_CONTEXT + ) + await db_store_parameterized.set_info( + task_id, config2, MINIMAL_CALL_CONTEXT + ) + retrieved_configs = await db_store_parameterized.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) assert len(retrieved_configs) == 2 assert config1 in retrieved_configs @@ -221,9 +249,15 @@ async def test_set_info_updates_existing_config( id=config_id, url='http://updated.url' ) - await db_store_parameterized.set_info(task_id, initial_config) - await db_store_parameterized.set_info(task_id, updated_config) - retrieved_configs = await db_store_parameterized.get_info(task_id) + await db_store_parameterized.set_info( + task_id, initial_config, MINIMAL_CALL_CONTEXT + ) + await db_store_parameterized.set_info( + task_id, updated_config, MINIMAL_CALL_CONTEXT + ) + retrieved_configs = await db_store_parameterized.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) assert len(retrieved_configs) == 1 assert retrieved_configs[0].url == 'http://updated.url' @@ -237,8 +271,10 @@ async def test_set_info_defaults_config_id_to_task_id( task_id = 'task-1' config = PushNotificationConfig(url='http://example.com') # id is None - await db_store_parameterized.set_info(task_id, config) - retrieved_configs = await db_store_parameterized.get_info(task_id) + await db_store_parameterized.set_info(task_id, config, MINIMAL_CALL_CONTEXT) + retrieved_configs = await db_store_parameterized.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) assert len(retrieved_configs) == 1 assert retrieved_configs[0].id == task_id @@ -250,7 +286,7 @@ async def test_get_info_not_found( ): """Test getting info for a task with no configs returns an empty list.""" retrieved_configs = await db_store_parameterized.get_info( - 'non-existent-task' + 'non-existent-task', MINIMAL_CALL_CONTEXT ) assert retrieved_configs == [] @@ -264,11 +300,19 @@ async def test_delete_info_specific_config( config1 = PushNotificationConfig(id='config-1', url='http://a.com') config2 = PushNotificationConfig(id='config-2', url='http://b.com') - await db_store_parameterized.set_info(task_id, config1) - await db_store_parameterized.set_info(task_id, config2) + await db_store_parameterized.set_info( + task_id, config1, MINIMAL_CALL_CONTEXT + ) + await db_store_parameterized.set_info( + task_id, config2, MINIMAL_CALL_CONTEXT + ) - await db_store_parameterized.delete_info(task_id, 'config-1') - retrieved_configs = await db_store_parameterized.get_info(task_id) + await db_store_parameterized.delete_info( + task_id, MINIMAL_CALL_CONTEXT, 'config-1' + ) + retrieved_configs = await db_store_parameterized.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) assert len(retrieved_configs) == 1 assert retrieved_configs[0] == config2 @@ -284,11 +328,19 @@ async def test_delete_info_all_for_task( config1 = PushNotificationConfig(id='config-1', url='http://a.com') config2 = PushNotificationConfig(id='config-2', url='http://b.com') - await db_store_parameterized.set_info(task_id, config1) - await db_store_parameterized.set_info(task_id, config2) + await db_store_parameterized.set_info( + task_id, config1, MINIMAL_CALL_CONTEXT + ) + await db_store_parameterized.set_info( + task_id, config2, MINIMAL_CALL_CONTEXT + ) - await db_store_parameterized.delete_info(task_id, None) - retrieved_configs = await db_store_parameterized.get_info(task_id) + await db_store_parameterized.delete_info( + task_id, MINIMAL_CALL_CONTEXT, None + ) + retrieved_configs = await db_store_parameterized.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) assert retrieved_configs == [] @@ -299,7 +351,9 @@ async def test_delete_info_not_found( ): """Test that deleting a non-existent config does not raise an error.""" # Should not raise - await db_store_parameterized.delete_info('task-1', 'non-existent-config') + await db_store_parameterized.delete_info( + 'task-1', MINIMAL_CALL_CONTEXT, 'non-existent-config' + ) @pytest.mark.asyncio @@ -313,7 +367,7 @@ async def test_data_is_encrypted_in_db( ) plain_json = MessageToJson(config) - await db_store_parameterized.set_info(task_id, config) + await db_store_parameterized.set_info(task_id, config, MINIMAL_CALL_CONTEXT) # Directly query the database to inspect the raw data async_session = async_sessionmaker( @@ -343,7 +397,7 @@ async def test_decryption_error_with_wrong_key( task_id = 'wrong-key-task' config = PushNotificationConfig(id='config-1', url='http://secret.url') - await db_store_parameterized.set_info(task_id, config) + await db_store_parameterized.set_info(task_id, config, MINIMAL_CALL_CONTEXT) # 2. Try to read with a different key # Directly query the database to inspect the raw data @@ -352,7 +406,7 @@ async def test_decryption_error_with_wrong_key( db_store_parameterized.engine, encryption_key=wrong_key ) - retrieved_configs = await store2.get_info(task_id) + retrieved_configs = await store2.get_info(task_id, MINIMAL_CALL_CONTEXT) assert retrieved_configs == [] # _from_orm should raise a ValueError @@ -377,13 +431,13 @@ async def test_decryption_error_with_no_key( task_id = 'wrong-key-task' config = PushNotificationConfig(id='config-1', url='http://secret.url') - await db_store_parameterized.set_info(task_id, config) + await db_store_parameterized.set_info(task_id, config, MINIMAL_CALL_CONTEXT) # 2. Try to read with no key set # Directly query the database to inspect the raw data store2 = DatabasePushNotificationConfigStore(db_store_parameterized.engine) - retrieved_configs = await store2.get_info(task_id) + retrieved_configs = await store2.get_info(task_id, MINIMAL_CALL_CONTEXT) assert retrieved_configs == [] # _from_orm should raise a ValueError @@ -420,8 +474,10 @@ async def test_custom_table_name( config = PushNotificationConfig(id='config-1', url='http://custom.url') # This will create the table on first use - await custom_store.set_info(task_id, config) - retrieved_configs = await custom_store.get_info(task_id) + await custom_store.set_info(task_id, config, MINIMAL_CALL_CONTEXT) + retrieved_configs = await custom_store.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) assert len(retrieved_configs) == 1 assert retrieved_configs[0] == config @@ -465,9 +521,9 @@ async def test_set_and_get_info_multiple_configs_no_key( config1 = PushNotificationConfig(id='config-1', url='http://example.com/1') config2 = PushNotificationConfig(id='config-2', url='http://example.com/2') - await store.set_info(task_id, config1) - await store.set_info(task_id, config2) - retrieved_configs = await store.get_info(task_id) + await store.set_info(task_id, config1, MINIMAL_CALL_CONTEXT) + await store.set_info(task_id, config2, MINIMAL_CALL_CONTEXT) + retrieved_configs = await store.get_info(task_id, MINIMAL_CALL_CONTEXT) assert len(retrieved_configs) == 2 assert config1 in retrieved_configs @@ -491,7 +547,7 @@ async def test_data_is_not_encrypted_in_db_if_no_key_is_set( config = PushNotificationConfig(id='config-1', url='http://example.com/1') plain_json = MessageToJson(config) - await store.set_info(task_id, config) + await store.set_info(task_id, config, MINIMAL_CALL_CONTEXT) # Directly query the database to inspect the raw data async_session = async_sessionmaker( @@ -522,10 +578,12 @@ async def test_decryption_fallback_for_unencrypted_data( task_id = 'mixed-encryption-task' config = PushNotificationConfig(id='config-1', url='http://plain.url') - await unencrypted_store.set_info(task_id, config) + await unencrypted_store.set_info(task_id, config, MINIMAL_CALL_CONTEXT) # 2. Try to read with the encryption-enabled store from the fixture - retrieved_configs = await db_store_parameterized.get_info(task_id) + retrieved_configs = await db_store_parameterized.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) # Should fall back to parsing as plain JSON and not fail assert len(retrieved_configs) == 1 @@ -555,12 +613,15 @@ async def test_parsing_error_after_successful_decryption( task_id=task_id, config_id=config_id, config_data=encrypted_data, + owner='user', ) session.add(db_model) await session.commit() # 3. get_info should log an error and return an empty list - retrieved_configs = await db_store_parameterized.get_info(task_id) + retrieved_configs = await db_store_parameterized.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) assert retrieved_configs == [] # 4. _from_orm should raise a ValueError @@ -571,3 +632,78 @@ async def test_parsing_error_after_successful_decryption( with pytest.raises(ValueError): db_store_parameterized._from_orm(db_model_retrieved) # type: ignore + + +@pytest.mark.asyncio +async def test_owner_resource_scoping( + db_store_parameterized: DatabasePushNotificationConfigStore, +) -> None: + """Test that operations are scoped to the correct owner.""" + config_store = db_store_parameterized + + context_user1 = ServerCallContext(user=SampleUser(user_name='user1')) + context_user2 = ServerCallContext(user=SampleUser(user_name='user2')) + + # Create configs for different owners + task1_u1_config1 = PushNotificationConfig( + id='t1-u1-c1', url='http://u1.com/1' + ) + task1_u1_config2 = PushNotificationConfig( + id='t1-u1-c2', url='http://u1.com/2' + ) + task1_u2_config1 = PushNotificationConfig( + id='t1-u2-c1', url='http://u2.com/1' + ) + task2_u1_config1 = PushNotificationConfig( + id='t2-u1-c1', url='http://u1.com/3' + ) + + await config_store.set_info('task1', task1_u1_config1, context_user1) + await config_store.set_info('task1', task1_u1_config2, context_user1) + await config_store.set_info('task1', task1_u2_config1, context_user2) + await config_store.set_info('task2', task2_u1_config1, context_user1) + + # Test GET_INFO + # User 1 should get only their configs for task1 + u1_task1_configs = await config_store.get_info('task1', context_user1) + assert len(u1_task1_configs) == 2 + assert {c.id for c in u1_task1_configs} == {'t1-u1-c1', 't1-u1-c2'} + + # User 2 should get only their configs for task1 + u2_task1_configs = await config_store.get_info('task1', context_user2) + assert len(u2_task1_configs) == 1 + assert u2_task1_configs[0].id == 't1-u2-c1' + + # User 2 should get no configs for task2 + u2_task2_configs = await config_store.get_info('task2', context_user2) + assert len(u2_task2_configs) == 0 + + # User 1 should get their config for task2 + u1_task2_configs = await config_store.get_info('task2', context_user1) + assert len(u1_task2_configs) == 1 + assert u1_task2_configs[0].id == 't2-u1-c1' + + # Test DELETE_INFO + # User 2 deleting User 1's config should not work + await config_store.delete_info('task1', context_user2, 't1-u1-c1') + u1_task1_configs = await config_store.get_info('task1', context_user1) + assert len(u1_task1_configs) == 2 + + # User 1 deleting their own config + await config_store.delete_info( + 'task1', + context_user1, + 't1-u1-c1', + ) + u1_task1_configs = await config_store.get_info('task1', context_user1) + assert len(u1_task1_configs) == 1 + assert u1_task1_configs[0].id == 't1-u1-c2' + + # User 1 deleting all configs for task2 + await config_store.delete_info('task2', context=context_user1) + u1_task2_configs = await config_store.get_info('task2', context_user1) + assert len(u1_task2_configs) == 0 + + # Cleanup remaining + await config_store.delete_info('task1', context=context_user1) + await config_store.delete_info('task1', context=context_user2) diff --git a/tests/server/tasks/test_database_task_store.py b/tests/server/tasks/test_database_task_store.py index aa9132172..e6b67701c 100644 --- a/tests/server/tasks/test_database_task_store.py +++ b/tests/server/tasks/test_database_task_store.py @@ -7,6 +7,7 @@ import pytest_asyncio from _pytest.mark.structures import ParameterSet +from a2a.types.a2a_pb2 import ListTasksRequest # Skip entire test module if SQLAlchemy is not installed @@ -30,9 +31,26 @@ TaskState, TaskStatus, ) +from a2a.auth.user import User +from a2a.server.context import ServerCallContext from a2a.utils.constants import DEFAULT_LIST_TASKS_PAGE_SIZE +class TestUser(User): + """A test implementation of the User interface.""" + + def __init__(self, user_name: str): + self._user_name = user_name + + @property + def is_authenticated(self) -> bool: + return True + + @property + def user_name(self) -> str: + return self._user_name + + # DSNs for different databases SQLITE_TEST_DSN = ( 'sqlite+aiosqlite:///file:testdb?mode=memory&cache=shared&uri=true' @@ -605,4 +623,57 @@ async def test_metadata_field_mapping( await db_store_parameterized.delete('task-metadata-test-4') +@pytest.mark.asyncio +async def test_owner_resource_scoping( + db_store_parameterized: DatabaseTaskStore, +) -> None: + """Test that operations are scoped to the correct owner.""" + task_store = db_store_parameterized + + context_user1 = ServerCallContext(user=TestUser(user_name='user1')) + context_user2 = ServerCallContext(user=TestUser(user_name='user2')) + + # Create tasks for different owners + task1_user1, task2_user1, task1_user2 = Task(), Task(), Task() + task1_user1.CopyFrom(MINIMAL_TASK_OBJ) + task1_user1.id = 'u1-task1' + task2_user1.CopyFrom(MINIMAL_TASK_OBJ) + task2_user1.id = 'u1-task2' + task1_user2.CopyFrom(MINIMAL_TASK_OBJ) + task1_user2.id = 'u2-task1' + + await task_store.save(task1_user1, context_user1) + await task_store.save(task2_user1, context_user1) + await task_store.save(task1_user2, context_user2) + + # Test GET + assert await task_store.get('u1-task1', context_user1) is not None + assert await task_store.get('u1-task1', context_user2) is None + assert await task_store.get('u2-task1', context_user1) is None + assert await task_store.get('u2-task1', context_user2) is not None + + # Test LIST + params = ListTasksRequest() + page_user1 = await task_store.list(params, context_user1) + assert len(page_user1.tasks) == 2 + assert {t.id for t in page_user1.tasks} == {'u1-task1', 'u1-task2'} + assert page_user1.total_size == 2 + + page_user2 = await task_store.list(params, context_user2) + assert len(page_user2.tasks) == 1 + assert {t.id for t in page_user2.tasks} == {'u2-task1'} + assert page_user2.total_size == 1 + + # Test DELETE + await task_store.delete('u1-task1', context_user2) # Should not delete + assert await task_store.get('u1-task1', context_user1) is not None + + await task_store.delete('u1-task1', context_user1) # Should delete + assert await task_store.get('u1-task1', context_user1) is None + + # Cleanup remaining tasks + await task_store.delete('u1-task2', context_user1) + await task_store.delete('u2-task1', context_user2) + + # Ensure aiosqlite, asyncpg, and aiomysql are installed in the test environment (added to pyproject.toml). diff --git a/tests/server/tasks/test_inmemory_push_notifications.py b/tests/server/tasks/test_inmemory_push_notifications.py index bbb01de2c..0024a95a6 100644 --- a/tests/server/tasks/test_inmemory_push_notifications.py +++ b/tests/server/tasks/test_inmemory_push_notifications.py @@ -5,6 +5,8 @@ import httpx from google.protobuf.json_format import MessageToDict +from a2a.auth.user import User +from a2a.server.context import ServerCallContext from a2a.server.tasks.base_push_notification_sender import ( BasePushNotificationSender, ) @@ -24,7 +26,7 @@ # logging.disable(logging.CRITICAL) -def create_sample_task( +def _create_sample_task( task_id: str = 'task123', status_state: TaskState = TaskState.TASK_STATE_COMPLETED, ) -> Task: @@ -35,7 +37,7 @@ def create_sample_task( ) -def create_sample_push_config( +def _create_sample_push_config( url: str = 'http://example.com/callback', config_id: str = 'cfg1', token: str | None = None, @@ -43,12 +45,32 @@ def create_sample_push_config( return PushNotificationConfig(id=config_id, url=url, token=token) +class SampleUser(User): + """A test implementation of the User interface.""" + + def __init__(self, user_name: str): + self._user_name = user_name + + @property + def is_authenticated(self) -> bool: + return True + + @property + def user_name(self) -> str: + return self._user_name + + +MINIMAL_CALL_CONTEXT = ServerCallContext(user=SampleUser(user_name='user')) + + class TestInMemoryPushNotifier(unittest.IsolatedAsyncioTestCase): def setUp(self) -> None: self.mock_httpx_client = AsyncMock(spec=httpx.AsyncClient) self.config_store = InMemoryPushNotificationConfigStore() self.notifier = BasePushNotificationSender( - httpx_client=self.mock_httpx_client, config_store=self.config_store + httpx_client=self.mock_httpx_client, + config_store=self.config_store, + context=MINIMAL_CALL_CONTEXT, ) # Corrected argument name def test_constructor_stores_client(self) -> None: @@ -56,100 +78,121 @@ def test_constructor_stores_client(self) -> None: async def test_set_info_adds_new_config(self) -> None: task_id = 'task_new' - config = create_sample_push_config(url='http://new.url/callback') + config = _create_sample_push_config(url='http://new.url/callback') - await self.config_store.set_info(task_id, config) + await self.config_store.set_info(task_id, config, MINIMAL_CALL_CONTEXT) - self.assertIn(task_id, self.config_store._push_notification_infos) - self.assertEqual( - self.config_store._push_notification_infos[task_id], [config] + retrieved = await self.config_store.get_info( + task_id, MINIMAL_CALL_CONTEXT ) + self.assertEqual(retrieved, [config]) async def test_set_info_appends_to_existing_config(self) -> None: task_id = 'task_update' - initial_config = create_sample_push_config( + initial_config = _create_sample_push_config( url='http://initial.url/callback', config_id='cfg_initial' ) - await self.config_store.set_info(task_id, initial_config) + await self.config_store.set_info( + task_id, initial_config, MINIMAL_CALL_CONTEXT + ) - updated_config = create_sample_push_config( + updated_config = _create_sample_push_config( url='http://updated.url/callback', config_id='cfg_updated' ) - await self.config_store.set_info(task_id, updated_config) - - self.assertIn(task_id, self.config_store._push_notification_infos) - self.assertEqual( - self.config_store._push_notification_infos[task_id][0], - initial_config, + await self.config_store.set_info( + task_id, updated_config, MINIMAL_CALL_CONTEXT ) - self.assertEqual( - self.config_store._push_notification_infos[task_id][1], - updated_config, + + retrieved = await self.config_store.get_info( + task_id, MINIMAL_CALL_CONTEXT ) + self.assertEqual(len(retrieved), 2) + self.assertEqual(retrieved[0], initial_config) + self.assertEqual(retrieved[1], updated_config) async def test_set_info_without_config_id(self) -> None: task_id = 'task1' initial_config = PushNotificationConfig( url='http://initial.url/callback' ) - await self.config_store.set_info(task_id, initial_config) + await self.config_store.set_info( + task_id, initial_config, MINIMAL_CALL_CONTEXT + ) - assert ( - self.config_store._push_notification_infos[task_id][0].id == task_id + retrieved = await self.config_store.get_info( + task_id, MINIMAL_CALL_CONTEXT ) + assert retrieved[0].id == task_id updated_config = PushNotificationConfig( url='http://initial.url/callback_new' ) - await self.config_store.set_info(task_id, updated_config) + await self.config_store.set_info( + task_id, updated_config, MINIMAL_CALL_CONTEXT + ) - self.assertIn(task_id, self.config_store._push_notification_infos) - assert len(self.config_store._push_notification_infos[task_id]) == 1 - self.assertEqual( - self.config_store._push_notification_infos[task_id][0].url, - updated_config.url, + retrieved = await self.config_store.get_info( + task_id, MINIMAL_CALL_CONTEXT ) + assert len(retrieved) == 1 + self.assertEqual(retrieved[0].url, updated_config.url) async def test_get_info_existing_config(self) -> None: task_id = 'task_get_exist' - config = create_sample_push_config(url='http://get.this/callback') - await self.config_store.set_info(task_id, config) + config = _create_sample_push_config(url='http://get.this/callback') + await self.config_store.set_info(task_id, config, MINIMAL_CALL_CONTEXT) - retrieved_config = await self.config_store.get_info(task_id) + retrieved_config = await self.config_store.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) self.assertEqual(retrieved_config, [config]) async def test_get_info_non_existent_config(self) -> None: task_id = 'task_get_non_exist' - retrieved_config = await self.config_store.get_info(task_id) + retrieved_config = await self.config_store.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) assert retrieved_config == [] async def test_delete_info_existing_config(self) -> None: task_id = 'task_delete_exist' - config = create_sample_push_config(url='http://delete.this/callback') - await self.config_store.set_info(task_id, config) + config = _create_sample_push_config(url='http://delete.this/callback') + await self.config_store.set_info(task_id, config, MINIMAL_CALL_CONTEXT) - self.assertIn(task_id, self.config_store._push_notification_infos) - await self.config_store.delete_info(task_id, config_id=config.id) - self.assertNotIn(task_id, self.config_store._push_notification_infos) + retrieved = await self.config_store.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) + self.assertEqual(len(retrieved), 1) + + await self.config_store.delete_info( + task_id, config_id=config.id, context=MINIMAL_CALL_CONTEXT + ) + retrieved = await self.config_store.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) + self.assertEqual(len(retrieved), 0) async def test_delete_info_non_existent_config(self) -> None: task_id = 'task_delete_non_exist' # Ensure it doesn't raise an error try: - await self.config_store.delete_info(task_id) + await self.config_store.delete_info( + task_id, context=MINIMAL_CALL_CONTEXT + ) except Exception as e: self.fail( f'delete_info raised {e} unexpectedly for nonexistent task_id' ) - self.assertNotIn( - task_id, self.config_store._push_notification_infos - ) # Should still not be there + retrieved = await self.config_store.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) + self.assertEqual(len(retrieved), 0) async def test_send_notification_success(self) -> None: task_id = 'task_send_success' - task_data = create_sample_task(task_id=task_id) - config = create_sample_push_config(url='http://notify.me/here') - await self.config_store.set_info(task_id, config) + task_data = _create_sample_task(task_id=task_id) + config = _create_sample_push_config(url='http://notify.me/here') + await self.config_store.set_info(task_id, config, MINIMAL_CALL_CONTEXT) # Mock the post call to simulate success mock_response = AsyncMock(spec=httpx.Response) @@ -172,11 +215,11 @@ async def test_send_notification_success(self) -> None: async def test_send_notification_with_token_success(self) -> None: task_id = 'task_send_success' - task_data = create_sample_task(task_id=task_id) - config = create_sample_push_config( + task_data = _create_sample_task(task_id=task_id) + config = _create_sample_push_config( url='http://notify.me/here', token='unique_token' ) - await self.config_store.set_info(task_id, config) + await self.config_store.set_info(task_id, config, MINIMAL_CALL_CONTEXT) # Mock the post call to simulate success mock_response = AsyncMock(spec=httpx.Response) @@ -203,7 +246,7 @@ async def test_send_notification_with_token_success(self) -> None: async def test_send_notification_no_config(self) -> None: task_id = 'task_send_no_config' - task_data = create_sample_task(task_id=task_id) + task_data = _create_sample_task(task_id=task_id) await self.notifier.send_notification(task_data) # Pass only task_data @@ -214,9 +257,9 @@ async def test_send_notification_http_status_error( self, mock_logger: MagicMock ) -> None: task_id = 'task_send_http_err' - task_data = create_sample_task(task_id=task_id) - config = create_sample_push_config(url='http://notify.me/http_error') - await self.config_store.set_info(task_id, config) + task_data = _create_sample_task(task_id=task_id) + config = _create_sample_push_config(url='http://notify.me/http_error') + await self.config_store.set_info(task_id, config, MINIMAL_CALL_CONTEXT) mock_response = MagicMock( spec=httpx.Response @@ -244,9 +287,9 @@ async def test_send_notification_request_error( self, mock_logger: MagicMock ) -> None: task_id = 'task_send_req_err' - task_data = create_sample_task(task_id=task_id) - config = create_sample_push_config(url='http://notify.me/req_error') - await self.config_store.set_info(task_id, config) + task_data = _create_sample_task(task_id=task_id) + config = _create_sample_push_config(url='http://notify.me/req_error') + await self.config_store.set_info(task_id, config, MINIMAL_CALL_CONTEXT) request_error = httpx.RequestError('Network issue', request=MagicMock()) self.mock_httpx_client.post.side_effect = request_error @@ -271,11 +314,11 @@ async def test_send_notification_with_auth( still works even if the config has an authentication field set. """ task_id = 'task_send_auth' - task_data = create_sample_task(task_id=task_id) - config = create_sample_push_config(url='http://notify.me/auth') + task_data = _create_sample_task(task_id=task_id) + config = _create_sample_push_config(url='http://notify.me/auth') # The current implementation doesn't use the authentication field # It only supports token-based auth via the token field - await self.config_store.set_info(task_id, config) + await self.config_store.set_info(task_id, config, MINIMAL_CALL_CONTEXT) mock_response = AsyncMock(spec=httpx.Response) mock_response.status_code = 200 @@ -295,6 +338,95 @@ async def test_send_notification_with_auth( ) # auth is not passed by current implementation mock_response.raise_for_status.assert_called_once() + async def test_owner_resource_scoping(self) -> None: + """Test that operations are scoped to the correct owner.""" + context_user1 = ServerCallContext(user=SampleUser(user_name='user1')) + context_user2 = ServerCallContext(user=SampleUser(user_name='user2')) + + # Create configs for different owners + task1_u1_config1 = PushNotificationConfig( + id='t1-u1-c1', url='http://u1.com/1' + ) + task1_u1_config2 = PushNotificationConfig( + id='t1-u1-c2', url='http://u1.com/2' + ) + task1_u2_config1 = PushNotificationConfig( + id='t1-u2-c1', url='http://u2.com/1' + ) + task2_u1_config1 = PushNotificationConfig( + id='t2-u1-c1', url='http://u1.com/3' + ) + + await self.config_store.set_info( + 'task1', task1_u1_config1, context_user1 + ) + await self.config_store.set_info( + 'task1', task1_u1_config2, context_user1 + ) + await self.config_store.set_info( + 'task1', task1_u2_config1, context_user2 + ) + await self.config_store.set_info( + 'task2', task2_u1_config1, context_user1 + ) + + # Test GET_INFO + # User 1 should get only their configs for task1 + u1_task1_configs = await self.config_store.get_info( + 'task1', context_user1 + ) + self.assertEqual(len(u1_task1_configs), 2) + self.assertEqual( + {c.id for c in u1_task1_configs}, {'t1-u1-c1', 't1-u1-c2'} + ) + + # User 2 should get only their configs for task1 + u2_task1_configs = await self.config_store.get_info( + 'task1', context_user2 + ) + self.assertEqual(len(u2_task1_configs), 1) + self.assertEqual(u2_task1_configs[0].id, 't1-u2-c1') + + # User 2 should get no configs for task2 + u2_task2_configs = await self.config_store.get_info( + 'task2', context_user2 + ) + self.assertEqual(len(u2_task2_configs), 0) + + # User 1 should get their config for task2 + u1_task2_configs = await self.config_store.get_info( + 'task2', context_user1 + ) + self.assertEqual(len(u1_task2_configs), 1) + self.assertEqual(u1_task2_configs[0].id, 't2-u1-c1') + + # Test DELETE_INFO + # User 2 deleting User 1's config should not work + await self.config_store.delete_info('task1', context_user2, 't1-u1-c1') + u1_task1_configs = await self.config_store.get_info( + 'task1', context_user1 + ) + self.assertEqual(len(u1_task1_configs), 2) + + # User 1 deleting their own config + await self.config_store.delete_info('task1', context_user1, 't1-u1-c1') + u1_task1_configs = await self.config_store.get_info( + 'task1', context_user1 + ) + self.assertEqual(len(u1_task1_configs), 1) + self.assertEqual(u1_task1_configs[0].id, 't1-u1-c2') + + # User 1 deleting all configs for task2 + await self.config_store.delete_info('task2', context=context_user1) + u1_task2_configs = await self.config_store.get_info( + 'task2', context_user1 + ) + self.assertEqual(len(u1_task2_configs), 0) + + # Cleanup remaining + await self.config_store.delete_info('task1', context=context_user1) + await self.config_store.delete_info('task1', context=context_user2) + if __name__ == '__main__': unittest.main() diff --git a/tests/server/tasks/test_inmemory_task_store.py b/tests/server/tasks/test_inmemory_task_store.py index d6ebc5919..6aa1bb7e5 100644 --- a/tests/server/tasks/test_inmemory_task_store.py +++ b/tests/server/tasks/test_inmemory_task_store.py @@ -1,3 +1,4 @@ +from a2a.server.context import ServerCallContext import pytest from datetime import datetime, timezone @@ -5,6 +6,23 @@ from a2a.types.a2a_pb2 import Task, TaskState, TaskStatus, ListTasksRequest from a2a.utils.constants import DEFAULT_LIST_TASKS_PAGE_SIZE +from a2a.auth.user import User + + +class SampleUser(User): + """A test implementation of the User interface.""" + + def __init__(self, user_name: str): + self._user_name = user_name + + @property + def is_authenticated(self) -> bool: + return True + + @property + def user_name(self) -> str: + return self._user_name + def create_minimal_task( task_id: str = 'task-abc', context_id: str = 'session-xyz' @@ -247,3 +265,67 @@ async def test_in_memory_task_store_delete_nonexistent() -> None: """Test deleting a nonexistent task.""" store = InMemoryTaskStore() await store.delete('nonexistent') + + +@pytest.mark.asyncio +async def test_owner_resource_scoping() -> None: + """Test that operations are scoped to the correct owner.""" + store = InMemoryTaskStore() + task = create_minimal_task() + + context_user1 = ServerCallContext(user=SampleUser(user_name='user1')) + context_user2 = ServerCallContext(user=SampleUser(user_name='user2')) + context_user3 = ServerCallContext( + user=SampleUser(user_name='user3') + ) # For testing non-existent user + + # Create tasks for different owners + task1_user1 = Task() + task1_user1.CopyFrom(task) + task1_user1.id = 'u1-task1' + + task2_user1 = Task() + task2_user1.CopyFrom(task) + task2_user1.id = 'u1-task2' + + task1_user2 = Task() + task1_user2.CopyFrom(task) + task1_user2.id = 'u2-task1' + + await store.save(task1_user1, context_user1) + await store.save(task2_user1, context_user1) + await store.save(task1_user2, context_user2) + + # Test GET + assert await store.get('u1-task1', context_user1) is not None + assert await store.get('u1-task1', context_user2) is None + assert await store.get('u2-task1', context_user1) is None + assert await store.get('u2-task1', context_user2) is not None + assert await store.get('u2-task1', context_user3) is None + + # Test LIST + params = ListTasksRequest() + page_user1 = await store.list(params, context_user1) + assert len(page_user1.tasks) == 2 + assert {t.id for t in page_user1.tasks} == {'u1-task1', 'u1-task2'} + assert page_user1.total_size == 2 + + page_user2 = await store.list(params, context_user2) + assert len(page_user2.tasks) == 1 + assert {t.id for t in page_user2.tasks} == {'u2-task1'} + assert page_user2.total_size == 1 + + page_user3 = await store.list(params, context_user3) + assert len(page_user3.tasks) == 0 + assert page_user3.total_size == 0 + + # Test DELETE + await store.delete('u1-task1', context_user2) # Should not delete + assert await store.get('u1-task1', context_user1) is not None + + await store.delete('u1-task1', context_user1) # Should delete + assert await store.get('u1-task1', context_user1) is None + + # Cleanup remaining tasks + await store.delete('u1-task2', context_user1) + await store.delete('u2-task1', context_user2) diff --git a/tests/server/tasks/test_push_notification_sender.py b/tests/server/tasks/test_push_notification_sender.py index a7b5f7603..985ae6b7a 100644 --- a/tests/server/tasks/test_push_notification_sender.py +++ b/tests/server/tasks/test_push_notification_sender.py @@ -5,6 +5,8 @@ import httpx from google.protobuf.json_format import MessageToDict +from a2a.auth.user import User +from a2a.server.context import ServerCallContext from a2a.server.tasks.base_push_notification_sender import ( BasePushNotificationSender, ) @@ -17,7 +19,22 @@ ) -def create_sample_task( +class SampleUser(User): + """A test implementation of the User interface.""" + + def __init__(self, user_name: str): + self._user_name = user_name + + @property + def is_authenticated(self) -> bool: + return True + + @property + def user_name(self) -> str: + return self._user_name + + +def _create_sample_task( task_id: str = 'task123', status_state: TaskState = TaskState.TASK_STATE_COMPLETED, ) -> Task: @@ -28,7 +45,7 @@ def create_sample_task( ) -def create_sample_push_config( +def _create_sample_push_config( url: str = 'http://example.com/callback', config_id: str = 'cfg1', token: str | None = None, @@ -36,6 +53,9 @@ def create_sample_push_config( return PushNotificationConfig(id=config_id, url=url, token=token) +MINIMAL_CALL_CONTEXT = ServerCallContext(user=SampleUser(user_name='user')) + + class TestBasePushNotificationSender(unittest.IsolatedAsyncioTestCase): def setUp(self) -> None: self.mock_httpx_client = AsyncMock(spec=httpx.AsyncClient) @@ -43,6 +63,7 @@ def setUp(self) -> None: self.sender = BasePushNotificationSender( httpx_client=self.mock_httpx_client, config_store=self.mock_config_store, + context=MINIMAL_CALL_CONTEXT, ) def test_constructor_stores_client_and_config_store(self) -> None: @@ -51,8 +72,8 @@ def test_constructor_stores_client_and_config_store(self) -> None: async def test_send_notification_success(self) -> None: task_id = 'task_send_success' - task_data = create_sample_task(task_id=task_id) - config = create_sample_push_config(url='http://notify.me/here') + task_data = _create_sample_task(task_id=task_id) + config = _create_sample_push_config(url='http://notify.me/here') self.mock_config_store.get_info.return_value = [config] mock_response = AsyncMock(spec=httpx.Response) @@ -61,7 +82,9 @@ async def test_send_notification_success(self) -> None: await self.sender.send_notification(task_data) - self.mock_config_store.get_info.assert_awaited_once_with + self.mock_config_store.get_info.assert_awaited_once_with( + task_id, MINIMAL_CALL_CONTEXT + ) # assert httpx_client post method got invoked with right parameters self.mock_httpx_client.post.assert_awaited_once_with( @@ -73,8 +96,8 @@ async def test_send_notification_success(self) -> None: async def test_send_notification_with_token_success(self) -> None: task_id = 'task_send_success' - task_data = create_sample_task(task_id=task_id) - config = create_sample_push_config( + task_data = _create_sample_task(task_id=task_id) + config = _create_sample_push_config( url='http://notify.me/here', token='unique_token' ) self.mock_config_store.get_info.return_value = [config] @@ -85,7 +108,9 @@ async def test_send_notification_with_token_success(self) -> None: await self.sender.send_notification(task_data) - self.mock_config_store.get_info.assert_awaited_once_with + self.mock_config_store.get_info.assert_awaited_once_with( + task_id, MINIMAL_CALL_CONTEXT + ) # assert httpx_client post method got invoked with right parameters self.mock_httpx_client.post.assert_awaited_once_with( @@ -97,12 +122,14 @@ async def test_send_notification_with_token_success(self) -> None: async def test_send_notification_no_config(self) -> None: task_id = 'task_send_no_config' - task_data = create_sample_task(task_id=task_id) + task_data = _create_sample_task(task_id=task_id) self.mock_config_store.get_info.return_value = [] await self.sender.send_notification(task_data) - self.mock_config_store.get_info.assert_awaited_once_with(task_id) + self.mock_config_store.get_info.assert_awaited_once_with( + task_id, MINIMAL_CALL_CONTEXT + ) self.mock_httpx_client.post.assert_not_called() @patch('a2a.server.tasks.base_push_notification_sender.logger') @@ -110,8 +137,8 @@ async def test_send_notification_http_status_error( self, mock_logger: MagicMock ) -> None: task_id = 'task_send_http_err' - task_data = create_sample_task(task_id=task_id) - config = create_sample_push_config(url='http://notify.me/http_error') + task_data = _create_sample_task(task_id=task_id) + config = _create_sample_push_config(url='http://notify.me/http_error') self.mock_config_store.get_info.return_value = [config] mock_response = MagicMock(spec=httpx.Response) @@ -124,7 +151,9 @@ async def test_send_notification_http_status_error( await self.sender.send_notification(task_data) - self.mock_config_store.get_info.assert_awaited_once_with(task_id) + self.mock_config_store.get_info.assert_awaited_once_with( + task_id, MINIMAL_CALL_CONTEXT + ) self.mock_httpx_client.post.assert_awaited_once_with( config.url, json=MessageToDict(StreamResponse(task=task_data)), @@ -134,11 +163,11 @@ async def test_send_notification_http_status_error( async def test_send_notification_multiple_configs(self) -> None: task_id = 'task_multiple_configs' - task_data = create_sample_task(task_id=task_id) - config1 = create_sample_push_config( + task_data = _create_sample_task(task_id=task_id) + config1 = _create_sample_push_config( url='http://notify.me/cfg1', config_id='cfg1' ) - config2 = create_sample_push_config( + config2 = _create_sample_push_config( url='http://notify.me/cfg2', config_id='cfg2' ) self.mock_config_store.get_info.return_value = [config1, config2] @@ -149,7 +178,9 @@ async def test_send_notification_multiple_configs(self) -> None: await self.sender.send_notification(task_data) - self.mock_config_store.get_info.assert_awaited_once_with(task_id) + self.mock_config_store.get_info.assert_awaited_once_with( + task_id, MINIMAL_CALL_CONTEXT + ) self.assertEqual(self.mock_httpx_client.post.call_count, 2) # Check calls for config1 diff --git a/tests/server/test_owner_resolver.py b/tests/server/test_owner_resolver.py new file mode 100644 index 000000000..8a0686865 --- /dev/null +++ b/tests/server/test_owner_resolver.py @@ -0,0 +1,31 @@ +from a2a.auth.user import User + +from a2a.server.context import ServerCallContext +from a2a.server.owner_resolver import resolve_user_scope + + +class TestUser(User): + """A test implementation of the User interface.""" + + def __init__(self, user_name: str): + self._user_name = user_name + + @property + def is_authenticated(self) -> bool: + return True + + @property + def user_name(self) -> str: + return self._user_name + + +def test_resolve_user_scope_valid_user(): + """Test resolve_user_scope with a valid user in the context.""" + user = TestUser(user_name='testuser') + context = ServerCallContext(user=user) + assert resolve_user_scope(context) == 'testuser' + + +def test_resolve_user_scope_no_context(): + """Test resolve_user_scope when the context is None.""" + assert resolve_user_scope(None) == 'unknown' diff --git a/uv.lock b/uv.lock index 2cecfc177..748ef3ee6 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.10" resolution-markers = [ "python_full_version >= '3.14'", @@ -70,6 +70,7 @@ telemetry = [ [package.dev-dependencies] dev = [ { name = "a2a-sdk", extra = ["all"] }, + { name = "alembic" }, { name = "autoflake" }, { name = "mypy" }, { name = "no-implicit-optional" }, @@ -135,6 +136,7 @@ provides-extras = ["all", "encryption", "grpc", "http-server", "mysql", "postgre [package.metadata.requires-dev] dev = [ { name = "a2a-sdk", extras = ["all"], editable = "." }, + { name = "alembic", specifier = ">=1.14.0" }, { name = "autoflake" }, { name = "mypy", specifier = ">=1.15.0" }, { name = "no-implicit-optional" }, @@ -177,6 +179,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/00/b7/e3bf5133d697a08128598c8d0abc5e16377b51465a33756de24fa7dee953/aiosqlite-0.22.1-py3-none-any.whl", hash = "sha256:21c002eb13823fad740196c5a2e9d8e62f6243bd9e7e4a1f87fb5e44ecb4fceb", size = 17405, upload-time = "2025-12-23T19:25:42.139Z" }, ] +[[package]] +name = "alembic" +version = "1.18.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mako" }, + { name = "sqlalchemy" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/94/13/8b084e0f2efb0275a1d534838844926f798bd766566b1375174e2448cd31/alembic-1.18.4.tar.gz", hash = "sha256:cb6e1fd84b6174ab8dbb2329f86d631ba9559dd78df550b57804d607672cedbc", size = 2056725, upload-time = "2026-02-10T16:00:47.195Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d2/29/6533c317b74f707ea28f8d633734dbda2119bbadfc61b2f3640ba835d0f7/alembic-1.18.4-py3-none-any.whl", hash = "sha256:a5ed4adcf6d8a4cb575f3d759f071b03cd6e5c7618eb796cb52497be25bfe19a", size = 263893, upload-time = "2026-02-10T16:00:49.997Z" }, +] + [[package]] name = "annotated-doc" version = "0.0.4" @@ -1277,6 +1294,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/94/d1/433b3c06e78f23486fe4fdd19bc134657eb30997d2054b0dbf52bbf3382e/librt-0.8.0-cp314-cp314t-win_arm64.whl", hash = "sha256:92249938ab744a5890580d3cb2b22042f0dce71cdaa7c1369823df62bedf7cbc", size = 48753, upload-time = "2026-02-12T14:53:38.539Z" }, ] +[[package]] +name = "mako" +version = "1.3.10" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markupsafe" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9e/38/bd5b78a920a64d708fe6bc8e0a2c075e1389d53bef8413725c63ba041535/mako-1.3.10.tar.gz", hash = "sha256:99579a6f39583fa7e5630a28c3c1f440e4e97a414b80372649c0ce338da2ea28", size = 392474, upload-time = "2025-04-10T12:44:31.16Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/87/fb/99f81ac72ae23375f22b7afdb7642aba97c00a713c217124420147681a2f/mako-1.3.10-py3-none-any.whl", hash = "sha256:baef24a52fc4fc514a0887ac600f9f1cff3d82c61d4d700a1fa84d597b88db59", size = 78509, upload-time = "2025-04-10T12:50:53.297Z" }, +] + [[package]] name = "markupsafe" version = "3.0.3" @@ -2323,7 +2352,7 @@ wheels = [ [[package]] name = "virtualenv" -version = "20.37.0" +version = "20.38.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "distlib" }, @@ -2331,9 +2360,9 @@ dependencies = [ { name = "platformdirs" }, { name = "typing-extensions", marker = "python_full_version < '3.11'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c1/ef/d9d4ce633df789bf3430bd81fb0d8b9d9465dfc1d1f0deb3fb62cd80f5c2/virtualenv-20.37.0.tar.gz", hash = "sha256:6f7e2064ed470aa7418874e70b6369d53b66bcd9e9fd5389763e96b6c94ccb7c", size = 5864710, upload-time = "2026-02-16T16:17:59.42Z" } +sdist = { url = "https://files.pythonhosted.org/packages/d2/03/a94d404ca09a89a7301a7008467aed525d4cdeb9186d262154dd23208709/virtualenv-20.38.0.tar.gz", hash = "sha256:94f39b1abaea5185bf7ea5a46702b56f1d0c9aa2f41a6c2b8b0af4ddc74c10a7", size = 5864558, upload-time = "2026-02-19T07:48:02.385Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/42/4b/6cf85b485be7ec29db837ec2a1d8cd68bc1147b1abf23d8636c5bd65b3cc/virtualenv-20.37.0-py3-none-any.whl", hash = "sha256:5d3951c32d57232ae3569d4de4cc256c439e045135ebf43518131175d9be435d", size = 5837480, upload-time = "2026-02-16T16:17:57.341Z" }, + { url = "https://files.pythonhosted.org/packages/42/d7/394801755d4c8684b655d35c665aea7836ec68320304f62ab3c94395b442/virtualenv-20.38.0-py3-none-any.whl", hash = "sha256:d6e78e5889de3a4742df2d3d44e779366325a90cf356f15621fddace82431794", size = 5837778, upload-time = "2026-02-19T07:47:59.778Z" }, ] [[package]]