Commit 5c194b3

Minitron pruning refactor [2/2]: De-couple importance estimator from Dynamic Module (#693)
## What does this PR do?

- Code refactor, no logic change
- De-couple the Minitron pruning importance estimator from the Dynamic Module so it is easy to configure different importance logic for pruning (a conceptual sketch follows below)

## Testing

- [x] CI/CD tests passing
- [x] Compared MMLU on pruned Qwen3-8B with the previous and current implementations
- [x] Compared MMLU on pruned Qwen3-30B-A3B with the previous and current implementations (a lot of variance in the results; some pruning configs are better, some worse)
- [x] Compared MMLU on pruned Nemotron-Nano-v2-9B with the previous and current implementations

Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
1 parent 7233616 commit 5c194b3
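For readers skimming the diff, here is a rough sketch of what "de-coupling the importance estimator from the Dynamic Module" can look like in principle: the estimator collects activation statistics through hooks and produces per-channel rankings, while the dynamic/prunable module only consumes a ranking. All names below (`ActivationImportanceEstimator`, `attach`, `ranking`, `score_fn`) are hypothetical illustrations for this idea, not the modelopt API added in this commit.

```python
# Hypothetical sketch only: an importance estimator that lives outside the
# dynamic/pruned module, so different scoring logic can be swapped in.
from collections.abc import Callable

import torch
import torch.nn as nn


class ActivationImportanceEstimator:
    """Accumulates per-channel activation statistics outside the dynamic module."""

    def __init__(self, score_fn: Callable[[torch.Tensor], torch.Tensor] | None = None):
        # score_fn reduces an activation tensor [..., C] to per-channel scores [C];
        # the default is the mean absolute activation per channel.
        self.score_fn = score_fn or (
            lambda act: act.abs().reshape(-1, act.shape[-1]).mean(dim=0)
        )
        self.scores: dict[str, torch.Tensor] = {}

    def attach(self, model: nn.Module, layer_names: set[str]) -> list:
        """Register forward hooks on the named submodules; returns the hook handles."""
        handles = []
        for name, module in model.named_modules():
            if name in layer_names:
                handles.append(module.register_forward_hook(self._make_hook(name)))
        return handles

    def _make_hook(self, name: str):
        def hook(_module, _inputs, output):
            # Accumulate per-channel scores across calibration forward passes.
            self.scores[name] = self.scores.get(name, 0) + self.score_fn(output.detach())

        return hook

    def ranking(self, name: str) -> torch.Tensor:
        """Channel indices sorted from most to least important."""
        return torch.argsort(self.scores[name], descending=True)
```

With a split along these lines, trying a different importance metric only means passing a different `score_fn`; the dynamic module itself is untouched.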

7 files changed

Lines changed: 909 additions & 806 deletions

File tree

modelopt/torch/nas/plugins/megatron.py

Lines changed: 6 additions & 494 deletions
Large diffs are not rendered by default.

modelopt/torch/prune/plugins/mcore_minitron.py

Lines changed: 589 additions & 14 deletions
Large diffs are not rendered by default.

tests/gpu/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py

Lines changed: 1 addition & 219 deletions
```diff
@@ -18,29 +18,20 @@
 import pytest
 import torch
 from _test_utils.import_helper import skip_if_no_megatron
-from _test_utils.torch.misc import compare_outputs

 skip_if_no_megatron(apex_or_te_required=True)

 from _test_utils.torch.distributed.utils import spawn_multiprocess_job
 from _test_utils.torch.megatron.models import get_mcore_gpt_model
-from _test_utils.torch.megatron.utils import (
-    run_mcore_inference,
-    run_mcore_inference_with_dummy_input,
-)
-from _test_utils.torch.misc import set_seed
+from _test_utils.torch.megatron.utils import run_mcore_inference
 from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
-from megatron.core.parallel_state import destroy_model_parallel
 from megatron.core.transformer.attention import SelfAttention
-from megatron.core.transformer.identity_op import IdentityOp
 from megatron.core.transformer.mlp import MLP
 from megatron.core.transformer.transformer_layer import TransformerLayer

 import modelopt.torch.nas as mtn
-from modelopt.torch.nas.conversion import export_searchspace
 from modelopt.torch.nas.modules import DynamicModuleList
 from modelopt.torch.nas.plugins.megatron import (
-    NumAttentionHeadsHp,
     _DynamicColumnParallelLinear,
     _DynamicEmbedding,
     _DynamicLanguageModelEmbedding,
@@ -57,7 +48,6 @@
     expand_head_indices,
 )
 from modelopt.torch.opt.utils import named_dynamic_modules, search_space_size
-from modelopt.torch.prune.plugins.mcore_minitron import _convert_model_to_dynamic_space
 from modelopt.torch.utils.random import centroid

 SEED = 1234
```
```diff
@@ -156,147 +146,12 @@ def test_gpt_search_space(num_attention_heads, num_query_groups, activation_func
 )


-def _test_gpt_parameter_sorting(activation_func, rank, size):
-    num_layers = size
-    hidden_size = 128
-    num_attention_heads = 8
-    num_query_groups = 4
-    ffn_hidden_size = 64
-    max_sequence_length = 32
-    vocab_size = 128
-    batch_size = 2
-
-    model = get_mcore_gpt_model(
-        tensor_model_parallel_size=1,
-        pipeline_model_parallel_size=size,
-        initialize_megatron=True,
-        num_layers=num_layers,
-        hidden_size=hidden_size,
-        num_attention_heads=num_attention_heads,
-        num_query_groups=num_query_groups,
-        ffn_hidden_size=ffn_hidden_size,
-        max_sequence_length=max_sequence_length,
-        vocab_size=vocab_size,
-        activation_func=activation_func,
-        bf16=False,
-    ).cuda()
-
-    # Randomize layernorm weights instead of all zeros or ones
-    for n, m in model.named_modules():
-        if "layernorm" in n and not isinstance(m, IdentityOp):
-            m.weight.data = torch.randn_like(m.weight)
-
-    model.eval()
-    dynamic_space = _convert_model_to_dynamic_space(model)
-
-    # Compute activations for sorting
-    for _ in range(5):
-        run_mcore_inference_with_dummy_input(model, batch_size)
-
-    # Get the output of the original model
-    prompt_tokens = torch.randint(0, vocab_size, (batch_size, max_sequence_length)).cuda()
-    y1 = run_mcore_inference(model, prompt_tokens)
-
-    mtn.utils.sort_parameters(model)
-
-    # check if all ffn_hidden_size, num_attention_heads, hidden_size have been sorted
-    sortable_per_pp = [
-        n for n, hp in dynamic_space.named_hparams(configurable=True) if hp.importance is not None
-    ]
-    # 2 hps per layer (num_attention_heads, ffn_hidden_size) + 1 for hidden_size (num_layers is not sorted!)
-    assert len(sortable_per_pp) == 2 * num_layers // size + 1
-
-    # sanity check if the model functionality is preserved after sorting
-    y2 = run_mcore_inference(model, prompt_tokens)
-
-    # check if the inference results after sorting is the same
-    compare_outputs(y1, y2, rtol=1e-5, atol=1e-3)
-
-
-@pytest.mark.parametrize("activation_func", ["swiglu"])
-def test_gpt_parameter_sorting(activation_func, need_2_gpus):
-    set_seed(SEED)
-    spawn_multiprocess_job(
-        size=torch.cuda.device_count(),
-        job=partial(_test_gpt_parameter_sorting, activation_func),
-        backend="nccl",
-    )
-
-
 def test_expand_head_indices():
     heads = torch.LongTensor([1, 3, 2, 0])
     hidden_size_per_head = 2
     assert expand_head_indices(heads, hidden_size_per_head).tolist() == [2, 3, 6, 7, 4, 5, 0, 1]


-def test_self_attention_head_sorting(distributed_setup_size_1):
-    model = get_mcore_gpt_model(
-        tensor_model_parallel_size=1,
-        pipeline_model_parallel_size=1,
-        initialize_megatron=True,
-        num_layers=1,
-        hidden_size=16,
-        num_attention_heads=8,
-        num_query_groups=2,
-        ffn_hidden_size=16,
-        activation_func="squared_relu",
-    ).cuda()
-
-    model = mtn.convert(model, "mcore_minitron")
-
-    self_attn = model.decoder.layers[0].self_attention
-    assert isinstance(self_attn, _DynamicSelfAttention)
-    assert isinstance(self_attn.linear_qkv, _DynamicQKVColumnParallelLinear)
-    assert isinstance(self_attn.linear_proj, _DynamicProjRowParallelLinear)
-
-    hp_num_attention_heads = self_attn.get_hparam("num_attention_heads")
-    assert isinstance(hp_num_attention_heads, NumAttentionHeadsHp)
-
-    # Choices are multiples of num_query_groups (2): [2, 4, 6, 8]
-    assert hp_num_attention_heads.choices == [2, 4, 6, 8]
-    assert hp_num_attention_heads._num_query_groups == 2
-
-    # Set importance and slice order
-    # Importance per head (group-aware): [2.2, 0.1, 1.1, 2.1, 3.0, 2.0, 0.0, 1.0]
-    # Group 0 (heads 0-3): [2.2, 0.1, 1.1, 2.1] → sorted: [0, 3, 2, 1]
-    # Group 1 (heads 4-7): [3.0, 2.0, 0.0, 1.0] → sorted: [4, 5, 7, 6]
-    # Global ranking (group-aware, flattened): [0, 3, 2, 1, 4, 5, 7, 6]
-    hp_num_attention_heads._get_importance = lambda: torch.tensor(
-        [2.2, 0.1, 1.1, 2.1, 3.0, 2.0, 0.0, 1.0]
-    )
-    # _estimate_head_ranking returns ranking as 1D tensor
-    expected_ranking = torch.tensor([0, 3, 2, 1, 4, 5, 7, 6])
-    hp_num_attention_heads.enforce_order(expected_ranking)
-
-    assert hp_num_attention_heads.active_slice.tolist() == [0, 3, 2, 1, 4, 5, 7, 6]
-
-    # check if we get correct selection of sorted + pruned heads after setting active values
-    hp_num_attention_heads.active = 4  # top 2 heads per group (2 groups * 2 heads = 4 total)
-
-    # Expected: Top 2 heads from each group: [0, 3] from group 0, [4, 5] from group 1
-    expected_q_heads = [0, 3, 4, 5]
-    # In QKV layout (4 heads/group → 6 QKV heads/group):
-    # Group 0: Q=[0, 3], K=4, V=5 → QKV indices [0, 3, 4, 5]
-    # Group 1: Q=[4, 5], K=10, V=11 → QKV indices [6, 7, 10, 11]
-    expected_qkv_heads = [0, 3, 4, 5, 6, 7, 10, 11]
-
-    assert (
-        self_attn.linear_qkv._get_output_size_indices().tolist()
-        == expand_head_indices(
-            torch.LongTensor(expected_qkv_heads), model.config.kv_channels
-        ).tolist()
-    )
-    assert (
-        self_attn.linear_proj._get_input_size_indices().tolist()
-        == expand_head_indices(
-            torch.LongTensor(expected_q_heads), model.config.kv_channels
-        ).tolist()
-    )
-
-    # Clean up since this is not a spawned process
-    destroy_model_parallel()
-
-
 def _test_gpt_moe_search_space(rank, size):
     channel_divisor = 64
```
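The removed `test_self_attention_head_sorting` above documents the group-aware head-ranking arithmetic in its comments. As a worked example, those numbers can be reproduced with plain `torch`; the snippet below only illustrates that arithmetic (the `expand_heads` helper is a hypothetical stand-in mirroring what `expand_head_indices` is asserted to return), not the library implementation.

```python
import torch

# Per-head importance from the removed test: 8 heads in 2 query groups of 4.
importance = torch.tensor([2.2, 0.1, 1.1, 2.1, 3.0, 2.0, 0.0, 1.0])
num_query_groups, heads_per_group = 2, 4

# Group-aware ranking: argsort each group's heads by importance, keeping groups in place.
ranking = torch.cat([
    g * heads_per_group
    + importance[g * heads_per_group : (g + 1) * heads_per_group].argsort(descending=True)
    for g in range(num_query_groups)
])
assert ranking.tolist() == [0, 3, 2, 1, 4, 5, 7, 6]

# Keeping the top 2 heads per group (active = 4) selects Q heads [0, 3] and [4, 5].
top_per_group = 2
kept = torch.cat([
    ranking[g * heads_per_group : g * heads_per_group + top_per_group]
    for g in range(num_query_groups)
])
assert kept.tolist() == [0, 3, 4, 5]

# expand_head_indices-style mapping: head h with d channels per head covers
# channels [h*d, ..., h*d + d - 1].
def expand_heads(heads: torch.Tensor, d: int) -> torch.Tensor:
    return (heads.unsqueeze(-1) * d + torch.arange(d)).flatten()

assert expand_heads(torch.tensor([1, 3, 2, 0]), 2).tolist() == [2, 3, 6, 7, 4, 5, 0, 1]
```

The QKV indices asserted in the test ([0, 3, 4, 5, 6, 7, 10, 11]) then follow from the packed per-group layout described in its comments: each group stores its Q heads followed by its K and V heads, and the K/V heads are always kept.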
302157

```diff
@@ -374,76 +229,3 @@ def test_gpt_moe_search_space():
     spawn_multiprocess_job(
         size=torch.cuda.device_count(), job=_test_gpt_moe_search_space, backend="nccl"
     )
-
-
-def _test_gpt_moe_parameter_sorting(rank, size):
-    num_layers = min(size * 2, 8)
-    hidden_size = 256
-    num_attention_heads = 8
-    num_query_groups = 4
-    moe_ffn_hidden_size = 128
-    num_moe_experts = 4
-    moe_shared_expert_intermediate_size = 256
-    max_sequence_length = 16
-    vocab_size = 64
-    batch_size = 2
-
-    model = get_mcore_gpt_model(
-        tensor_model_parallel_size=1,
-        pipeline_model_parallel_size=size,
-        initialize_megatron=True,
-        num_layers=num_layers,
-        hidden_size=hidden_size,
-        num_attention_heads=num_attention_heads,
-        num_query_groups=num_query_groups,
-        max_sequence_length=max_sequence_length,
-        vocab_size=vocab_size,
-        activation_func="squared_relu",
-        num_moe_experts=num_moe_experts,
-        moe_ffn_hidden_size=moe_ffn_hidden_size,
-        moe_shared_expert_intermediate_size=moe_shared_expert_intermediate_size,
-        bf16=False,
-    ).cuda()
-
-    # Randomize layernorm weights instead of all zeros or ones
-    for n, m in model.named_modules():
-        if "layernorm" in n and not isinstance(m, IdentityOp):
-            m.weight.data = torch.randn_like(m.weight)
-
-    model.eval()
-    dynamic_space = _convert_model_to_dynamic_space(model)
-
-    # Compute activations for sorting
-    for _ in range(10):
-        run_mcore_inference_with_dummy_input(model, batch_size)
-
-    # Get the output of the original model
-    prompt_tokens = torch.randint(0, vocab_size, (batch_size, max_sequence_length)).cuda()
-    y1 = run_mcore_inference(model, prompt_tokens)
-
-    mtn.utils.sort_parameters(model)
-
-    # check if all num_moe_experts, moe_ffn, moe_shared_ffn, num_attention_heads, hidden_size
-    # have been sorted
-    sortable_per_pp = [
-        n for n, hp in dynamic_space.named_hparams(configurable=True) if hp.importance is not None
-    ]
-    # (num_moe_experts + 3) hps per layer + 1 for hidden_size (num_layers is not sorted!)
-    # Per layer: num_attention_heads, num_moe_experts, moe_ffn (per expert), moe_shared_ffn
-    assert len(sortable_per_pp) == (num_moe_experts + 3) * num_layers // size + 1
-
-    # sanity check if the model functionality is preserved after sorting
-    export_searchspace(model, mtn.get_subnet_config(model))
-    y2 = run_mcore_inference(model, prompt_tokens)
-
-    # check if the inference results after sorting is the same
-    compare_outputs(y1, y2, rtol=1e-5, atol=1e-3)
-
-
-def test_gpt_moe_parameter_sorting(need_2_gpus):
-    set_seed(SEED)
-    spawn_multiprocess_job(
-        size=torch.cuda.device_count(),
-        job=_test_gpt_moe_parameter_sorting,
-        backend="nccl",
-    )
```
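As a quick cross-check of the count asserted in the removed MoE test, the arithmetic from its comment works out as follows for a hypothetical 2-GPU run (`size` is the pipeline-parallel size the test uses; illustrative only):

```python
# Values from the removed _test_gpt_moe_parameter_sorting, assuming 2 GPUs.
size = 2
num_layers = min(size * 2, 8)  # 4 layers
num_moe_experts = 4

# Per layer: one hparam each for num_attention_heads, num_moe_experts, and
# moe_shared_ffn, plus one moe_ffn hparam per expert -> num_moe_experts + 3.
# Each pipeline stage sees num_layers // size layers, plus one global hidden_size.
sortable_per_pp = (num_moe_experts + 3) * num_layers // size + 1
assert sortable_per_pp == 15  # sortable hparams per pipeline stage
```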

tests/gpu/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py

Lines changed: 1 addition & 75 deletions
```diff
@@ -16,19 +16,13 @@

 import torch
 from _test_utils.import_helper import skip_if_no_megatron
-from _test_utils.torch.misc import compare_outputs

 skip_if_no_megatron(apex_or_te_required=True, mamba_required=True)

 from _test_utils.torch.distributed.utils import spawn_multiprocess_job
 from _test_utils.torch.megatron.models import get_mcore_mamba_hybrid_model
-from _test_utils.torch.megatron.utils import (
-    run_mcore_inference,
-    run_mcore_inference_with_dummy_input,
-)
-from _test_utils.torch.misc import set_seed
+from _test_utils.torch.megatron.utils import run_mcore_inference
 from megatron.core.parallel_state import is_pipeline_first_stage, is_pipeline_last_stage
-from megatron.core.transformer.identity_op import IdentityOp

 import modelopt.torch.nas as mtn
 from modelopt.torch.nas.modules.conv import _DynamicConvNd
@@ -46,7 +40,6 @@
 )
 from modelopt.torch.nas.traced_hp import TracedHp
 from modelopt.torch.opt.utils import named_dynamic_modules, search_space_size
-from modelopt.torch.prune.plugins.mcore_minitron import _convert_model_to_dynamic_space
 from modelopt.torch.utils.random import centroid

 SEED = 1234
```
```diff
@@ -131,73 +124,6 @@ def test_mamba_search_space():
     )


-def _test_mamba_parameter_sorting(rank, size):
-    num_layers = size
-    hybrid_override_pattern = "M" * size
-    hidden_size = 256
-    mamba_state_dim = 64
-    mamba_head_dim = 16
-    mamba_num_groups = 2
-    max_sequence_length = 32
-    vocab_size = 64
-    batch_size = 2
-
-    model = get_mcore_mamba_hybrid_model(
-        tensor_model_parallel_size=1,
-        pipeline_model_parallel_size=size,
-        initialize_megatron=True,
-        num_layers=num_layers,
-        hybrid_override_pattern=hybrid_override_pattern,
-        hidden_size=hidden_size,
-        mamba_state_dim=mamba_state_dim,
-        mamba_head_dim=mamba_head_dim,
-        mamba_num_groups=mamba_num_groups,
-        max_sequence_length=max_sequence_length,
-        vocab_size=vocab_size,
-        bf16=False,
-    ).cuda()
-
-    # Randomize norm weights instead of all zeros or ones
-    for n, m in model.named_modules():
-        if "norm" in n and not isinstance(m, IdentityOp):
-            m.weight.data = torch.randn_like(m.weight)
-
-    model.eval()
-    dynamic_space = _convert_model_to_dynamic_space(model)
-
-    # Compute activations for sorting
-    for _ in range(5):
-        run_mcore_inference_with_dummy_input(model, batch_size)
-
-    # Get the output of the original model
-    prompt_tokens = torch.randint(0, vocab_size, (batch_size, max_sequence_length)).cuda()
-    y1 = run_mcore_inference(model, prompt_tokens)
-
-    mtn.utils.sort_parameters(model)
-
-    # check if all mamba_num_heads, mamba_head_dim, hidden_size have been sorted
-    sortable_per_pp = [
-        n for n, hp in dynamic_space.named_hparams(configurable=True) if hp.importance is not None
-    ]
-    # 2 mamba hps per layer + 1 for hidden_size (num_layers is not sorted!)
-    assert len(sortable_per_pp) == 2 * num_layers // size + 1
-
-    # sanity check if the model functionality is preserved after sorting
-    y2 = run_mcore_inference(model, prompt_tokens)
-
-    # check if the inference results after sorting is the same
-    compare_outputs(y1, y2, rtol=1e-5, atol=1e-3)
-
-
-def test_mamba_parameter_sorting(need_2_gpus):
-    set_seed(SEED)
-    spawn_multiprocess_job(
-        size=torch.cuda.device_count(),
-        job=_test_mamba_parameter_sorting,
-        backend="nccl",
-    )
-
-
 def test_mamba_num_heads_hp():
     num_heads = MambaNumHeadsHp([2, 4, 6, 8], ngroups=2)  # 4 heads per group
     assert num_heads.choices == [2, 4, 6, 8]
```
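Both head-count hparams exercised in these tests (`NumAttentionHeadsHp` with `num_query_groups=2` and `MambaNumHeadsHp` with `ngroups=2`) expose choices that are multiples of the group count, so every group keeps the same number of heads. A minimal sketch of that constraint, assuming only a maximum head count and a group count (not the actual hparam classes):

```python
def grouped_head_choices(max_heads: int, num_groups: int) -> list[int]:
    # Every choice must divide evenly across the groups, so the valid choices
    # are the multiples of num_groups up to max_heads.
    return list(range(num_groups, max_heads + 1, num_groups))

# Matches the choices asserted in both removed/kept tests above.
assert grouped_head_choices(max_heads=8, num_groups=2) == [2, 4, 6, 8]
```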
