From 4e67f1cd5db49bf551cd78db94cf731c8bfa320b Mon Sep 17 00:00:00 2001 From: ddeleo Date: Wed, 18 Feb 2026 14:15:38 -0500 Subject: [PATCH 1/7] refactor: Migrate ComposerClient from gcloud shell commands to direct API calls using Google Auth and Requests. --- composer/tools/composer_migrate.py | 211 +++++++++++++++++------------ 1 file changed, 126 insertions(+), 85 deletions(-) diff --git a/composer/tools/composer_migrate.py b/composer/tools/composer_migrate.py index c4ef2fbb5f9..d1f0847b620 100644 --- a/composer/tools/composer_migrate.py +++ b/composer/tools/composer_migrate.py @@ -19,9 +19,13 @@ import json import math import pprint -import subprocess +import time from typing import Any, Dict, List +import google.auth +from google.auth.transport.requests import AuthorizedSession +import requests + import logging @@ -32,62 +36,71 @@ class ComposerClient: """Client for interacting with Composer API. - The client uses gcloud under the hood. + The client uses Google Auth and Requests under the hood. """ def __init__(self, project: str, location: str, sdk_endpoint: str) -> None: self.project = project self.location = location - self.sdk_endpoint = sdk_endpoint + self.sdk_endpoint = sdk_endpoint.rstrip("/") + self.credentials, _ = google.auth.default() + self.session = AuthorizedSession(self.credentials) def get_environment(self, environment_name: str) -> Any: """Returns an environment json for a given Composer environment.""" - command = ( - f"CLOUDSDK_API_ENDPOINT_OVERRIDES_COMPOSER={self.sdk_endpoint} gcloud" - " composer environments describe" - f" {environment_name} --project={self.project} --location={self.location} --format" - " json" + url = ( + f"{self.sdk_endpoint}/v1/projects/{self.project}/locations/" + f"{self.location}/environments/{environment_name}" ) - output = run_shell_command(command) - return json.loads(output) + response = self.session.get(url) + if response.status_code != 200: + raise RuntimeError( + f"Failed to get environment {environment_name}: {response.text}" + ) + return response.json() def create_environment_from_config(self, config: Any) -> Any: """Creates a Composer environment based on the given json config.""" - # Obtain access token through gcloud - access_token = run_shell_command("gcloud auth print-access-token") - - # gcloud does not support creating composer environments from json, so we - # need to use the API directly. - create_environment_command = ( - f"curl -s -X POST -H 'Authorization: Bearer {access_token}'" - " -H 'Content-Type: application/json'" - f" -d '{json.dumps(config)}'" - f" {self.sdk_endpoint}/v1/projects/{self.project}/locations/{self.location}/environments" - ) - output = run_shell_command(create_environment_command) - logging.info("Create environment operation: %s", output) - - # Poll create operation using gcloud. - operation_id = json.loads(output)["name"].split("/")[-1] - poll_operation_command = ( - f"CLOUDSDK_API_ENDPOINT_OVERRIDES_COMPOSER={self.sdk_endpoint} gcloud" - " composer operations wait" - f" {operation_id} --project={self.project} --location={self.location}" + url = ( + f"{self.sdk_endpoint}/v1/projects/{self.project}/locations/" + f"{self.location}/environments" ) - run_shell_command(poll_operation_command) - - def list_dags(self, environment_name: str) -> List[str]: + # Verify that the environment name is present in the config. + # The API expects the resource name in the format: + # projects/{project}/locations/{location}/environments/{environment_name} + if "name" not in config: + raise ValueError("Environment name is missing in the config.") + + # Extract environment ID from the full name if needed as query param, + # but the original code didn't use it, so we trust the body 'name' field. + # However, usually for Create, we might need environmentId query param if we want to specify it explicitly + # and it's not inferred. + # The original code did: POST .../environments with body. + + response = self.session.post(url, json=config) + if response.status_code != 200: + raise RuntimeError( + f"Failed to create environment: {response.text}" + ) + + operation = response.json() + logging.info("Create environment operation: %s", operation["name"]) + self._wait_for_operation(operation["name"]) + + + def list_dags(self, environment_name: str) -> List[Dict[str, Any]]: """Returns a list of DAGs in a given Composer environment.""" - command = ( - f"CLOUDSDK_API_ENDPOINT_OVERRIDES_COMPOSER={self.sdk_endpoint} gcloud" - " composer environments run" - f" {environment_name} --project={self.project} --location={self.location} dags" - " list -- -o json" - ) - output = run_shell_command(command) - # Output may contain text from top level print statements. - # The last line of the output is always a json array of DAGs. - return json.loads(output.splitlines()[-1]) + # Get authentication context and Airflow URI + environment = self.get_environment(environment_name) + airflow_uri = environment["config"]["airflowUri"] + + url = f"{airflow_uri}/api/v1/dags" + response = self.session.get(url) + if response.status_code != 200: + raise RuntimeError( + f"Failed to list DAGs: {response.text}" + ) + return response.json()["dags"] def pause_dag( self, @@ -95,13 +108,15 @@ def pause_dag( environment_name: str, ) -> Any: """Pauses a DAG in a Composer environment.""" - command = ( - f"CLOUDSDK_API_ENDPOINT_OVERRIDES_COMPOSER={self.sdk_endpoint} gcloud" - " composer environments run" - f" {environment_name} --project={self.project} --location={self.location} dags" - f" pause -- {dag_id}" - ) - run_shell_command(command) + environment = self.get_environment(environment_name) + airflow_uri = environment["config"]["airflowUri"] + + url = f"{airflow_uri}/api/v1/dags/{dag_id}" + response = self.session.patch(url, json={"is_paused": True}) + if response.status_code != 200: + raise RuntimeError( + f"Failed to pause DAG {dag_id}: {response.text}" + ) def unpause_dag( self, @@ -109,25 +124,32 @@ def unpause_dag( environment_name: str, ) -> Any: """Unpauses a DAG in a Composer environment.""" - command = ( - f"CLOUDSDK_API_ENDPOINT_OVERRIDES_COMPOSER={self.sdk_endpoint} gcloud" - " composer environments run" - f" {environment_name} --project={self.project} --location={self.location} dags" - f" unpause -- {dag_id}" - ) - run_shell_command(command) + environment = self.get_environment(environment_name) + airflow_uri = environment["config"]["airflowUri"] + + url = f"{airflow_uri}/api/v1/dags/{dag_id}" + response = self.session.patch(url, json={"is_paused": False}) + if response.status_code != 200: + raise RuntimeError( + f"Failed to unpause DAG {dag_id}: {response.text}" + ) def save_snapshot(self, environment_name: str) -> str: """Saves a snapshot of a Composer environment.""" - command = ( - f"CLOUDSDK_API_ENDPOINT_OVERRIDES_COMPOSER={self.sdk_endpoint} gcloud" - " composer" - " environments snapshots save" - f" {environment_name} --project={self.project}" - f" --location={self.location} --format=json" + url = ( + f"{self.sdk_endpoint}/v1/projects/{self.project}/locations/" + f"{self.location}/environments/{environment_name}:saveSnapshot" ) - output = run_shell_command(command) - return json.loads(output)["snapshotPath"] + response = self.session.post(url, json={}) + if response.status_code != 200: + raise RuntimeError( + f"Failed to initiate snapshot save: {response.text}" + ) + + operation = response.json() + logging.info("Save snapshot operation: %s", operation["name"]) + completed_operation = self._wait_for_operation(operation["name"]) + return completed_operation["response"]["snapshotPath"] def load_snapshot( self, @@ -135,25 +157,44 @@ def load_snapshot( snapshot_path: str, ) -> Any: """Loads a snapshot to a Composer environment.""" - command = ( - f"CLOUDSDK_API_ENDPOINT_OVERRIDES_COMPOSER={self.sdk_endpoint} gcloud" - " composer" - f" environments snapshots load {environment_name}" - f" --snapshot-path={snapshot_path} --project={self.project}" - f" --location={self.location} --format=json" + url = ( + f"{self.sdk_endpoint}/v1/projects/{self.project}/locations/" + f"{self.location}/environments/{environment_name}:loadSnapshot" ) - run_shell_command(command) - - -def run_shell_command(command: str, command_input: str = None) -> str: - """Executes shell command and returns its output.""" - p = subprocess.Popen(command, stdout=subprocess.PIPE, shell=True) - (res, _) = p.communicate(input=command_input) - output = str(res.decode().strip("\n")) - - if p.returncode: - raise RuntimeError(f"Failed to run shell command: {command}, details: {output}") - return output + response = self.session.post(url, json={"snapshotPath": snapshot_path}) + if response.status_code != 200: + raise RuntimeError( + f"Failed to initiate snapshot load: {response.text}" + ) + + operation = response.json() + logging.info("Load snapshot operation: %s", operation["name"]) + self._wait_for_operation(operation["name"]) + + + def _wait_for_operation(self, operation_name: str) -> Any: + """Waits for a long-running operation to complete.""" + # operation_name is distinct from operation_id. + # It is a full resource name: projects/.../locations/.../operations/... + + # We need to poll the operation status. + url = f"{self.sdk_endpoint}/v1/{operation_name}" + + while True: + response = self.session.get(url) + if response.status_code != 200: + raise RuntimeError( + f"Failed to get operation status: {response.text}" + ) + operation = response.json() + if "done" in operation and operation["done"]: + if "error" in operation: + raise RuntimeError(f"Operation failed: {operation['error']}") + logging.info("Operation completed successfully.") + return operation + + logging.info("Waiting for operation to complete...") + time.sleep(10) def get_target_cpu(source_cpu: float, max_cpu: float) -> float: @@ -395,7 +436,7 @@ def main( for dag in source_env_dags: if dag["dag_id"] == "airflow_monitoring": continue - if dag["is_paused"] == "True": + if dag["is_paused"]: logger.info("DAG %s is already paused.", dag["dag_id"]) continue logger.info("Pausing DAG %s in the source environment.", dag["dag_id"]) @@ -426,7 +467,7 @@ def main( for dag in source_env_dags: if dag["dag_id"] == "airflow_monitoring": continue - if dag["is_paused"] == "True": + if dag["is_paused"]: logger.info("DAG %s was paused in the source environment.", dag["dag_id"]) continue logger.info("Unpausing DAG %s in the target environment.", dag["dag_id"]) From fc2be091616878d3aec2d84c64f3941794c64bb6 Mon Sep 17 00:00:00 2001 From: ddeleo Date: Wed, 18 Feb 2026 15:04:29 -0500 Subject: [PATCH 2/7] feat: Add `pause_all_dags` method and update `main` to call it --- composer/tools/composer_migrate.py | 43 +++++++++++++++++++++++------- 1 file changed, 34 insertions(+), 9 deletions(-) diff --git a/composer/tools/composer_migrate.py b/composer/tools/composer_migrate.py index d1f0847b620..a081da0dc61 100644 --- a/composer/tools/composer_migrate.py +++ b/composer/tools/composer_migrate.py @@ -102,6 +102,7 @@ def list_dags(self, environment_name: str) -> List[Dict[str, Any]]: ) return response.json()["dags"] + def pause_dag( self, dag_id: str, @@ -117,6 +118,23 @@ def pause_dag( raise RuntimeError( f"Failed to pause DAG {dag_id}: {response.text}" ) + + + def pause_all_dags( + self, + environment_name: str, + ) -> Any: + """Pauses a DAG in a Composer environment.""" + environment = self.get_environment(environment_name) + airflow_uri = environment["config"]["airflowUri"] + + url = f"{airflow_uri}/api/v1/dags?dag_id_pattern=%" # Pause all DAGs using % as a wildcard + response = self.session.patch(url, json={"is_paused": True}) + if response.status_code != 200: + raise RuntimeError( + f"Failed to pause DAG {dag_id}: {response.text}" + ) + def unpause_dag( self, @@ -134,6 +152,21 @@ def unpause_dag( f"Failed to unpause DAG {dag_id}: {response.text}" ) + def unpause_all_dags( + self, + environment_name: str, + ) -> Any: + """Pauses a DAG in a Composer environment.""" + environment = self.get_environment(environment_name) + airflow_uri = environment["config"]["airflowUri"] + + url = f"{airflow_uri}/api/v1/dags?dag_id_pattern=%" # Pause all DAGs using % as a wildcard + response = self.session.patch(url, json={"is_paused": False}) + if response.status_code != 200: + raise RuntimeError( + f"Failed to pause DAG {dag_id}: {response.text}" + ) + def save_snapshot(self, environment_name: str) -> str: """Saves a snapshot of a Composer environment.""" url = ( @@ -433,15 +466,7 @@ def main( len(source_env_dags), source_env_dag_ids, ) - for dag in source_env_dags: - if dag["dag_id"] == "airflow_monitoring": - continue - if dag["is_paused"]: - logger.info("DAG %s is already paused.", dag["dag_id"]) - continue - logger.info("Pausing DAG %s in the source environment.", dag["dag_id"]) - client.pause_dag(dag["dag_id"], source_environment_name) - logger.info("DAG %s paused.", dag["dag_id"]) + client.pause_all_dags(source_environment_name) logger.info("All DAGs in the source environment paused.") # 4. Save snapshot of the source environment From 73ac8c5066f50286297248427855c15bff4ba173 Mon Sep 17 00:00:00 2001 From: ddeleo Date: Wed, 18 Feb 2026 15:35:06 -0500 Subject: [PATCH 3/7] feat: modify `pause_dag` and `unpause_dag` to operate on all DAGs and optimize `restore_dags_state` to use bulk unpause when possible. --- composer/tools/composer_migrate.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/composer/tools/composer_migrate.py b/composer/tools/composer_migrate.py index a081da0dc61..5b044c3d072 100644 --- a/composer/tools/composer_migrate.py +++ b/composer/tools/composer_migrate.py @@ -124,7 +124,7 @@ def pause_all_dags( self, environment_name: str, ) -> Any: - """Pauses a DAG in a Composer environment.""" + """Pauses all DAGs in a Composer environment.""" environment = self.get_environment(environment_name) airflow_uri = environment["config"]["airflowUri"] @@ -132,7 +132,7 @@ def pause_all_dags( response = self.session.patch(url, json={"is_paused": True}) if response.status_code != 200: raise RuntimeError( - f"Failed to pause DAG {dag_id}: {response.text}" + f"Failed to pause all DAGs: {response.text}" ) @@ -156,7 +156,7 @@ def unpause_all_dags( self, environment_name: str, ) -> Any: - """Pauses a DAG in a Composer environment.""" + """Unpauses all DAGs in a Composer environment.""" environment = self.get_environment(environment_name) airflow_uri = environment["config"]["airflowUri"] @@ -164,7 +164,7 @@ def unpause_all_dags( response = self.session.patch(url, json={"is_paused": False}) if response.status_code != 200: raise RuntimeError( - f"Failed to pause DAG {dag_id}: {response.text}" + f"Failed to unpause all DAGs: {response.text}" ) def save_snapshot(self, environment_name: str) -> str: @@ -489,15 +489,18 @@ def main( all_dags_present = set(source_env_dag_ids) == set(target_env_dag_ids) logger.info("List of DAGs in the target environment: %s", target_env_dag_ids) # Unpause only DAGs that were not paused in the source environment. - for dag in source_env_dags: - if dag["dag_id"] == "airflow_monitoring": - continue - if dag["is_paused"]: - logger.info("DAG %s was paused in the source environment.", dag["dag_id"]) - continue - logger.info("Unpausing DAG %s in the target environment.", dag["dag_id"]) - client.unpause_dag(dag["dag_id"], target_environment_name) - logger.info("DAG %s unpaused.", dag["dag_id"]) + # Optimization: if all DAGs were unpaused in source, use bulk unpause. + if not any(d["is_paused"] for d in source_env_dags): + logger.info("All DAGs were unpaused in source. Unpausing all DAGs in target.") + client.unpause_all_dags(target_environment_name) + else: + for dag in source_env_dags: + if dag["is_paused"]: + logger.info("DAG %s was paused in the source environment.", dag["dag_id"]) + continue + logger.info("Unpausing DAG %s in the target environment.", dag["dag_id"]) + client.unpause_dag(dag["dag_id"], target_environment_name) + logger.info("DAG %s unpaused.", dag["dag_id"]) logger.info("DAGs in the target environment unpaused.") logger.info("Migration complete.") From 24a6d83597563f43af068ab1b0e6b432d69753cb Mon Sep 17 00:00:00 2001 From: ddeleo Date: Wed, 18 Feb 2026 15:40:29 -0500 Subject: [PATCH 4/7] feat: gracefully handle existing environments during creation and add a 10-second delay after listing target DAGs. --- composer/tools/composer_migrate.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/composer/tools/composer_migrate.py b/composer/tools/composer_migrate.py index 5b044c3d072..1f305bc3a6b 100644 --- a/composer/tools/composer_migrate.py +++ b/composer/tools/composer_migrate.py @@ -78,6 +78,10 @@ def create_environment_from_config(self, config: Any) -> Any: # The original code did: POST .../environments with body. response = self.session.post(url, json=config) + if response.status_code == 409: + logger.info("Environment already exists, skipping creation.") + return + if response.status_code != 200: raise RuntimeError( f"Failed to create environment: {response.text}" @@ -488,6 +492,7 @@ def main( target_env_dag_ids = [dag["dag_id"] for dag in target_env_dags] all_dags_present = set(source_env_dag_ids) == set(target_env_dag_ids) logger.info("List of DAGs in the target environment: %s", target_env_dag_ids) + time.sleep(10) # Unpause only DAGs that were not paused in the source environment. # Optimization: if all DAGs were unpaused in source, use bulk unpause. if not any(d["is_paused"] for d in source_env_dags): From 7a8ee7030b6b9f125f42018db7b34a04c62abd41 Mon Sep 17 00:00:00 2001 From: ddeleo Date: Wed, 18 Feb 2026 16:38:01 -0500 Subject: [PATCH 5/7] feat: Log specific missing DAGs while waiting for all DAGs to appear in the target environment. --- composer/tools/composer_migrate.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/composer/tools/composer_migrate.py b/composer/tools/composer_migrate.py index 1f305bc3a6b..9dc285dd2a5 100644 --- a/composer/tools/composer_migrate.py +++ b/composer/tools/composer_migrate.py @@ -490,8 +490,12 @@ def main( while not all_dags_present: target_env_dags = client.list_dags(target_environment_name) target_env_dag_ids = [dag["dag_id"] for dag in target_env_dags] - all_dags_present = set(source_env_dag_ids) == set(target_env_dag_ids) - logger.info("List of DAGs in the target environment: %s", target_env_dag_ids) + missing_dags = set(source_env_dag_ids) - set(target_env_dag_ids) + all_dags_present = not missing_dags + if missing_dags: + logger.info("Waiting for DAGs to appear in target: %s", missing_dags) + else: + logger.info("All DAGs present in target environment.") time.sleep(10) # Unpause only DAGs that were not paused in the source environment. # Optimization: if all DAGs were unpaused in source, use bulk unpause. From eb73dd4fdf254c0b4c16abdcb07029727144e491 Mon Sep 17 00:00:00 2001 From: ddeleo Date: Wed, 18 Feb 2026 22:26:09 -0500 Subject: [PATCH 6/7] refactor: Cache Airflow URIs to optimize environment lookups and reduce redundant API calls. --- composer/tools/composer_migrate.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/composer/tools/composer_migrate.py b/composer/tools/composer_migrate.py index 9dc285dd2a5..8e84c424879 100644 --- a/composer/tools/composer_migrate.py +++ b/composer/tools/composer_migrate.py @@ -45,6 +45,14 @@ def __init__(self, project: str, location: str, sdk_endpoint: str) -> None: self.sdk_endpoint = sdk_endpoint.rstrip("/") self.credentials, _ = google.auth.default() self.session = AuthorizedSession(self.credentials) + self._airflow_uris = {} + + def _get_airflow_uri(self, environment_name: str) -> str: + """Returns the Airflow URI for a given environment, caching the result.""" + if environment_name not in self._airflow_uris: + environment = self.get_environment(environment_name) + self._airflow_uris[environment_name] = environment["config"]["airflowUri"] + return self._airflow_uris[environment_name] def get_environment(self, environment_name: str) -> Any: """Returns an environment json for a given Composer environment.""" @@ -94,9 +102,7 @@ def create_environment_from_config(self, config: Any) -> Any: def list_dags(self, environment_name: str) -> List[Dict[str, Any]]: """Returns a list of DAGs in a given Composer environment.""" - # Get authentication context and Airflow URI - environment = self.get_environment(environment_name) - airflow_uri = environment["config"]["airflowUri"] + airflow_uri = self._get_airflow_uri(environment_name) url = f"{airflow_uri}/api/v1/dags" response = self.session.get(url) @@ -113,8 +119,7 @@ def pause_dag( environment_name: str, ) -> Any: """Pauses a DAG in a Composer environment.""" - environment = self.get_environment(environment_name) - airflow_uri = environment["config"]["airflowUri"] + airflow_uri = self._get_airflow_uri(environment_name) url = f"{airflow_uri}/api/v1/dags/{dag_id}" response = self.session.patch(url, json={"is_paused": True}) @@ -129,8 +134,7 @@ def pause_all_dags( environment_name: str, ) -> Any: """Pauses all DAGs in a Composer environment.""" - environment = self.get_environment(environment_name) - airflow_uri = environment["config"]["airflowUri"] + airflow_uri = self._get_airflow_uri(environment_name) url = f"{airflow_uri}/api/v1/dags?dag_id_pattern=%" # Pause all DAGs using % as a wildcard response = self.session.patch(url, json={"is_paused": True}) @@ -146,8 +150,7 @@ def unpause_dag( environment_name: str, ) -> Any: """Unpauses a DAG in a Composer environment.""" - environment = self.get_environment(environment_name) - airflow_uri = environment["config"]["airflowUri"] + airflow_uri = self._get_airflow_uri(environment_name) url = f"{airflow_uri}/api/v1/dags/{dag_id}" response = self.session.patch(url, json={"is_paused": False}) @@ -161,8 +164,7 @@ def unpause_all_dags( environment_name: str, ) -> Any: """Unpauses all DAGs in a Composer environment.""" - environment = self.get_environment(environment_name) - airflow_uri = environment["config"]["airflowUri"] + airflow_uri = self._get_airflow_uri(environment_name) url = f"{airflow_uri}/api/v1/dags?dag_id_pattern=%" # Pause all DAGs using % as a wildcard response = self.session.patch(url, json={"is_paused": False}) From e2bca55417dfc80d9d644ffa95145b7c5f4466a2 Mon Sep 17 00:00:00 2001 From: Daniel De Leo Date: Wed, 18 Feb 2026 22:29:36 -0500 Subject: [PATCH 7/7] Apply suggestion from @gemini-code-assist[bot] Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- composer/tools/composer_migrate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/composer/tools/composer_migrate.py b/composer/tools/composer_migrate.py index 8e84c424879..a2ef5ae9be5 100644 --- a/composer/tools/composer_migrate.py +++ b/composer/tools/composer_migrate.py @@ -96,7 +96,7 @@ def create_environment_from_config(self, config: Any) -> Any: ) operation = response.json() - logging.info("Create environment operation: %s", operation["name"]) + logger.info("Create environment operation: %s", operation["name"]) self._wait_for_operation(operation["name"])