diff --git a/benchmarks/autosp/.gitignore b/benchmarks/autosp/.gitignore new file mode 100644 index 000000000..1d27669c4 --- /dev/null +++ b/benchmarks/autosp/.gitignore @@ -0,0 +1,5 @@ +*.log +*.pyc +logs +*. +*.pt diff --git a/benchmarks/autosp/README.md b/benchmarks/autosp/README.md new file mode 100644 index 000000000..e892bdd8f --- /dev/null +++ b/benchmarks/autosp/README.md @@ -0,0 +1,87 @@ +# AutoSP Benchmarking Examples + +This directory contains AutoSP benchmarking examples that demonstrate model compilation and optimization techniques using DeepSpeed and HuggingFace Accelerate. The example script show four compilation modes (AutoSP and baselines) for training large language models: + +| Mode | Parallelism Strategy | Execution Backend | +|------|----------------------|-------------------| +| **eager** | Ulysses DistributedAttention | PyTorch Eager | +| **compile** | Ulysses DistributedAttention | PyTorch Inductor | +| **autosp** | Automatic Sequence Parallelism | AutoSP Compiler | +| **ringattn** | RingAttention-style Sequence Parallelism | PyTorch Inductor | + +## Files in this Directory + +- **run.py**: Benchmarking script with an option to choose either of the 4 compilation modes listed above +- **run_autosp.sh**: Launcher script that configures training runs across multiple GPUs using Hugging Face Accelerate +- **sp_dp_registry.py**: Sequence Parallel and Data Parallel mesh management utilities +- **distributed_attention.py**: Ulysses-styled sequence paralllelism which can be plugged in as an attention backend for HuggingFace +- **ring_attention.py**: Ring Attention algorithm implementation which can be plugged in as an attention backend for HuggingFace +- **configs/**: Training configuration templates for different model sizes and scenarios +- **correctness/**: Correctness validation suite for AutoSP + - **correctness_run.py**: Runs training for a specific configuration (compile mode, sequence parallel size, ZeRO stage) and saves per-rank losses to a JSON file for comparison + - **correctness.sh**: Launcher script that orchestrates correctness testing across multiple configurations, running both baseline (compiled Ulysses) and AutoSP modes + - **validator.py**: Compares per-rank losses between AutoSP and baseline to verify numerical correctness within a configurable threshold + +## Setup Guide + +Quick start guide to clone and set up the AutoSP repository. + + +### Install dependencies + +```bash +pip install torch==2.7.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128 +``` + +```bash +pip install \ + transformers==4.50.3 \ + tokenizers \ + huggingface-hub \ + safetensors \ + datasets \ + accelerate \ + scipy \ + tqdm \ + pyyaml +``` + +## Benchmarking + +The `benchmarks/autosp/` directory contains for benchmarking scripts: + +```bash +cd benchmarks/autosp +``` + +#### Run autosp on 2 GPUs +```bash +./run_autosp.sh --compile autosp --batch-size 1 --seq-length 64 --sp-size 2 --num-layers 1 --steps 1 --deterministic +``` + +#### Run eager mode ulysses on 2 GPUs +```bash +./run_autosp.sh --compile eager --batch-size 1 --seq-length 64 --sp-size 2 --num-layers 1 --steps 1 --deterministic +``` + +#### Run torch.compile'd ulysses on 2 GPUs +```bash +./run_autosp.sh --compile compile --batch-size 1 --seq-length 64 --sp-size 2 --num-layers 1 --steps 1 --deterministic +``` + +#### Run torch.compile'd ring attention on 2 GPUs +```bash +./run_autosp.sh --compile ringattn --batch-size 1 --seq-length 64 --sp-size 2 --num-layers 1 --steps 1 --deterministic +``` + +## Correctness Testing + +To validate that AutoSP produces numerically correct results matching the baseline, use the correctness test suite: + +```bash +cd correctness +./correctness.sh # Test default sp-sizes: 1, 2, 4, 8 +./correctness.sh 2 4 # Test only sp-sizes 2 and 4 +``` + +This runs training for each configuration with both baseline (compiled Ulysses) and AutoSP modes, then compares per-rank losses to verify correctness. diff --git a/benchmarks/autosp/configs/autosp_config.json b/benchmarks/autosp/configs/autosp_config.json new file mode 100644 index 000000000..d448ae893 --- /dev/null +++ b/benchmarks/autosp/configs/autosp_config.json @@ -0,0 +1,21 @@ +{ + + "bf16": { + "enabled": true + }, + + "zero_optimization": { + "stage": 0 + }, + "compile": { + "deepcompile": true, + "passes": ["autosp"] + }, + "gradient_accumulation_steps": 1, + "gradient_clipping": "auto", + "steps_per_print": 2000, + "train_batch_size": 1, + "train_micro_batch_size_per_gpu": 1, + "wall_clock_breakdown": false, + "sequence_parallel_size": 2 +} diff --git a/benchmarks/autosp/configs/autosp_config.yaml b/benchmarks/autosp/configs/autosp_config.yaml new file mode 100644 index 000000000..5ba20b9a6 --- /dev/null +++ b/benchmarks/autosp/configs/autosp_config.yaml @@ -0,0 +1,16 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_multinode_launcher: standard + deepspeed_config_file: configs/autosp_config.json +distributed_type: DEEPSPEED +machine_rank: 0 +main_training_function: main +num_machines: 1 +num_processes: 2 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/benchmarks/autosp/configs/torchcompile_config.json b/benchmarks/autosp/configs/torchcompile_config.json new file mode 100644 index 000000000..d61b17b9f --- /dev/null +++ b/benchmarks/autosp/configs/torchcompile_config.json @@ -0,0 +1,14 @@ +{ + "bf16": { + "enabled": true + }, + "zero_optimization":{ + "stage": 0 + }, + "gradient_accumulation_steps": 1, + "gradient_clipping": "auto", + "steps_per_print": 2000, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false +} diff --git a/benchmarks/autosp/configs/torchcompile_config.yaml b/benchmarks/autosp/configs/torchcompile_config.yaml new file mode 100644 index 000000000..2e35b1185 --- /dev/null +++ b/benchmarks/autosp/configs/torchcompile_config.yaml @@ -0,0 +1,16 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_multinode_launcher: standard + deepspeed_config_file: configs/torchcompile_config.json +distributed_type: DEEPSPEED +machine_rank: 0 +main_training_function: main +num_machines: 1 +num_processes: 2 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/benchmarks/autosp/correctness/correctness.sh b/benchmarks/autosp/correctness/correctness.sh new file mode 100755 index 000000000..68e80f3f5 --- /dev/null +++ b/benchmarks/autosp/correctness/correctness.sh @@ -0,0 +1,137 @@ +#!/bin/bash + +# Correctness test suite for autosp vs baseline compiled DS-Ulysses. +# +# For each (sp_size, dp_size) x zero_stage configuration: +# 1. Runs baseline (--compile compile) for N steps +# 2. Runs autosp (--compile autosp) for N steps +# 3. Compares per-rank losses with validator.py +# +# Usage: +# ./correctness.sh # Default configs +# ./correctness.sh 2,1 2,2 4,1 # Custom sp,dp pairs + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +OUTPUT_DIR="${SCRIPT_DIR}/output" +STEPS=5 + +cleanup() { + rm -rf "${OUTPUT_DIR}" +} +trap cleanup EXIT + +# Parse sp,dp pairs from positional args (e.g. 2,1 2,2 4,1) +declare -a CONFIGS=() + +if [ $# -gt 0 ]; then + for arg in "$@"; do + if [[ "$arg" =~ ^([0-9]+),([0-9]+)$ ]]; then + CONFIGS+=("$arg") + else + echo "Error: invalid config '${arg}'. Expected format: sp,dp (e.g. 2,1)" + exit 1 + fi + done +else + CONFIGS=("1,1" "2,1" "4,1" "8,1") +fi + +ZERO_STAGES=(0 1) + +PASS_COUNT=0 +FAIL_COUNT=0 +TOTAL_COUNT=0 +declare -a RESULTS=() + +echo "" +echo "================================================================" +echo " AutoSP Correctness Test Suite" +echo "================================================================" +echo " Configs (sp,dp): ${CONFIGS[*]}" +echo " Zero stages: ${ZERO_STAGES[*]}" +echo " Steps: ${STEPS}" +echo " Output dir: ${OUTPUT_DIR}" +echo "================================================================" +echo "" + +for config in "${CONFIGS[@]}"; do + sp_size="${config%%,*}" + dp_size="${config##*,}" + + for zero_stage in "${ZERO_STAGES[@]}"; do + TEST_NAME="sp${sp_size}_dp${dp_size}_zero${zero_stage}" + TEST_DIR="${OUTPUT_DIR}/${TEST_NAME}" + mkdir -p "${TEST_DIR}" + + ((TOTAL_COUNT++)) + + echo "----------------------------------------------------------------" + echo " Test: sp_size=${sp_size}, dp_size=${dp_size}, zero_stage=${zero_stage}" + echo "----------------------------------------------------------------" + + # --- Baseline (compiled DS-Ulysses) --- + echo " [1/3] Running baseline (--compile compile) ..." + if ! python3 "${SCRIPT_DIR}/correctness_run.py" \ + --compile compile \ + --sp-size "${sp_size}" \ + --dp-size "${dp_size}" \ + --zero-stage "${zero_stage}" \ + --steps "${STEPS}" \ + --output-file "${TEST_DIR}/baseline.json"; then + + echo " FAIL: Baseline training failed" + RESULTS+=(" ${TEST_NAME}: FAIL (baseline training error)") + ((FAIL_COUNT++)) + echo "" + continue + fi + + # --- AutoSP --- + echo " [2/3] Running autosp (--compile autosp) ..." + if ! python3 "${SCRIPT_DIR}/correctness_run.py" \ + --compile autosp \ + --sp-size "${sp_size}" \ + --dp-size "${dp_size}" \ + --zero-stage "${zero_stage}" \ + --steps "${STEPS}" \ + --output-file "${TEST_DIR}/autosp.json"; then + + echo " FAIL: AutoSP training failed" + RESULTS+=(" ${TEST_NAME}: FAIL (autosp training error)") + ((FAIL_COUNT++)) + echo "" + continue + fi + + # --- Validate --- + echo " [3/3] Validating per-rank losses ..." + if python3 "${SCRIPT_DIR}/validator.py" \ + --baseline "${TEST_DIR}/baseline.json" \ + --autosp "${TEST_DIR}/autosp.json"; then + + RESULTS+=(" ${TEST_NAME}: PASS") + ((PASS_COUNT++)) + else + RESULTS+=(" ${TEST_NAME}: FAIL") + ((FAIL_COUNT++)) + fi + + echo "" + done +done + +# ---- Summary ---- +echo "================================================================" +echo " SUMMARY" +echo "================================================================" +for result in "${RESULTS[@]}"; do + echo "${result}" +done +echo "" +echo " Passed: ${PASS_COUNT}/${TOTAL_COUNT} Failed: ${FAIL_COUNT}/${TOTAL_COUNT}" +echo "================================================================" + +if [ "${FAIL_COUNT}" -gt 0 ]; then + exit 1 +fi +exit 0 diff --git a/benchmarks/autosp/correctness/correctness_run.py b/benchmarks/autosp/correctness/correctness_run.py new file mode 100644 index 000000000..89f1b6311 --- /dev/null +++ b/benchmarks/autosp/correctness/correctness_run.py @@ -0,0 +1,269 @@ +""" +Runs training for a specific configuration (compile mode, sp_size, dp_size, zero_stage) +and saves per-rank losses to a JSON file. + +Reuses the existing run.py training script with temporary config files, +launching via accelerate in the same way as run_autosp.sh. +""" + +import argparse +import csv +import json +import os +import re +import socket +import subprocess +import sys +import tempfile + + +def get_free_port(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def get_host_ip(): + try: + result = subprocess.run( + ["hostname", "-i"], capture_output=True, text=True, check=True + ) + return result.stdout.strip().split()[0] + except Exception: + return "127.0.0.1" + + +def create_ds_config(compile_mode, sp_size, dp_size, zero_stage, batch_size, config_path): + """Create a DeepSpeed JSON config for the given configuration.""" + total_devices = sp_size * dp_size + train_batch_size = total_devices // sp_size + + config = { + "bf16": {"enabled": True}, + "zero_optimization": {"stage": zero_stage}, + "gradient_accumulation_steps": 1, + "gradient_clipping": "auto", + "steps_per_print": 2000, + "train_batch_size": train_batch_size, + "train_micro_batch_size_per_gpu": 1, + "wall_clock_breakdown": False, + "sequence_parallel_size" : sp_size + } + if compile_mode == "autosp": + config["compile"] = { + "deepcompile": True, + "passes": ["autosp"], + } + with open(config_path, "w") as f: + json.dump(config, f, indent=4) + + +def create_accelerate_config(ds_config_path, num_processes, config_path): + """Create an accelerate YAML config pointing to the DS JSON config.""" + content = ( + "compute_environment: LOCAL_MACHINE\n" + "debug: false\n" + "deepspeed_config:\n" + " deepspeed_multinode_launcher: standard\n" + f" deepspeed_config_file: {ds_config_path}\n" + "distributed_type: DEEPSPEED\n" + "machine_rank: 0\n" + "main_training_function: main\n" + "num_machines: 1\n" + f"num_processes: {num_processes}\n" + "rdzv_backend: static\n" + "same_network: true\n" + "tpu_env: []\n" + "tpu_use_cluster: false\n" + "tpu_use_sudo: false\n" + "use_cpu: false\n" + ) + with open(config_path, "w") as f: + f.write(content) + + +def parse_losses_from_csv(logs_dir, compile_mode, seq_length, num_processes): + """Read per-rank loss CSV files written by run.py (full precision).""" + losses = {} + for rank in range(num_processes): + csv_path = os.path.join( + logs_dir, f"loss_{compile_mode}_seq{seq_length}_rank{rank}.csv" + ) + if not os.path.exists(csv_path): + continue + rank_losses = {} + with open(csv_path, "r") as f: + reader = csv.DictReader(f) + for row in reader: + rank_losses[str(row["step"])] = float(row["loss"]) + losses[str(rank)] = rank_losses + return losses + + +def parse_losses_from_stdout(output): + """Fallback: parse loss values from the printed training output.""" + losses = {} + for line in output.split("\n"): + match = re.search(r"\[Rank (\d+)\].*Step (\d+), Loss: ([\d.]+)", line) + if match: + rank, step = match.group(1), match.group(2) + loss = float(match.group(3)) + losses.setdefault(rank, {})[step] = loss + return losses + + +def cleanup_csv_files(logs_dir, compile_mode, seq_length, num_processes): + """Remove loss CSV files created by run.py during training.""" + for rank in range(num_processes): + csv_path = os.path.join( + logs_dir, f"loss_{compile_mode}_seq{seq_length}_rank{rank}.csv" + ) + try: + os.remove(csv_path) + except FileNotFoundError: + pass + + +def main(): + parser = argparse.ArgumentParser( + description="Run training and capture per-rank losses" + ) + parser.add_argument("--compile", choices=["compile", "autosp"], required=True) + parser.add_argument("--sp-size", type=int, required=True) + parser.add_argument("--dp-size", type=int, default=1) + parser.add_argument("--zero-stage", type=int, choices=[0, 1], required=True) + parser.add_argument("--steps", type=int, default=5) + parser.add_argument("--output-file", type=str, required=True) + parser.add_argument("--seq-length", type=int, default=64) + parser.add_argument("--batch-size", type=int, default=1) + parser.add_argument("--num-layers", type=int, default=1) + parser.add_argument("--verbose", action="store_true") + args = parser.parse_args() + + num_processes = args.sp_size * args.dp_size + + script_dir = os.path.dirname(os.path.abspath(__file__)) + autosp_dir = os.path.abspath(os.path.join(script_dir, "..")) + run_py = os.path.join(autosp_dir, "run.py") + logs_dir = os.path.join(autosp_dir, "logs") + + output_dir = os.path.dirname(os.path.abspath(args.output_file)) + os.makedirs(output_dir, exist_ok=True) + + with tempfile.TemporaryDirectory() as tmpdir: + ds_config_path = os.path.join(tmpdir, "ds_config.json") + accel_config_path = os.path.join(tmpdir, "accelerate_config.yaml") + + create_ds_config( + args.compile, args.sp_size, args.dp_size, + args.zero_stage, args.batch_size, ds_config_path, + ) + create_accelerate_config(ds_config_path, num_processes, accel_config_path) + + host_ip = get_host_ip() + port = get_free_port() + + cmd = [ + "accelerate", "launch", + "--main_process_ip", host_ip, + "--main_process_port", str(port), + "--num_machines", "1", + "--num_processes", str(num_processes), + "--machine_rank", "0", + "--config_file", accel_config_path, + run_py, + "--model_name", "meta-llama/Llama-2-7b-chat-hf", + "--batch_size", str(args.batch_size), + "--seq_length", str(args.seq_length), + "--sp_size", str(args.sp_size), + "--dp_size", str(args.dp_size), + "--backend", "inductor", + "--compile", args.compile, + "--num_layers", str(args.num_layers), + "--steps", str(args.steps), + "--deterministic", + ] + + env = os.environ.copy() + env["NCCL_DEBUG"] = "WARN" + + output = "" + stderr_output = "" + + if args.verbose: + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + cwd=autosp_dir, + env=env, + ) + for line in process.stdout: + output += line + sys.stdout.write(line) + sys.stdout.flush() + process.wait() + return_code = process.returncode + else: + result = subprocess.run( + cmd, capture_output=True, text=True, cwd=autosp_dir, env=env + ) + output = result.stdout + stderr_output = result.stderr + return_code = result.returncode + + # Save training log for debugging + log_path = args.output_file.replace(".json", ".log") + with open(log_path, "w") as f: + f.write(f"Command: {' '.join(cmd)}\n") + f.write(f"Return code: {return_code}\n") + f.write("=" * 60 + "\n") + f.write(output) + if stderr_output: + f.write("\n--- STDERR ---\n") + f.write(stderr_output) + + if return_code != 0: + print(f" Training failed (exit code {return_code}). See: {log_path}") + if not args.verbose: + lines = (output + stderr_output).strip().split("\n") + for line in lines[-30:]: + print(f" {line}") + cleanup_csv_files(logs_dir, args.compile, args.seq_length, num_processes) + sys.exit(1) + + losses = parse_losses_from_csv( + logs_dir, args.compile, args.seq_length, num_processes + ) + cleanup_csv_files(logs_dir, args.compile, args.seq_length, num_processes) + + if not losses: + print(" Warning: CSV loss files not found, falling back to stdout parsing") + losses = parse_losses_from_stdout(output) + + if not losses: + print(" Error: No losses found in training output") + sys.exit(1) + + result_data = { + "config": { + "compile": args.compile, + "sp_size": args.sp_size, + "dp_size": args.dp_size, + "zero_stage": args.zero_stage, + "steps": args.steps, + }, + "losses": losses, + } + + with open(args.output_file, "w") as f: + json.dump(result_data, f, indent=2) + + num_ranks = len(losses) + num_steps = max(len(v) for v in losses.values()) + print(f" Losses saved: {num_ranks} rank(s), {num_steps} step(s) -> {args.output_file}") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/autosp/correctness/validator.py b/benchmarks/autosp/correctness/validator.py new file mode 100644 index 000000000..b7fff2b57 --- /dev/null +++ b/benchmarks/autosp/correctness/validator.py @@ -0,0 +1,94 @@ +""" +Validates that per-rank losses from autosp match the baseline (compiled DS-Ulysses) +within a configurable threshold. + +Reads two JSON files produced by correctness_run.py and compares them element-wise. +""" + +import argparse +import json +import sys + + +def validate(baseline_path, autosp_path, threshold): + with open(baseline_path) as f: + baseline_data = json.load(f) + with open(autosp_path) as f: + autosp_data = json.load(f) + + baseline_losses = baseline_data["losses"] + autosp_losses = autosp_data["losses"] + + baseline_ranks = sorted(baseline_losses.keys(), key=int) + autosp_ranks = sorted(autosp_losses.keys(), key=int) + + if baseline_ranks != autosp_ranks: + print( + f" FAIL: Rank mismatch — " + f"baseline has ranks {baseline_ranks}, autosp has ranks {autosp_ranks}" + ) + return False + + all_pass = True + max_diff = 0.0 + mismatches = [] + + for rank in baseline_ranks: + bl_steps = baseline_losses[rank] + asp_steps = autosp_losses[rank] + + all_steps = sorted(set(bl_steps.keys()) | set(asp_steps.keys()), key=int) + for step in all_steps: + if step not in bl_steps: + mismatches.append(f" Rank {rank}, Step {step}: missing in baseline") + all_pass = False + continue + if step not in asp_steps: + mismatches.append(f" Rank {rank}, Step {step}: missing in autosp") + all_pass = False + continue + + bl_val = bl_steps[step] + asp_val = asp_steps[step] + diff = abs(bl_val - asp_val) + max_diff = max(max_diff, diff) + + if diff > threshold: + mismatches.append( + f" Rank {rank}, Step {step}: " + f"baseline={bl_val:.6f}, autosp={asp_val:.6f}, diff={diff:.6e}" + ) + all_pass = False + + if all_pass: + print(f" PASS (max diff: {max_diff:.6e}, threshold: {threshold:.6e})") + else: + print(f" FAIL (max diff: {max_diff:.6e}, threshold: {threshold:.6e})") + for m in mismatches: + print(m) + + return all_pass + + +def main(): + parser = argparse.ArgumentParser( + description="Validate autosp losses against baseline" + ) + parser.add_argument( + "--baseline", required=True, help="Path to baseline losses JSON" + ) + parser.add_argument("--autosp", required=True, help="Path to autosp losses JSON") + parser.add_argument( + "--threshold", + type=float, + default=1e-2, + help="Maximum allowed absolute difference per loss value (default: 1e-2)", + ) + args = parser.parse_args() + + passed = validate(args.baseline, args.autosp, args.threshold) + sys.exit(0 if passed else 1) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/autosp/distributed_attention.py b/benchmarks/autosp/distributed_attention.py new file mode 100644 index 000000000..b9f9667e1 --- /dev/null +++ b/benchmarks/autosp/distributed_attention.py @@ -0,0 +1,93 @@ +import os +import torch +import torch.distributed as dist +from deepspeed.sequence.layer import DistributedAttention +from sp_dp_registry import get_group, is_setup, sp_size + +#TODO: See if there is a better way to pass the mask +_padding_mask_context = None + +def set_padding_mask(mask): + global _padding_mask_context + _padding_mask_context = mask + +def get_padding_mask(): + global _padding_mask_context + return _padding_mask_context + +def ulysses_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + scaling=None, + dropout=0.0, + is_causal=False, + **kwargs, +): + assert is_setup(), 'Incorrectly setup SP/DP Groups.' + + gid = dist.get_rank() // sp_size() + group = get_group(gid) + + # Ulysses expects (batch, seq, heads, dim) + # HF standard provides (batch, heads, seq, dim) + q = query_states.transpose(1, 2).contiguous() + k = key_states.transpose(1, 2).contiguous() + v = value_states.transpose(1, 2).contiguous() + + if not hasattr(self, "ulysses_engine"): + self.ulysses_engine = DistributedAttention( + sdpa_wrapper, + group, + scatter_idx=2, # Shard heads + gather_idx=1 # Gather sequences + ) + + attn_output = self.ulysses_engine( + q, k, v, + batch_dim_idx=0, + attn_mask=None, + dropout_p=dropout, + is_causal=is_causal, + scale=scaling + ) + + return attn_output, None + +def sdpa_wrapper(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=True, scale=None): + # Permute from [b, s, n, h] to [b, n, s, h] for SDPA + q = query.permute(0, 2, 1, 3).contiguous() + k = key.permute(0, 2, 1, 3).contiguous() + v = value.permute(0, 2, 1, 3).contiguous() + + # Create the attention mask from padding mask + causal mask + padding_mask = get_padding_mask() + combined_mask = None + + if padding_mask is not None: + B, S = padding_mask.shape # [B, S] + device = padding_mask.device + + causal_mask = torch.tril(torch.ones(S, S, device=device, dtype=torch.bool)) + padding_mask_bool = (padding_mask != 0).unsqueeze(1) + causal_expanded = causal_mask.unsqueeze(0) + combined_mask = causal_expanded & padding_mask_bool + combined_mask = combined_mask.unsqueeze(1) + + elif is_causal: + pass + + output = torch.nn.functional.scaled_dot_product_attention( + q, k, v, + attn_mask=combined_mask, + dropout_p=dropout_p, + is_causal=(combined_mask is None and is_causal), + scale=scale, + enable_gqa=False + ) + + # Permute back from [b, n, s, h] to [b, s, n, h] for all-to-all on output + output = output.permute(0, 2, 1, 3).contiguous() + return output diff --git a/benchmarks/autosp/ring_attention.py b/benchmarks/autosp/ring_attention.py new file mode 100644 index 000000000..7b01da7b9 --- /dev/null +++ b/benchmarks/autosp/ring_attention.py @@ -0,0 +1,530 @@ +## Code is taken directly from the RingFlashAttention +## repository: https://github.com/zhuzilin/ring-flash-attention +from typing import Optional, Tuple + +import torch +import torch.distributed as dist +import torch.nn.functional as F +import inspect +from functools import cache + +from sp_dp_registry import get_group, is_setup, sp_size +from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward + +__all__ = ["update_out_and_lse", "RingComm", "get_default_args"] + +## Utility communication files. ## +@cache +def _get_default_args(func): + spec = inspect.getfullargspec(func) + defaults = spec.defaults if spec.defaults is not None else () + padded_defaults = (None,) * (len(spec.args) - len(defaults)) + defaults + args = dict(zip(spec.args, padded_defaults)) + if "softcap" in args: + args["softcap"] = 0.0 + return args + + +def get_default_args(func): + if inspect.isfunction(func): + return _get_default_args(func) + else: + # Use the origin _init_fn in CustomOpDef + return _get_default_args(func._init_fn) + + +@torch.jit.script +def _update_out_and_lse( + out: torch.Tensor, + lse: torch.Tensor, + block_out: torch.Tensor, + block_lse: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + + block_out = block_out.to(torch.float32) + block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) + + # new_lse = lse + torch.log(1 + torch.exp(block_lse - lse)) + # torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out + # For additional context and discussion, please refer to: + # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795 + out = out - F.sigmoid(block_lse - lse) * (out - block_out) + lse = lse - F.logsigmoid(lse - block_lse) + + return out, lse + + +def update_out_and_lse( + out: Optional[torch.Tensor], + lse: Optional[torch.Tensor], + block_out: torch.Tensor, + block_lse: torch.Tensor, + slice_=None, +) -> Tuple[torch.Tensor, torch.Tensor]: + if out is None: + if slice_ is not None: + raise RuntimeError("first update_out_and_lse should not pass slice_ args") + out = block_out.to(torch.float32) + lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) + elif slice_ is not None: + slice_out, slice_lse = out[slice_], lse[slice_] + slice_out, slice_lse = _update_out_and_lse( + slice_out, slice_lse, block_out, block_lse + ) + out[slice_], lse[slice_] = slice_out, slice_lse + else: + out, lse = _update_out_and_lse(out, lse, block_out, block_lse) + return out, lse + + +@torch.jit.script +def flatten_varlen_lse(lse, cu_seqlens): + new_lse = [] + for i in range(len(cu_seqlens) - 1): + start, end = cu_seqlens[i], cu_seqlens[i + 1] + new_lse.append(lse[i, :, : end - start]) + return torch.cat(new_lse, dim=1) + + +@torch.jit.script +def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int): + num_seq = len(cu_seqlens) - 1 + num_head = lse.shape[-2] + new_lse = torch.empty( + (num_seq, max_seqlen, num_head, 1), dtype=torch.float32, device=lse.device + ) + for i in range(num_seq): + start, end = cu_seqlens[i], cu_seqlens[i + 1] + new_lse[i, : end - start] = lse[start:end] + return new_lse.squeeze(dim=-1).transpose(1, 2).contiguous() + + +class RingComm: + def __init__(self, process_group: dist.ProcessGroup): + self._process_group = process_group + self._ops = [] + self.rank = dist.get_rank(self._process_group) + self.world_size = dist.get_world_size(self._process_group) + self._reqs = None + + self.send_rank = (self.rank + 1) % self.world_size + self.recv_rank = (self.rank - 1) % self.world_size + + if process_group is not None: + self.send_rank = dist.get_global_rank(self._process_group, self.send_rank) + self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank) + + def send_recv( + self, to_send: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None + ) -> torch.Tensor: + if recv_tensor is None: + res = torch.empty_like(to_send) + else: + res = recv_tensor + + send_op = dist.P2POp( + dist.isend, to_send, self.send_rank, group=self._process_group + ) + recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group) + self._ops.append(send_op) + self._ops.append(recv_op) + return res + + def commit(self): + if self._reqs is not None: + raise RuntimeError("commit called twice") + self._reqs = dist.batch_isend_irecv(self._ops) + + def wait(self): + if self._reqs is None: + raise RuntimeError("wait called before commit") + for req in self._reqs: + req.wait() + self._reqs = None + self._ops = [] + + def send_recv_kv( + self, + k: torch.Tensor, + v: torch.Tensor, + k_buffer: Optional[torch.Tensor] = None, + v_buffer: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + next_k, next_v = self.send_recv(k, k_buffer), self.send_recv(v, v_buffer) + self.commit() + return next_k, next_v + + +class AllGatherComm: + def __init__(self, group=None) -> None: + self.group = group + self.handles = [] + + def all_gather(self, output_tensor: torch.Tensor, input_tensor: torch.Tensor): + handle = dist.all_gather_into_tensor( + output_tensor, input_tensor, group=self.group, async_op=True + ) + self.handles.append(handle) + + def wait(self): + for handle in self.handles: + handle.wait() + self.handles = [] + + +def ring_flash_attn_forward( + process_group, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + comm = RingComm(process_group) + + out = None + lse = None + + next_k, next_v = None, None + + for step in range(comm.world_size): + if step + 1 != comm.world_size: + next_k, next_v = comm.send_recv_kv(k, v) + + if not causal or step <= comm.rank: + params = get_default_args(_flash_attn_forward).copy() + params.update( + { + "q": q, + "k": k, + "v": v, + "dropout_p": dropout_p, + "softmax_scale": softmax_scale, + "causal": causal and step == 0, + "alibi_slopes": alibi_slopes, + "return_softmax": True and dropout_p > 0, + } + ) + if "window_size" in params: + params.update({"window_size": window_size}) + else: + params.update( + { + "window_size_left": window_size[0], + "window_size_right": window_size[1], + } + ) + outputs = _flash_attn_forward(**params) + if len(outputs) == 8: + block_out, _, _, _, _, block_lse, _, _ = outputs + else: + assert len(outputs) == 4 + block_out, block_lse, _, _ = outputs + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + + if step + 1 != comm.world_size: + comm.wait() + k, v = next_k, next_v + + out = out.to(q.dtype) + lse = lse.squeeze(dim=-1).transpose(1, 2) + return out, lse + + +def ring_flash_attn_backward( + process_group, + dout, + q, + k, + v, + out, + softmax_lse, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + kv_comm = RingComm(process_group) + d_kv_comm = RingComm(process_group) + dq, dk, dv = None, None, None + next_dk, next_dv = None, None + + block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) + block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) + block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) + + next_dk, next_dv = None, None + next_k, next_v = None, None + + for step in range(kv_comm.world_size): + if step + 1 != kv_comm.world_size: + next_k, next_v = kv_comm.send_recv_kv(k, v) + + if step <= kv_comm.rank or not causal: + bwd_causal = causal and step == 0 + params = get_default_args(_flash_attn_backward).copy() + params.update( + { + "dout": dout, + "q": q, + "k": k, + "v": v, + "out": out, + "softmax_lse": softmax_lse, + "dq": block_dq_buffer, + "dk": block_dk_buffer, + "dv": block_dv_buffer, + "dropout_p": dropout_p, + "softmax_scale": softmax_scale, + "causal": bwd_causal, + "alibi_slopes": alibi_slopes, + "deterministic": deterministic, + } + ) + if "window_size" in params: + params.update({"window_size": window_size}) + else: + params.update( + { + "window_size_left": window_size[0], + "window_size_right": window_size[1], + } + ) + _flash_attn_backward(**params) + + if dq is None: + dq = block_dq_buffer.to(torch.float32) + dk = block_dk_buffer.to(torch.float32) + dv = block_dv_buffer.to(torch.float32) + else: + dq += block_dq_buffer + d_kv_comm.wait() + dk = block_dk_buffer + next_dk + dv = block_dv_buffer + next_dv + elif step != 0: + d_kv_comm.wait() + dk, dv = next_dk, next_dv + + if step + 1 != kv_comm.world_size: + kv_comm.wait() + k, v = next_k, next_v + + next_dk, next_dv = d_kv_comm.send_recv_kv(dk, dv) + + d_kv_comm.wait() + + return dq.to(torch.bfloat16), next_dk.to(q.dtype), next_dv.to(q.dtype) + + +class RingFlashAttnFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_softmax, + group, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + assert alibi_slopes is None + k = k.contiguous() + v = v.contiguous() + out, softmax_lse = ring_flash_attn_forward( + group, + q, + k, + v, + softmax_scale=softmax_scale, + dropout_p=dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=False, + ) + # this should be out_padded + ctx.save_for_backward(q, k, v, out, softmax_lse) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + ctx.group = group + return out if not return_softmax else (out, softmax_lse, None) + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse = ctx.saved_tensors + dq, dk, dv = ring_flash_attn_backward( + ctx.group, + dout, + q, + k, + v, + out, + softmax_lse, + softmax_scale=ctx.softmax_scale, + dropout_p=ctx.dropout_p, + causal=ctx.causal, + window_size=ctx.window_size, + alibi_slopes=ctx.alibi_slopes, + deterministic=ctx.deterministic, + ) + return dq, dk, dv, None, None, None, None, None, None, None, None + + +def ring_flash_attn_qkvpacked_func( + qkv, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return RingFlashAttnFunc.apply( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def ring_flash_attn_kvpacked_func( + q, + kv, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return RingFlashAttnFunc.apply( + q, + kv[:, :, 0], + kv[:, :, 1], + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def ring_flash_attn_func( + q, + k, + v, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return RingFlashAttnFunc.apply( + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +# HuggingFace-compatible wrapper for ring attention +# This follows the same pattern as ulysses_attention_forward in distributed_attention.py +def ring_attention_forward( + self, # This will be the LlamaAttention instance + query_states, + key_states, + value_states, + attention_mask=None, + scaling=None, + dropout=0.0, + is_causal=True, + **kwargs, +): + """ + Ring attention forward pass compatible with HuggingFace's attention interface. + + Args: + self: The LlamaAttention module instance + query_states: (batch, heads, seq, dim) - HuggingFace format + key_states: (batch, heads, seq, dim) - HuggingFace format + value_states: (batch, heads, seq, dim) - HuggingFace format + attention_mask: Not used (ring attention handles masking internally) + scaling: Softmax scaling factor + dropout: Dropout probability + is_causal: Whether to use causal masking + **kwargs: Additional arguments (ignored) + + Returns: + tuple: (attn_output, None) where attn_output is (batch, seq, heads, dim) + """ + # Convert from HF format (batch, heads, seq, dim) to flash_attn format (batch, seq, heads, dim) + assert is_setup(), 'Incorrectly setup SP/DP Groups.' + + gid = dist.get_rank() // sp_size() + group = get_group(gid) + + q = query_states.transpose(1, 2).contiguous() + k = key_states.transpose(1, 2).contiguous() + v = value_states.transpose(1, 2).contiguous() + + # Ring attention expects (batch, seq, heads, dim) + # Call the ring flash attention function + attn_output = ring_flash_attn_func( + q, + k, + v, + dropout_p=dropout, + softmax_scale=scaling, + causal=is_causal, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=group, + ) + + # Output is already in (batch, seq, heads, dim) format, which HF expects after attention + # Note: Llama's forward handles the reshape internally + return attn_output, None diff --git a/benchmarks/autosp/run.py b/benchmarks/autosp/run.py new file mode 100644 index 000000000..35aac1e5d --- /dev/null +++ b/benchmarks/autosp/run.py @@ -0,0 +1,358 @@ +import os + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +import argparse +import random +import time +from datetime import datetime + +import numpy as np +import torch +import torch.distributed as dist +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +from accelerate import Accelerator +from datasets import load_dataset +from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, enable_full_determinism + +from deepspeed.compile.passes.sp_compile import prepare_autosp_inputs + +from distributed_attention import ulysses_attention_forward, set_padding_mask +# from ring_attention import ring_attention_forward +from sp_dp_registry import get_group, populate_registry, get_registry + +torch.set_float32_matmul_precision("high") + +def set_seed(seed: int = 42): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + +def seed_worker(worker_id): + worker_seed = 12 + worker_id + np.random.seed(worker_seed) + random.seed(worker_seed) + +def get_args(): + parser = argparse.ArgumentParser( + description="AutoSP benchmark script for distributed sequence parallel training", + formatter_class=argparse.RawDescriptionHelpFormatter + ) + + parser.add_argument( + "--model_name", + type=str, + default="meta-llama/Llama-2-7b-hf", + help="HuggingFace model name or path" + ) + parser.add_argument( + "--batch_size", + type=int, + default=1, + help="Batch size per GPU" + ) + parser.add_argument( + "--num_epochs", + type=int, + default=1, + help="Number of training epochs" + ) + parser.add_argument( + "--seq_length", + type=int, + default=512, + help="Sequence length for training" + ) + parser.add_argument( + "--steps", + type=int, + default=1, + help="Total training steps" + ) + parser.add_argument( + "--learning_rate", + type=float, + default=2e-5, + help="Learning rate for optimizer" + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Gradient accumulation steps" + ) + parser.add_argument( + "--activation_checkpointing", + action="store_true", + help="Enable gradient checkpointing" + ) + parser.add_argument( + "--dataset_name", + type=str, + default="timdettmers/openassistant-guanaco", + help="HuggingFace dataset name" + ) + parser.add_argument( + "--num_layers", + type=int, + default=None, + help="Number of transformer layers (None means use full model)" + ) + + # Compilation arguments + parser.add_argument( + "--compile", + type=str, + default="autosp", + choices=["eager", "compile", "autosp", "ringattn"], + help="Compilation mode: eager (no compilation), compile (torch.compile), autosp (AutoSP), ringattn (ring attention)" + ) + parser.add_argument( + "--backend", + type=str, + default="inductor", + help="Backend compiler (e.g., inductor, cudagraph)" + ) + + parser.add_argument( + "--deterministic", + action="store_true", + help="Enable deterministic mode for reproducibility" + ) + + parser.add_argument( + "--print_interval", + type=int, + default=1, + help="Interval for printing metrics" + ) + + parser.add_argument( + "--sp_size", + type=int, + default=2, + help="Sequence parallel size" + ) + parser.add_argument( + "--dp_size", + type=int, + default=1, + help="Data parallel size" + ) + + return parser.parse_args() + +def validate_args(args): + valid_compile_modes = ["eager", "compile", "autosp", "ringattn"] + if args.compile not in valid_compile_modes: + raise ValueError( + f"Invalid compile mode: {args.compile}. " + f"Must be one of {valid_compile_modes}" + ) + + if args.sp_size <= 0 or args.dp_size <= 0: + raise ValueError("sp_size and dp_size must be positive integers") + + if args.seq_length <= 0: + raise ValueError("seq_length must be positive") + + +def print_rank_0(accelerator, *args, **kwargs): + """Print only on main process (rank 0).""" + if accelerator.is_main_process: + print(*args, **kwargs) + + +def main(): + args = get_args() + validate_args(args) + set_seed(12) + + if args.deterministic: + enable_full_determinism(12) + from torch._inductor import config + config.fallback_random = True + torch.use_deterministic_algorithms(True) + + accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps) + device = accelerator.device + assert accelerator.num_processes == args.sp_size * args.dp_size, 'Incorrect dp/sp sizing' + + print_rank_0(accelerator, "\n" + "="*60) + print_rank_0(accelerator, "AutoSP Benchmark Configuration") + print_rank_0(accelerator, "="*60) + print_rank_0(accelerator, f"Model: {args.model_name}") + print_rank_0(accelerator, f"Compile Mode: {args.compile}") + print_rank_0(accelerator, f"Backend: {args.backend}") + print_rank_0(accelerator, f"Sequence Parallel Size: {args.sp_size}") + print_rank_0(accelerator, f"Data Parallel Size: {args.dp_size}") + print_rank_0(accelerator, f"Total Processes: {accelerator.num_processes}") + print_rank_0(accelerator, f"Batch Size: {args.batch_size}") + print_rank_0(accelerator, f"Sequence Length: {args.seq_length}") + print_rank_0(accelerator, f"Num Layers: {args.num_layers if args.num_layers else 'Full model'}") + print_rank_0(accelerator, f"Deterministic: {args.deterministic}") + print_rank_0(accelerator, f"Activation Checkpointing: {args.activation_checkpointing}") + print_rank_0(accelerator, f"Learning Rate: {args.learning_rate}") + print_rank_0(accelerator, f"Gradient Accumulation Steps: {args.gradient_accumulation_steps}") + print_rank_0(accelerator, "="*60 + "\n") + + ## Set sp/dp groups accordingly. + if args.compile in ['compile', 'eager', 'ringattn']: + populate_registry(args.sp_size, args.dp_size) + + print_rank_0(accelerator, "Loading model and tokenizer...") + + model_name = args.model_name + if args.compile == "autosp": + attention_backend = "sdpa" + else: + if args.compile == "eager" or args.compile == "compile": + from transformers.models.llama import modeling_llama + attention_backend = "ulyssess" + modeling_llama.ALL_ATTENTION_FUNCTIONS["ulyssess"] = ulysses_attention_forward + elif args.compile == "ringattn": + from transformers.models.llama import modeling_llama + attention_backend = "ringattn" + modeling_llama.ALL_ATTENTION_FUNCTIONS["ringattn"] = ring_attention_forward + + if args.num_layers is not None: + model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) + print_rank_0(accelerator, f"num_hidden_layers: {model_config.num_hidden_layers} -> {args.num_layers}") + model_config.num_hidden_layers = args.num_layers + model_config._attn_implementation = attention_backend + model = AutoModelForCausalLM.from_config(model_config, trust_remote_code=True) + else: + model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) + model_config._attn_implementation = attention_backend + model = AutoModelForCausalLM.from_pretrained(model_name, config=model_config, trust_remote_code=True) + + if args.activation_checkpointing: + model.gradient_checkpointing_enable() + + print_rank_0(accelerator, "Loading tokenizer...") + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + print_rank_0(accelerator, "Loading dataset...") + + g = torch.Generator() + g.manual_seed(12) + dataset = load_dataset('ag_news', split='train[:1%]') + + def tokenize_function(examples): + return tokenizer(examples['text'], padding='max_length', max_length=args.seq_length, truncation=True) + + tokenized_dataset = dataset.map(tokenize_function, batched=True) + tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask']) + + num_replicas_ = args.dp_size + rank_ = accelerator.process_index // args.sp_size + + sampler = DistributedSampler(tokenized_dataset, num_replicas=num_replicas_, rank=rank_, seed=12, shuffle=False) + data_loader = DataLoader(tokenized_dataset, batch_size=args.batch_size, sampler=sampler, num_workers=4, worker_init_fn=seed_worker, generator=g) + + optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate) + + model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader) + print_rank_0(accelerator, f"Model prepared: {model.__class__}") + + if args.compile == "autosp": + print_rank_0(accelerator, f"Running autosp with backend={args.backend}") + torch._dynamo.config.capture_dynamic_output_shape_ops = True + torch._dynamo.config.capture_scalar_outputs = True + model.compile(backend=args.backend) + elif args.compile in ["compile", "ringattn"]: + print_rank_0(accelerator, f"Running torch.compile with backend={args.backend}") + torch._dynamo.config.capture_dynamic_output_shape_ops = True + torch._dynamo.config.capture_scalar_outputs = True + model = torch.compile(model, backend=args.backend) + else: + print_rank_0(accelerator, f"Running eager") + + timestamp = datetime.now().strftime("%Y%m%d%H%M%S") + model_name = args.model_name.split("/")[-1] + exp_name = f"{model_name}_np{accelerator.num_processes}_{args.compile}_" \ + f"B{args.backend}_" \ + f"L{0 if args.num_layers is None else args.num_layers}_" \ + f"bs{args.batch_size}_seq{args.seq_length}_" \ + f"T{timestamp}" + + model.train() + global_step = 0 + print_rank_0(accelerator, f"Using global sequence length: {args.seq_length}") + + os.makedirs("logs", exist_ok=True) + loss_log_file = open(f"logs/loss_{args.compile}_seq{args.seq_length}_rank{accelerator.process_index}.csv", "w") + loss_log_file.write("step,loss\n") + + sp_rank = dist.get_rank() % args.sp_size + for epoch in range(args.num_epochs): + start_iter = time.time() + + for step, batch in enumerate(data_loader): + input_ids = batch['input_ids'].to(device) + B, S = input_ids.shape + + label_ids = input_ids.clone() + position_ids = torch.arange(S, device=device).unsqueeze(0) + attention_mask = batch['attention_mask'].to(device) + + if args.compile == 'autosp': + # prepare inputs for autosp + input_ids, label_ids, position_ids, attention_mask = prepare_autosp_inputs( + input_ids, label_ids, position_ids, attention_mask, seq_dim=1 + ) + else: + chunk_size = S // args.sp_size + start = sp_rank * chunk_size + end = start + chunk_size + input_ids = input_ids[:, start:end] + label_ids = label_ids[:, start:end] + position_ids = position_ids[:, start:end] + + # Store the padding mask to be accessed directly in local attention + set_padding_mask(attention_mask) + + outputs = model( + input_ids=input_ids, + labels=label_ids, + position_ids=position_ids, + attention_mask=attention_mask + ) + loss = outputs.loss + + elapsed_time = time.time() - start_iter + alloc_mem_gb = torch.cuda.memory_allocated() / (1024 ** 3) + peak_mem_gb = torch.cuda.max_memory_allocated() / (1024 ** 3) + + if global_step % args.print_interval == 0: + print( + f"[Rank {accelerator.process_index}] Epoch {epoch+1}, Step {global_step}, Loss: {loss.item():.4f}, " + f"Time: {elapsed_time:.2f}s, " + f"Alloc Mem: {alloc_mem_gb:.2f} GB, " + f"Peak Mem: {peak_mem_gb:.2f} GB" + ) + + accelerator.backward(loss) + + loss_log_file.write(f"{global_step},{loss.item()}\n") + loss_log_file.flush() + + global_step += 1 + if global_step > args.steps: + break + +if __name__ == "__main__": + torch._dynamo.config.accumulated_cache_size_limit = 256 + torch._dynamo.config.cache_size_limit = 128 + torch._dynamo.config.optimize_ddp = False + + main() + diff --git a/benchmarks/autosp/run_autosp.sh b/benchmarks/autosp/run_autosp.sh new file mode 100755 index 000000000..51697004b --- /dev/null +++ b/benchmarks/autosp/run_autosp.sh @@ -0,0 +1,120 @@ +#!/bin/bash + +# Default parameters +MODEL="meta-llama/Llama-2-7b-chat-hf" +COMPILE="eager" +BACKEND="inductor" +SP_SIZE=2 +DP_SIZE=1 +BATCH_SIZE=1 +SEQ_LENGTH=64 +EXTRA_OPTS="" + +while [[ $# -gt 0 ]]; do + case $1 in + --host-ip) + HOST_IP="$2" + shift 2 + ;; + --model) + MODEL="$2" + shift 2 + ;; + --compile) + COMPILE="$2" + shift 2 + ;; + --backend) + BACKEND="$2" + shift 2 + ;; + --batch-size) + BATCH_SIZE="$2" + shift 2 + ;; + --seq-length) + SEQ_LENGTH="$2" + shift 2 + ;; + --sp-size) + SP_SIZE="$2" + shift 2 + ;; + --dp-size) + DP_SIZE="$2" + shift 2 + ;; + --num-layers) + EXTRA_OPTS="${EXTRA_OPTS} --num_layers $2" + shift 2 + ;; + *) + EXTRA_OPTS="${EXTRA_OPTS} $1" + shift + ;; + esac +done + +if [[ "$COMPILE" != "eager" && "$COMPILE" != "compile" && "$COMPILE" != "autosp" && "$COMPILE" != "ringattn" ]]; then + echo "Invalid compile mode: $COMPILE. Choose from eager, compile, autosp, ringattn." + exit 1 +fi + +if [[ -z "${HOST_IP}" ]]; then + HOST_IP=$(hostname -i | awk '{print $1}') +fi + +PORT=$(python3 -c "import socket; s = socket.socket(); s.bind(('', 0)); print(s.getsockname()[1]); s.close()") + +NUM_PROCESSES=$((SP_SIZE * DP_SIZE)) + +CONFIG_FILE="configs/torchcompile_config.yaml" +if [ "${COMPILE}" == "autosp" ]; then + CONFIG_FILE="configs/autosp_config.yaml" +fi + +mkdir -p logs + +# Generate timestamp for log file +TIMESTAMP=$(date +"%Y%m%d_%H%M%S") +LOG_FILE=logs/log_${COMPILE}_sp${SP_SIZE}_dp${DP_SIZE}_seq${SEQ_LENGTH}_${TIMESTAMP}.log + +# Print configuration +echo "" +echo "================================================================" +echo "Configuration" +echo "================================================================" +echo "HOST_IP: ${HOST_IP}" +echo "PORT: ${PORT}" +echo "NUM_PROCESSES: ${NUM_PROCESSES}" +echo "MODEL: ${MODEL}" +echo "COMPILE: ${COMPILE}" +echo "BACKEND: ${BACKEND}" +echo "SP_SIZE: ${SP_SIZE}" +echo "DP_SIZE: ${DP_SIZE}" +echo "BATCH_SIZE: ${BATCH_SIZE}" +echo "SEQ_LENGTH: ${SEQ_LENGTH}" +echo "LOG_FILE: ${LOG_FILE}" +echo "================================================================" +echo "" + +export NCCL_DEBUG=WARN + +# Launch training +accelerate launch \ + --main_process_ip ${HOST_IP} \ + --main_process_port ${PORT} \ + --num_machines 1 \ + --num_processes ${NUM_PROCESSES} \ + --machine_rank 0 \ + --config_file ${CONFIG_FILE} \ + run.py \ + --model_name "${MODEL}" \ + --batch_size ${BATCH_SIZE} \ + --seq_length ${SEQ_LENGTH} \ + --sp_size ${SP_SIZE} \ + --dp_size ${DP_SIZE} \ + --backend ${BACKEND} \ + --compile ${COMPILE} \ + ${EXTRA_OPTS} \ + 2>&1 | tee ${LOG_FILE} diff --git a/benchmarks/autosp/sp_dp_registry.py b/benchmarks/autosp/sp_dp_registry.py new file mode 100644 index 000000000..ebb29d91a --- /dev/null +++ b/benchmarks/autosp/sp_dp_registry.py @@ -0,0 +1,43 @@ +import torch +import torch.distributed as dist + +GROUP_REGISTRY = {} # int -> dist.ProcessGroup + +def register_groups(groups): + """groups: List[List[int]], e.g. [[0,1],[2,3]]""" + for gid, ranks in enumerate(groups): + if gid not in GROUP_REGISTRY: + GROUP_REGISTRY[gid] = dist.new_group(ranks) + +def get_group(gid: int): + return GROUP_REGISTRY[gid] if gid is not None else dist.group.WORLD + +def get_registry(): + return GROUP_REGISTRY + +def is_setup(): + return GROUP_REGISTRY['is_reg'] if 'is_reg' in GROUP_REGISTRY else False + +def sp_size(): + assert 'SP_SIZE' in GROUP_REGISTRY, 'SP_SIZE not init properly.' + + return GROUP_REGISTRY['SP_SIZE'] + +def dp_size(): + assert 'DP_SIZE' in GROUP_REGISTRY, 'DP_SIZE not init properly' + + return GROUP_REGISTRY['DP_SIZE'] + +def populate_registry(SP_SIZE, DP_SIZE): + group_listing = [] + offset = 0 + for _ in range(DP_SIZE): + group_listing.append([i + offset for i in range(SP_SIZE)]) + offset += SP_SIZE + + register_groups(group_listing) + + ## Extraneous metadata required for proper instatiation. ## + GROUP_REGISTRY['SP_SIZE'] = SP_SIZE + GROUP_REGISTRY['DP_SIZE'] = DP_SIZE + GROUP_REGISTRY['is_reg'] = True