-
Notifications
You must be signed in to change notification settings - Fork 6.7k
Improve composer 2 to composer 3 migration script with various optimizations #13840
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
4e67f1c
fc2be09
73ac8c5
24a6d83
7a8ee70
eb73dd4
e2bca55
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -19,9 +19,13 @@ | |||||
| import json | ||||||
| import math | ||||||
| import pprint | ||||||
| import subprocess | ||||||
| import time | ||||||
| from typing import Any, Dict, List | ||||||
|
|
||||||
| import google.auth | ||||||
| from google.auth.transport.requests import AuthorizedSession | ||||||
| import requests | ||||||
|
|
||||||
| import logging | ||||||
|
|
||||||
|
|
||||||
|
|
@@ -32,128 +36,204 @@ | |||||
| class ComposerClient: | ||||||
| """Client for interacting with Composer API. | ||||||
|
|
||||||
| The client uses gcloud under the hood. | ||||||
| The client uses Google Auth and Requests under the hood. | ||||||
| """ | ||||||
|
|
||||||
| def __init__(self, project: str, location: str, sdk_endpoint: str) -> None: | ||||||
| self.project = project | ||||||
| self.location = location | ||||||
| self.sdk_endpoint = sdk_endpoint | ||||||
| self.sdk_endpoint = sdk_endpoint.rstrip("/") | ||||||
| self.credentials, _ = google.auth.default() | ||||||
| self.session = AuthorizedSession(self.credentials) | ||||||
| self._airflow_uris = {} | ||||||
|
|
||||||
| def _get_airflow_uri(self, environment_name: str) -> str: | ||||||
| """Returns the Airflow URI for a given environment, caching the result.""" | ||||||
| if environment_name not in self._airflow_uris: | ||||||
| environment = self.get_environment(environment_name) | ||||||
| self._airflow_uris[environment_name] = environment["config"]["airflowUri"] | ||||||
| return self._airflow_uris[environment_name] | ||||||
|
|
||||||
| def get_environment(self, environment_name: str) -> Any: | ||||||
| """Returns an environment json for a given Composer environment.""" | ||||||
| command = ( | ||||||
| f"CLOUDSDK_API_ENDPOINT_OVERRIDES_COMPOSER={self.sdk_endpoint} gcloud" | ||||||
| " composer environments describe" | ||||||
| f" {environment_name} --project={self.project} --location={self.location} --format" | ||||||
| " json" | ||||||
| url = ( | ||||||
| f"{self.sdk_endpoint}/v1/projects/{self.project}/locations/" | ||||||
| f"{self.location}/environments/{environment_name}" | ||||||
| ) | ||||||
| output = run_shell_command(command) | ||||||
| return json.loads(output) | ||||||
| response = self.session.get(url) | ||||||
| if response.status_code != 200: | ||||||
| raise RuntimeError( | ||||||
| f"Failed to get environment {environment_name}: {response.text}" | ||||||
| ) | ||||||
| return response.json() | ||||||
|
|
||||||
| def create_environment_from_config(self, config: Any) -> Any: | ||||||
| """Creates a Composer environment based on the given json config.""" | ||||||
| # Obtain access token through gcloud | ||||||
| access_token = run_shell_command("gcloud auth print-access-token") | ||||||
|
|
||||||
| # gcloud does not support creating composer environments from json, so we | ||||||
| # need to use the API directly. | ||||||
| create_environment_command = ( | ||||||
| f"curl -s -X POST -H 'Authorization: Bearer {access_token}'" | ||||||
| " -H 'Content-Type: application/json'" | ||||||
| f" -d '{json.dumps(config)}'" | ||||||
| f" {self.sdk_endpoint}/v1/projects/{self.project}/locations/{self.location}/environments" | ||||||
| ) | ||||||
| output = run_shell_command(create_environment_command) | ||||||
| logging.info("Create environment operation: %s", output) | ||||||
|
|
||||||
| # Poll create operation using gcloud. | ||||||
| operation_id = json.loads(output)["name"].split("/")[-1] | ||||||
| poll_operation_command = ( | ||||||
| f"CLOUDSDK_API_ENDPOINT_OVERRIDES_COMPOSER={self.sdk_endpoint} gcloud" | ||||||
| " composer operations wait" | ||||||
| f" {operation_id} --project={self.project} --location={self.location}" | ||||||
| url = ( | ||||||
| f"{self.sdk_endpoint}/v1/projects/{self.project}/locations/" | ||||||
| f"{self.location}/environments" | ||||||
| ) | ||||||
| run_shell_command(poll_operation_command) | ||||||
|
|
||||||
| def list_dags(self, environment_name: str) -> List[str]: | ||||||
| # Verify that the environment name is present in the config. | ||||||
| # The API expects the resource name in the format: | ||||||
| # projects/{project}/locations/{location}/environments/{environment_name} | ||||||
| if "name" not in config: | ||||||
| raise ValueError("Environment name is missing in the config.") | ||||||
|
|
||||||
| # Extract environment ID from the full name if needed as query param, | ||||||
| # but the original code didn't use it, so we trust the body 'name' field. | ||||||
| # However, usually for Create, we might need environmentId query param if we want to specify it explicitly | ||||||
| # and it's not inferred. | ||||||
| # The original code did: POST .../environments with body. | ||||||
|
|
||||||
| response = self.session.post(url, json=config) | ||||||
| if response.status_code == 409: | ||||||
| logger.info("Environment already exists, skipping creation.") | ||||||
| return | ||||||
|
|
||||||
| if response.status_code != 200: | ||||||
| raise RuntimeError( | ||||||
| f"Failed to create environment: {response.text}" | ||||||
| ) | ||||||
|
|
||||||
| operation = response.json() | ||||||
| logger.info("Create environment operation: %s", operation["name"]) | ||||||
| self._wait_for_operation(operation["name"]) | ||||||
|
|
||||||
|
|
||||||
| def list_dags(self, environment_name: str) -> List[Dict[str, Any]]: | ||||||
| """Returns a list of DAGs in a given Composer environment.""" | ||||||
| command = ( | ||||||
| f"CLOUDSDK_API_ENDPOINT_OVERRIDES_COMPOSER={self.sdk_endpoint} gcloud" | ||||||
| " composer environments run" | ||||||
| f" {environment_name} --project={self.project} --location={self.location} dags" | ||||||
| " list -- -o json" | ||||||
| ) | ||||||
| output = run_shell_command(command) | ||||||
| # Output may contain text from top level print statements. | ||||||
| # The last line of the output is always a json array of DAGs. | ||||||
| return json.loads(output.splitlines()[-1]) | ||||||
| airflow_uri = self._get_airflow_uri(environment_name) | ||||||
|
|
||||||
| url = f"{airflow_uri}/api/v1/dags" | ||||||
| response = self.session.get(url) | ||||||
| if response.status_code != 200: | ||||||
| raise RuntimeError( | ||||||
| f"Failed to list DAGs: {response.text}" | ||||||
| ) | ||||||
| return response.json()["dags"] | ||||||
|
|
||||||
|
|
||||||
| def pause_dag( | ||||||
| self, | ||||||
| dag_id: str, | ||||||
| environment_name: str, | ||||||
| ) -> Any: | ||||||
| """Pauses a DAG in a Composer environment.""" | ||||||
| command = ( | ||||||
| f"CLOUDSDK_API_ENDPOINT_OVERRIDES_COMPOSER={self.sdk_endpoint} gcloud" | ||||||
| " composer environments run" | ||||||
| f" {environment_name} --project={self.project} --location={self.location} dags" | ||||||
| f" pause -- {dag_id}" | ||||||
| ) | ||||||
| run_shell_command(command) | ||||||
| airflow_uri = self._get_airflow_uri(environment_name) | ||||||
|
|
||||||
| url = f"{airflow_uri}/api/v1/dags/{dag_id}" | ||||||
| response = self.session.patch(url, json={"is_paused": True}) | ||||||
| if response.status_code != 200: | ||||||
| raise RuntimeError( | ||||||
| f"Failed to pause DAG {dag_id}: {response.text}" | ||||||
| ) | ||||||
|
|
||||||
|
|
||||||
| def pause_all_dags( | ||||||
| self, | ||||||
| environment_name: str, | ||||||
| ) -> Any: | ||||||
| """Pauses all DAGs in a Composer environment.""" | ||||||
| airflow_uri = self._get_airflow_uri(environment_name) | ||||||
|
|
||||||
| url = f"{airflow_uri}/api/v1/dags?dag_id_pattern=%" # Pause all DAGs using % as a wildcard | ||||||
| response = self.session.patch(url, json={"is_paused": True}) | ||||||
| if response.status_code != 200: | ||||||
| raise RuntimeError( | ||||||
| f"Failed to pause all DAGs: {response.text}" | ||||||
| ) | ||||||
|
|
||||||
|
|
||||||
| def unpause_dag( | ||||||
| self, | ||||||
| dag_id: str, | ||||||
| environment_name: str, | ||||||
| ) -> Any: | ||||||
| """Unpauses a DAG in a Composer environment.""" | ||||||
| command = ( | ||||||
| f"CLOUDSDK_API_ENDPOINT_OVERRIDES_COMPOSER={self.sdk_endpoint} gcloud" | ||||||
| " composer environments run" | ||||||
| f" {environment_name} --project={self.project} --location={self.location} dags" | ||||||
| f" unpause -- {dag_id}" | ||||||
| ) | ||||||
| run_shell_command(command) | ||||||
| airflow_uri = self._get_airflow_uri(environment_name) | ||||||
|
|
||||||
| url = f"{airflow_uri}/api/v1/dags/{dag_id}" | ||||||
| response = self.session.patch(url, json={"is_paused": False}) | ||||||
| if response.status_code != 200: | ||||||
| raise RuntimeError( | ||||||
| f"Failed to unpause DAG {dag_id}: {response.text}" | ||||||
| ) | ||||||
|
|
||||||
| def unpause_all_dags( | ||||||
| self, | ||||||
| environment_name: str, | ||||||
| ) -> Any: | ||||||
| """Unpauses all DAGs in a Composer environment.""" | ||||||
| airflow_uri = self._get_airflow_uri(environment_name) | ||||||
|
|
||||||
| url = f"{airflow_uri}/api/v1/dags?dag_id_pattern=%" # Pause all DAGs using % as a wildcard | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Similar to
Suggested change
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Again, this suggestion is incorrect: https://github.com/apache/airflow/blob/4404bc05b3e77bf1c50219ba2ec1da5ef560a684/airflow-core/src/airflow/api_fastapi/common/parameters.py#L253 — Airflow clearly states regular expressions are not supported and that you should use `%` and `_` wildcards. |
||||||
| response = self.session.patch(url, json={"is_paused": False}) | ||||||
| if response.status_code != 200: | ||||||
| raise RuntimeError( | ||||||
| f"Failed to unpause all DAGs: {response.text}" | ||||||
| ) | ||||||
|
|
||||||
| def save_snapshot(self, environment_name: str) -> str: | ||||||
| """Saves a snapshot of a Composer environment.""" | ||||||
| command = ( | ||||||
| f"CLOUDSDK_API_ENDPOINT_OVERRIDES_COMPOSER={self.sdk_endpoint} gcloud" | ||||||
| " composer" | ||||||
| " environments snapshots save" | ||||||
| f" {environment_name} --project={self.project}" | ||||||
| f" --location={self.location} --format=json" | ||||||
| url = ( | ||||||
| f"{self.sdk_endpoint}/v1/projects/{self.project}/locations/" | ||||||
| f"{self.location}/environments/{environment_name}:saveSnapshot" | ||||||
| ) | ||||||
| output = run_shell_command(command) | ||||||
| return json.loads(output)["snapshotPath"] | ||||||
| response = self.session.post(url, json={}) | ||||||
| if response.status_code != 200: | ||||||
| raise RuntimeError( | ||||||
| f"Failed to initiate snapshot save: {response.text}" | ||||||
| ) | ||||||
|
|
||||||
| operation = response.json() | ||||||
| logging.info("Save snapshot operation: %s", operation["name"]) | ||||||
| completed_operation = self._wait_for_operation(operation["name"]) | ||||||
| return completed_operation["response"]["snapshotPath"] | ||||||
|
|
||||||
| def load_snapshot( | ||||||
| self, | ||||||
| environment_name: str, | ||||||
| snapshot_path: str, | ||||||
| ) -> Any: | ||||||
| """Loads a snapshot to a Composer environment.""" | ||||||
| command = ( | ||||||
| f"CLOUDSDK_API_ENDPOINT_OVERRIDES_COMPOSER={self.sdk_endpoint} gcloud" | ||||||
| " composer" | ||||||
| f" environments snapshots load {environment_name}" | ||||||
| f" --snapshot-path={snapshot_path} --project={self.project}" | ||||||
| f" --location={self.location} --format=json" | ||||||
| url = ( | ||||||
| f"{self.sdk_endpoint}/v1/projects/{self.project}/locations/" | ||||||
| f"{self.location}/environments/{environment_name}:loadSnapshot" | ||||||
| ) | ||||||
| run_shell_command(command) | ||||||
|
|
||||||
|
|
||||||
| def run_shell_command(command: str, command_input: str = None) -> str: | ||||||
| """Executes shell command and returns its output.""" | ||||||
| p = subprocess.Popen(command, stdout=subprocess.PIPE, shell=True) | ||||||
| (res, _) = p.communicate(input=command_input) | ||||||
| output = str(res.decode().strip("\n")) | ||||||
|
|
||||||
| if p.returncode: | ||||||
| raise RuntimeError(f"Failed to run shell command: {command}, details: {output}") | ||||||
| return output | ||||||
| response = self.session.post(url, json={"snapshotPath": snapshot_path}) | ||||||
| if response.status_code != 200: | ||||||
| raise RuntimeError( | ||||||
| f"Failed to initiate snapshot load: {response.text}" | ||||||
| ) | ||||||
|
|
||||||
| operation = response.json() | ||||||
| logging.info("Load snapshot operation: %s", operation["name"]) | ||||||
| self._wait_for_operation(operation["name"]) | ||||||
|
|
||||||
|
|
||||||
| def _wait_for_operation(self, operation_name: str) -> Any: | ||||||
| """Waits for a long-running operation to complete.""" | ||||||
| # operation_name is distinct from operation_id. | ||||||
| # It is a full resource name: projects/.../locations/.../operations/... | ||||||
|
|
||||||
| # We need to poll the operation status. | ||||||
| url = f"{self.sdk_endpoint}/v1/{operation_name}" | ||||||
|
|
||||||
| while True: | ||||||
| response = self.session.get(url) | ||||||
| if response.status_code != 200: | ||||||
| raise RuntimeError( | ||||||
| f"Failed to get operation status: {response.text}" | ||||||
| ) | ||||||
| operation = response.json() | ||||||
| if "done" in operation and operation["done"]: | ||||||
| if "error" in operation: | ||||||
| raise RuntimeError(f"Operation failed: {operation['error']}") | ||||||
| logging.info("Operation completed successfully.") | ||||||
| return operation | ||||||
|
|
||||||
| logging.info("Waiting for operation to complete...") | ||||||
| time.sleep(10) | ||||||
|
|
||||||
|
|
||||||
| def get_target_cpu(source_cpu: float, max_cpu: float) -> float: | ||||||
|
|
@@ -392,15 +472,7 @@ def main( | |||||
| len(source_env_dags), | ||||||
| source_env_dag_ids, | ||||||
| ) | ||||||
| for dag in source_env_dags: | ||||||
| if dag["dag_id"] == "airflow_monitoring": | ||||||
| continue | ||||||
| if dag["is_paused"] == "True": | ||||||
| logger.info("DAG %s is already paused.", dag["dag_id"]) | ||||||
| continue | ||||||
| logger.info("Pausing DAG %s in the source environment.", dag["dag_id"]) | ||||||
| client.pause_dag(dag["dag_id"], source_environment_name) | ||||||
| logger.info("DAG %s paused.", dag["dag_id"]) | ||||||
| client.pause_all_dags(source_environment_name) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The previous implementation explicitly skipped the `airflow_monitoring` DAG.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. airflow_monitoring is safe to pause because Composer automatically unpauses it shortly afterwards |
||||||
| logger.info("All DAGs in the source environment paused.") | ||||||
|
|
||||||
| # 4. Save snapshot of the source environment | ||||||
|
|
@@ -420,18 +492,26 @@ def main( | |||||
| while not all_dags_present: | ||||||
| target_env_dags = client.list_dags(target_environment_name) | ||||||
| target_env_dag_ids = [dag["dag_id"] for dag in target_env_dags] | ||||||
| all_dags_present = set(source_env_dag_ids) == set(target_env_dag_ids) | ||||||
| logger.info("List of DAGs in the target environment: %s", target_env_dag_ids) | ||||||
| missing_dags = set(source_env_dag_ids) - set(target_env_dag_ids) | ||||||
| all_dags_present = not missing_dags | ||||||
| if missing_dags: | ||||||
| logger.info("Waiting for DAGs to appear in target: %s", missing_dags) | ||||||
| else: | ||||||
| logger.info("All DAGs present in target environment.") | ||||||
| time.sleep(10) | ||||||
| # Unpause only DAGs that were not paused in the source environment. | ||||||
| for dag in source_env_dags: | ||||||
| if dag["dag_id"] == "airflow_monitoring": | ||||||
| continue | ||||||
| if dag["is_paused"] == "True": | ||||||
| logger.info("DAG %s was paused in the source environment.", dag["dag_id"]) | ||||||
| continue | ||||||
| logger.info("Unpausing DAG %s in the target environment.", dag["dag_id"]) | ||||||
| client.unpause_dag(dag["dag_id"], target_environment_name) | ||||||
| logger.info("DAG %s unpaused.", dag["dag_id"]) | ||||||
| # Optimization: if all DAGs were unpaused in source, use bulk unpause. | ||||||
| if not any(d["is_paused"] for d in source_env_dags): | ||||||
| logger.info("All DAGs were unpaused in source. Unpausing all DAGs in target.") | ||||||
| client.unpause_all_dags(target_environment_name) | ||||||
| else: | ||||||
| for dag in source_env_dags: | ||||||
| if dag["is_paused"]: | ||||||
| logger.info("DAG %s was paused in the source environment.", dag["dag_id"]) | ||||||
| continue | ||||||
| logger.info("Unpausing DAG %s in the target environment.", dag["dag_id"]) | ||||||
| client.unpause_dag(dag["dag_id"], target_environment_name) | ||||||
| logger.info("DAG %s unpaused.", dag["dag_id"]) | ||||||
| logger.info("DAGs in the target environment unpaused.") | ||||||
|
|
||||||
| logger.info("Migration complete.") | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The Airflow REST API's
`dag_id_pattern` parameter expects a glob expression. The `%` character is not a standard glob wildcard; `*` should be used to match all DAGs. Using `%` will likely result in no DAGs being matched, causing this function to fail silently. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry gemini, this suggestion is just incorrect: https://github.com/apache/airflow/blob/4404bc05b3e77bf1c50219ba2ec1da5ef560a684/airflow-core/src/airflow/api_fastapi/common/parameters.py#L253 — Airflow clearly states regular expressions are not supported and that you should use
`%` and `_` wildcards.