Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion tests/pytorch/debug/test_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
"underflows%",
"scale_inv_min",
"scale_inv_max",
"scale_inv_std",
"mse",
]

Expand Down Expand Up @@ -248,6 +249,10 @@ def test_log_quantized_stats_numerics(fp8_recipe, feature_dirs):
debug_api.step()

dequantized_tensor = quantized_tensor.dequantize()
if hasattr(quantized_tensor, "_scale_inv"):
scale_inv_rowwise = quantized_tensor._scale_inv.float()
else:
scale_inv_rowwise = quantized_tensor._rowwise_scale_inv.float()
output = read_log(log_dir)

for line in output.splitlines():
Expand All @@ -267,6 +272,17 @@ def test_log_quantized_stats_numerics(fp8_recipe, feature_dirs):
(abs(dequantized_tensor) > abs(tensor)).sum() / dequantized_tensor.numel() * 100
)
assert overflows == pytest.approx(expected.cpu(), abs=1e-4)
# Rowwise scale_inv stats only; logger formats with {:.4f} so abs<1e-4.
if "scale_inv_min" in line and "_columnwise" not in line:
value = float(line.split("value=")[1])
assert value == pytest.approx(scale_inv_rowwise.min().cpu().item(), abs=1e-4)
if "scale_inv_max" in line and "_columnwise" not in line:
value = float(line.split("value=")[1])
assert value == pytest.approx(scale_inv_rowwise.max().cpu().item(), abs=1e-4)
if "scale_inv_std" in line and "_columnwise" not in line:
value = float(line.split("value=")[1])
expected = torch.std(scale_inv_rowwise, unbiased=False).cpu().item()
assert value == pytest.approx(expected, abs=1e-4)


