From cc3b3e6100f6244edbd475bb5acef04037bf1f76 Mon Sep 17 00:00:00 2001 From: Amos Date: Thu, 26 Feb 2026 15:00:43 +0100 Subject: [PATCH] add support for ProgressTracker --- taskiq_nats/result_backend.py | 21 +++++++++++++++++++++ tests/test_result_backend.py | 21 +++++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/taskiq_nats/result_backend.py b/taskiq_nats/result_backend.py index a9ea934..07b44e8 100644 --- a/taskiq_nats/result_backend.py +++ b/taskiq_nats/result_backend.py @@ -8,6 +8,7 @@ from taskiq import AsyncResultBackend, ResultGetError from taskiq.abc.serializer import TaskiqSerializer from taskiq.compat import model_dump, model_validate +from taskiq.depends.progress_tracker import TaskProgress from taskiq.result import TaskiqResult from taskiq.serializers import PickleSerializer @@ -123,3 +124,23 @@ async def get_result( taskiq_result.log = None return taskiq_result + + async def set_progress( + self, + task_id: str, + progress: TaskProgress[Any], + ) -> None: + await self.object_store.put( + name=f"progress:{task_id}", + data=self.serializer.dumpb(model_dump(progress)), + ) + + async def get_progress(self, task_id: str) -> TaskProgress[Any] | None: + try: + result = await self.object_store.get(name=f"progress:{task_id}") + except ObjectNotFoundError: + return None + return model_validate( + TaskProgress[Any], + self.serializer.loadb(result.data), + ) diff --git a/tests/test_result_backend.py b/tests/test_result_backend.py index fcc940d..50a9efd 100644 --- a/tests/test_result_backend.py +++ b/tests/test_result_backend.py @@ -3,6 +3,7 @@ import pytest from taskiq import ResultGetError, TaskiqResult +from taskiq.depends.progress_tracker import TaskProgress from taskiq_nats import NATSObjectStoreResultBackend @@ -146,3 +147,23 @@ async def test_success_backend_is_result_ready( ) assert await nats_result_backend.is_result_ready(task_id=task_id) + + +async def test_set_and_get_progress( + nats_result_backend: NATSObjectStoreResultBackend[_ReturnType], + task_id: str, +) -> None: + progress = TaskProgress(state="PROGRESS", meta={"current": 5, "total": 10}) + await nats_result_backend.set_progress(task_id=task_id, progress=progress) + result = await nats_result_backend.get_progress(task_id=task_id) + assert result is not None + assert result.state == "PROGRESS" + assert result.meta == {"current": 5, "total": 10} + + +async def test_get_progress_not_found( + nats_result_backend: NATSObjectStoreResultBackend[_ReturnType], + task_id: str, +) -> None: + result = await nats_result_backend.get_progress(task_id=task_id) + assert result is None