Skip to content
1 change: 1 addition & 0 deletions aredis_om/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from . import redis


URL = os.environ.get("REDIS_OM_URL", None)


Expand Down
1 change: 1 addition & 0 deletions aredis_om/model/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

from pydantic import BaseModel


try:
from pydantic.deprecated.json import ENCODERS_BY_TYPE
from pydantic_core import PydanticUndefined
Expand Down
1 change: 1 addition & 0 deletions aredis_om/model/migrations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
SchemaMigrator,
)


__all__ = [
# Data migrations
"BaseMigration",
Expand Down
1 change: 1 addition & 0 deletions aredis_om/model/migrations/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@
from .base import BaseMigration, DataMigrationError
from .migrator import DataMigrator


__all__ = ["BaseMigration", "DataMigrationError", "DataMigrator"]
1 change: 1 addition & 0 deletions aredis_om/model/migrations/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import time
from typing import Any, Dict, List


try:
import psutil
except ImportError:
Expand Down
1 change: 1 addition & 0 deletions aredis_om/model/migrations/data/builtin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@
DatetimeFieldMigration,
)


__all__ = ["DatetimeFieldMigration", "DatetimeFieldDetector", "ConversionFailureMode"]
11 changes: 7 additions & 4 deletions aredis_om/model/migrations/data/builtin/datetime_migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from ..base import BaseMigration, DataMigrationError


log = logging.getLogger(__name__)


Expand Down Expand Up @@ -180,9 +181,9 @@ def __init__(self):
self.converted_fields = 0
self.skipped_fields = 0
self.failed_conversions = 0
self.errors: List[Tuple[str, str, str, Exception]] = (
[]
) # (key, field, value, error)
self.errors: List[
Tuple[str, str, str, Exception]
] = [] # (key, field, value, error)

def add_conversion_error(self, key: str, field: str, value: Any, error: Exception):
"""Record a conversion error."""
Expand Down Expand Up @@ -393,7 +394,9 @@ async def save_progress(
}

await self.redis.set(
self.state_key, json.dumps(state_data), ex=86400 # Expire after 24 hours
self.state_key,
json.dumps(state_data),
ex=86400, # Expire after 24 hours
)

async def load_progress(self) -> Dict[str, Any]:
Expand Down
1 change: 1 addition & 0 deletions aredis_om/model/migrations/schema/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .legacy_migrator import MigrationAction, MigrationError, Migrator, SchemaDetector
from .migrator import SchemaMigrator


__all__ = [
# Primary API
"BaseSchemaMigration",
Expand Down
4 changes: 3 additions & 1 deletion aredis_om/model/migrations/schema/legacy_migrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import redis


log = logging.getLogger(__name__)


Expand Down Expand Up @@ -52,7 +53,8 @@ def import_submodules(root_module_name: str):
)

for loader, module_name, is_pkg in pkgutil.walk_packages(
root_module.__path__, root_module.__name__ + "." # type: ignore
root_module.__path__,
root_module.__name__ + ".", # type: ignore
):
importlib.import_module(module_name)

Expand Down
61 changes: 36 additions & 25 deletions aredis_om/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,14 @@
Type,
TypeVar,
Union,
)
from typing import get_args as typing_get_args
from typing import (
no_type_check,
)
from typing import get_args as typing_get_args

from more_itertools import ichunked
from pydantic import BaseModel


try:
from pydantic import ConfigDict, TypeAdapter, field_validator

Expand Down Expand Up @@ -73,6 +72,7 @@
from .token_escaper import TokenEscaper
from .types import Coordinates, CoordinateType, GeoFilter


