
Commit 6a3d817

Tohtana/bf16 master weights examples (#994)
* add bf16 master weights example
  Signed-off-by: Masahiro Tanaka <[email protected]>
* add results
  Signed-off-by: Masahiro Tanaka <[email protected]>
* update estimation
  Signed-off-by: Masahiro Tanaka <[email protected]>
---------
Signed-off-by: Masahiro Tanaka <[email protected]>
1 parent e676aa3 commit 6a3d817

File tree

10 files changed, +2880 −0 lines changed

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
# BF16 Low-Precision Master Weights and Optimizer States

This example demonstrates DeepSpeed's [new low-precision training options](https://github.com/deepspeedai/DeepSpeed/pull/7700) that can significantly reduce memory usage:

- `bf16_master_weights_and_grads`: Keep master parameters and gradients in BF16 instead of FP32
- `bf16_optimizer_states`: Keep optimizer states (e.g., Adam moments) in BF16

### Running an Example

The following commands run training for 1000 steps on the Wikitext-103 dataset using both the baseline and BF16 low-precision configurations, then generate a loss comparison plot.
The model has approximately 6.86 billion parameters (hidden=4096, layers=32, heads=32, batch=1, seq=512).
For BF16 low-precision training, we use `torch.autocast`. The ZeRO stage is set to 3 for both configurations.

```bash
# Run 1000 steps with wikitext dataset
deepspeed --num_gpus=4 train.py --deepspeed_config configs/baseline.json \
    --num_layers 32 --hidden_dim 4096 --num_heads 32 --batch_size 1 \
    --num_steps 1000 --activation_checkpointing \
    --loss_log_file logs/baseline_loss.csv --use_real_data --seed 42

deepspeed --num_gpus=4 train.py --deepspeed_config configs/bf16_full.json \
    --num_layers 32 --hidden_dim 4096 --num_heads 32 --batch_size 1 \
    --num_steps 1000 --activation_checkpointing \
    --loss_log_file logs/bf16_full_loss.csv --use_real_data --seed 42

# Generate comparison plot
python plot_loss.py --baseline logs/baseline_loss.csv --bf16 logs/bf16_full_loss.csv \
    --output loss_comparison.png
```
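
As a quick sanity check on the parameter count quoted above, the following back-of-envelope sketch reproduces the ~6.86B figure. It assumes a standard GPT-style block (roughly 12·hidden² parameters per layer) and untied input/output embeddings with a ~51K-token vocabulary; both are assumptions for illustration, since `train.py`'s exact architecture is not shown here.

```python
# Rough parameter-count estimate for hidden=4096, layers=32.
# Assumptions (not taken from train.py): GPT-style blocks with ~12*h^2 parameters
# each (attention 4*h^2 + MLP 8*h^2), untied embeddings, vocab of ~51,200 tokens.
hidden, layers, vocab = 4096, 32, 51_200

block_params = 12 * hidden ** 2        # per transformer layer
embed_params = 2 * vocab * hidden      # input embedding + output projection
total = layers * block_params + embed_params

print(f"~{total / 1e9:.2f}B parameters")  # ~6.86B
```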

Here is a summary of the memory usage and training time results on 4x H100 GPUs.
They show a significant memory reduction: 9.57 GB less allocated memory (37%) and 12.45 GB lower peak memory (39.7%).

| Configuration | Allocated Memory | Peak Memory | Avg Step Time |
|---------------|------------------|-------------|---------------|
| Baseline (fp32 master) | 25.74 GB | 31.38 GB | 0.6016s |
| BF16 low-precision (master + opt states) | **16.17 GB** | **18.93 GB** | 0.6427s |

## Loss Curve Comparison

To verify that BF16 low-precision training maintains numerical stability, we trained for 1000 steps on the Wikitext-103 dataset:

![Loss Comparison](logs/7b_loss_run/loss_comparison.png)

| Configuration | Final Loss | Mean Loss | Loss Std |
|---------------|------------|-----------|----------|
| Baseline (fp32 master) | 3.09 | 2.78 | 1.56 |
| BF16 Low-Precision | 3.12 | 2.90 | 2.37 |

The loss curves show that both configurations converge similarly, demonstrating that the reduced precision does not significantly impact training quality while providing substantial memory savings.
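
The summary statistics in the table above can be recomputed from the loss CSVs written via `--loss_log_file`. The sketch below is a minimal example of doing so; it assumes the loss value is the last column of each row, which is an assumption about the CSV layout since `train.py` is not shown here.

```python
# Recompute final/mean/std loss from the CSVs written via --loss_log_file.
# Assumption: the loss value is the last column of each row (header rows are skipped).
import csv
import statistics

def loss_stats(path):
    losses = []
    with open(path, newline="") as f:
        for row in csv.reader(f):
            try:
                losses.append(float(row[-1]))
            except (ValueError, IndexError):
                continue  # skip headers or malformed rows
    return losses[-1], statistics.mean(losses), statistics.pstdev(losses)

for path in ("logs/baseline_loss.csv", "logs/bf16_full_loss.csv"):
    final, mean, std = loss_stats(path)
    print(f"{path}: final={final:.2f} mean={mean:.2f} std={std:.2f}")
```
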
### Memory Breakdown

For a model with N parameters:

| Component | Baseline | BF16 Low-Precision |
|-----------|----------|-------------------|
| Model params | 2N bytes (BF16) | 2N bytes (BF16) |
| Gradients | 2N bytes (BF16) | 2N bytes (BF16) |
| Master weights | 4N bytes (FP32) | 2N bytes (BF16) |
| Master gradients | 4N bytes (FP32) | 2N bytes (BF16) |
| Adam momentum | 4N bytes (FP32) | 2N bytes (BF16) |
| Adam variance | 4N bytes (FP32) | 2N bytes (BF16) |
| **Total** | **20N bytes** | **12N bytes** |

Note that DeepSpeed ZeRO partitions model states across multiple GPUs: ZeRO Stage 1 partitions the optimizer states (master parameters and Adam's momentum and variance), ZeRO Stage 2 additionally partitions the gradients, and ZeRO Stage 3 partitions the model parameters as well, so all of these model states are sharded.

With ZeRO-3, the BF16 low-precision configuration provides a theoretical ~40% reduction in model-state memory (20N bytes down to 12N bytes). Actual savings depend on activation memory and other factors, but our results show a close match to the theoretical estimate.
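
As a back-of-envelope check, the sketch below turns the table above into a per-GPU estimate for the ~6.86B-parameter model trained here with ZeRO-3 on 4 GPUs. It counts only the model states listed in the table (activations, communication buffers, and fragmentation are ignored), so it is an approximation rather than a prediction of the measured numbers.

```python
# Theoretical model-state memory per GPU; ZeRO-3 shards all model states
# across the data-parallel ranks. Component sizes taken from the table above.
N = 6.86e9      # model parameters in this example
gpus = 4        # data-parallel degree used in this example
GB = 1024 ** 3

baseline = 20 * N / gpus   # 2N params + 2N grads + 16N FP32 master weights/grads and Adam states
bf16 = 12 * N / gpus       # same components with master weights/grads and Adam states in BF16

print(f"baseline: {baseline / GB:.2f} GB per GPU")  # ~31.9 GB
print(f"bf16:     {bf16 / GB:.2f} GB per GPU")      # ~19.2 GB
print(f"reduction: {1 - bf16 / baseline:.1%}")      # 40.0%, vs. the measured 39.7% peak reduction
```
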
## Related Resources

- [DeepSpeed BF16 Documentation](https://www.deepspeed.ai/docs/config-json/#bf16-training-options)
- [Low-precision master params PR](https://github.com/deepspeedai/DeepSpeed/pull/7700)
Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
{
  "train_micro_batch_size_per_gpu": 4,
  "gradient_accumulation_steps": 1,
  "steps_per_print": 100,

  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": 1e-4,
      "betas": [0.9, 0.999],
      "eps": 1e-8,
      "weight_decay": 0.01
    }
  },

  "bf16": {
    "enabled": true
  },

  "zero_optimization": {
    "stage": 3,
    "overlap_comm": true,
    "contiguous_gradients": true,
    "reduce_bucket_size": 5e7,
    "stage3_param_persistence_threshold": 0
  },

  "torch_autocast": {
    "enabled": false
  }
}
Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
{
  "train_micro_batch_size_per_gpu": 4,
  "gradient_accumulation_steps": 1,
  "steps_per_print": 100,

  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": 1e-4,
      "betas": [0.9, 0.999],
      "eps": 1e-8,
      "weight_decay": 0.01
    }
  },

  "bf16": {
    "enabled": true,
    "bf16_master_weights_and_grads": true,
    "bf16_optimizer_states": true
  },

  "zero_optimization": {
    "stage": 3,
    "overlap_comm": true,
    "contiguous_gradients": true,
    "reduce_bucket_size": 5e7,
    "stage3_param_persistence_threshold": 0
  },

  "torch_autocast": {
    "enabled": true,
    "dtype": "torch.bfloat16"
  }
}
Lines changed: 175 additions & 0 deletions
@@ -0,0 +1,175 @@
#!/usr/bin/env python
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

"""
Script to gather and compare memory usage from training logs.

Usage:
    python gather_memory.py --log_dir logs/20231201_120000
"""

import argparse
import os
import re
from pathlib import Path


def parse_summary_line(line):
    """Parse the SUMMARY line from log output."""
    pattern = r"SUMMARY: config=(\S+) params=(\d+) peak_mem_bytes=(\d+) alloc_mem_bytes=(\d+) avg_step_time=(\S+)"
    match = re.search(pattern, line)
    if match:
        return {
            "config": match.group(1),
            "params": int(match.group(2)),
            "peak_mem_bytes": int(match.group(3)),
            "alloc_mem_bytes": int(match.group(4)),
            "avg_step_time": float(match.group(5)),
        }
    return None


def format_bytes(bytes_val):
    """Format bytes to human-readable string."""
    gb = bytes_val / (1024 ** 3)
    return f"{gb:.2f} GB"


def format_bytes_mb(bytes_val):
    """Format bytes to MB."""
    mb = bytes_val / (1024 ** 2)
    return f"{mb:.1f} MB"


def get_config_name(config_path):
    """Extract clean config name from path."""
    name = Path(config_path).stem
    if name == "baseline":
        return "Baseline (fp32 master)"
    elif name == "bf16_master_wg":
        return "bf16_master_weights_and_grads"
    elif name == "bf16_full":
        return "bf16_full (master + opt states)"
    return name


def main():
    parser = argparse.ArgumentParser(description="Gather memory usage from training logs")
    parser.add_argument("--log_dir", type=str, required=True, help="Directory containing log files")
    parser.add_argument("--output", type=str, default=None, help="Output file for summary")
    args = parser.parse_args()

    log_dir = Path(args.log_dir)
    if not log_dir.exists():
        print(f"Error: Log directory '{log_dir}' does not exist")
        return 1

    # Find and parse all log files
    results = []
    log_files = ["baseline.log", "bf16_full.log"]

    for log_file in log_files:
        log_path = log_dir / log_file
        if not log_path.exists():
            print(f"Warning: Log file '{log_path}' not found, skipping")
            continue

        with open(log_path, "r") as f:
            for line in f:
                summary = parse_summary_line(line)
                if summary:
                    results.append(summary)
                    break

    if not results:
        print("No results found in log files")
        return 1

    # Calculate baseline for comparison
    baseline_peak = None
    for r in results:
        if "baseline" in r["config"]:
            baseline_peak = r["peak_mem_bytes"]
            break

    # Generate summary
    output_lines = []
    output_lines.append("=" * 80)
    output_lines.append("BF16 Low-Precision Master Weights - Memory Usage Comparison")
    output_lines.append("=" * 80)
    output_lines.append("")

    # Table header
    output_lines.append(f"{'Configuration':<40} {'Peak Memory':<15} {'Reduction':<15} {'Step Time':<12}")
    output_lines.append("-" * 80)

    for r in results:
        config_name = get_config_name(r["config"])
        peak_mem = format_bytes(r["peak_mem_bytes"])
        step_time = f"{r['avg_step_time']:.4f}s"

        if baseline_peak and baseline_peak > 0:
            reduction = ((baseline_peak - r["peak_mem_bytes"]) / baseline_peak) * 100
            reduction_str = f"{reduction:+.1f}%" if reduction != 0 else "-"
        else:
            reduction_str = "-"

        output_lines.append(f"{config_name:<40} {peak_mem:<15} {reduction_str:<15} {step_time:<12}")

    output_lines.append("-" * 80)
    output_lines.append("")

    # Detailed breakdown
    output_lines.append("Detailed Results:")
    output_lines.append("-" * 40)
    for r in results:
        config_name = get_config_name(r["config"])
        output_lines.append(f"\n{config_name}:")
        output_lines.append(f"  Parameters: {r['params']:,}")
        output_lines.append(f"  Peak Memory: {format_bytes(r['peak_mem_bytes'])} ({r['peak_mem_bytes']:,} bytes)")
        output_lines.append(f"  Allocated Memory: {format_bytes(r['alloc_mem_bytes'])} ({r['alloc_mem_bytes']:,} bytes)")
        output_lines.append(f"  Avg Step Time: {r['avg_step_time']:.4f}s")

    output_lines.append("")
    output_lines.append("=" * 80)

    # Generate markdown table
    output_lines.append("")
    output_lines.append("Markdown Table (for README):")
    output_lines.append("-" * 40)
    output_lines.append("")
    output_lines.append("| Configuration | Peak Memory | Memory Reduction | Avg Step Time |")
    output_lines.append("|---------------|-------------|------------------|---------------|")

    for r in results:
        config_name = get_config_name(r["config"])
        peak_mem = format_bytes(r["peak_mem_bytes"])
        step_time = f"{r['avg_step_time']:.4f}s"

        if baseline_peak and baseline_peak > 0:
            reduction = ((baseline_peak - r["peak_mem_bytes"]) / baseline_peak) * 100
            reduction_str = f"{reduction:+.1f}%" if reduction != 0 else "-"
        else:
            reduction_str = "-"

        output_lines.append(f"| {config_name} | {peak_mem} | {reduction_str} | {step_time} |")

    output_lines.append("")

    # Print to stdout
    summary_text = "\n".join(output_lines)
    print(summary_text)

    # Save to file
    output_path = args.output or (log_dir / "summary.txt")
    with open(output_path, "w") as f:
        f.write(summary_text)

    print(f"\nSummary saved to: {output_path}")

    return 0


if __name__ == "__main__":
    exit(main())
