Commit 4eb1835: Support for KV cache quantization for MLA Attention vLLM fakequant (#714)
## What does this PR do?
**Type of change:** Feature extension
**Overview:**
Added support for quantizing the KV cache in vLLM fakequant by extending
quantization support to
[MLAAttention](https://github.com/vllm-project/vllm/blob/v0.11.1/vllm/attention/layer.py#L641).
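For context, a rough sketch of the pattern (not the actual plugin code): the MLAAttention module gets wrapped so the compressed KV latent is fake-quantized before it is written to the cache. The `q_bmm_quantizer` and `kv_c_bmm_quantizer` names come from the module dump under Testing; everything else here is a simplified stand-in.

```python
import torch
import torch.nn as nn


class TensorQuantizer(nn.Module):
    """Simplified stand-in for modelopt's TensorQuantizer (fake quantization)."""

    def __init__(self, enabled: bool = True):
        super().__init__()
        self.enabled = enabled
        self.amax = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.enabled:
            return x
        # Max calibration: track the largest absolute value seen so far.
        amax = x.detach().abs().amax()
        self.amax = amax if self.amax is None else torch.maximum(self.amax, amax)
        if self.amax == 0:
            return x
        # Toy 4-bit symmetric fake quantization; the real plugin applies
        # NVFP4 block quantization (block size 16, FP8 scales) instead.
        scale = self.amax / 7.0
        return torch.clamp((x / scale).round(), -8, 7) * scale


class QuantMLAAttention(nn.Module):
    """Hypothetical wrapper: fake-quantize the compressed KV latent (kv_c)."""

    def __init__(self, mla_attn: nn.Module):
        super().__init__()
        self.mla_attn = mla_attn
        self.q_bmm_quantizer = TensorQuantizer(enabled=False)    # q left unquantized
        self.kv_c_bmm_quantizer = TensorQuantizer(enabled=True)  # KV cache quantized

    def forward(self, q: torch.Tensor, kv_c: torch.Tensor, *args, **kwargs):
        q = self.q_bmm_quantizer(q)
        # The compressed KV latent is fake-quantized before it reaches the cache.
        kv_c = self.kv_c_bmm_quantizer(kv_c)
        return self.mla_attn(q, kv_c, *args, **kwargs)
```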
## Usage
Please refer to the
[README](https://github.com/NVIDIA/Model-Optimizer/tree/kinjal/vllm_att_quant/examples/vllm_serve#calibrate-and-serve-fake-quant-model-in-vllm).
```shell
KV_QUANT_CFG=NVFP4_KV_CFG QUANT_CFG=NVFP4_DEFAULT_CFG \
    python vllm_serve_fakequant.py deepseek-ai/DeepSeek-V2 \
    --served-model-name deepseek-ai/DeepSeek-V2 \
    --host 0.0.0.0 --port 8001 \
    --trust-remote-code --enforce-eager \
    --gpu-memory-utilization 0.8
```
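Once the server is up, the fake-quantized model can be exercised through vLLM's OpenAI-compatible endpoint. A minimal client sketch (port and served model name taken from the command above):

```python
from openai import OpenAI  # pip install openai

# vLLM exposes an OpenAI-compatible API; the key is unused but required.
client = OpenAI(base_url="http://0.0.0.0:8001/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="deepseek-ai/DeepSeek-V2",
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
    max_tokens=32,
)
print(response.choices[0].message.content)
```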
## Testing
Locally tested KV cache quantization; the calibrated module tree shows the new quantizers:
```text
(rotary_emb): DeepseekScalingRotaryEmbedding()
(mla_attn): MultiHeadLatentAttentionWrapper(
  (fused_qkv_a_proj): QuantMergedColumnParallelLinear(
    in_features=5120, output_features=2112, bias=False, tp_size=1, gather_output=False
    (input_quantizer): TensorQuantizer((2, 1) bit fake block_sizes={-1: 16, 'type': 'dynamic', 'scale_bits': (4, 3)}, amax=141.0000 calibrator=MaxCalibrator quant)
    (weight_quantizer): TensorQuantizer((2, 1) bit fake block_sizes={-1: 16, 'type': 'dynamic', 'scale_bits': (4, 3)}, amax=1.4297 calibrator=MaxCalibrator quant)
    (output_quantizer): TensorQuantizer(disabled)
  )
  (q_a_layernorm): RMSNorm(hidden_size=1536, eps=1e-06)
  (q_b_proj): QuantColumnParallelLinear(
    in_features=1536, output_features=3072, bias=False, tp_size=8, gather_output=False
    (input_quantizer): TensorQuantizer((2, 1) bit fake block_sizes={-1: 16, 'type': 'dynamic', 'scale_bits': (4, 3)}, amax=32.0000 calibrator=MaxCalibrator quant)
    (weight_quantizer): TensorQuantizer((2, 1) bit fake block_sizes={-1: 16, 'type': 'dynamic', 'scale_bits': (4, 3)}, amax=0.1670 calibrator=MaxCalibrator quant)
    (output_quantizer): TensorQuantizer(disabled)
  )
  (kv_a_layernorm): RMSNorm(hidden_size=512, eps=1e-06)
  (kv_b_proj): QuantColumnParallelLinear(
    in_features=512, output_features=4096, bias=False, tp_size=8, gather_output=False
    (input_quantizer): TensorQuantizer((2, 1) bit fake block_sizes={-1: 16, 'type': 'dynamic', 'scale_bits': (4, 3)}, amax=7.5312 calibrator=MaxCalibrator quant)
    (weight_quantizer): TensorQuantizer((2, 1) bit fake block_sizes={-1: 16, 'type': 'dynamic', 'scale_bits': (4, 3)}, amax=0.2773 calibrator=MaxCalibrator quant)
    (output_quantizer): TensorQuantizer(disabled)
  )
  (rotary_emb): DeepseekScalingRotaryEmbedding()
  (o_proj): QuantRowParallelLinear(
    in_features=2048, output_features=5120, bias=False, tp_size=8, reduce_results=True
    (input_quantizer): TensorQuantizer((2, 1) bit fake block_sizes={-1: 16, 'type': 'dynamic', 'scale_bits': (4, 3)}, amax=1.7188 calibrator=MaxCalibrator quant)
    (weight_quantizer): TensorQuantizer((2, 1) bit fake block_sizes={-1: 16, 'type': 'dynamic', 'scale_bits': (4, 3)}, amax=0.4336 calibrator=MaxCalibrator quant)
    (output_quantizer): TensorQuantizer(disabled)
  )
  (mla_attn): QuantMLAAttention(
    (q_bmm_quantizer): TensorQuantizer(disabled)
    (kv_c_bmm_quantizer): TensorQuantizer((2, 1) bit fake block_sizes={-1: 16, 'type': 'dynamic', 'scale_bits': (4, 3)}, amax=7.5312 calibrator=MaxCalibrator quant)
  )
)
```
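In the dump above, `kv_c_bmm_quantizer` (the compressed KV latent) picks up the fake-quant settings from `KV_QUANT_CFG=NVFP4_KV_CFG` (block_sizes={-1: 16} with scale_bits=(4, 3)), while `q_bmm_quantizer` remains disabled.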
## Before your PR is "*Ready for review*"
- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: No
- **Did you add or update any necessary documentation?**: N/A
- **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: N/A
---------
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
3 files changed: 67 additions & 8 deletions

File tree:
- examples/vllm_serve
- modelopt/torch/quantization/plugins