Commit bdd10c2
h-guo18 and yeyu-nvidia authored
Feat: MLA eagle (#689)
## What does this PR do?

**Type of change:** New feature

**Overview:**

- Add MLA Eagle support.
- Add a new argument, `eagle_decoder_type`, to switch between the llama and kimik2 eagle decoders.
- Add patches to dynamically load the kimik2 model implementation.
- Add a new default config for Kimi K2.
- Refactor eagle export to support multi-layer/multi-type eagle export concisely.
- Rename some modules to simplify the export logic.
- Other minor improvements.

## Usage

```python
# Add a code snippet demonstrating how to use this
```

## Testing

- Tested that Kimi K2 Thinking works with `eagle_decoder_type=kimik2`:
  <img width="1068" height="636" alt="image" src="https://github.com/user-attachments/assets/5557ef87-c719-4fb1-be18-30435f6b3885" />
- Tested that Llama 3.2 1B works with `eagle_decoder_type=llama`:
  <img width="1066" height="634" alt="image" src="https://github.com/user-attachments/assets/633c575c-cc79-43af-aed3-0378a303ebc7" />

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes/No
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No

## Additional Information

---------

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: yeyu-nvidia <yeyu@nvidia.com>
Co-authored-by: yeyu-nvidia <yeyu@nvidia.com>
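The Usage section above is an unfilled template placeholder. Based on the `main.py` changes in this commit, a minimal sketch of the new conversion flow might look like the following; the base model name is illustrative, and the config keys and the `kimik2_eagle_default_config` import are taken from the diff below:

```python
import transformers

import modelopt.torch.speculative as mtsp
from modelopt.torch.speculative.config import kimik2_eagle_default_config

# Illustrative base model; the kimik2 decoder type targets Kimi K2 checkpoints.
model = transformers.AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0", torch_dtype="auto", device_map="cpu"
)

config = {
    "eagle_decoder_type": "kimik2",  # new argument: "llama" or "kimik2"
    "eagle_offline": False,
    "eagle_architecture_config": kimik2_eagle_default_config,
}
mtsp.convert(model, [("eagle", config)])  # attach the eagle draft module
```

`launch_train.sh` forwards the same choice through the new `--eagle_decoder_type` flag (see the script diff below).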
1 parent b286165 · commit bdd10c2

13 files changed: 500 additions & 182 deletions


Lines changed: 0 additions & 1 deletion
```diff
@@ -1,3 +1,2 @@
 {
-    "_attn_implementation": "sdpa"
 }
```

examples/speculative_decoding/eagle_utils.py

Lines changed: 29 additions & 26 deletions
```diff
@@ -259,35 +259,24 @@ def __len__(self):
     def __getitem__(self, i) -> dict[str, torch.Tensor]:
         # Load the conversational data, using the cache
         raw_data, offline_file_path = self.data_entries[i]
-        if i in self.cached_data_dict:
-            preprocessed_base = self.cached_data_dict[i]
-        else:
-            ret = self.preprocess_fn(
-                [raw_data], self.tokenizer, processor=self.vlm_processor, img_dir=self.img_dir
-            )
-            preprocessed_base = {k: ret[k][0] for k in ret}
-            self.cached_data_dict[i] = preprocessed_base
-
         # Extend the data sample with the hidden states from the .pt file
         max_length = self.tokenizer.model_max_length
         offline_data = torch.load(offline_file_path)
         offline_data["input_ids"] = offline_data["input_ids"][:max_length]
         offline_data["hidden_states"] = offline_data["hidden_states"][:max_length, :]
         offline_data["aux_hidden_states"] = offline_data["aux_hidden_states"][:max_length, :]
 
-        # Make sure the input_ids have the same shape
-        if preprocessed_base["input_ids"].shape != offline_data["input_ids"].shape:
-            msg = f"""Input IDs from offline data do not match the preprocessed input IDs
-            for offline data sample at {offline_file_path}."""
-            raise ValueError(msg)
-
-        ret = {**preprocessed_base}  # Shallow copy so we don't accidentally modify the cache
-        ret["input_ids"] = offline_data["input_ids"]
-        ret["kwargs"] = {
-            "base_model_outputs": {
-                "base_model_hidden_states": offline_data["hidden_states"],
-                "aux_hidden_states": offline_data["aux_hidden_states"],
-            }
+        ret = {
+            "input_ids": offline_data["input_ids"],
+            "attention_mask": torch.ones_like(offline_data["input_ids"]),
+            "loss_mask": torch.ones_like(offline_data["input_ids"]),
+            "labels": torch.full_like(offline_data["input_ids"], IGNORE_TOKEN_ID),
+            "kwargs": {
+                "base_model_outputs": {
+                    "base_model_hidden_states": offline_data["hidden_states"],
+                    "aux_hidden_states": offline_data["aux_hidden_states"],
+                }
+            },
         }
         return ret
```
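With the cached tokenizer pass gone, every training signal now comes from the `.pt` file itself, and the shape-consistency check is no longer needed because there is no second copy of `input_ids` to disagree with. A sketch of the sample layout `__getitem__` expects, with hypothetical sizes (`IGNORE_TOKEN_ID` is assumed to be the usual `-100` label-ignore index):

```python
import torch

IGNORE_TOKEN_ID = -100  # assumption: the standard HF ignore index used by eagle_utils

# Hypothetical offline sample, one {conv_id}.pt per conversation.
seq_len, hidden_size = 256, 2048
offline_data = {
    "input_ids": torch.randint(0, 32000, (seq_len,)),
    "hidden_states": torch.randn(seq_len, hidden_size),          # final-layer states
    "aux_hidden_states": torch.randn(seq_len, 3 * hidden_size),  # concatenated aux layers
}
torch.save(offline_data, "conv_0042.pt")

# The new __getitem__ derives everything else from input_ids:
labels = torch.full_like(offline_data["input_ids"], IGNORE_TOKEN_ID)
attention_mask = torch.ones_like(offline_data["input_ids"])
```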

```diff
@@ -338,12 +327,24 @@ def make_eagle_supervised_data_module(
                 "offline_data_path must be provided for offline training."
             )
         offline_data_path = Path(data_args.offline_data_path)
+        # Collect all pt file paths
         all_files = {str(p) for p in offline_data_path.glob("*.pt")}
+        all_files |= {str(p) for p in offline_data_path.glob("**/*.pt")}
         if not all_files:
             raise ValueError(f"No .pt files found in {data_args.offline_data_path}")
 
-        # Filter to conversations that exist in the offline data and in the provided json
+        # Build a map from conv_id to file_path for fast lookup
+        print("building conv_id_to_file map...")
+        conv_id_to_file = {}
+        for pt_path in all_files:
+            pt_name = Path(pt_path).name
+            # Expect conv_id.pt
+            if pt_name.endswith(".pt"):
+                conv_id = pt_name[:-3]
+                conv_id_to_file[conv_id] = pt_path
+
         valid_entries = []
+        print("filtering valid entries...")
         for entry in data_json:
             conv_id = entry.get("conversation_id")
             if conv_id is None:
@@ -352,9 +353,11 @@ def make_eagle_supervised_data_module(
                 conv_id = entry.get("id")
             if conv_id is None:
                 raise ValueError(f"Conversation ID required but not found for entry {entry}")
-            file_path = str(offline_data_path / f"{conv_id}.pt")
-            if file_path in all_files:
-                valid_entries.append((entry, file_path))
+
+            file_path = conv_id_to_file.get(str(conv_id))
+            if file_path is None:
+                continue
+            valid_entries.append((entry, file_path))
 
         if len(valid_entries) == 0:
             msg = """No valid files found in the offline data path that match the conversation IDs

examples/speculative_decoding/launch_train.sh

Lines changed: 6 additions & 0 deletions
```diff
@@ -38,6 +38,10 @@ while [ $# -gt 0 ]; do
       if [[ "$1" != *=* ]]; then shift; fi
       MODE="${1#*=}"
       ;;
+    --eagle_decoder_type*)
+      if [[ "$1" != *=* ]]; then shift; fi
+      EAGLE_DECODER_TYPE="${1#*=}"
+      ;;
     --output_dir*)
       if [[ "$1" != *=* ]]; then shift; fi
       OUTPUT_DIR="${1#*=}"
@@ -115,6 +119,7 @@ DEFAULT_SAVE_STEPS=$((8192 / GPU_COUNT))
 
 MODEL=${MODEL:-"TinyLlama/TinyLlama-1.1B-Chat-v1.0"}
 MODE=${MODE:-"eagle3"}
+EAGLE_DECODER_TYPE=${EAGLE_DECODER_TYPE:-"llama"}
 # Set default OUTPUT_DIR to ckpts/{modelname}, where {modelname} is the last part of the model path
 MODEL_BASENAME=$(basename "$MODEL")
 OUTPUT_DIR=${OUTPUT_DIR:-"ckpts/${MODEL_BASENAME}-$(date +%Y%m%d_%H%M)"}
@@ -174,6 +179,7 @@ fi
 export TOKENIZERS_PARALLELISM=False
 CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \
     --mode $MODE \
+    --eagle_decoder_type $EAGLE_DECODER_TYPE \
     --model_name_or_path $MODEL \
     --training_seq_len $TRAINING_SEQ_LEN \
     --dataloader_drop_last True \
```

examples/speculative_decoding/main.py

Lines changed: 30 additions & 13 deletions
```diff
@@ -111,6 +111,10 @@ class MedusaArguments:
 @dataclass
 class EagleArguments:
     eagle_config: str = field(default=None, metadata={"help": "Path to eagle_config.json"})
+    eagle_decoder_type: str = field(
+        default="llama",
+        metadata={"help": "The class of eagle decoder to use. Available options: llama, kimik2"},
+    )
 
 
 def train():
```
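For context, the new field becomes a CLI flag the same way the existing ones do. A minimal sketch, assuming `main.py` parses its argument dataclasses with `transformers.HfArgumentParser` (the dataclass body is copied from the hunk above):

```python
from dataclasses import dataclass, field

import transformers


@dataclass
class EagleArguments:
    eagle_config: str = field(default=None, metadata={"help": "Path to eagle_config.json"})
    eagle_decoder_type: str = field(
        default="llama",
        metadata={"help": "The class of eagle decoder to use. Available options: llama, kimik2"},
    )


parser = transformers.HfArgumentParser((EagleArguments,))
(eagle_args,) = parser.parse_args_into_dataclasses(["--eagle_decoder_type", "kimik2"])
assert eagle_args.eagle_decoder_type == "kimik2"
```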
```diff
@@ -144,24 +148,29 @@ def train():
 
     if checkpoint:
         model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype="auto")
-        tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+        tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
     else:
         # To avoid OOM for large models, we load and convert model on CPU first.
         # Model will be moved to GPU during HF trainer.init().
+        offline_kwargs = {"num_hidden_layers": 0} if use_offline_training else {}
         model = transformers.AutoModelForCausalLM.from_pretrained(
             model_args.model_name_or_path,
             torch_dtype="auto",
             device_map="cpu",
             trust_remote_code=True,
+            **offline_kwargs,
         )
         if use_offline_training:
             # When doing offline training, we need to set num_hidden_layers
             # since we override it when loading the model for space savings
-            model_config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path)
+            model_config = transformers.AutoConfig.from_pretrained(
+                model_args.model_name_or_path, trust_remote_code=True
+            )
             model.config.num_orig_hidden_layers = model_config.num_hidden_layers
         tokenizer = transformers.AutoTokenizer.from_pretrained(
             model_args.model_name_or_path,
             model_max_length=training_args.training_seq_len,
+            trust_remote_code=True,
         )
     if tokenizer.chat_template is None:
         tokenizer.chat_template = (
```
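The `num_hidden_layers=0` override is the space-saving trick the comment refers to: during offline training the base model's decoder stack is never executed (hidden states come from the dumped `.pt` files), so it is not materialized at all, and only the original depth is recorded for the eagle module. A condensed sketch of that load path (model name illustrative):

```python
import transformers

model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # illustrative

# Load the model with zero decoder layers; the extra kwarg is forwarded as a config override.
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="cpu",
    trust_remote_code=True,
    num_hidden_layers=0,  # skip materializing the decoder stack
)

# Recover the real depth from the unmodified config so downstream logic still knows it.
model_config = transformers.AutoConfig.from_pretrained(model_name, trust_remote_code=True)
model.config.num_orig_hidden_layers = model_config.num_hidden_layers
```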
```diff
@@ -179,22 +188,30 @@ def train():
         }
         mtsp.convert(model, [("medusa", config)])
     elif training_args.mode in ["eagle1", "eagle3"]:
-        from modelopt.torch.speculative.config import EAGLE1_DEFAULT_CFG, EAGLE3_DEFAULT_CFG
-
-        # Load default config
-        config = {
-            "eagle1": EAGLE1_DEFAULT_CFG,
-            "eagle3": EAGLE3_DEFAULT_CFG,
-        }[training_args.mode]["config"]
+        from modelopt.torch.speculative.config import (
+            default_eagle_config,
+            eagle3_default_config,
+            kimik2_eagle_default_config,
+        )
 
-        # overwrite config with custom config
-        if use_offline_training:
-            config["eagle_offline"] = True
+        if eagle_args.eagle_decoder_type == "kimik2":
+            eagle_architecture_config = kimik2_eagle_default_config
+        else:
+            eagle_architecture_config = {
+                "eagle1": default_eagle_config,
+                "eagle3": eagle3_default_config,
+            }[training_args.mode]
 
         if eagle_args.eagle_config:
             with open(eagle_args.eagle_config) as f:
                 custom_config = json.load(f)
-            config["eagle_architecture_config"].update(custom_config)
+            eagle_architecture_config.update(custom_config)
+
+        config = {
+            "eagle_decoder_type": eagle_args.eagle_decoder_type,
+            "eagle_offline": use_offline_training,
+            "eagle_architecture_config": eagle_architecture_config,
+        }
 
         mtsp.convert(model, [("eagle", config)])
```
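Since the default architecture configs are plain dicts merged with `dict.update`, a user-supplied `eagle_config.json` only needs the keys it wants to override; everything else falls through from the selected default. A toy illustration of that merge (key names hypothetical):

```python
import json

eagle_architecture_config = {"hidden_size": 4096, "num_eagle_layers": 1}  # stand-in default
custom_config = json.loads('{"num_eagle_layers": 2}')  # e.g. contents of eagle_config.json
eagle_architecture_config.update(custom_config)
assert eagle_architecture_config == {"hidden_size": 4096, "num_eagle_layers": 2}
```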