LOG_HIGH_PRECISION_CONFIG = """
Expand Down Expand Up @@ -403,7 +419,8 @@ def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs):

with open(
os.path.join(
temp_dir, "nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-0.log"
temp_dir,
"nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-0.log",
),
"r",
) as f:
Expand Down
82 changes: 63 additions & 19 deletions transformer_engine/debug/features/log_fp8_tensor_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,26 @@

import torch
import nvdlfw_inspect.api as debug_api
from nvdlfw_inspect.debug_features.log_tensor_stats import LogTensorStats as BaseLogTensorStats
from nvdlfw_inspect.debug_features.log_tensor_stats import (
LogTensorStats as BaseLogTensorStats,
)
from nvdlfw_inspect.registry import Registry, api_method
import transformer_engine_torch as tex

from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter
from transformer_engine.debug.features.utils import (
get_reduction_params,
next_enabled_iter,
)
from transformer_engine.pytorch.tensor import Quantizer, QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Quantizer,
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
Float8BlockQuantizer,
)

try:
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
Expand All @@ -33,7 +40,12 @@
NVFP4Quantizer = None


ALL_RECIPE_NAMES = ["fp8_delayed_scaling", "fp8_current_scaling", "mxfp8", "fp8_block_scaling"]
ALL_RECIPE_NAMES = [
"fp8_delayed_scaling",
"fp8_current_scaling",
"mxfp8",
"fp8_block_scaling",
]


def _get_recipe_name(quantizer: Optional[Quantizer]):
Expand All @@ -57,7 +69,10 @@ def _get_new_quantizer(recipe_name, fp8_dtype):
return Float8BlockQuantizer(fp8_dtype=fp8_dtype, rowwise=True, columnwise=True)
if recipe_name == "fp8_current_scaling":
return Float8CurrentScalingQuantizer(
fp8_dtype=fp8_dtype, device=torch.device("cuda"), rowwise=True, columnwise=True
fp8_dtype=fp8_dtype,
device=torch.device("cuda"),
rowwise=True,
columnwise=True,
)
if recipe_name == "mxfp8":
return MXFP8Quantizer(fp8_dtype=fp8_dtype, rowwise=True, columnwise=True)
Expand Down Expand Up @@ -119,10 +134,13 @@ class LogFp8TensorStats(BaseLogTensorStats):
- overflows% - percentage of elements of tensor that were clipped to the max/min value of the FP8 range - supported only for fp8_delayed_scaling,
- scale_inv_min - minimum of the inverse of the scaling factors,
- scale_inv_max - maximum of the inverse of the scaling factors,
- scale_inv_std - population standard deviation of the inverse of the scaling factors;
useful for spotting clipping that min/max alone can miss (degenerates to 0 for
fp8_delayed_scaling / fp8_current_scaling, since those use a single scalar scale),
- mse - mean squared error of the quantized tensor and the original tensor = sum((quantized_tensor - original_tensor)**2) / num_elements,

When collecting stats for the weight tensor with FP8 model parameters enabled,
only "scale_inv_min" and "scale_inv_max" are available.
only "scale_inv_min", "scale_inv_max" and "scale_inv_std" are available.
All other statistics require access to the high precision tensor.

tensors/tensors_struct: List[str]
Expand Down Expand Up @@ -191,15 +209,8 @@ def check_if_stat_is_supported(
if recipe_from_stat != "" and recipe_from_stat not in ALL_RECIPE_NAMES:
raise ValueError(f"Stat {stat} contains an unsupported recipe name: {recipe_from_stat}")

# Block any NVFP4 stats in LogFp8TensorStats (FP8-specific logic won't work)
# But allow recipe-prefixed FP8 stats like "mxfp8_underflows%" even with NVFP4 quantizer
if recipe_from_stat == "nvfp4":
raise ValueError(
f"[NVTORCH INSPECT ERROR] Cannot compute NVFP4 stats '{stat}' in LogFp8TensorStats."
" FP8-specific statistics do not work with NVFP4. Use LogNvfp4TensorStats for"
" NVFP4-specific stats, or use FP8 recipe-prefixed stats (e.g.,"
" 'mxfp8_underflows%', 'fp8_block_scaling_mse') for what-if FP8 comparisons."
)
# NVFP4-resolved stats are filtered out before this point in inspect_tensor().
assert recipe_from_stat != "nvfp4"
Comment on lines +212 to +213
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Replacing a user-facing raise ValueError with a bare assert weakens the defensive guard. Python's -O flag silently disables all assert statements, so if this path is ever reached in an optimised build, execution would continue silently and produce a confusing failure deep in the quantization path instead of a clear error message.

Suggested change
# NVFP4-resolved stats are filtered out before this point in inspect_tensor().
assert recipe_from_stat != "nvfp4"
# NVFP4-resolved stats are filtered out before this point in inspect_tensor().
if recipe_from_stat == "nvfp4":
raise ValueError(
f"[NVTORCH INSPECT ERROR] Cannot compute NVFP4 stats '{stat}' in "
"LogFp8TensorStats. This is an internal error: bare NVFP4 stats should "
"have been filtered in inspect_tensor() before reaching this point."
)


if recipe_from_stat in ["fp8_delayed_scaling", "fp8_current_scaling"] and columnwise:
raise ValueError(
Expand All @@ -216,7 +227,13 @@ def check_if_stat_is_supported(
if recipe_from_stat == "mxfp8" and torch.cuda.get_device_capability()[0] < 10:
raise ValueError(f"Stat {stat} needs Blackwell or later GPU.")

supported_stats = ["underflows%", "scale_inv_min", "scale_inv_max", "mse"]
supported_stats = [
"underflows%",
"scale_inv_min",
"scale_inv_max",
"scale_inv_std",
"mse",
]
if stat_without_recipe not in supported_stats:
raise ValueError(
f"Stat {stat} contains an unsupported stat name: {stat_without_recipe}"
Expand Down Expand Up @@ -252,9 +269,14 @@ def update_aux_dict(
Needs to clean after usage, because it possibly change the usage of the quantized tensor.
"""
fp8_dtype = tex.DType.kFloat8E4M3
if recipe_name in ["fp8_delayed_scaling", "fp8_current_scaling", "fp8_block_scaling"]:
if recipe_name in [
"fp8_delayed_scaling",
"fp8_current_scaling",
"fp8_block_scaling",
]:
assert isinstance(
quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer, Float8BlockQuantizer)
quantizer,
(Float8Quantizer, Float8CurrentScalingQuantizer, Float8BlockQuantizer),
)
fp8_dtype = quantizer.dtype