model_registry = {}
_T = TypeVar("_T")
Model = TypeVar("Model", bound="RedisModel")
Expand Down Expand Up @@ -115,8 +115,11 @@ def convert_datetime_to_timestamp(obj):
elif isinstance(obj, datetime.datetime):
return obj.timestamp()
elif isinstance(obj, datetime.date):
# Convert date to datetime at midnight and get timestamp
dt = datetime.datetime.combine(obj, datetime.time.min)
# Date values represent calendar days, so normalize to UTC midnight
# to avoid timezone-dependent day shifts on round-trip conversion.
dt = datetime.datetime.combine(
obj, datetime.time.min, tzinfo=datetime.timezone.utc
)
return dt.timestamp()
else:
return obj
Expand All @@ -138,7 +141,9 @@ def convert_timestamp_to_datetime(obj, model_fields):
# For Optional[T] which is Union[T, None], get the non-None type
args = getattr(field_type, "__args__", ())
non_none_types = [
arg for arg in args if arg is not type(None) # noqa: E721
arg
for arg in args
if arg is not type(None) # noqa: E721
]
if len(non_none_types) == 1:
field_type = non_none_types[0]
Expand All @@ -150,8 +155,13 @@ def convert_timestamp_to_datetime(obj, model_fields):
try:
if isinstance(value, str):
value = float(value)
# Use fromtimestamp to preserve local timezone behavior
dt = datetime.datetime.fromtimestamp(value)
# Return UTC-aware datetime for consistency.
# Timestamps are always UTC-referenced, so we return
# UTC-aware datetimes. Users can convert to their
# preferred timezone with dt.astimezone(tz).
dt = datetime.datetime.fromtimestamp(
value, datetime.timezone.utc
)
# If the field is specifically a date, convert to date
if field_type is datetime.date:
result[key] = dt.date()
Expand Down Expand Up @@ -255,7 +265,9 @@ def convert_base64_to_bytes(obj, model_fields):
# For Optional[T] which is Union[T, None], get the non-None type
args = getattr(field_type, "__args__", ())
non_none_types = [
arg for arg in args if arg is not type(None) # noqa: E721
arg
for arg in args
if arg is not type(None) # noqa: E721
]
if len(non_none_types) == 1:
field_type = non_none_types[0]
Expand Down Expand Up @@ -636,10 +648,10 @@ def embedded(cls):

def is_supported_container_type(typ: Optional[type]) -> bool:
# TODO: Wait, why don't we support indexing sets?
if typ == list or typ == tuple or typ == Literal:
if typ is list or typ is tuple or typ is Literal:
return True
unwrapped = get_origin(typ)
return unwrapped == list or unwrapped == tuple or unwrapped == Literal
return unwrapped is list or unwrapped is tuple or unwrapped is Literal


def validate_model_fields(model: Type["RedisModel"], field_values: Dict[str, Any]):
Expand Down Expand Up @@ -1056,7 +1068,7 @@ def _validate_deep_field_path(self, field_path: str):
field_type, RedisModel
):
current_model = field_type
elif field_type == dict:
elif field_type is dict:
# Dict fields - we can't validate nested paths, just accept them
return
else:
Expand Down Expand Up @@ -1089,7 +1101,7 @@ def _validate_deep_field_path(self, field_path: str):
field_type, RedisModel
):
current_model = field_type
elif field_type == dict:
elif field_type is dict:
return # Can't validate further into dict
else:
raise QueryNotSupportedError(
Expand Down Expand Up @@ -1174,18 +1186,18 @@ def _convert_projected_fields(self, raw_data: Dict[str, str]) -> Dict[str, Any]:
field_type = getattr(field_info, "type_", str)

# Handle common type conversions directly for efficiency
if field_type == int:
if field_type is int:
converted_data[field_name] = int(raw_value)
elif field_type == float:
elif field_type is float:
converted_data[field_name] = float(raw_value)
elif field_type == bool:
elif field_type is bool:
# Redis may store bool as "True"/"False" or "1"/"0"
converted_data[field_name] = raw_value.lower() in (
"true",
"1",
"yes",
)
elif field_type == str:
elif field_type is str:
converted_data[field_name] = raw_value
else:
# For complex types, keep as string (could be enhanced later)
Expand Down Expand Up @@ -1231,7 +1243,7 @@ def _has_complex_projected_fields(self) -> bool:
field_type = getattr(field_info, "annotation", None)

# Check for dict fields
if field_type == dict:
if field_type is dict:
return True

# Check for embedded models (subclasses of RedisModel)
Expand Down Expand Up @@ -1524,8 +1536,7 @@ def expand_tag_value(value):
return "|".join([escaper.escape(str(v)) for v in value])
except TypeError:
log.debug(
"Escaping single non-iterable value used for an IN or "
"NOT_IN query: %s",
"Escaping single non-iterable value used for an IN or NOT_IN query: %s",
value,
)
return escaper.escape(str(value))
Expand Down Expand Up @@ -1571,8 +1582,10 @@ def convert_numeric_value(v):
if isinstance(v, datetime.date) and not isinstance(
v, datetime.datetime
):
# Convert date to datetime at midnight
v = datetime.datetime.combine(v, datetime.time.min)
# Use UTC midnight so query conversion matches storage conversion.
v = datetime.datetime.combine(
v, datetime.time.min, tzinfo=datetime.timezone.utc
)
v = v.timestamp()
return v

Expand Down Expand Up @@ -3352,9 +3365,7 @@ def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo):
field_info, "separator", SINGLE_VALUE_TAG_FIELD_SEPARATOR
)
if getattr(field_info, "full_text_search", False) is True:
schema = (
f"{name} TAG SEPARATOR {separator} " f"{name} AS {name}_fts TEXT"
)
schema = f"{name} TAG SEPARATOR {separator} {name} AS {name}_fts TEXT"
else:
schema = f"{name} TAG SEPARATOR {separator}"
elif issubclass(typ, RedisModel):
Expand Down
1 change: 1 addition & 0 deletions aredis_om/model/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Annotated, Any, Literal, Tuple, Union


