Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 43 additions & 21 deletions src/datajoint/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ def __init__(
port: int | None = None,
use_tls: bool | dict | None = None,
*,
database_name: str | None = None,
backend: str | None = None,
config_override: "Config | None" = None,
) -> None:
Expand All @@ -180,7 +181,9 @@ def __init__(
port = int(port)
elif port is None:
port = self._config["database.port"]
self.conn_info = dict(host=host, port=port, user=user, passwd=password)
if database_name is None:
database_name = self._config.get("database.name")
self.conn_info = dict(host=host, port=port, user=user, passwd=password, database_name=database_name)
if use_tls is not False:
# use_tls can be: None (auto-detect), True (enable), False (disable), or dict (custom config)
if isinstance(use_tls, dict):
Expand All @@ -201,12 +204,27 @@ def __init__(
backend = self._config["database.backend"]
self.adapter = get_adapter(backend)

if database_name and self.adapter.backend == "mysql":
warnings.warn(
"database.name is set but the MySQL backend does not support database selection. "
"This setting only applies to PostgreSQL connections.",
UserWarning,
stacklevel=2,
)

self.connect()
if self.is_connected:
logger.info("DataJoint {version} connected to {user}@{host}:{port}".format(version=__version__, **self.conn_info))
db = self.conn_info.get("database_name")
db_str = f"/{db}" if db else ""
logger.info(
f"DataJoint {__version__} connected to "
f"{self.conn_info['user']}@{self.conn_info['host']}:{self.conn_info['port']}{db_str}"
)
self.connection_id = self.adapter.get_connection_id(self._conn)
else:
raise errors.LostConnectionError("Connection failed {user}@{host}:{port}".format(**self.conn_info))
raise errors.LostConnectionError(
f"Connection failed {self.conn_info['user']}@{self.conn_info['host']}:{self.conn_info['port']}"
)
self._in_transaction = False
self.schemas = dict()
self.dependencies = Dependencies(self)
Expand All @@ -216,22 +234,33 @@ def __eq__(self, other):

def __repr__(self):
connected = "connected" if self.is_connected else "disconnected"
return "DataJoint connection ({connected}) {user}@{host}:{port}".format(connected=connected, **self.conn_info)
user = self.conn_info["user"]
host = self.conn_info["host"]
port = self.conn_info["port"]
db = self.conn_info.get("database_name")
db_str = f"/{db}" if db else ""
return f"DataJoint connection ({connected}) {user}@{host}:{port}{db_str}"

def _build_connect_kwargs(self, use_tls=None):
"""Build kwargs dict for adapter.connect()."""
kwargs = dict(
host=self.conn_info["host"],
port=self.conn_info["port"],
user=self.conn_info["user"],
password=self.conn_info["passwd"],
charset=self._config["connection.charset"],
use_tls=use_tls if use_tls is not None else self.conn_info.get("ssl"),
)
if self.conn_info.get("database_name"):
kwargs["dbname"] = self.conn_info["database_name"]
return kwargs

def connect(self) -> None:
"""Establish or re-establish connection to the database server."""
with warnings.catch_warnings():
warnings.filterwarnings("ignore", ".*deprecated.*")
try:
# Use adapter to create connection
self._conn = self.adapter.connect(
host=self.conn_info["host"],
port=self.conn_info["port"],
user=self.conn_info["user"],
password=self.conn_info["passwd"],
charset=self._config["connection.charset"],
use_tls=self.conn_info.get("ssl"),
)
self._conn = self.adapter.connect(**self._build_connect_kwargs())
except Exception as ssl_error:
# If SSL fails, retry without SSL (if it was auto-detected)
if self.conn_info.get("ssl_input") is None:
Expand All @@ -240,14 +269,7 @@ def connect(self) -> None:
"To require SSL, set use_tls=True explicitly.",
ssl_error,
)
self._conn = self.adapter.connect(
host=self.conn_info["host"],
port=self.conn_info["port"],
user=self.conn_info["user"],
password=self.conn_info["passwd"],
charset=self._config["connection.charset"],
use_tls=False, # Explicitly disable SSL for fallback
)
self._conn = self.adapter.connect(**self._build_connect_kwargs(use_tls=False))
else:
raise
self._is_closed = False # Mark as connected after successful connection
Expand Down
7 changes: 7 additions & 0 deletions src/datajoint/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,13 @@ def activate(
self.connection = connection
if self.connection is None:
self.connection = _get_singleton_connection()
if self.connection._config.get("database.database_prefix"):
warnings.warn(
"database_prefix is deprecated and will be removed in DataJoint 2.3. "
"Use database.name to select a PostgreSQL database instead.",
DeprecationWarning,
stacklevel=2,
)
self.database = schema_name
if create_schema is not None:
self.create_schema = create_schema
Expand Down
9 changes: 7 additions & 2 deletions src/datajoint/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
"database.password": "DJ_PASS",
"database.backend": "DJ_BACKEND",
"database.port": "DJ_PORT",
"database.name": "DJ_DATABASE_NAME",
"database.database_prefix": "DJ_DATABASE_PREFIX",
"database.create_tables": "DJ_CREATE_TABLES",
"loglevel": "DJ_LOG_LEVEL",
Expand Down Expand Up @@ -196,13 +197,17 @@ class DatabaseSettings(BaseSettings):
description="Database backend: 'mysql' or 'postgresql'",
)
port: int | None = Field(default=None, validation_alias="DJ_PORT")
name: str | None = Field(
default=None,
validation_alias="DJ_DATABASE_NAME",
description="Database name for PostgreSQL connections. Defaults to 'postgres' if not set.",
)
reconnect: bool = True
use_tls: bool | None = Field(default=None, validation_alias="DJ_USE_TLS")
database_prefix: str = Field(
default="",
validation_alias="DJ_DATABASE_PREFIX",
description="Prefix for database/schema names. "
"Not automatically applied; use dj.config.database.database_prefix when creating schemas.",
description="Deprecated. Use database.name instead.",
)
create_tables: bool = Field(
default=True,
Expand Down
65 changes: 65 additions & 0 deletions tests/unit/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,6 +750,71 @@ def test_similar_prefix_names_allowed(self):
dj.config.stores.update(original_stores)


class TestDatabaseNameConfiguration:
"""Test database.name configuration."""

def test_database_name_default_is_none(self):
"""Database name defaults to None when not configured."""
from datajoint.settings import DatabaseSettings

s = DatabaseSettings()
assert s.name is None

def test_database_name_env_var(self, monkeypatch):
"""DJ_DATABASE_NAME environment variable sets database name."""
from datajoint.settings import DatabaseSettings

monkeypatch.setenv("DJ_DATABASE_NAME", "my_database")
s = DatabaseSettings()
assert s.name == "my_database"

def test_database_name_from_config_file(self, tmp_path, monkeypatch):
"""Load database name from config file."""
import json

from datajoint.settings import Config

config_file = tmp_path / "test_config.json"
config_file.write_text(json.dumps({"database": {"name": "custom_db", "host": "localhost"}}))

monkeypatch.delenv("DJ_DATABASE_NAME", raising=False)
monkeypatch.delenv("DJ_HOST", raising=False)

cfg = Config()
cfg.load(config_file)
assert cfg.database.name == "custom_db"

def test_database_name_dict_access(self):
"""Dict-style access reads and writes database name."""
original = dj.config.database.name
try:
dj.config.database.name = "test_db"
assert dj.config["database.name"] == "test_db"
finally:
dj.config.database.name = original

def test_database_name_override_context_manager(self):
"""Override context manager temporarily sets database name."""
original = dj.config.database.name
with dj.config.override(database__name="override_db"):
assert dj.config.database.name == "override_db"
assert dj.config.database.name == original

def test_database_prefix_empty_no_warning(self):
"""Empty database_prefix does not emit DeprecationWarning at config load."""
import warnings

from datajoint.settings import DatabaseSettings

with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
DatabaseSettings()
deprecation_warnings = [
x for x in w if issubclass(x.category, DeprecationWarning) and "database_prefix" in str(x.message)
]
assert len(deprecation_warnings) == 0


class TestBackendConfiguration:
"""Test database backend configuration and port auto-detection."""

Expand Down
Loading