diff --git a/src/strands/models/mistral.py b/src/strands/models/mistral.py index f44a11d30..7c9f6fbbd 100644 --- a/src/strands/models/mistral.py +++ b/src/strands/models/mistral.py @@ -9,8 +9,14 @@ from collections.abc import AsyncGenerator, Iterable from typing import Any, TypeVar -import mistralai from pydantic import BaseModel + +# Support both mistralai v1.x and v2.x import paths. +# v2.0 moved the client class from mistralai.Mistral to mistralai.client.Mistral. +try: + from mistralai.client import Mistral as MistralClient # type: ignore[attr-defined] +except ImportError: + from mistralai import Mistral as MistralClient from typing_extensions import TypedDict, Unpack, override from ..types.content import ContentBlock, Messages @@ -434,7 +440,7 @@ async def stream( logger.debug("got response from model") if not self.config.get("stream", True): # Use non-streaming API - async with mistralai.Mistral(**self.client_args) as client: + async with MistralClient(**self.client_args) as client: response = await client.chat.complete_async(**request) for event in self._handle_non_streaming_response(response): yield self.format_chunk(event) @@ -442,7 +448,7 @@ async def stream( return # Use the streaming API - async with mistralai.Mistral(**self.client_args) as client: + async with MistralClient(**self.client_args) as client: stream_response = await client.chat.stream_async(**request) yield self.format_chunk({"chunk_type": "message_start"}) @@ -535,7 +541,7 @@ async def structured_output( formatted_request["tool_choice"] = "any" formatted_request["parallel_tool_calls"] = False - async with mistralai.Mistral(**self.client_args) as client: + async with MistralClient(**self.client_args) as client: response = await client.chat.complete_async(**formatted_request) if response.choices and response.choices[0].message.tool_calls: diff --git a/tests/strands/models/test_mistral.py b/tests/strands/models/test_mistral.py index 57189748e..f2220835a 100644 --- a/tests/strands/models/test_mistral.py +++ b/tests/strands/models/test_mistral.py @@ -1,17 +1,18 @@ +import importlib import logging import unittest.mock import pydantic import pytest -import strands +import strands.models.mistral as mistral_module from strands.models.mistral import MistralModel from strands.types.exceptions import ModelThrottledException @pytest.fixture def mistral_client(): - with unittest.mock.patch.object(strands.models.mistral.mistralai, "Mistral") as mock_client_cls: + with unittest.mock.patch("strands.models.mistral.MistralClient") as mock_client_cls: mock_client = unittest.mock.AsyncMock() mock_client_cls.return_value.__aenter__.return_value = mock_client yield mock_client @@ -679,3 +680,42 @@ def test_format_request_filters_location_source_document(model, caplog): user_content = formatted_messages[0]["content"] assert user_content == "analyze this document" assert "Location sources are not supported by Mistral" in caplog.text + + +def test_mistral_client_import_v2(): + """Test that MistralClient resolves from mistralai.client.Mistral (v2.x import path).""" + mock_client_cls = unittest.mock.MagicMock(name="MistralClientV2") + mock_client_module = unittest.mock.MagicMock() + mock_client_module.Mistral = mock_client_cls + + with unittest.mock.patch.dict("sys.modules", {"mistralai.client": mock_client_module}): + importlib.reload(mistral_module) + + actual_client = mistral_module.MistralClient + exp_client = mock_client_cls + + assert actual_client is exp_client + assert actual_client is not None + + # Restore original module state + importlib.reload(mistral_module) + + +def test_mistral_client_import_v1_fallback(): + """Test that MistralClient falls back to mistralai.Mistral when mistralai.client is unavailable (v1.x path).""" + mock_client_cls = unittest.mock.MagicMock(name="MistralClientV1") + mock_mistralai = unittest.mock.MagicMock() + mock_mistralai.Mistral = mock_client_cls + + # Setting a module to None in sys.modules causes ImportError on import + with unittest.mock.patch.dict("sys.modules", {"mistralai.client": None, "mistralai": mock_mistralai}): + importlib.reload(mistral_module) + + actual_client = mistral_module.MistralClient + exp_client = mock_client_cls + + assert actual_client is exp_client + assert actual_client is not None + + # Restore original module state + importlib.reload(mistral_module)