Expand All @@ -280,7 +302,8 @@ def update_aux_dict(
finally:
if isinstance(quantized_tensor, QuantizedTensor):
quantized_tensor.update_usage(
rowwise_usage=old_rowwise_usage, columnwise_usage=old_columnwise_usage
rowwise_usage=old_rowwise_usage,
columnwise_usage=old_columnwise_usage,
)

@api_method
Expand Down Expand Up @@ -338,6 +361,27 @@ def inspect_tensor(

recipe_name = _get_recipe_name(quantizer)

# If the layer uses NVFP4, drop bare stats (which would target the NVFP4
# recipe that LogFp8TensorStats can't handle) but keep stats explicitly
# prefixed with an FP8 recipe (e.g. "mxfp8_mse") for what-if FP8 comparison.
if _nvfp4_available and isinstance(quantizer, NVFP4Quantizer):
kept_stats, dropped_stats = [], []
for stat in config["stats"]:
if any(r in stat for r in ALL_RECIPE_NAMES):
kept_stats.append(stat)
else:
dropped_stats.append(stat)
if dropped_stats:
warnings.warn(
f"[LogFp8TensorStats] Skipping stats {dropped_stats} for layer "
f"'{layer_name}', tensor '{tensor_name}': layer uses NVFP4. Use "
"LogNvfp4TensorStats for NVFP4 stats, or prefix stats with an FP8 "
"recipe name (e.g. 'mxfp8_mse') for what-if FP8 comparisons."
)
if not kept_stats:
return
config = {**config, "stats": kept_stats}

for stat in config["stats"]:
self.check_if_stat_is_supported(
stat, recipe_name, high_precision_tensor_provided=tensor is not None
Expand Down
24 changes: 20 additions & 4 deletions transformer_engine/debug/features/log_nvfp4_tensor_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,21 @@
import torch
import nvdlfw_inspect.api as debug_api

from nvdlfw_inspect.debug_features.log_tensor_stats import LogTensorStats as BaseLogTensorStats
from nvdlfw_inspect.debug_features.log_tensor_stats import (
LogTensorStats as BaseLogTensorStats,
)
from nvdlfw_inspect.registry import Registry, api_method

from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
from transformer_engine.pytorch.tensor import Quantizer, QuantizedTensor
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter
from transformer_engine.pytorch.tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage
from transformer_engine.debug.features.utils import (
get_reduction_params,
next_enabled_iter,
)
from transformer_engine.pytorch.tensor.storage.nvfp4_tensor_storage import (
NVFP4TensorStorage,
)


@Registry.register_feature(namespace="transformer_engine")
Expand All @@ -45,6 +52,10 @@ class LogNvfp4TensorStats(BaseLogTensorStats):
List of statistics to collect. Available stats:
- underflows% - percentage of non-zero elements clipped to 0 (from packed FP4 data)
- mse - mean squared error = sum((quantized_tensor - original_tensor)**2) / num_elements
- scale_inv_min - minimum of the inverse of the scaling factors
- scale_inv_max - maximum of the inverse of the scaling factors
- scale_inv_std - population standard deviation of the inverse of the scaling factors;
useful for spotting clipping that min/max alone can miss

tensors/tensors_struct: List[str]
list of tensors to log
Expand Down Expand Up @@ -85,13 +96,18 @@ class LogNvfp4TensorStats(BaseLogTensorStats):

def check_if_stat_is_supported(self, stat: str):
    """Check that ``stat`` names a statistic supported for NVFP4.

    A stat name may carry an optional ``_columnwise`` suffix. The suffix is
    stripped before the name is compared against the supported list, so
    ``"mse"`` and ``"mse_columnwise"`` are both accepted.

    Args:
        stat: Stat name to validate.

    Returns:
        True when the stat is supported.

    Raises:
        ValueError: If the suffix-stripped name is not a supported stat.
    """
    suffix = "_columnwise"
    # Validate the bare name only: comparing the raw ``stat`` against the list
    # would reject every "_columnwise" variant that the error message below
    # advertises as allowed.
    bare = stat[: -len(suffix)] if stat.endswith(suffix) else stat
    supported_stats = [
        "underflows%",
        "mse",
        "scale_inv_min",
        "scale_inv_max",
        "scale_inv_std",
    ]
    if bare not in supported_stats:
        raise ValueError(
            f"Stat {stat} is not supported for NVFP4. Supported stats: {supported_stats}"
            " (any of these may take an optional '_columnwise' suffix)"
        )
    return True

Expand Down
Loading
Loading