diff --git a/openevolve/config.py b/openevolve/config.py index bef193da21..c6bb2f8069 100644 --- a/openevolve/config.py +++ b/openevolve/config.py @@ -350,6 +350,7 @@ class DatabaseConfig: novelty_llm: Optional["LLMInterface"] = None embedding_model: Optional[str] = None + embedding_base_url: Optional[str] = None similarity_threshold: float = 0.99 diff --git a/openevolve/database.py b/openevolve/database.py index eca5eab0bb..d6f85f2f6b 100644 --- a/openevolve/database.py +++ b/openevolve/database.py @@ -204,7 +204,7 @@ def __init__(self, config: DatabaseConfig): self.novelty_llm = config.novelty_llm self.embedding_client = ( - EmbeddingClient(config.embedding_model) if config.embedding_model else None + EmbeddingClient(config.embedding_model, base_url=config.embedding_base_url) if config.embedding_model else None ) self.similarity_threshold = config.similarity_threshold diff --git a/openevolve/embedding.py b/openevolve/embedding.py index 302d4513f6..7c3bd3484a 100644 --- a/openevolve/embedding.py +++ b/openevolve/embedding.py @@ -10,51 +10,40 @@ logger = logging.getLogger(__name__) -M = 1_000_000 - -OPENAI_EMBEDDING_MODELS = [ - "text-embedding-3-small", - "text-embedding-3-large", -] - AZURE_EMBEDDING_MODELS = [ "azure-text-embedding-3-small", "azure-text-embedding-3-large", ] -OPENAI_EMBEDDING_COSTS = { - "text-embedding-3-small": 0.02 / M, - "text-embedding-3-large": 0.13 / M, -} - class EmbeddingClient: - def __init__(self, model_name: str = "text-embedding-3-small"): + def __init__(self, model_name: str = "text-embedding-3-small", base_url: str | None = None): """ Initialize the EmbeddingClient. Args: - model (str): The OpenAI embedding model name to use. + model_name: The embedding model name to use. + base_url: Optional base URL for the embedding API endpoint. """ - self.client, self.model = self._get_client_model(model_name) + self.client, self.model = self._get_client_model(model_name, base_url) - def _get_client_model(self, model_name: str) -> tuple[openai.OpenAI, str]: - if model_name in OPENAI_EMBEDDING_MODELS: - # Use OPENAI_EMBEDDING_API_KEY if set, otherwise fall back to OPENAI_API_KEY - # This allows users to use OpenRouter for LLMs while using OpenAI for embeddings - embedding_api_key = os.getenv("OPENAI_EMBEDDING_API_KEY") or os.getenv("OPENAI_API_KEY") - client = openai.OpenAI(api_key=embedding_api_key) - model_to_use = model_name - elif model_name in AZURE_EMBEDDING_MODELS: + def _get_client_model( + self, model_name: str, base_url: str | None = None + ) -> tuple[openai.OpenAI, str]: + if model_name in AZURE_EMBEDDING_MODELS: # get rid of the azure- prefix model_to_use = model_name.split("azure-")[-1] client = openai.AzureOpenAI( api_key=os.getenv("AZURE_OPENAI_API_KEY"), api_version=os.getenv("AZURE_API_VERSION"), - azure_endpoint=os.getenv("AZURE_API_ENDPOINT"), + azure_endpoint=os.environ["AZURE_API_ENDPOINT"], ) else: - raise ValueError(f"Invalid embedding model: {model_name}") + # Use OPENAI_EMBEDDING_API_KEY if set, otherwise fall back to OPENAI_API_KEY + # This allows users to use OpenRouter for LLMs while using OpenAI for embeddings + embedding_api_key = os.getenv("OPENAI_EMBEDDING_API_KEY") or os.getenv("OPENAI_API_KEY") + client = openai.OpenAI(api_key=embedding_api_key, base_url=base_url) + model_to_use = model_name return client, model_to_use diff --git a/openevolve/process_parallel.py b/openevolve/process_parallel.py index a2fd6592a9..352c52bc84 100644 --- a/openevolve/process_parallel.py +++ b/openevolve/process_parallel.py @@ -5,14 +5,11 @@ import asyncio import logging import multiprocessing as mp -import pickle -import signal import time -from concurrent.futures import Future, ProcessPoolExecutor +from concurrent.futures import BrokenExecutor, Future, ProcessPoolExecutor from concurrent.futures import TimeoutError as FutureTimeoutError from dataclasses import asdict, dataclass -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional from openevolve.config import Config from openevolve.database import Program, ProgramDatabase @@ -357,6 +354,10 @@ def __init__( self.num_workers = config.evaluator.parallel_evaluations self.num_islands = config.database.num_islands + # Recovery tracking for process pool crashes + self.recovery_attempts = 0 + self.max_recovery_attempts = 3 + logger.info(f"Initialized process parallel controller with {self.num_workers} workers") def _serialize_config(self, config: Config) -> dict: @@ -434,6 +435,38 @@ def stop(self) -> None: logger.info("Stopped process pool") + def _recover_process_pool(self, failed_iterations: list[int] | None = None) -> None: + """Recover from a crashed process pool by recreating it. + + Args: + failed_iterations: List of iteration numbers that failed and need re-queuing + """ + import gc + + logger.warning("Process pool crashed, attempting recovery...") + + # Shutdown broken executor without waiting (it's already broken) + if self.executor: + try: + self.executor.shutdown(wait=False, cancel_futures=True) + except Exception: + pass # Executor may already be in bad state + self.executor = None + + # Force garbage collection to free memory before restarting + gc.collect() + + # Brief delay to let system stabilize (memory freed, processes cleaned up) + time.sleep(2.0) + + # Recreate the pool + self.start() + + if failed_iterations: + logger.info(f"Pool recovered. {len(failed_iterations)} iterations will be re-queued.") + else: + logger.info("Pool recovered successfully.") + def request_shutdown(self) -> None: """Request graceful shutdown""" logger.info("Graceful shutdown requested...") @@ -559,6 +592,14 @@ async def run_evolution( # Reconstruct program from dict child_program = Program(**result.child_program_dict) + # Reset recovery counter on successful iteration + if self.recovery_attempts > 0: + logger.info( + f"Pool stable after recovery, resetting recovery counter " + f"(was {self.recovery_attempts})" + ) + self.recovery_attempts = 0 + # Add to database with explicit target_island to ensure proper island placement # This fixes issue #391: children should go to the target island, not inherit # from the parent (which may be from a different island due to fallback sampling) @@ -752,6 +793,38 @@ async def run_evolution( ) # Cancel the future to clean up the process future.cancel() + except BrokenExecutor as e: + logger.error(f"Process pool crashed during iteration {completed_iteration}: {e}") + + # Collect all failed iterations from pending futures + failed_iterations = [completed_iteration] + list(pending_futures.keys()) + + # Clear pending futures (they're all invalid now) + pending_futures.clear() + for island_id in island_pending: + island_pending[island_id].clear() + + # Attempt recovery + self.recovery_attempts += 1 + if self.recovery_attempts > self.max_recovery_attempts: + logger.error( + f"Max recovery attempts ({self.max_recovery_attempts}) exceeded. " + f"Stopping evolution." + ) + break + + self._recover_process_pool(failed_iterations) + + # Re-queue failed iterations (distribute across islands) + for i, failed_iter in enumerate(failed_iterations): + if failed_iter < total_iterations: + island_id = i % self.num_islands + future = self._submit_iteration(failed_iter, island_id) + if future: + pending_futures[failed_iter] = future + island_pending[island_id].append(failed_iter) + + continue except Exception as e: logger.error(f"Error processing result from iteration {completed_iteration}: {e}") @@ -822,6 +895,9 @@ def _submit_iteration( return future + except BrokenExecutor: + # Let this propagate up to run_evolution for recovery + raise except Exception as e: logger.error(f"Error submitting iteration {iteration}: {e}") return None diff --git a/openevolve/prompt/sampler.py b/openevolve/prompt/sampler.py index 61a5b98ba0..b793957784 100644 --- a/openevolve/prompt/sampler.py +++ b/openevolve/prompt/sampler.py @@ -110,11 +110,17 @@ def build_prompt( if self.config.programs_as_changes_description: if self.config.system_message_changes_description: - system_message_changes_description = self.config.system_message_changes_description.strip() + system_message_changes_description = ( + self.config.system_message_changes_description.strip() + ) else: - system_message_changes_description = self.template_manager.get_template("system_message_changes_description") + system_message_changes_description = self.template_manager.get_template( + "system_message_changes_description" + ) - system_message = self.template_manager.get_template("system_message_with_changes_description").format( + system_message = self.template_manager.get_template( + "system_message_with_changes_description" + ).format( system_message=system_message, system_message_changes_description=system_message_changes_description, ) @@ -160,8 +166,10 @@ def build_prompt( **kwargs, ) - if self.config.programs_as_changes_description: - user_message = self.template_manager.get_template("user_message_with_changes_description").format( + if self.config.programs_as_changes_description and current_changes_description is not None: + user_message = self.template_manager.get_template( + "user_message_with_changes_description" + ).format( user_message=user_message, changes_description=current_changes_description.rstrip(), ) @@ -265,11 +273,8 @@ def _format_evolution_history( for i, program in enumerate(reversed(selected_previous)): attempt_number = len(previous_programs) - i - changes = ( - program.get("changes_description") - or program.get("metadata", {}).get( - "changes", self.template_manager.get_fragment("attempt_unknown_changes") - ) + changes = program.get("changes_description") or program.get("metadata", {}).get( + "changes", self.template_manager.get_fragment("attempt_unknown_changes") ) # Format performance metrics using safe formatting @@ -334,9 +339,7 @@ def _format_evolution_history( for i, program in enumerate(selected_top): use_changes = self.config.programs_as_changes_description program_code = ( - program.get("changes_description", "") - if use_changes - else program.get("code", "") + program.get("changes_description", "") if use_changes else program.get("code", "") ) if not program_code: program_code = "" if use_changes else "" @@ -351,11 +354,20 @@ def _format_evolution_history( for name, value in program.get("metrics", {}).items(): if isinstance(value, (int, float)): try: - key_features.append(self.template_manager.get_fragment("top_program_metrics_prefix") + f" {name} ({value:.4f})") + key_features.append( + self.template_manager.get_fragment("top_program_metrics_prefix") + + f" {name} ({value:.4f})" + ) except (ValueError, TypeError): - key_features.append(self.template_manager.get_fragment("top_program_metrics_prefix") + f" {name} ({value})") + key_features.append( + self.template_manager.get_fragment("top_program_metrics_prefix") + + f" {name} ({value})" + ) else: - key_features.append(self.template_manager.get_fragment("top_program_metrics_prefix") + f" {name} ({value})") + key_features.append( + self.template_manager.get_fragment("top_program_metrics_prefix") + + f" {name} ({value})" + ) key_features_str = ", ".join(key_features) @@ -385,7 +397,11 @@ def _format_evolution_history( # Use random sampling to get diverse programs diverse_programs = random.sample(remaining_programs, num_diverse) - diverse_programs_str += "\n\n## " + self.template_manager.get_fragment("diverse_programs_title") + "\n\n" + diverse_programs_str += ( + "\n\n## " + + self.template_manager.get_fragment("diverse_programs_title") + + "\n\n" + ) for i, program in enumerate(diverse_programs): use_changes = self.config.programs_as_changes_description @@ -404,7 +420,8 @@ def _format_evolution_history( key_features = program.get("key_features", []) if not key_features: key_features = [ - self.template_manager.get_fragment("diverse_program_metrics_prefix") + f" {name}" + self.template_manager.get_fragment("diverse_program_metrics_prefix") + + f" {name}" for name in list(program.get("metrics", {}).keys())[ :2 ] # Just first 2 metrics @@ -416,7 +433,9 @@ def _format_evolution_history( top_program_template.format( program_number=f"D{i + 1}", score=f"{score:.4f}", - language=("text" if self.config.programs_as_changes_description else language), + language=( + "text" if self.config.programs_as_changes_description else language + ), program_snippet=program_code, key_features=key_features_str, ) @@ -466,9 +485,7 @@ def _format_inspirations_section( for i, program in enumerate(inspirations): use_changes = self.config.programs_as_changes_description program_code = ( - program.get("changes_description", "") - if use_changes - else program.get("code", "") + program.get("changes_description", "") if use_changes else program.get("code", "") ) if not program_code: program_code = "" if use_changes else "" @@ -551,16 +568,24 @@ def _extract_unique_features(self, program: Dict[str, Any]) -> str: and self.config.include_changes_under_chars and len(changes) < self.config.include_changes_under_chars ): - features.append(self.template_manager.get_fragment("inspiration_changes_prefix").format(changes=changes)) + features.append( + self.template_manager.get_fragment("inspiration_changes_prefix").format( + changes=changes + ) + ) # Analyze metrics for standout characteristics metrics = program.get("metrics", {}) for metric_name, value in metrics.items(): if isinstance(value, (int, float)): if value >= 0.9: - features.append(f"{self.template_manager.get_fragment('inspiration_metrics_excellent').format(metric_name=metric_name, value=value)}") + features.append( + f"{self.template_manager.get_fragment('inspiration_metrics_excellent').format(metric_name=metric_name, value=value)}" + ) elif value <= 0.3: - features.append(f"{self.template_manager.get_fragment('inspiration_metrics_alternative').format(metric_name=metric_name)}") + features.append( + f"{self.template_manager.get_fragment('inspiration_metrics_alternative').format(metric_name=metric_name)}" + ) # Code-based features (simple heuristics) code = program.get("code", "") @@ -571,22 +596,32 @@ def _extract_unique_features(self, program: Dict[str, Any]) -> str: if "numpy" in code_lower or "np." in code_lower: features.append(self.template_manager.get_fragment("inspiration_code_with_numpy")) if "for" in code_lower and "while" in code_lower: - features.append(self.template_manager.get_fragment("inspiration_code_with_mixed_iteration")) + features.append( + self.template_manager.get_fragment("inspiration_code_with_mixed_iteration") + ) if ( self.config.concise_implementation_max_lines and len(code.split("\n")) <= self.config.concise_implementation_max_lines ): - features.append(self.template_manager.get_fragment("inspiration_code_with_concise_line")) + features.append( + self.template_manager.get_fragment("inspiration_code_with_concise_line") + ) elif ( self.config.comprehensive_implementation_min_lines and len(code.split("\n")) >= self.config.comprehensive_implementation_min_lines ): - features.append(self.template_manager.get_fragment("inspiration_code_with_comprehensive_line")) + features.append( + self.template_manager.get_fragment("inspiration_code_with_comprehensive_line") + ) # Default if no specific features found if not features: program_type = self._determine_program_type(program) - features.append(self.template_manager.get_fragment("inspiration_no_features_postfix").format(program_type=program_type)) + features.append( + self.template_manager.get_fragment("inspiration_no_features_postfix").format( + program_type=program_type + ) + ) # Use num_top_programs as limit for features (similar to how we limit programs) feature_limit = self.config.num_top_programs @@ -629,7 +664,12 @@ def _render_artifacts(self, artifacts: Dict[str, Union[str, bytes]]) -> str: sections.append(f"### {key}\n```\n{content}\n```") if sections: - return "## " + self.template_manager.get_fragment("artifact_title") + "\n\n" + "\n\n".join(sections) + return ( + "## " + + self.template_manager.get_fragment("artifact_title") + + "\n\n" + + "\n\n".join(sections) + ) else: return "" diff --git a/tests/test_process_pool_recovery.py b/tests/test_process_pool_recovery.py new file mode 100644 index 0000000000..d79ccb21c7 --- /dev/null +++ b/tests/test_process_pool_recovery.py @@ -0,0 +1,157 @@ +""" +Tests for process pool crash recovery +""" + +import asyncio +import os +import tempfile +import unittest +from unittest.mock import MagicMock, patch +from concurrent.futures import BrokenExecutor, Future + +# Set dummy API key for testing +os.environ["OPENAI_API_KEY"] = "test" + +from openevolve.config import Config +from openevolve.database import Program, ProgramDatabase +from openevolve.process_parallel import ProcessParallelController, SerializableResult + + +class TestProcessPoolRecovery(unittest.TestCase): + """Tests for process pool crash recovery""" + + def setUp(self): + """Set up test environment""" + self.test_dir = tempfile.mkdtemp() + + # Create test config + self.config = Config() + self.config.max_iterations = 10 + self.config.evaluator.parallel_evaluations = 2 + self.config.evaluator.timeout = 10 + self.config.database.num_islands = 2 + self.config.database.in_memory = True + self.config.checkpoint_interval = 5 + + # Create test evaluation file + self.eval_content = """ +def evaluate(program_path): + return {"score": 0.5} +""" + self.eval_file = os.path.join(self.test_dir, "evaluator.py") + with open(self.eval_file, "w") as f: + f.write(self.eval_content) + + # Create test database + self.database = ProgramDatabase(self.config.database) + + # Add some test programs + for i in range(2): + program = Program( + id=f"test_{i}", + code=f"def func_{i}(): return {i}", + language="python", + metrics={"score": 0.5}, + iteration_found=0, + ) + self.database.add(program) + + def tearDown(self): + """Clean up test environment""" + import shutil + + shutil.rmtree(self.test_dir, ignore_errors=True) + + def test_controller_has_recovery_tracking(self): + """Test that controller initializes with recovery tracking attributes""" + controller = ProcessParallelController(self.config, self.eval_file, self.database) + + self.assertEqual(controller.recovery_attempts, 0) + self.assertEqual(controller.max_recovery_attempts, 3) + + def test_recover_process_pool_recreates_executor(self): + """Test that _recover_process_pool recreates the executor""" + controller = ProcessParallelController(self.config, self.eval_file, self.database) + + # Start the controller to create initial executor + controller.start() + self.assertIsNotNone(controller.executor) + original_executor = controller.executor + + # Simulate recovery + with patch("time.sleep"): + controller._recover_process_pool() + + # Verify executor was recreated + self.assertIsNotNone(controller.executor) + self.assertIsNot(controller.executor, original_executor) + + # Clean up + controller.stop() + + def test_broken_executor_triggers_recovery_and_resets_on_success(self): + """Test that BrokenExecutor triggers recovery and counter resets on success""" + + async def run_test(): + controller = ProcessParallelController(self.config, self.eval_file, self.database) + + # Track recovery calls + recovery_called = [] + + def mock_recover(failed_iterations=None): + recovery_called.append(failed_iterations) + + controller._recover_process_pool = mock_recover + + # First call raises BrokenExecutor, subsequent calls succeed + call_count = [0] + + def mock_submit(iteration, island_id): + call_count[0] += 1 + mock_future = MagicMock(spec=Future) + + if call_count[0] == 1: + # First future raises BrokenExecutor when result() is called + mock_future.done.return_value = True + mock_future.result.side_effect = BrokenExecutor("Pool crashed") + else: + # Subsequent calls succeed + mock_result = SerializableResult( + child_program_dict={ + "id": f"child_{call_count[0]}", + "code": "def evolved(): return 1", + "language": "python", + "parent_id": "test_0", + "generation": 1, + "metrics": {"score": 0.7}, + "iteration_found": iteration, + "metadata": {"island": island_id}, + }, + parent_id="test_0", + iteration_time=0.1, + iteration=iteration, + ) + mock_future.done.return_value = True + mock_future.result.return_value = mock_result + mock_future.cancel.return_value = True + + return mock_future + + with patch.object(controller, "_submit_iteration", side_effect=mock_submit): + controller.start() + + # Run evolution - should recover from crash and reset counter on success + await controller.run_evolution( + start_iteration=1, max_iterations=2, target_score=None + ) + + # Verify recovery was triggered + self.assertEqual(len(recovery_called), 1) + # Verify counter was reset after successful iteration + self.assertEqual(controller.recovery_attempts, 0) + + asyncio.run(run_test()) + + +if __name__ == "__main__": + unittest.main()