try:
from pydantic import BeforeValidator, PlainSerializer
from pydantic_extra_types.coordinate import Coordinate
Expand Down
33 changes: 27 additions & 6 deletions make_sync.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import re
import shutil
import subprocess
from pathlib import Path

import unasync
Expand Down Expand Up @@ -116,25 +117,45 @@ def remove_run_async_call(match):
# Post-process model.py to fix async imports for sync version
model_file = Path(__file__).absolute().parent / "redis_om/model/model.py"
if model_file.exists():
with open(model_file, 'r') as f:
with open(model_file, "r") as f:
content = f.read()

# Fix supports_hash_field_expiration to check sync Redis class
# The unasync replacement doesn't work for dotted attribute access
content = content.replace(
'redis_lib.asyncio.Redis',
'redis_lib.Redis'
"redis_lib.asyncio.Redis",
"redis_lib.Redis",
)

# Fix Pipeline import: redis.asyncio.client -> redis.client
content = content.replace(
'from redis.asyncio.client import Pipeline',
'from redis.client import Pipeline'
"from redis.asyncio.client import Pipeline",
"from redis.client import Pipeline",
)

with open(model_file, 'w') as f:
with open(model_file, "w") as f:
f.write(content)

# Fix duplicated import introduced by Async->sync class replacement.
redisvl_file = Path(__file__).absolute().parent / "redis_om/redisvl.py"
if redisvl_file.exists():
with open(redisvl_file, "r") as f:
content = f.read()

content = content.replace(
"from redisvl.index import SearchIndex, SearchIndex",
"from redisvl.index import SearchIndex",
)

with open(redisvl_file, "w") as f:
f.write(content)

# Ensure generated sync code is formatter-clean for CI lint checks.
subprocess.run(
["ruff", "format", str(redis_om_dir), str(tests_sync_dir)],
check=True,
)


if __name__ == "__main__":
main()
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from aredis_om import get_redis_connection


TEST_PREFIX = "redis-om:testing"


Expand Down
1 change: 1 addition & 0 deletions tests/test_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from aredis_om import EmbeddedJsonModel, Field, HashModel, JsonModel, Migrator


# Skip if pytest-benchmark is not installed
pytest.importorskip("pytest_benchmark")

Expand Down
1 change: 1 addition & 0 deletions tests/test_bug_fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from .conftest import py_test_mark_asyncio


if not has_redisearch():
pytestmark = pytest.mark.skip

Expand Down
Loading
Loading