Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ test-lmdb:
@echo "⚡ Running LMDB tests..."
$(PYTHON) pytest tests/ -m "lmdb" -v --log-cli-level=ERROR

# Run Kafka integration tests (pytest marker: kafka); mirrors the other test-* targets.
test-kafka:
	@echo "⚡ Running Kafka tests..."
	$(PYTHON) pytest tests/ -m "kafka" -v --log-cli-level=ERROR

# Parallel streaming integration tests
test-parallel-streaming:
@echo "⚡ Running parallel streaming integration tests..."
Expand Down Expand Up @@ -132,6 +136,7 @@ help:
@echo " make test-postgresql - Run PostgreSQL tests"
@echo " make test-redis - Run Redis tests"
@echo " make test-snowflake - Run Snowflake tests"
@echo " make test-kafka - Run Kafka tests"
@echo " make test-performance - Run performance tests"
@echo " make lint - Lint code with ruff"
@echo " make format - Format code with ruff"
Expand Down
2 changes: 1 addition & 1 deletion apps/test_kafka_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from amp.loaders.types import LabelJoinConfig

# Connect to Amp server
server_url = os.getenv('AMP_SERVER_URL', 'grpc://34.27.238.174:80')
server_url = os.getenv('AMP_SERVER_URL', 'grpc://127.0.0.1:1602')
print(f'Connecting to {server_url}...')
client = Client(server_url)
print('✅ Connected!')
Expand Down
32 changes: 19 additions & 13 deletions src/amp/loaders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,8 +840,21 @@ def _rewind_to_watermark(self, table_name: str, connection_name: Optional[str] =
for range_obj in resume_pos.ranges:
from_block = range_obj.end + 1

# Check if there are actually uncommitted batches beyond the watermark
uncommitted = self.state_store.invalidate_from_block(
connection_name, table_name, range_obj.network, from_block
)

if not uncommitted:
self.logger.debug(
f'No uncommitted batches for {range_obj.network} beyond block {from_block}, '
f'skipping crash recovery cleanup'
)
continue

self.logger.info(
f'Crash recovery: Cleaning up {table_name} data for {range_obj.network} from block {from_block} onwards'
f'Crash recovery: Cleaning up {len(uncommitted)} uncommitted batches '
f'for {range_obj.network} from block {from_block} onwards in {table_name}'
)

invalidation_ranges = [
Expand All @@ -859,20 +872,13 @@ def _rewind_to_watermark(self, table_name: str, connection_name: Optional[str] =
self.logger.info(f'Crash recovery completed for {range_obj.network} in {table_name}')

except NotImplementedError:
invalidated = self.state_store.invalidate_from_block(
connection_name, table_name, range_obj.network, from_block
self.logger.warning(
f'Crash recovery: Cleared {len(uncommitted)} batches from state '
f'for {range_obj.network} but cannot delete data from {table_name}. '
f'{self.__class__.__name__} does not support data deletion. '
f'Duplicates may occur on resume.'
)

if invalidated:
self.logger.warning(
f'Crash recovery: Cleared {len(invalidated)} batches from state '
f'for {range_obj.network} but cannot delete data from {table_name}. '
f'{self.__class__.__name__} does not support data deletion. '
f'Duplicates may occur on resume.'
)
else:
self.logger.debug(f'No uncommitted batches found for {range_obj.network}')

def _add_metadata_columns(self, data: pa.RecordBatch, block_ranges: List[BlockRange]) -> pa.RecordBatch:
"""
Add metadata columns for streaming data with compact batch identification.
Expand Down
39 changes: 30 additions & 9 deletions src/amp/loaders/implementations/kafka_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class KafkaConfig:
class KafkaLoader(DataLoader[KafkaConfig]):
SUPPORTED_MODES = {LoadMode.APPEND}
REQUIRES_SCHEMA_MATCH = False
SUPPORTS_TRANSACTIONS = True
SUPPORTS_TRANSACTIONS = False

def __init__(self, config: Dict[str, Any], label_manager=None) -> None:
self._extra_producer_config = {
Expand All @@ -34,6 +34,14 @@ def __init__(self, config: Dict[str, Any], label_manager=None) -> None:
super().__init__(config, label_manager)
self._producer = None

# Replace in-memory state store with LMDB if configured (before connect, consistent with other loaders)
if self.state_enabled and self.state_storage == 'lmdb':
self.state_store = LMDBStreamStateStore(
connection_name=self.config.client_id,
data_dir=self.state_data_dir,
)
self.logger.info(f'Initialized LMDB state store at {self.state_store.data_dir}')

def _get_required_config_fields(self) -> list[str]:
return ['bootstrap_servers']

Expand All @@ -57,13 +65,6 @@ def connect(self) -> None:
self.logger.info(f'Connected to Kafka at {self.config.bootstrap_servers}')
self.logger.info(f'Client ID: {self.config.client_id}')

if self.state_enabled and self.state_storage == 'lmdb':
self.state_store = LMDBStreamStateStore(
connection_name=self.config.client_id,
data_dir=self.state_data_dir,
)
self.logger.info(f'Initialized LMDB state store at {self.state_store.data_dir}')

self._is_connected = True

except Exception as e:
Expand All @@ -84,9 +85,27 @@ def disconnect(self) -> None:
self._is_connected = False
self.logger.info('Disconnected from Kafka')

def health_check(self) -> Dict[str, Any]:
    """Check Kafka broker connectivity.

    Returns:
        Dict with 'healthy' (bool), 'loader_type', 'bootstrap_servers',
        'client_id', and an 'error' message when the check fails.
        Never raises.
    """
    base = {
        'healthy': False,
        'loader_type': 'kafka',
        'bootstrap_servers': self.config.bootstrap_servers,
        'client_id': self.config.client_id,
    }
    if not self._is_connected or not self._producer:
        base['error'] = 'Not connected'
        return base
    try:
        # Prefer the public kafka-python API (KafkaProducer.bootstrap_connected,
        # available since 2.0); fall back to the private sender-thread liveness
        # check for older client versions rather than depending on _sender alone.
        if hasattr(self._producer, 'bootstrap_connected'):
            base['healthy'] = bool(self._producer.bootstrap_connected())
        else:
            base['healthy'] = hasattr(self._producer, '_sender') and self._producer._sender.is_alive()
        return base
    except Exception as e:
        base['error'] = str(e)
        return base

def _create_table_from_schema(self, schema: pa.Schema, table_name: str) -> None:
    """No-op for Kafka: topics need no pre-created schema.

    The topic named *table_name* is created by the broker on the first
    message send (presumably auto.create.topics.enable is on — confirm
    for production brokers). The *schema* argument is intentionally unused.
    """
    # Log only; the redundant trailing `pass` after the statement was removed.
    self.logger.info(f'Kafka topic {table_name} will be auto-created on first message send')

def _load_batch_impl(self, batch: pa.RecordBatch, table_name: str, **kwargs) -> int:
if not self._producer:
Expand All @@ -108,6 +127,7 @@ def _load_batch_impl(self, batch: pa.RecordBatch, table_name: str, **kwargs) ->

self._producer.send(topic=table_name, key=key, value=row)

self._producer.flush()
self._producer.commit_transaction()
self.logger.debug(f'Committed transaction with {num_rows} messages to topic {table_name}')

Expand Down Expand Up @@ -169,6 +189,7 @@ def _handle_reorg(self, invalidation_ranges: List[BlockRange], table_name: str,
f'{invalidation_range.network} blocks {invalidation_range.start}-{invalidation_range.end}'
)

self._producer.flush()
self._producer.commit_transaction()
self.logger.info(f'Committed {len(invalidation_ranges)} reorg events to {reorg_topic}')

Expand Down
4 changes: 3 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,9 @@ def kafka_container():
container.with_env('KAFKA_OFFSETS_TOPIC_REPLICATION_FACTOR', '1')
container.start()

time.sleep(10)
# KafkaContainer.start() already waits for "[KafkaServer id=N] started" log.
# Brief additional wait for transaction coordinator to be fully ready.
time.sleep(5)

yield container

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,18 @@ def test_loader_connection(self, kafka_test_config):
assert loader._is_connected == False
assert loader._producer is None

def test_health_check(self, kafka_test_config):
    """health_check reports unhealthy before connect and healthy once connected."""
    loader = KafkaLoader(kafka_test_config)

    # Before connecting: unhealthy with an explanatory error.
    health = loader.health_check()
    assert health['healthy'] is False  # identity check per PEP 8, not == False
    assert 'error' in health

    # Inside the context manager the loader is connected and healthy.
    with loader:
        health = loader.health_check()
        assert health['healthy'] is True
        assert health['bootstrap_servers'] == kafka_test_config['bootstrap_servers']

def test_context_manager(self, kafka_test_config):
loader = KafkaLoader(kafka_test_config)

Expand Down
16 changes: 15 additions & 1 deletion tests/unit/test_crash_recovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def test_rewind_calls_handle_reorg(self, mock_loader):
"""Should call _handle_reorg with correct invalidation ranges"""
watermark = ResumeWatermark(ranges=[BlockRange(network='ethereum', start=1000, end=1010, hash='0xabc')])
mock_loader.state_store.get_resume_position = Mock(return_value=watermark)
mock_loader.state_store.invalidate_from_block = Mock(return_value=['batch1'])
mock_loader._handle_reorg = Mock()

mock_loader._rewind_to_watermark('test_table', 'test_conn')
Expand All @@ -61,12 +62,23 @@ def test_rewind_calls_handle_reorg(self, mock_loader):
assert call_args[0][1] == 'test_table'
assert call_args[0][2] == 'test_conn'

def test_rewind_skips_reorg_when_no_uncommitted_batches(self, mock_loader):
    """A clean restart (no uncommitted batches) must not trigger _handle_reorg."""
    committed_range = BlockRange(network='ethereum', start=1000, end=1010, hash='0xabc')
    mock_loader.state_store.get_resume_position = Mock(
        return_value=ResumeWatermark(ranges=[committed_range])
    )
    # Empty list => everything up to the watermark was committed.
    mock_loader.state_store.invalidate_from_block = Mock(return_value=[])
    mock_loader._handle_reorg = Mock()

    mock_loader._rewind_to_watermark('test_table', 'test_conn')

    mock_loader._handle_reorg.assert_not_called()

def test_rewind_handles_not_implemented(self, mock_loader):
"""Should gracefully handle loaders without _handle_reorg"""
watermark = ResumeWatermark(ranges=[BlockRange(network='ethereum', start=1000, end=1010, hash='0xabc')])
mock_loader.state_store.get_resume_position = Mock(return_value=watermark)
mock_loader._handle_reorg = Mock(side_effect=NotImplementedError())
mock_loader.state_store.invalidate_from_block = Mock(return_value=[])
mock_loader.state_store.invalidate_from_block = Mock(return_value=['batch1'])

mock_loader._rewind_to_watermark('test_table', 'test_conn')

Expand All @@ -83,6 +95,7 @@ def test_rewind_with_multiple_networks(self, mock_loader):
]
)
mock_loader.state_store.get_resume_position = Mock(return_value=watermark)
mock_loader.state_store.invalidate_from_block = Mock(return_value=['batch1'])
mock_loader._handle_reorg = Mock()

mock_loader._rewind_to_watermark('test_table', 'test_conn')
Expand All @@ -101,6 +114,7 @@ def test_rewind_uses_default_connection_name(self, mock_loader):
"""Should use default connection name from loader class"""
watermark = ResumeWatermark(ranges=[BlockRange(network='ethereum', start=1000, end=1010, hash='0xabc')])
mock_loader.state_store.get_resume_position = Mock(return_value=watermark)
mock_loader.state_store.invalidate_from_block = Mock(return_value=['batch1'])
mock_loader._handle_reorg = Mock()

mock_loader._rewind_to_watermark('test_table', connection_name=None)
Expand Down