
Commit dd16a96

example demonstrating how to train CosmosReason2 Eagle3 (#965)
## What does this PR do?

**Type of change:** New example

**Overview:** Adds `examples/speculative_decoding/guides/train_eagle_head_cosmos_reason2.ipynb`, a step-by-step Jupyter notebook that walks through the full EAGLE3 draft-head training workflow for `nvidia/Cosmos-Reason2-8B`. The notebook covers:

1. Installing dependencies
2. Authenticating with Hugging Face
3. Preparing training data from the Nemotron-Post-Training-Dataset-v2 (chat split) using a curated row-selection mapping (`guides/nemotron_mapping.bin`)
4. Inspecting the bundled EAGLE3 config (`guides/CR2_eagle_config.json`) tuned for Cosmos-Reason2 (YaRN RoPE, FlexAttention, reduced draft vocabulary)
5. (Optional) Calibrating the draft vocabulary to 32k tokens for faster training and inference
6. Launching training via `launch_train.sh` with FSDP2 multi-GPU support
7. Exporting the checkpoint to HF format and serving with vLLM

Also includes `guides/nemotron_mapping.bin` and `guides/CR2_eagle_config.json` as companion files.

## Usage

Open and run `examples/speculative_decoding/guides/train_eagle_head_cosmos_reason2.ipynb` cell by cell. After training, serve the exported checkpoint with:

```
vllm serve nvidia/Cosmos-Reason2-8B \
    --host 0.0.0.0 \
    --port 8000 \
    --speculative_config '{"method": "eagle3", "model": "export/cosmos-reason2-8b-eagle3", "num_speculative_tokens": 3}' \
    --dtype bfloat16
```

## Testing

Tested end-to-end on 4x B100 GPUs. The exported checkpoint was validated with `specdec_bench`.

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: No
- **Did you add or update any necessary documentation?**: Yes (the notebook is self-documenting)
- **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: No

## Additional Information

Cosmos-Reason2-8B requires at least one 80 GB GPU (H100/A100).

## Summary by CodeRabbit

* **New Features**
  * Added a speculative-decoding configuration for draft model and rotary/attention behavior.
  * Added an end-to-end training notebook demonstrating training, export, and deployment of a speculative-decoding draft head on Cosmos-Reason2.
  * Added a data-preparation tool to download, normalize, and convert Nemotron chat conversations into a standardized conversation format.
* **Documentation**
  * Notebook documents environment setup, data prep, training/validation cadence, export, and deployment steps.

Signed-off-by: Slawek Kierat <skierat@nvidia.com>
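The `--speculative_config` flag in the notebook's deployment step takes a single JSON string. A minimal sketch of building that string programmatically (the values mirror the example above; the export path is only valid after running the notebook):

```python
import json

# Speculative-decoding settings matching the vLLM serve example above.
spec_config = {
    "method": "eagle3",
    "model": "export/cosmos-reason2-8b-eagle3",
    "num_speculative_tokens": 3,
}

# vLLM expects this as one JSON string on the command line,
# typically wrapped in single quotes in the shell.
spec_json = json.dumps(spec_config)
print(spec_json)
```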
1 parent 31f0783 commit dd16a96

4 files changed

Lines changed: 516 additions & 0 deletions

Lines changed: 15 additions & 0 deletions
guides/CR2_eagle_config.json
@@ -0,0 +1,15 @@
{
  "draft_vocab_size": 32000,
  "initializer_range": 0.02,
  "rms_norm_eps": 1e-06,
  "_attn_implementation": "flex_attention",
  "rope_scaling": {
    "beta_fast": 32.0,
    "beta_slow": 1.0,
    "factor": 32.0,
    "original_max_position_embeddings": 8192,
    "rope_type": "yarn",
    "truncate": false
  },
  "rope_theta": 150000
}
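The `rope_scaling` block uses YaRN with `factor: 32.0` over an original window of 8192 positions. As a rough sanity check, a minimal sketch computing the effective context length, assuming the usual YaRN convention that the usable window scales linearly with `factor`:

```python
import json

# Relevant fragment of the config above.
cfg = json.loads("""
{
  "rope_scaling": {
    "factor": 32.0,
    "original_max_position_embeddings": 8192,
    "rope_type": "yarn"
  },
  "rope_theta": 150000
}
""")

rs = cfg["rope_scaling"]
# Under YaRN, the usable context is roughly the original window times `factor`.
effective_ctx = int(rs["factor"] * rs["original_max_position_embeddings"])
print(effective_ctx)  # 262144
```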
guides/nemotron_mapping.bin – 350 KB (binary file not shown)
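The binary above is the row-selection mapping described in the PR: packed little-endian `int32` dataset row indices. A toy sketch of that layout using only the standard library (the indices here are made up; the real file is written with `numpy.ndarray.tofile`, which produces the same raw bytes):

```python
import struct

# Hypothetical row indices; the real mapping has tens of thousands of entries.
indices = [0, 5, 42, 89510]

# Write packed little-endian int32 values, the layout used by nemotron_mapping.bin.
with open("mapping_demo.bin", "wb") as f:
    f.write(struct.pack(f"<{len(indices)}i", *indices))

# Read them back, as numpy.fromfile(..., dtype="<i4") would.
with open("mapping_demo.bin", "rb") as f:
    raw = f.read()
decoded = list(struct.unpack(f"<{len(raw) // 4}i", raw))
print(decoded)  # [0, 5, 42, 89510]
```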
Lines changed: 260 additions & 0 deletions
guides/train_eagle_head_cosmos_reason2.ipynb
@@ -0,0 +1,260 @@
{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Training an EAGLE3 Draft Head for Cosmos-Reason2\n",
        "\n",
        "This notebook walks through the full workflow for training an EAGLE3 speculative-decoding draft head on top of [nvidia/Cosmos-Reason2-8B](https://huggingface.co/nvidia/Cosmos-Reason2-8B).\n",
        "\n",
        "**Workflow overview**\n",
        "\n",
        "| Step | Description |\n",
        "| :---: | :--- |\n",
        "| 1 | Install dependencies |\n",
        "| 2 | Authenticate with Hugging Face |\n",
        "| 3 | Prepare training data from the Nemotron dataset |\n",
        "| 4 | Calibrate the draft vocabulary |\n",
        "| 5 | Launch training |\n",
        "| 6 | Export checkpoint for deployment |\n",
        "\n",
        "> **Hardware requirement** – Cosmos-Reason2-8B requires at least one 80 GB GPU (e.g. H100/A100).\n",
        "> Multi-GPU training is supported automatically via FSDP2 when more than one GPU is available."
      ],
      "id": "efe23925"
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Step 1 – Install Dependencies"
      ],
      "id": "e64d39b5"
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "%%bash\n",
        "pip install -U nvidia-modelopt[hf]\n",
        "pip install -r ../requirements.txt"
      ],
      "execution_count": null,
      "outputs": [],
      "id": "f0049171"
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Step 2 – Authenticate with Hugging Face\n",
        "\n",
        "Both `nvidia/Cosmos-Reason2-8B` and `nvidia/Nemotron-Post-Training-Dataset-v2` require accepting\n",
        "their licence agreements on the Hub. Run the cell below and follow the interactive prompt to log in:"
      ],
      "id": "fe68982a"
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "%%bash\n",
        "hf auth login"
      ],
      "execution_count": null,
      "outputs": [],
      "id": "b62417b6"
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Step 3 – Prepare Training Data\n",
        "\n",
        "We use a curated subset of [nvidia/Nemotron-Post-Training-Dataset-v2](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2)\n",
        "(chat split) for training. The `nemotron_mapping.bin` file (bundled alongside this notebook) selects the specific rows to use.\n",
        "It stores 0-based dataset row indices as packed `int32` values (little-endian, produced by `numpy.ndarray.tofile`).\n",
        "\n",
        "The script streams only the required parquet shards and writes a conversation file in the\n",
        "standard `jsonl` format expected by `launch_train.sh`."
      ],
      "id": "cdd4d470"
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "%%bash\n",
        "python ../prepare_input_conversations/add_nemotron_chat.py \\\n",
        "    --mapping-file nemotron_mapping.bin"
      ],
      "execution_count": null,
      "outputs": [],
      "id": "32259e23"
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "%%bash\n",
        "# Expect exactly 89511 conversations.\n",
        "count=$(wc -l < input_conversations/nemotron-chat.jsonl)\n",
        "echo \"${count} conversations in input_conversations/nemotron-chat.jsonl\"\n",
        "[ \"$count\" -eq 89511 ] || { echo \"ERROR: expected 89511, got ${count}\"; exit 1; }"
      ],
      "execution_count": null,
      "outputs": [],
      "id": "d05b97d3"
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Step 4 – Calibrate the Draft Vocabulary\n",
        "\n",
        "`CR2_eagle_config.json` sets `\"draft_vocab_size\": 32000`. Using a compressed vocabulary\n",
        "speeds up training and inference, but requires a one-time calibration step that produces a\n",
        "token-mapping file (`d2t.pt`)."
      ],
      "id": "09717fcc"
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "%%bash\n",
        "python ../scripts/calibrate_draft_vocab.py \\\n",
        "    --model nvidia/Cosmos-Reason2-8B \\\n",
        "    --data input_conversations/nemotron-chat.jsonl \\\n",
        "    --draft_vocab_size 32000 \\\n",
        "    --save_dir draft_vocab_cache"
      ],
      "execution_count": null,
      "outputs": [],
      "id": "388f6897"
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Step 5 – Train the EAGLE3 Draft Head\n",
        "\n",
        "Training is launched via `launch_train.sh`, which internally calls `accelerate launch main.py`\n",
        "and sets up FSDP2 automatically when multiple GPUs are available.\n",
        "\n",
        "Key arguments used for Cosmos-Reason2:\n",
        "\n",
        "| Argument | Value | Notes |\n",
        "| :--- | :--- | :--- |\n",
        "| `--model` | `nvidia/Cosmos-Reason2-8B` | Target VLM |\n",
        "| `--data` | `guides/input_conversations/nemotron-chat.jsonl` | Training conversations |\n",
        "| `--eagle_config` | `guides/CR2_eagle_config.json` | Draft-head architecture |\n",
        "| `--draft_vocab_cache` | `guides/draft_vocab_cache/Cosmos-Reason2-8B/d2t.pt` | Token-mapping from Step 4 |\n",
        "| `--vlm_processor` | `nvidia/Cosmos-Reason2-8B` | VLM image processor |\n",
        "| `--vlm_img_dir` | `data/` | Directory containing referenced images |\n",
        "| `--training_seq_len` | `16384` | Max token length per sample (lower to save GPU memory or speed up training) |\n",
        "| `--lr` | `1.5e-4` | Learning rate |\n",
        "| `--num_epochs` | `20` | Training epochs |\n",
        "| `--train_bs` | `1` | Per-device batch size |\n",
        "| `--save_steps` | `1000` | Checkpoint frequency |\n",
        "| `--ar_validate_steps` | `1000000` | Effectively disables in-training AR validation |\n",
        "\n",
        "> **Tip** – Set `--ar_validate_steps` to a smaller value (e.g. `500`) to periodically measure\n",
        "> acceptance rate on MT-Bench during training."
      ],
      "id": "336c43b9"
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "%%bash\n",
        "export WANDB_MODE=disabled\n",
        "OUTPUT_DIR=ckpts/cosmos-reason2-8b-eagle3\n",
        "EAGLE_CONFIG=guides/CR2_eagle_config.json\n",
        "DRAFT_VOCAB_CACHE=guides/draft_vocab_cache/Cosmos-Reason2-8B/d2t.pt\n",
        "\n",
        "# 20 epochs on 89k samples (4x B100): ~24 hours.\n",
        "cd ..; OUTPUT_DIR=$OUTPUT_DIR ./launch_train.sh \\\n",
        "    --model nvidia/Cosmos-Reason2-8B \\\n",
        "    --output_dir $OUTPUT_DIR \\\n",
        "    --data guides/input_conversations/nemotron-chat.jsonl \\\n",
        "    --lr 1.5e-4 \\\n",
        "    --num_epochs 20 \\\n",
        "    --train_bs 1 \\\n",
        "    --eagle_config $EAGLE_CONFIG \\\n",
        "    --draft_vocab_cache $DRAFT_VOCAB_CACHE \\\n",
        "    --training_seq_len 16384 \\\n",
        "    --save_steps 1000 \\\n",
        "    --ar_validate_steps 1000000 \\\n",
        "    --vlm_processor nvidia/Cosmos-Reason2-8B \\\n",
        "    --vlm_img_dir data/"
      ],
      "execution_count": null,
      "outputs": [],
      "id": "0380f773"
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Step 6 – Export Checkpoint for Deployment\n",
        "\n",
        "After training completes, convert the ModelOpt checkpoint to the Hugging Face–compatible\n",
        "format expected by vLLM. Point `--model_path` to the desired checkpoint subdirectory\n",
        "(e.g. `checkpoint-110000`)."
      ],
      "id": "98e0f8c4"
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "%%bash\n",
        "# Run from the example root, matching the training cell's working directory.\n",
        "cd ..\n",
        "CKPT_DIR=ckpts/cosmos-reason2-8b-eagle3/checkpoint-110000\n",
        "EXPORT_PATH=export/cosmos-reason2-8b-eagle3\n",
        "\n",
        "python scripts/export_hf_checkpoint.py \\\n",
        "    --model_path $CKPT_DIR \\\n",
        "    --export_path $EXPORT_PATH"
      ],
      "execution_count": null,
      "outputs": [],
      "id": "63880f67"
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Deployment\n",
        "\n",
        "The exported checkpoint can be served directly with **vLLM**:\n",
        "\n",
        "```bash\n",
        "vllm serve nvidia/Cosmos-Reason2-8B \\\n",
        "    --host 0.0.0.0 \\\n",
        "    --port 8000 \\\n",
        "    --speculative_config '{\"method\": \"eagle3\", \"model\": \"export/cosmos-reason2-8b-eagle3\", \"num_speculative_tokens\": 3}'\n",
        "```\n",
        "\n",
        "Refer to the [vLLM speculative decoding docs](https://docs.vllm.ai/en/latest/features/spec_decode/) for the full list of options."
      ],
      "id": "413c4275"
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "name": "python",
      "version": "3.10.0"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
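The calibration step in the notebook produces `d2t.pt`, a token-mapping file that translates ids in the reduced 32k draft vocabulary back into the target model's full vocabulary before verification. A toy illustration of that role, assuming a plain lookup table (the ids below are made up; the real artifact is a PyTorch tensor with 32000 entries):

```python
# Hypothetical draft-to-target vocabulary mapping (toy values).
d2t = {0: 11, 1: 57, 2: 198}

# Ids proposed by the draft head over its reduced vocabulary.
draft_ids = [2, 0, 1]

# Translate each draft id into the target model's vocabulary id.
target_ids = [d2t[i] for i in draft_ids]
print(target_ids)  # [198, 11, 57]
```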

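With `num_speculative_tokens` set to 3 as in the serve command, the payoff of a higher acceptance rate can be estimated with the standard speculative-decoding analysis. A minimal sketch, assuming each of the k drafted tokens is accepted independently with probability `alpha` (an assumed value; measure the real rate with `--ar_validate_steps` or `specdec_bench`):

```python
def expected_tokens_per_step(alpha: float, k: int) -> float:
    """Expected tokens emitted per target-model forward pass.

    Uses the geometric model: E = (1 - alpha**(k+1)) / (1 - alpha),
    where k is the number of speculative tokens per step.
    """
    if alpha >= 1.0:
        return float(k + 1)  # every draft token accepted
    return (1.0 - alpha ** (k + 1)) / (1.0 - alpha)

# num_speculative_tokens = 3, as in the vLLM serve command.
for alpha in (0.5, 0.7, 0.9):
    print(f"alpha={alpha}: {expected_tokens_per_step(alpha, 3):.2f} tokens/step")
```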