Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions taskiq_nats/result_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from taskiq import AsyncResultBackend, ResultGetError
from taskiq.abc.serializer import TaskiqSerializer
from taskiq.compat import model_dump, model_validate
from taskiq.depends.progress_tracker import TaskProgress
from taskiq.result import TaskiqResult
from taskiq.serializers import PickleSerializer

Expand Down Expand Up @@ -123,3 +124,23 @@ async def get_result(
taskiq_result.log = None

return taskiq_result

async def set_progress(
self,
task_id: str,
progress: TaskProgress[Any],
) -> None:
await self.object_store.put(
name=f"progress:{task_id}",
data=self.serializer.dumpb(model_dump(progress)),
)

async def get_progress(self, task_id: str) -> TaskProgress[Any] | None:
try:
result = await self.object_store.get(name=f"progress:{task_id}")
except ObjectNotFoundError:
return None
return model_validate(
TaskProgress[Any],
self.serializer.loadb(result.data),
)
21 changes: 21 additions & 0 deletions tests/test_result_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pytest
from taskiq import ResultGetError, TaskiqResult
from taskiq.depends.progress_tracker import TaskProgress

from taskiq_nats import NATSObjectStoreResultBackend

Expand Down Expand Up @@ -146,3 +147,23 @@ async def test_success_backend_is_result_ready(
)

assert await nats_result_backend.is_result_ready(task_id=task_id)


async def test_set_and_get_progress(
nats_result_backend: NATSObjectStoreResultBackend[_ReturnType],
task_id: str,
) -> None:
progress = TaskProgress(state="PROGRESS", meta={"current": 5, "total": 10})
await nats_result_backend.set_progress(task_id=task_id, progress=progress)
result = await nats_result_backend.get_progress(task_id=task_id)
assert result is not None
assert result.state == "PROGRESS"
assert result.meta == {"current": 5, "total": 10}


async def test_get_progress_not_found(
nats_result_backend: NATSObjectStoreResultBackend[_ReturnType],
task_id: str,
) -> None:
result = await nats_result_backend.get_progress(task_id=task_id)
assert result is None