
Commit 1619421

Added support for MoE for vllm >= 0.14.0rc1 (#1162)
### What does this PR do?

**Type of change:** Bug fix

`_QuantFusedMoEBase.forward()` previously monkey-patched `vllm_fused_moe_package.invoke_fused_moe_kernel`, an entry point that was replaced starting in vLLM v0.14.0rc1. Since that release there are two paths for FusedMoE forward:

```
Path 1 (Modular — standard CUDA path):
FusedMoE.forward()
  → self.runner.forward()
  → TritonExperts.apply()
  → invoke_fused_moe_triton_kernel()   ← called twice (w1, w2)

Path 2 (legacy):
inplace_fused_experts / outplace_fused_experts
  → fused_experts_impl()
  → dispatch_fused_moe_kernel()
  → invoke_fused_moe_triton_kernel()
    or invoke_fused_moe_wna16_triton_kernel()
    or invoke_fused_moe_wna16_cuda_kernel()
```

The stale attribute name caused an `AttributeError` / assertion failure for any MoE model quantized with vLLM ≥ v0.14.0rc1. The fix refactors the kernel-patching logic to probe the fused-MoE module for all known entry-point names and patch whichever ones the installed vLLM exposes; inspecting every release from v0.10.0 to v0.19.1 confirmed that the old name and the new names never coexist in the same version. Because `dispatch_fused_moe_kernel()` can itself call `invoke_fused_moe_triton_kernel()`, a context variable guards against applying fakequant twice on nested entries.

### Usage

N/A

### Testing

```
docker run --gpus all -it --shm-size=160GB --network host --rm -v <modelopt path>:/home/modelopt \
  vllm/vllm-openai:v0.15.0 bash -c "cd /home/modelopt && pip install . && pip install datasets && \
  QUANT_CFG=NVFP4_DEFAULT_CFG python3 /home/modelopt/examples/vllm_serve/vllm_serve_fakequant.py \
  nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 -tp 1 --served-model-name NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 \
  --host 0.0.0.0 --port 8001 --trust-remote-code --disable-custom-all-reduce \
  --gpu-memory-utilization 0.8"
```

### Before your PR is "*Ready for review*"

Make sure you read and follow the [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and your commits are signed (`git commit -s -S`).

Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).

- Is this change backward compatible?: ✅
- If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`?: N/A
- Did you write any new necessary tests?: N/A
- Did you update the [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: N/A

## Summary by CodeRabbit

* **Refactor**
  * Ensures quantized expert weights are correctly used by the fused-MoE execution path so inference uses the intended quantized tensors.
  * Replaces fragile manual swapping of the runtime kernel with a safer swap that reliably caches and restores the originals.
  * Adds runtime detection and selection among the available fused-MoE kernel entry points to support multiple vLLM versions.

Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
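For illustration, here is a minimal, self-contained sketch of the pattern the fix relies on. The toy module, kernel bodies, and numbers are invented for the example; only the shape of the solution (patch every exposed entry point, guard nested calls with a `contextvars.ContextVar`, restore the originals in `finally`) mirrors the actual change below.

```python
import contextvars
import types
from functools import partial

# Toy module standing in for vllm.model_executor.layers.fused_moe.fused_moe.
# As in vLLM >= 0.14.0rc1, one module-level entry point calls another.
mod = types.ModuleType("toy_fused_moe")
mod.invoke_fused_moe_triton_kernel = lambda x: x + 1
mod.dispatch_fused_moe_kernel = lambda x: mod.invoke_fused_moe_triton_kernel(x)

_active = contextvars.ContextVar("fakequant_active", default=False)

def fakequant_wrapper(x, original_kernel):
    if _active.get():              # nested entry: run the real kernel, no re-quant
        return original_kernel(x)
    token = _active.set(True)
    try:
        x = x * 10                 # stand-in for input fake-quantization
        return original_kernel(x)
    finally:
        _active.reset(token)

names = ("invoke_fused_moe_triton_kernel", "dispatch_fused_moe_kernel")
originals = {n: getattr(mod, n) for n in names}
try:
    for n in names:                # patch every entry point this "vLLM" exposes
        setattr(mod, n, partial(fakequant_wrapper, original_kernel=originals[n]))
    assert mod.dispatch_fused_moe_kernel(2) == 21   # quantized once: (2 * 10) + 1
finally:
    for n in names:                # always restore the module's real kernels
        setattr(mod, n, originals[n])
```

Without the guard, the nested call through the patched `invoke_fused_moe_triton_kernel` would scale the input a second time.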
1 parent 3131195 commit 1619421

4 files changed: 100 additions & 28 deletions


examples/vllm_serve/README.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -4,7 +4,7 @@ This is a simple example to demonstrate calibrating and serving ModelOpt fakequa
 
 Compared with realquant, fakequant is 2-5x slower, but doesn't require dedicated kernel support and facilitates research.
 
-This example is tested with vllm 0.9.0 and 0.11.2
+This example is tested with vllm 0.9.0 and 0.19.1
 
 ## Prepare environment
 
```

examples/vllm_serve/fakequant_worker.py

Lines changed: 5 additions & 2 deletions

```diff
@@ -134,12 +134,15 @@ def determine_available_memory(self) -> int:
         with disable_compilation(model):
             return super().determine_available_memory()
 
-    def compile_or_warm_up_model(self) -> None:
+    def compile_or_warm_up_model(self) -> float:
         if (
             quant_config["quant_cfg"]
             or quant_config["kv_quant_cfg"]
             or quant_config["modelopt_state_path"]
             or quant_config["recipe_path"]
         ):
             _fakequant_run_prolog_worker(self)
-        super().compile_or_warm_up_model()
+        # Must return the base worker's compilation time (seconds). Returning None
+        # breaks vLLM V1 executor: initialize_from_config does max(compilation_times)
+        # across TP workers.
+        return super().compile_or_warm_up_model()
```
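The new comment is the whole bug in miniature. A toy reproduction of the aggregation failure it describes (the times are invented):

```python
# One TP worker returned None instead of a float; max() over the collected
# per-worker compilation times then raises rather than picking the largest.
compilation_times = [1.8, None]
try:
    max(compilation_times)
except TypeError as err:
    print(err)  # roughly: '>' not supported between instances of 'NoneType' and 'float'
```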

examples/vllm_serve/vllm_serve_fakequant.py

Lines changed: 12 additions & 4 deletions

```diff
@@ -62,14 +62,12 @@
 
 vllm_version = version.parse(vllm.__version__)
 if vllm_version <= version.parse("0.11.0"):
-    from vllm.executor.ray_distributed_executor import RayDistributedExecutor
     from vllm.utils import FlexibleArgumentParser
 else:
     from vllm.utils.argparse_utils import FlexibleArgumentParser
-    from vllm.v1.executor.ray_executor import RayDistributedExecutor
 
 
-# Adding the envs you want to pass to the workers
+# Env vars to copy from the driver to Ray workers (must match fakequant_worker / vllm_ptq_utils).
 additional_env_vars = {
     "QUANT_DATASET",
     "QUANT_CALIB_SIZE",
@@ -82,7 +80,17 @@
     "TRUST_REMOTE_CODE",
 }
 
-RayDistributedExecutor.ADDITIONAL_ENV_VARS.update(additional_env_vars)
+try:
+    from vllm.executor.ray_distributed_executor import RayDistributedExecutor
+
+    RayDistributedExecutor.ADDITIONAL_ENV_VARS.update(additional_env_vars)
+except (ImportError, AttributeError):
+    # vLLM v1 Ray: vllm/ray/ray_env.py (get_env_vars_to_copy); merge with any user-set list.
+    extra_env_var = "VLLM_RAY_EXTRA_ENV_VARS_TO_COPY"
+    merged_env_vars = {
+        t.strip() for t in os.environ.get(extra_env_var, "").split(",") if t.strip()
+    } | additional_env_vars
+    os.environ[extra_env_var] = ",".join(sorted(merged_env_vars))
 
 
 def main():
```
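To make the merge semantics of the `except` branch concrete, a small standalone sketch (the env-var values are made up): names the user already listed in `VLLM_RAY_EXTRA_ENV_VARS_TO_COPY` survive, and the quantization variables are appended.

```python
import os

# Pretend the user already asked Ray to copy two of their own variables.
os.environ["VLLM_RAY_EXTRA_ENV_VARS_TO_COPY"] = "MY_TOKEN, HF_HOME"
additional_env_vars = {"QUANT_DATASET", "QUANT_CALIB_SIZE"}  # subset, for brevity

extra_env_var = "VLLM_RAY_EXTRA_ENV_VARS_TO_COPY"
merged_env_vars = {
    t.strip() for t in os.environ.get(extra_env_var, "").split(",") if t.strip()
} | additional_env_vars
os.environ[extra_env_var] = ",".join(sorted(merged_env_vars))

print(os.environ[extra_env_var])
# HF_HOME,MY_TOKEN,QUANT_CALIB_SIZE,QUANT_DATASET
```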

modelopt/torch/quantization/plugins/vllm.py

Lines changed: 82 additions & 21 deletions

```diff
@@ -15,8 +15,11 @@
 
 """Support quantization for VLLM layers."""
 
+import contextvars
 import importlib
+from collections.abc import Callable
 from contextlib import contextmanager
+from functools import partial
 from itertools import chain
 
 import torch
@@ -85,6 +88,21 @@
 )
 
 vllm_fused_moe_package = importlib.import_module("vllm.model_executor.layers.fused_moe.fused_moe")
+# vLLM may call one entry (e.g. ``dispatch_fused_moe_kernel``) which then calls another on the same
+# module (e.g. ``invoke_fused_moe_triton_kernel``). Patching every name would otherwise apply fakequant
+# twice; see ``_moe_fakequant_active`` in ``invoke_fused_moe_quantized``.
+_FUSED_MOE_KERNEL_CANDIDATES = (
+    "invoke_fused_moe_kernel",
+    "invoke_fused_moe_triton_kernel",
+    "dispatch_fused_moe_kernel",
+)
+_FUSED_MOE_KERNEL_FUNCS = tuple(
+    n for n in _FUSED_MOE_KERNEL_CANDIDATES if hasattr(vllm_fused_moe_package, n)
+)
+
+_moe_fakequant_active: contextvars.ContextVar[bool] = contextvars.ContextVar(
+    "moe_fakequant_active", default=False
+)
 
 
 @contextmanager
```
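A short aside on why the guard is a `contextvars.ContextVar` rather than a plain module-level bool: each thread (and asyncio task) sees its own value, so concurrent `forward()` calls cannot clobber each other's flag. A minimal standalone demonstration (not vLLM code):

```python
import contextvars
import threading

flag = contextvars.ContextVar("flag", default=False)

def worker() -> None:
    token = flag.set(True)   # visible only inside this thread's context
    try:
        assert flag.get() is True
    finally:
        flag.reset(token)

t = threading.Thread(target=worker)
t.start()
t.join()
print(flag.get())  # False: the worker's set() never leaked into the main thread
```

The remaining hunks wire the wrapper and this guard into the kernel call and `forward()`: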
```diff
@@ -340,29 +358,64 @@ def invoke_fused_moe_quantized(
         B: torch.Tensor,  # noqa: N803
         C: torch.Tensor,  # noqa: N803
         *args,
+        original_kernel: Callable,
+        **kwargs,
+    ):
+        # Nested module-level entry (e.g. dispatch -> triton): call the real kernel once, no second quant.
+        if _moe_fakequant_active.get():
+            return original_kernel(A, B, C, *args, **kwargs)
+        token = _moe_fakequant_active.set(True)
+        try:
+            return self._invoke_fused_moe_quantized_function(
+                A, B, C, *args, original_kernel=original_kernel, **kwargs
+            )
+        finally:
+            _moe_fakequant_active.reset(token)
+
+    def _invoke_fused_moe_quantized_function(
+        self,
+        A: torch.Tensor,  # noqa: N803
+        B: torch.Tensor,  # noqa: N803
+        C: torch.Tensor,  # noqa: N803
+        *args,
+        original_kernel: Callable,
         **kwargs,
     ):
         if B is self.w13_weight:
             # First layer of expert
             A = self.w13_input_quantizer(A)  # noqa: N806
-            if self.w13_weight_quantizer.is_enabled:
-                original_weight = self.w13_weight
-                self.w13_weight = self.w13_weight_quantizer(self.w13_weight)
-                vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
-                self.w13_weight = original_weight
+            if self.w13_weight_quantizer.is_enabled:  # pragma: no cover
+                original_weight, self.w13_weight = (
+                    self.w13_weight,
+                    self.w13_weight_quantizer(self.w13_weight),
+                )
+                # In case the weight quantizer isn't folded yet in vllm_serve_fakequant, pass the
+                # quantized weight to the kernel.
+                B = self.w13_weight  # noqa: N806
+                try:
+                    original_kernel(A, B, C, *args, **kwargs)
+                finally:
+                    self.w13_weight = original_weight
             else:
-                vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
+                original_kernel(A, B, C, *args, **kwargs)
             if self.w13_output_quantizer.is_enabled:
                 C[:] = self.w13_output_quantizer(C)
         elif B is self.w2_weight:
             A = self.w2_input_quantizer(A)  # noqa: N806
-            if self.w2_weight_quantizer.is_enabled:
-                original_weight = self.w2_weight
-                self.w2_weight = self.w2_weight_quantizer(self.w2_weight)
-                vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
-                self.w2_weight = original_weight
+            if self.w2_weight_quantizer.is_enabled:  # pragma: no cover
+                original_weight, self.w2_weight = (
+                    self.w2_weight,
+                    self.w2_weight_quantizer(self.w2_weight),
+                )
+                # In case the weight quantizer isn't folded yet in vllm_serve_fakequant, pass the
+                # quantized weight to the kernel.
+                B = self.w2_weight  # noqa: N806
+                try:
+                    original_kernel(A, B, C, *args, **kwargs)
+                finally:
+                    self.w2_weight = original_weight
             else:
-                vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
+                original_kernel(A, B, C, *args, **kwargs)
             if self.w2_output_quantizer.is_enabled:
                 C[:] = self.w2_output_quantizer(C)
         else:
@@ -372,24 +425,31 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
         # This is again due to the bad coding of vLLM
         # fused_moe submodule is overwritten by the fused_moe function
         # so we need to import the fused_moe module explicitly
-        assert vllm_fused_moe_package.invoke_fused_moe_kernel is not None
+        assert _FUSED_MOE_KERNEL_FUNCS and all(
+            getattr(vllm_fused_moe_package, n, None) is not None for n in _FUSED_MOE_KERNEL_FUNCS
+        )
         # This context manager will conflict with torch.compile
         # with replace_function(
         #     vllm_fused_moe_package,
         #     "invoke_fused_moe_kernel",
         #     self.invoke_fused_moe_quantized,
         # ):
+        originals = {n: getattr(vllm_fused_moe_package, n) for n in _FUSED_MOE_KERNEL_FUNCS}
         try:
-            vllm_fused_moe_package._invoke_fused_moe_kernel = (  # type: ignore[attr-defined]
-                vllm_fused_moe_package.invoke_fused_moe_kernel
-            )
-            vllm_fused_moe_package.invoke_fused_moe_kernel = self.invoke_fused_moe_quantized  # type: ignore[attr-defined]
+            for n in _FUSED_MOE_KERNEL_FUNCS:
+                setattr(
+                    vllm_fused_moe_package,
+                    n,
+                    partial(
+                        self.invoke_fused_moe_quantized,
+                        original_kernel=originals[n],
+                    ),
+                )
             output = super().forward(hidden_states, router_logits)
             return output
         finally:
-            vllm_fused_moe_package.invoke_fused_moe_kernel = (  # type: ignore[attr-defined]
-                vllm_fused_moe_package._invoke_fused_moe_kernel
-            )
+            for n in _FUSED_MOE_KERNEL_FUNCS:
+                setattr(vllm_fused_moe_package, n, originals[n])
 
     @torch.no_grad()
     def fold_weight(self, keep_attrs: bool = False):
@@ -409,7 +469,8 @@ def fold_weight(self, keep_attrs: bool = False):
             )
             self.w2_weight_quantizer.disable()
 
-        torch.cuda.empty_cache()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
 
 
 @QuantModuleRegistry.register({vllm_fused_moe_layer.FusedMoE: "vllm_FusedMoE"})
```
