Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,23 @@
from nncf.tensor import functions as fns


def process_stats(stats: WCTensorStatistic, subset_size: int, act_ch_axis: int = -1) -> tuple[Tensor, Tensor]:
def process_stats(
stats: WCTensorStatistic,
subset_size: int,
act_ch_axis: int = -1,
transpose_a: bool = False,
) -> tuple[Tensor, Tensor]:
"""
A function for processing activations. Shared between AWQ, Scale Estimation and LoRA Correction algorithms.

:param stats: An object containing statistics for the layer.
:param subset_size: The number of samples for AWQ. If subset_size <= 0, all samples are used.
:param act_ch_axis: The activation channel axis.
:param transpose_a: When True, returns X in [SampleSize, HiddenDim] layout instead of the default
[HiddenDim, SampleSize]. Used by LoRA Correction which requires samples as rows.
:return: tuple of the following tensors:
s - maximum channel magnitude across samples [HiddenDim]
X - average channel magnitude across tokens in the sequence [HiddenDim, min(SampleSize, ~subset_size)]
s - maximum channel magnitude across samples, shape [HiddenDim]
X - activation matrix, shape [HiddenDim, SampleSize] normally or [SampleSize, HiddenDim] if transpose_a=True
"""
X = fns.stack(
stats.mean_values
Expand All @@ -37,8 +44,13 @@ def process_stats(stats: WCTensorStatistic, subset_size: int, act_ch_axis: int =
axes = list(range(1, len(X.shape))) + [0]
X_full = fns.transpose(X, axes=axes)

# The sample dimension is always the last axis after transpose
sample_axis = -1
if transpose_a:
axes = list(range(len(X_full.shape)))
axes[-1], axes[-2] = axes[-2], axes[-1]
X_full = fns.transpose(X_full, axes=axes)

# The sample dimension is axis -1 by default, but moves to -2 if transpose_a is True
sample_axis = -2 if transpose_a else -1

# Prevent high memory and time consumption by sampling
if X_full.shape[sample_axis] > subset_size and subset_size > 0:
Expand All @@ -47,11 +59,13 @@ def process_stats(stats: WCTensorStatistic, subset_size: int, act_ch_axis: int =
]
step = X_full.shape[sample_axis] // subset_size
idxs = [i[0] for i in sorted(enumerate(lens), key=lambda x: -x[1])][::step]
X = X_full[..., idxs]
if transpose_a:
X = X_full[..., idxs, :]
else:
X = X_full[..., idxs]
else:
X = X_full

# Compute max magnitude along the sample axis (last axis)
# Result: [HiddenDim] or [No. of Experts, HiddenDim]
# Compute max magnitude along the sample axis
s = fns.max(fns.abs(X_full), axis=sample_axis)
return s, X
Original file line number Diff line number Diff line change
Expand Up @@ -1181,11 +1181,6 @@ def apply_with_parameters(
)

if self._lora_correction:
for wc_params in all_weight_params:
if self._backend_entity.matmul_has_transposed_activations(wc_params.node_with_weight, graph):
msg = "Transposed activations are not supported yet for the LoRa correction algorithm"
raise nncf.UnsupportedModelError(msg)

lora_correction_params = self._advanced_parameters.lora_correction_params
lora_correction_algo = LoraCorrectionAlgorithm(statistics, lora_correction_params)
description += " with correction of low-rank adapters"
Expand Down Expand Up @@ -1399,7 +1394,7 @@ def _get_statistics_for_weights_compression(
# Where mean_value is a 1D tensor representing an activation reduced over batch and sequence length dimensions,
# shape is an original shape of an activation before reduction, n is the size of the dataset (or subset_size).
statistics = {}
for (act_node, output_port_id, _), matmul_nodes in matmul_input_to_output_nodes_map.items():
for (act_node, output_port_id, _act_channel_axis), matmul_nodes in matmul_input_to_output_nodes_map.items():
tensor_collectors = list(
statistic_points.get_algo_statistics_for_node(
act_node.node_name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,14 +106,19 @@ def is_applicable(self, wc_params: WeightCompressionParameters):
return wc_params.compression_config.num_bits == 4

def calculate_adapters(
self, weight: Tensor, compressed_weight: CompressedWeight, wc_params: WeightCompressionParameters
self,
weight: Tensor,
compressed_weight: CompressedWeight,
wc_params: WeightCompressionParameters,
act_ch_axis: int,
) -> tuple[Tensor, Tensor, list[float]]:
"""
Calculates low rank matrices for a given original and compressed weights.

:param weight: original floating-point weight matrix.
:param compressed_weight: compressed weight matrix.
:param wc_params: parameters of weight compression.
:param act_ch_axis: axis number of the activation tensor which correspond to it channel.
:return: two low rank matrices in the order of execution of corresponding linear layers.
"""
layer_name = wc_params.node_with_weight.node_name
Expand All @@ -126,6 +131,7 @@ def calculate_adapters(
wc_params.reduction_axes,
self._lora_correction_params,
layer_statistics,
act_ch_axis,
is_debug,
)
if is_debug:
Expand All @@ -140,6 +146,7 @@ def calculate_low_rank_matrices(
reduction_axes: tuple[int, ...],
lora_correction_params: AdvancedLoraCorrectionParameters,
layer_statistics: WCTensorStatistic,
act_ch_axis: int,
is_debug: bool | None = False,
):
"""
Expand All @@ -155,6 +162,7 @@ def calculate_low_rank_matrices(
:param reduction_axes: axes along which different statistics reduced.
:param lora_correction_params: parameters to configure the algorithm.
:param layer_statistics: an object containing statistics for the layer.
:param act_ch_axis: axis number of the activation tensor which correspond to it channel.
:param is_debug: whether to collect debug information, defaults to False.
:return: two low rank matrices in the order of execution of corresponding linear layers and list of mean noises.
Noises are collected from each step of the algorithm if debug was enabled.
Expand All @@ -168,7 +176,12 @@ def calculate_low_rank_matrices(
)
mode = compression_config.mode
assert len(reduction_axes) == 1, "Assumed a single reduction axis"
reduction_axis = reduction_axes[0] if compression_config.group_size != -1 else -1

if compression_config.group_size != -1:
reduction_axis = reduction_axes[0]
else:
reduction_axis = -1

if mode in (CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM):
fq_weights = do_integer_dequantization(
compressed_weight,
Expand All @@ -190,8 +203,8 @@ def calculate_low_rank_matrices(
svd_residual = fns.transpose(svd_residual)
residual = svd_residual.clone() # [H, O]

s, X = process_stats(layer_statistics, subset_size) # [H], [H, SS]
X = fns.transpose(X) # [SS, H]
# Pass it to process_stats with transpose_a=True to get [SS, H] layout
s, X = process_stats(layer_statistics, subset_size, act_ch_axis, transpose_a=True)
if compression_config.group_size > 0:
# Multiply residual of weights by maximum channel magnitude of activations normalized per quantization
# group. As a consequence, weights corresponding to a "noisy" activations has a higher error to correct.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,8 @@ def insert_adapters(
A_W = opset.constant(lora_A.data)
B_W = opset.constant(lora_B.data)

A_MM = opset.matmul(input_node, A_W, transpose_a=False, transpose_b=True)
transpose_a = wc_params.node_with_weight.layer_attributes.input_attributes["transpose"]
A_MM = opset.matmul(input_node, A_W, transpose_a=transpose_a, transpose_b=True)
B_MM = opset.matmul(A_MM, B_W, transpose_a=False, transpose_b=True)

node_output_port = mm_node.output(0)
Expand Down Expand Up @@ -361,7 +362,15 @@ def transform_model(
compressed_weight.tensor = compressed_weight.tensor.as_numpy_tensor()
if compressed_weight.zero_point is not None:
compressed_weight.zero_point = compressed_weight.zero_point.as_numpy_tensor()
adapters = lora_correction_algo.calculate_adapters(weight, compressed_weight, wc_params)

activation_port_id = self.get_activation_port_id(wc_params.node_with_weight, graph)
activation_edge = graph.get_input_edge_by_port_id(wc_params.node_with_weight, activation_port_id)
activation_shape = activation_edge.tensor_shape
act_ch_axis = self.get_activation_channel_axis(
wc_params.node_with_weight, activation_port_id, activation_shape
)

adapters = lora_correction_algo.calculate_adapters(weight, compressed_weight, wc_params, act_ch_axis)
self.insert_adapters(wc_params, *adapters, int8_lora=lora_correction_algo.use_int8_adapters)
self.name_to_node_mapping = None

Expand Down
Loading