Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions src/strands/models/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,14 @@
from collections.abc import AsyncGenerator, Iterable
from typing import Any, TypeVar

import mistralai
from pydantic import BaseModel

# Support both mistralai v1.x and v2.x import paths.
# v2.0 moved the client class from mistralai.Mistral to mistralai.client.Mistral.
try:
from mistralai.client import Mistral as MistralClient # type: ignore[attr-defined]
except ImportError:
from mistralai import Mistral as MistralClient
from typing_extensions import TypedDict, Unpack, override

from ..types.content import ContentBlock, Messages
Expand Down Expand Up @@ -434,15 +440,15 @@ async def stream(
logger.debug("got response from model")
if not self.config.get("stream", True):
# Use non-streaming API
async with mistralai.Mistral(**self.client_args) as client:
async with MistralClient(**self.client_args) as client:
response = await client.chat.complete_async(**request)
for event in self._handle_non_streaming_response(response):
yield self.format_chunk(event)

return

# Use the streaming API
async with mistralai.Mistral(**self.client_args) as client:
async with MistralClient(**self.client_args) as client:
stream_response = await client.chat.stream_async(**request)

yield self.format_chunk({"chunk_type": "message_start"})
Expand Down Expand Up @@ -535,7 +541,7 @@ async def structured_output(
formatted_request["tool_choice"] = "any"
formatted_request["parallel_tool_calls"] = False

async with mistralai.Mistral(**self.client_args) as client:
async with MistralClient(**self.client_args) as client:
response = await client.chat.complete_async(**formatted_request)

if response.choices and response.choices[0].message.tool_calls:
Expand Down
44 changes: 42 additions & 2 deletions tests/strands/models/test_mistral.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import importlib
import logging
import unittest.mock

import pydantic
import pytest

import strands
import strands.models.mistral as mistral_module
from strands.models.mistral import MistralModel
from strands.types.exceptions import ModelThrottledException


@pytest.fixture
def mistral_client():
with unittest.mock.patch.object(strands.models.mistral.mistralai, "Mistral") as mock_client_cls:
with unittest.mock.patch("strands.models.mistral.MistralClient") as mock_client_cls:
mock_client = unittest.mock.AsyncMock()
mock_client_cls.return_value.__aenter__.return_value = mock_client
yield mock_client
Expand Down Expand Up @@ -679,3 +680,42 @@ def test_format_request_filters_location_source_document(model, caplog):
user_content = formatted_messages[0]["content"]
assert user_content == "analyze this document"
assert "Location sources are not supported by Mistral" in caplog.text


def test_mistral_client_import_v2():
"""Test that MistralClient resolves from mistralai.client.Mistral (v2.x import path)."""
mock_client_cls = unittest.mock.MagicMock(name="MistralClientV2")
mock_client_module = unittest.mock.MagicMock()
mock_client_module.Mistral = mock_client_cls

with unittest.mock.patch.dict("sys.modules", {"mistralai.client": mock_client_module}):
importlib.reload(mistral_module)

actual_client = mistral_module.MistralClient
exp_client = mock_client_cls

assert actual_client is exp_client
assert actual_client is not None

# Restore original module state
importlib.reload(mistral_module)


def test_mistral_client_import_v1_fallback():
"""Test that MistralClient falls back to mistralai.Mistral when mistralai.client is unavailable (v1.x path)."""
mock_client_cls = unittest.mock.MagicMock(name="MistralClientV1")
mock_mistralai = unittest.mock.MagicMock()
mock_mistralai.Mistral = mock_client_cls

# Setting a module to None in sys.modules causes ImportError on import
with unittest.mock.patch.dict("sys.modules", {"mistralai.client": None, "mistralai": mock_mistralai}):
importlib.reload(mistral_module)

actual_client = mistral_module.MistralClient
exp_client = mock_client_cls

assert actual_client is exp_client
assert actual_client is not None

# Restore original module state
importlib.reload(mistral_module)