286 changes: 183 additions & 103 deletions composer/tools/composer_migrate.py
@@ -19,9 +19,13 @@
import json
import math
import pprint
import subprocess
import time
from typing import Any, Dict, List

import google.auth
from google.auth.transport.requests import AuthorizedSession
import requests

import logging


@@ -32,128 +36,204 @@
class ComposerClient:
"""Client for interacting with Composer API.

The client uses gcloud under the hood.
The client uses Google Auth and Requests under the hood.
"""

def __init__(self, project: str, location: str, sdk_endpoint: str) -> None:
self.project = project
self.location = location
self.sdk_endpoint = sdk_endpoint
self.sdk_endpoint = sdk_endpoint.rstrip("/")
self.credentials, _ = google.auth.default()
self.session = AuthorizedSession(self.credentials)
self._airflow_uris = {}
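
A usage sketch (the project, location, and environment values below are hypothetical; the endpoint shown is the public Composer API endpoint):

    client = ComposerClient(
        project="my-project",
        location="us-central1",
        sdk_endpoint="https://composer.googleapis.com",
    )
    env = client.get_environment("my-env")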

def _get_airflow_uri(self, environment_name: str) -> str:
"""Returns the Airflow URI for a given environment, caching the result."""
if environment_name not in self._airflow_uris:
environment = self.get_environment(environment_name)
self._airflow_uris[environment_name] = environment["config"]["airflowUri"]
return self._airflow_uris[environment_name]

def get_environment(self, environment_name: str) -> Any:
"""Returns an environment json for a given Composer environment."""
command = (
f"CLOUDSDK_API_ENDPOINT_OVERRIDES_COMPOSER={self.sdk_endpoint} gcloud"
" composer environments describe"
f" {environment_name} --project={self.project} --location={self.location} --format"
" json"
url = (
f"{self.sdk_endpoint}/v1/projects/{self.project}/locations/"
f"{self.location}/environments/{environment_name}"
)
output = run_shell_command(command)
return json.loads(output)
response = self.session.get(url)
if response.status_code != 200:
raise RuntimeError(
f"Failed to get environment {environment_name}: {response.text}"
)
return response.json()

def create_environment_from_config(self, config: Any) -> Any:
"""Creates a Composer environment based on the given json config."""
# Obtain access token through gcloud
access_token = run_shell_command("gcloud auth print-access-token")

# gcloud does not support creating composer environments from json, so we
# need to use the API directly.
create_environment_command = (
f"curl -s -X POST -H 'Authorization: Bearer {access_token}'"
" -H 'Content-Type: application/json'"
f" -d '{json.dumps(config)}'"
f" {self.sdk_endpoint}/v1/projects/{self.project}/locations/{self.location}/environments"
)
output = run_shell_command(create_environment_command)
logging.info("Create environment operation: %s", output)

# Poll create operation using gcloud.
operation_id = json.loads(output)["name"].split("/")[-1]
poll_operation_command = (
f"CLOUDSDK_API_ENDPOINT_OVERRIDES_COMPOSER={self.sdk_endpoint} gcloud"
" composer operations wait"
f" {operation_id} --project={self.project} --location={self.location}"
url = (
f"{self.sdk_endpoint}/v1/projects/{self.project}/locations/"
f"{self.location}/environments"
)
run_shell_command(poll_operation_command)

def list_dags(self, environment_name: str) -> List[str]:
# Verify that the environment name is present in the config. The API
# expects the resource name in the format:
# projects/{project}/locations/{location}/environments/{environment_name}
if "name" not in config:
raise ValueError("Environment name is missing in the config.")

# The config is posted as the request body; the environment ID is taken
# from the body's 'name' field, so no query parameter is needed here.
response = self.session.post(url, json=config)
if response.status_code == 409:
logger.info("Environment already exists, skipping creation.")
return

if response.status_code != 200:
raise RuntimeError(
f"Failed to create environment: {response.text}"
)

operation = response.json()
logger.info("Create environment operation: %s", operation["name"])
self._wait_for_operation(operation["name"])
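
For illustration, a minimal config of the shape this method expects (all field values below are hypothetical; the full Environment schema is defined by the Composer API):

    config = {
        "name": "projects/my-project/locations/us-central1/environments/my-env",
        "config": {
            "softwareConfig": {"imageVersion": "composer-2-airflow-2"},
        },
    }
    client.create_environment_from_config(config)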


def list_dags(self, environment_name: str) -> List[Dict[str, Any]]:
"""Returns a list of DAGs in a given Composer environment."""
command = (
f"CLOUDSDK_API_ENDPOINT_OVERRIDES_COMPOSER={self.sdk_endpoint} gcloud"
" composer environments run"
f" {environment_name} --project={self.project} --location={self.location} dags"
" list -- -o json"
)
output = run_shell_command(command)
# Output may contain text from top level print statements.
# The last line of the output is always a json array of DAGs.
return json.loads(output.splitlines()[-1])
airflow_uri = self._get_airflow_uri(environment_name)

url = f"{airflow_uri}/api/v1/dags"
response = self.session.get(url)
if response.status_code != 200:
raise RuntimeError(
f"Failed to list DAGs: {response.text}"
)
return response.json()["dags"]
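
Note: GET /api/v1/dags in the stable Airflow REST API is paginated (commonly 100 results per page by default), so environments with many DAGs would be truncated by a single request. A minimal sketch of a pagination-aware variant, assuming the standard limit/offset parameters (the method name is hypothetical):

    def list_dags_paginated(self, environment_name: str) -> List[Dict[str, Any]]:
        """Returns all DAGs, following limit/offset pagination."""
        airflow_uri = self._get_airflow_uri(environment_name)
        dags: List[Dict[str, Any]] = []
        offset, page_size = 0, 100
        while True:
            response = self.session.get(
                f"{airflow_uri}/api/v1/dags",
                params={"limit": page_size, "offset": offset},
            )
            if response.status_code != 200:
                raise RuntimeError(f"Failed to list DAGs: {response.text}")
            page = response.json()["dags"]
            dags.extend(page)
            if len(page) < page_size:
                return dags
            offset += page_size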


def pause_dag(
self,
dag_id: str,
environment_name: str,
) -> Any:
"""Pauses a DAG in a Composer environment."""
command = (
f"CLOUDSDK_API_ENDPOINT_OVERRIDES_COMPOSER={self.sdk_endpoint} gcloud"
" composer environments run"
f" {environment_name} --project={self.project} --location={self.location} dags"
f" pause -- {dag_id}"
)
run_shell_command(command)
airflow_uri = self._get_airflow_uri(environment_name)

url = f"{airflow_uri}/api/v1/dags/{dag_id}"
response = self.session.patch(url, json={"is_paused": True})
if response.status_code != 200:
raise RuntimeError(
f"Failed to pause DAG {dag_id}: {response.text}"
)


def pause_all_dags(
self,
environment_name: str,
) -> Any:
"""Pauses all DAGs in a Composer environment."""
airflow_uri = self._get_airflow_uri(environment_name)

url = f"{airflow_uri}/api/v1/dags?dag_id_pattern=%" # Pause all DAGs using % as a wildcard
Contributor (review comment, severity: high):

The Airflow REST API's dag_id_pattern parameter expects a glob expression. The % character is not a standard glob wildcard; * should be used to match all DAGs. Using % will likely result in no DAGs being matched, causing this function to fail silently.

Suggested change:
url = f"{airflow_uri}/api/v1/dags?dag_id_pattern=%" # Pause all DAGs using % as a wildcard
url = f"{airflow_uri}/api/v1/dags?dag_id_pattern=*" # Pause all DAGs using * as a wildcard

Contributor Author (reply):

Sorry gemini, this suggestion is just incorrect: https://github.com/apache/airflow/blob/4404bc05b3e77bf1c50219ba2ec1da5ef560a684/airflow-core/src/airflow/api_fastapi/common/parameters.py#L253 — Airflow clearly states regular expressions are not supported and that you should use % and _ wildcards.
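
Note: since a bare % is not a valid percent-escape inside a URL, an equivalent (hypothetical) variant could pass the wildcard through requests' params= so it is encoded explicitly as %25, rather than relying on the session to requote the raw URL:

    response = self.session.patch(
        f"{airflow_uri}/api/v1/dags",
        params={"dag_id_pattern": "%"},  # '%' and '_' are the SQL-LIKE-style wildcards Airflow supports
        json={"is_paused": True},
    )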

response = self.session.patch(url, json={"is_paused": True})
if response.status_code != 200:
raise RuntimeError(
f"Failed to pause all DAGs: {response.text}"
)


def unpause_dag(
self,
dag_id: str,
environment_name: str,
) -> Any:
"""Unpauses a DAG in a Composer environment."""
command = (
f"CLOUDSDK_API_ENDPOINT_OVERRIDES_COMPOSER={self.sdk_endpoint} gcloud"
" composer environments run"
f" {environment_name} --project={self.project} --location={self.location} dags"
f" unpause -- {dag_id}"
)
run_shell_command(command)
airflow_uri = self._get_airflow_uri(environment_name)

url = f"{airflow_uri}/api/v1/dags/{dag_id}"
response = self.session.patch(url, json={"is_paused": False})
if response.status_code != 200:
raise RuntimeError(
f"Failed to unpause DAG {dag_id}: {response.text}"
)

def unpause_all_dags(
self,
environment_name: str,
) -> Any:
"""Unpauses all DAGs in a Composer environment."""
airflow_uri = self._get_airflow_uri(environment_name)

url = f"{airflow_uri}/api/v1/dags?dag_id_pattern=%" # Pause all DAGs using % as a wildcard
Contributor (review comment, severity: high):

Similar to pause_all_dags, the dag_id_pattern parameter for the Airflow REST API expects a glob expression. Please use * instead of % to correctly match all DAGs.

Suggested change:
url = f"{airflow_uri}/api/v1/dags?dag_id_pattern=%" # Pause all DAGs using % as a wildcard
url = f"{airflow_uri}/api/v1/dags?dag_id_pattern=*" # Unpause all DAGs using * as a wildcard

Contributor Author (reply):

Again this suggestion is incorrect: https://github.com/apache/airflow/blob/4404bc05b3e77bf1c50219ba2ec1da5ef560a684/airflow-core/src/airflow/api_fastapi/common/parameters.py#L253 — Airflow clearly states regular expressions are not supported and that you should use % and _ wildcards.

response = self.session.patch(url, json={"is_paused": False})
if response.status_code != 200:
raise RuntimeError(
f"Failed to unpause all DAGs: {response.text}"
)

def save_snapshot(self, environment_name: str) -> str:
"""Saves a snapshot of a Composer environment."""
command = (
f"CLOUDSDK_API_ENDPOINT_OVERRIDES_COMPOSER={self.sdk_endpoint} gcloud"
" composer"
" environments snapshots save"
f" {environment_name} --project={self.project}"
f" --location={self.location} --format=json"
url = (
f"{self.sdk_endpoint}/v1/projects/{self.project}/locations/"
f"{self.location}/environments/{environment_name}:saveSnapshot"
)
output = run_shell_command(command)
return json.loads(output)["snapshotPath"]
response = self.session.post(url, json={})
if response.status_code != 200:
raise RuntimeError(
f"Failed to initiate snapshot save: {response.text}"
)

operation = response.json()
logging.info("Save snapshot operation: %s", operation["name"])
completed_operation = self._wait_for_operation(operation["name"])
return completed_operation["response"]["snapshotPath"]

def load_snapshot(
self,
environment_name: str,
snapshot_path: str,
) -> Any:
"""Loads a snapshot to a Composer environment."""
command = (
f"CLOUDSDK_API_ENDPOINT_OVERRIDES_COMPOSER={self.sdk_endpoint} gcloud"
" composer"
f" environments snapshots load {environment_name}"
f" --snapshot-path={snapshot_path} --project={self.project}"
f" --location={self.location} --format=json"
url = (
f"{self.sdk_endpoint}/v1/projects/{self.project}/locations/"
f"{self.location}/environments/{environment_name}:loadSnapshot"
)
run_shell_command(command)


def run_shell_command(command: str, command_input: str = None) -> str:
"""Executes shell command and returns its output."""
p = subprocess.Popen(command, stdout=subprocess.PIPE, shell=True)
(res, _) = p.communicate(input=command_input)
output = str(res.decode().strip("\n"))

if p.returncode:
raise RuntimeError(f"Failed to run shell command: {command}, details: {output}")
return output
response = self.session.post(url, json={"snapshotPath": snapshot_path})
if response.status_code != 200:
raise RuntimeError(
f"Failed to initiate snapshot load: {response.text}"
)

operation = response.json()
logging.info("Load snapshot operation: %s", operation["name"])
self._wait_for_operation(operation["name"])


def _wait_for_operation(self, operation_name: str) -> Any:
"""Waits for a long-running operation to complete."""
# operation_name is a full resource name, e.g.
# projects/{project}/locations/{location}/operations/{operation_id},
# so it can be appended to the endpoint directly for polling.
url = f"{self.sdk_endpoint}/v1/{operation_name}"

while True:
response = self.session.get(url)
if response.status_code != 200:
raise RuntimeError(
f"Failed to get operation status: {response.text}"
)
operation = response.json()
if "done" in operation and operation["done"]:
if "error" in operation:
raise RuntimeError(f"Operation failed: {operation['error']}")
logging.info("Operation completed successfully.")
return operation

logging.info("Waiting for operation to complete...")
time.sleep(10)
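
Note: the polling loop above has no upper bound. A minimal sketch of a bounded variant (the timeout_s parameter is hypothetical), so a stuck operation raises instead of polling forever:

    def _wait_for_operation_bounded(self, operation_name: str, timeout_s: float = 3600.0) -> Any:
        """Polls the operation until done, raising if timeout_s elapses."""
        url = f"{self.sdk_endpoint}/v1/{operation_name}"
        deadline = time.monotonic() + timeout_s
        while time.monotonic() < deadline:
            response = self.session.get(url)
            if response.status_code != 200:
                raise RuntimeError(f"Failed to get operation status: {response.text}")
            operation = response.json()
            if operation.get("done"):
                if "error" in operation:
                    raise RuntimeError(f"Operation failed: {operation['error']}")
                return operation
            time.sleep(10)
        raise TimeoutError(f"Operation {operation_name} did not complete in {timeout_s}s")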


def get_target_cpu(source_cpu: float, max_cpu: float) -> float:
@@ -392,15 +472,7 @@ def main(
len(source_env_dags),
source_env_dag_ids,
)
for dag in source_env_dags:
if dag["dag_id"] == "airflow_monitoring":
continue
if dag["is_paused"] == "True":
logger.info("DAG %s is already paused.", dag["dag_id"])
continue
logger.info("Pausing DAG %s in the source environment.", dag["dag_id"])
client.pause_dag(dag["dag_id"], source_environment_name)
logger.info("DAG %s paused.", dag["dag_id"])
client.pause_all_dags(source_environment_name)
Contributor (review comment, severity: high):

The previous implementation explicitly skipped the airflow_monitoring DAG when pausing DAGs. The new bulk operation will attempt to pause all DAGs, including this protected system DAG. This could lead to errors if the API call fails for this specific DAG and aborts the entire batch operation. Please ensure the bulk operation handles this case gracefully. This concern also applies to the unpausing logic.

Contributor Author (reply):

airflow_monitoring is safe to pause because Composer automatically unpauses it shortly afterwards.

logger.info("All DAGs in the source environment paused.")

# 4. Save snapshot of the source environment
@@ -420,18 +492,26 @@
while not all_dags_present:
target_env_dags = client.list_dags(target_environment_name)
target_env_dag_ids = [dag["dag_id"] for dag in target_env_dags]
all_dags_present = set(source_env_dag_ids) == set(target_env_dag_ids)
logger.info("List of DAGs in the target environment: %s", target_env_dag_ids)
missing_dags = set(source_env_dag_ids) - set(target_env_dag_ids)
all_dags_present = not missing_dags
if missing_dags:
logger.info("Waiting for DAGs to appear in target: %s", missing_dags)
time.sleep(10)
else:
logger.info("All DAGs present in target environment.")
# Unpause only DAGs that were not paused in the source environment.
for dag in source_env_dags:
if dag["dag_id"] == "airflow_monitoring":
continue
if dag["is_paused"] == "True":
logger.info("DAG %s was paused in the source environment.", dag["dag_id"])
continue
logger.info("Unpausing DAG %s in the target environment.", dag["dag_id"])
client.unpause_dag(dag["dag_id"], target_environment_name)
logger.info("DAG %s unpaused.", dag["dag_id"])
# Optimization: if all DAGs were unpaused in source, use bulk unpause.
if not any(d["is_paused"] for d in source_env_dags):
logger.info("All DAGs were unpaused in source. Unpausing all DAGs in target.")
client.unpause_all_dags(target_environment_name)
else:
for dag in source_env_dags:
if dag["is_paused"]:
logger.info("DAG %s was paused in the source environment.", dag["dag_id"])
continue
logger.info("Unpausing DAG %s in the target environment.", dag["dag_id"])
client.unpause_dag(dag["dag_id"], target_environment_name)
logger.info("DAG %s unpaused.", dag["dag_id"])
logger.info("DAGs in the target environment unpaused.")

logger.info("Migration complete.")