diff --git a/composer/tools/composer_migrate.py b/composer/tools/composer_migrate.py index c4ef2fbb5f9..a2ef5ae9be5 100644 --- a/composer/tools/composer_migrate.py +++ b/composer/tools/composer_migrate.py @@ -19,9 +19,13 @@ import json import math import pprint -import subprocess +import time from typing import Any, Dict, List +import google.auth +from google.auth.transport.requests import AuthorizedSession +import requests + import logging @@ -32,62 +36,82 @@ class ComposerClient: """Client for interacting with Composer API. - The client uses gcloud under the hood. + The client uses Google Auth and Requests under the hood. """ def __init__(self, project: str, location: str, sdk_endpoint: str) -> None: self.project = project self.location = location - self.sdk_endpoint = sdk_endpoint + self.sdk_endpoint = sdk_endpoint.rstrip("/") + self.credentials, _ = google.auth.default() + self.session = AuthorizedSession(self.credentials) + self._airflow_uris = {} + + def _get_airflow_uri(self, environment_name: str) -> str: + """Returns the Airflow URI for a given environment, caching the result.""" + if environment_name not in self._airflow_uris: + environment = self.get_environment(environment_name) + self._airflow_uris[environment_name] = environment["config"]["airflowUri"] + return self._airflow_uris[environment_name] def get_environment(self, environment_name: str) -> Any: """Returns an environment json for a given Composer environment.""" - command = ( - f"CLOUDSDK_API_ENDPOINT_OVERRIDES_COMPOSER={self.sdk_endpoint} gcloud" - " composer environments describe" - f" {environment_name} --project={self.project} --location={self.location} --format" - " json" + url = ( + f"{self.sdk_endpoint}/v1/projects/{self.project}/locations/" + f"{self.location}/environments/{environment_name}" ) - output = run_shell_command(command) - return json.loads(output) + response = self.session.get(url) + if response.status_code != 200: + raise RuntimeError( + f"Failed to get environment {environment_name}: {response.text}" + ) + return response.json() def create_environment_from_config(self, config: Any) -> Any: """Creates a Composer environment based on the given json config.""" - # Obtain access token through gcloud - access_token = run_shell_command("gcloud auth print-access-token") - - # gcloud does not support creating composer environments from json, so we - # need to use the API directly. - create_environment_command = ( - f"curl -s -X POST -H 'Authorization: Bearer {access_token}'" - " -H 'Content-Type: application/json'" - f" -d '{json.dumps(config)}'" - f" {self.sdk_endpoint}/v1/projects/{self.project}/locations/{self.location}/environments" - ) - output = run_shell_command(create_environment_command) - logging.info("Create environment operation: %s", output) - - # Poll create operation using gcloud. - operation_id = json.loads(output)["name"].split("/")[-1] - poll_operation_command = ( - f"CLOUDSDK_API_ENDPOINT_OVERRIDES_COMPOSER={self.sdk_endpoint} gcloud" - " composer operations wait" - f" {operation_id} --project={self.project} --location={self.location}" + url = ( + f"{self.sdk_endpoint}/v1/projects/{self.project}/locations/" + f"{self.location}/environments" ) - run_shell_command(poll_operation_command) - - def list_dags(self, environment_name: str) -> List[str]: + # Verify that the environment name is present in the config. + # The API expects the resource name in the format: + # projects/{project}/locations/{location}/environments/{environment_name} + if "name" not in config: + raise ValueError("Environment name is missing in the config.") + + # Extract environment ID from the full name if needed as query param, + # but the original code didn't use it, so we trust the body 'name' field. + # However, usually for Create, we might need environmentId query param if we want to specify it explicitly + # and it's not inferred. + # The original code did: POST .../environments with body. + + response = self.session.post(url, json=config) + if response.status_code == 409: + logger.info("Environment already exists, skipping creation.") + return + + if response.status_code != 200: + raise RuntimeError( + f"Failed to create environment: {response.text}" + ) + + operation = response.json() + logger.info("Create environment operation: %s", operation["name"]) + self._wait_for_operation(operation["name"]) + + + def list_dags(self, environment_name: str) -> List[Dict[str, Any]]: """Returns a list of DAGs in a given Composer environment.""" - command = ( - f"CLOUDSDK_API_ENDPOINT_OVERRIDES_COMPOSER={self.sdk_endpoint} gcloud" - " composer environments run" - f" {environment_name} --project={self.project} --location={self.location} dags" - " list -- -o json" - ) - output = run_shell_command(command) - # Output may contain text from top level print statements. - # The last line of the output is always a json array of DAGs. - return json.loads(output.splitlines()[-1]) + airflow_uri = self._get_airflow_uri(environment_name) + + url = f"{airflow_uri}/api/v1/dags" + response = self.session.get(url) + if response.status_code != 200: + raise RuntimeError( + f"Failed to list DAGs: {response.text}" + ) + return response.json()["dags"] + def pause_dag( self, @@ -95,13 +119,30 @@ def pause_dag( environment_name: str, ) -> Any: """Pauses a DAG in a Composer environment.""" - command = ( - f"CLOUDSDK_API_ENDPOINT_OVERRIDES_COMPOSER={self.sdk_endpoint} gcloud" - " composer environments run" - f" {environment_name} --project={self.project} --location={self.location} dags" - f" pause -- {dag_id}" - ) - run_shell_command(command) + airflow_uri = self._get_airflow_uri(environment_name) + + url = f"{airflow_uri}/api/v1/dags/{dag_id}" + response = self.session.patch(url, json={"is_paused": True}) + if response.status_code != 200: + raise RuntimeError( + f"Failed to pause DAG {dag_id}: {response.text}" + ) + + + def pause_all_dags( + self, + environment_name: str, + ) -> Any: + """Pauses all DAGs in a Composer environment.""" + airflow_uri = self._get_airflow_uri(environment_name) + + url = f"{airflow_uri}/api/v1/dags?dag_id_pattern=%" # Pause all DAGs using % as a wildcard + response = self.session.patch(url, json={"is_paused": True}) + if response.status_code != 200: + raise RuntimeError( + f"Failed to pause all DAGs: {response.text}" + ) + def unpause_dag( self, @@ -109,25 +150,45 @@ def unpause_dag( environment_name: str, ) -> Any: """Unpauses a DAG in a Composer environment.""" - command = ( - f"CLOUDSDK_API_ENDPOINT_OVERRIDES_COMPOSER={self.sdk_endpoint} gcloud" - " composer environments run" - f" {environment_name} --project={self.project} --location={self.location} dags" - f" unpause -- {dag_id}" - ) - run_shell_command(command) + airflow_uri = self._get_airflow_uri(environment_name) + + url = f"{airflow_uri}/api/v1/dags/{dag_id}" + response = self.session.patch(url, json={"is_paused": False}) + if response.status_code != 200: + raise RuntimeError( + f"Failed to unpause DAG {dag_id}: {response.text}" + ) + + def unpause_all_dags( + self, + environment_name: str, + ) -> Any: + """Unpauses all DAGs in a Composer environment.""" + airflow_uri = self._get_airflow_uri(environment_name) + + url = f"{airflow_uri}/api/v1/dags?dag_id_pattern=%" # Pause all DAGs using % as a wildcard + response = self.session.patch(url, json={"is_paused": False}) + if response.status_code != 200: + raise RuntimeError( + f"Failed to unpause all DAGs: {response.text}" + ) def save_snapshot(self, environment_name: str) -> str: """Saves a snapshot of a Composer environment.""" - command = ( - f"CLOUDSDK_API_ENDPOINT_OVERRIDES_COMPOSER={self.sdk_endpoint} gcloud" - " composer" - " environments snapshots save" - f" {environment_name} --project={self.project}" - f" --location={self.location} --format=json" + url = ( + f"{self.sdk_endpoint}/v1/projects/{self.project}/locations/" + f"{self.location}/environments/{environment_name}:saveSnapshot" ) - output = run_shell_command(command) - return json.loads(output)["snapshotPath"] + response = self.session.post(url, json={}) + if response.status_code != 200: + raise RuntimeError( + f"Failed to initiate snapshot save: {response.text}" + ) + + operation = response.json() + logging.info("Save snapshot operation: %s", operation["name"]) + completed_operation = self._wait_for_operation(operation["name"]) + return completed_operation["response"]["snapshotPath"] def load_snapshot( self, @@ -135,25 +196,44 @@ def load_snapshot( snapshot_path: str, ) -> Any: """Loads a snapshot to a Composer environment.""" - command = ( - f"CLOUDSDK_API_ENDPOINT_OVERRIDES_COMPOSER={self.sdk_endpoint} gcloud" - " composer" - f" environments snapshots load {environment_name}" - f" --snapshot-path={snapshot_path} --project={self.project}" - f" --location={self.location} --format=json" + url = ( + f"{self.sdk_endpoint}/v1/projects/{self.project}/locations/" + f"{self.location}/environments/{environment_name}:loadSnapshot" ) - run_shell_command(command) - - -def run_shell_command(command: str, command_input: str = None) -> str: - """Executes shell command and returns its output.""" - p = subprocess.Popen(command, stdout=subprocess.PIPE, shell=True) - (res, _) = p.communicate(input=command_input) - output = str(res.decode().strip("\n")) - - if p.returncode: - raise RuntimeError(f"Failed to run shell command: {command}, details: {output}") - return output + response = self.session.post(url, json={"snapshotPath": snapshot_path}) + if response.status_code != 200: + raise RuntimeError( + f"Failed to initiate snapshot load: {response.text}" + ) + + operation = response.json() + logging.info("Load snapshot operation: %s", operation["name"]) + self._wait_for_operation(operation["name"]) + + + def _wait_for_operation(self, operation_name: str) -> Any: + """Waits for a long-running operation to complete.""" + # operation_name is distinct from operation_id. + # It is a full resource name: projects/.../locations/.../operations/... + + # We need to poll the operation status. + url = f"{self.sdk_endpoint}/v1/{operation_name}" + + while True: + response = self.session.get(url) + if response.status_code != 200: + raise RuntimeError( + f"Failed to get operation status: {response.text}" + ) + operation = response.json() + if "done" in operation and operation["done"]: + if "error" in operation: + raise RuntimeError(f"Operation failed: {operation['error']}") + logging.info("Operation completed successfully.") + return operation + + logging.info("Waiting for operation to complete...") + time.sleep(10) def get_target_cpu(source_cpu: float, max_cpu: float) -> float: @@ -392,15 +472,7 @@ def main( len(source_env_dags), source_env_dag_ids, ) - for dag in source_env_dags: - if dag["dag_id"] == "airflow_monitoring": - continue - if dag["is_paused"] == "True": - logger.info("DAG %s is already paused.", dag["dag_id"]) - continue - logger.info("Pausing DAG %s in the source environment.", dag["dag_id"]) - client.pause_dag(dag["dag_id"], source_environment_name) - logger.info("DAG %s paused.", dag["dag_id"]) + client.pause_all_dags(source_environment_name) logger.info("All DAGs in the source environment paused.") # 4. Save snapshot of the source environment @@ -420,18 +492,26 @@ def main( while not all_dags_present: target_env_dags = client.list_dags(target_environment_name) target_env_dag_ids = [dag["dag_id"] for dag in target_env_dags] - all_dags_present = set(source_env_dag_ids) == set(target_env_dag_ids) - logger.info("List of DAGs in the target environment: %s", target_env_dag_ids) + missing_dags = set(source_env_dag_ids) - set(target_env_dag_ids) + all_dags_present = not missing_dags + if missing_dags: + logger.info("Waiting for DAGs to appear in target: %s", missing_dags) + else: + logger.info("All DAGs present in target environment.") + time.sleep(10) # Unpause only DAGs that were not paused in the source environment. - for dag in source_env_dags: - if dag["dag_id"] == "airflow_monitoring": - continue - if dag["is_paused"] == "True": - logger.info("DAG %s was paused in the source environment.", dag["dag_id"]) - continue - logger.info("Unpausing DAG %s in the target environment.", dag["dag_id"]) - client.unpause_dag(dag["dag_id"], target_environment_name) - logger.info("DAG %s unpaused.", dag["dag_id"]) + # Optimization: if all DAGs were unpaused in source, use bulk unpause. + if not any(d["is_paused"] for d in source_env_dags): + logger.info("All DAGs were unpaused in source. Unpausing all DAGs in target.") + client.unpause_all_dags(target_environment_name) + else: + for dag in source_env_dags: + if dag["is_paused"]: + logger.info("DAG %s was paused in the source environment.", dag["dag_id"]) + continue + logger.info("Unpausing DAG %s in the target environment.", dag["dag_id"]) + client.unpause_dag(dag["dag_id"], target_environment_name) + logger.info("DAG %s unpaused.", dag["dag_id"]) logger.info("DAGs in the target environment unpaused.") logger.info("Migration complete.")