[feat] Add ada cache for flux2 ppt#1154
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces adaptive feature caching (AdaCache) support for Flux2 models, including new configuration files, shell scripts, and the implementation of caching-enabled transformer inference classes and schedulers. The review feedback highlights several critical and medium-severity issues: removing redundant and error-prone motion regulation (moreg) logic that causes division-by-zero errors in 2D image models, eliminating inefficient GPU-to-CPU tensor transfers before clearing references, reducing code duplication in the caching inference logic, handling null values for feature_caching to prevent unexpected NotImplementedError exceptions, and defaulting the path variable in shell scripts to prevent execution failures.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| def _update_spatial_dim(self, ada_args, residual): | ||
| if ada_args.spatial_dim <= 0: | ||
| ada_args.spatial_dim = residual.shape[0] | ||
|
|
||
| def _calculate_skip_step_length_for_args(self, ada_args): | ||
| if ada_args.previous_residual_tiny is None: | ||
| ada_args.previous_residual_tiny = ada_args.now_residual_tiny | ||
| return 1 | ||
|
|
||
| cache = ada_args.previous_residual_tiny | ||
| res = ada_args.now_residual_tiny | ||
| self._update_spatial_dim(ada_args, res) | ||
| norm_ord = ada_args.norm_ord | ||
| cache_diff = (cache - res).norm(dim=(0, 1), p=norm_ord) / cache.norm(dim=(0, 1), p=norm_ord) | ||
| cache_diff = cache_diff / ada_args.skipped_step_length | ||
|
|
||
| if ada_args.moreg_steps[0] <= self.scheduler.step_index <= ada_args.moreg_steps[1]: | ||
| moreg = 0 | ||
| for i in ada_args.moreg_strides: | ||
| moreg_i = (res[i * ada_args.spatial_dim :, :] - res[: -i * ada_args.spatial_dim, :]).norm(p=norm_ord) | ||
| moreg_i /= res[i * ada_args.spatial_dim :, :].norm(p=norm_ord) + res[: -i * ada_args.spatial_dim, :].norm(p=norm_ord) | ||
| moreg += moreg_i | ||
| moreg = moreg / len(ada_args.moreg_strides) | ||
| moreg = ((1 / ada_args.moreg_hyp[0] * moreg) ** ada_args.moreg_hyp[1]) / ada_args.moreg_hyp[2] | ||
| else: | ||
| moreg = 1.0 | ||
|
|
||
| mograd = ada_args.mograd_mul * (moreg - ada_args.previous_moreg) / ada_args.skipped_step_length | ||
| ada_args.previous_moreg = moreg | ||
| moreg = moreg + abs(mograd) | ||
| cache_diff = cache_diff * moreg | ||
|
|
||
| metric_thres, cache_rates = list(self.codebook.keys()), list(self.codebook.values()) | ||
| if cache_diff < metric_thres[0]: | ||
| new_rate = cache_rates[0] | ||
| elif cache_diff < metric_thres[1]: | ||
| new_rate = cache_rates[1] | ||
| elif cache_diff < metric_thres[2]: | ||
| new_rate = cache_rates[2] | ||
| elif cache_diff < metric_thres[3]: | ||
| new_rate = cache_rates[3] | ||
| elif cache_diff < metric_thres[4]: | ||
| new_rate = cache_rates[4] | ||
| else: | ||
| new_rate = cache_rates[-1] | ||
|
|
||
| ada_args.previous_residual_tiny = ada_args.now_residual_tiny | ||
| return new_rate |
There was a problem hiding this comment.
The motion regulation (moreg) calculation is designed for video models to measure temporal motion across frames. Since Flux2 is a 2D image model, there is no temporal dimension, and moreg is completely redundant. Furthermore, because ada_args.spatial_dim defaults to 0 and is updated to residual.shape[0] (the full image sequence length), the slicing res[i * ada_args.spatial_dim :, :] results in empty tensors, leading to division-by-zero (NaN) errors during norm calculations. Removing the moreg logic entirely avoids these issues, simplifies the code, and improves performance.
def _calculate_skip_step_length_for_args(self, ada_args):
    if ada_args.previous_residual_tiny is None:
        ada_args.previous_residual_tiny = ada_args.now_residual_tiny
        return 1
    cache = ada_args.previous_residual_tiny
    res = ada_args.now_residual_tiny
    norm_ord = ada_args.norm_ord
    cache_diff = (cache - res).norm(dim=(0, 1), p=norm_ord) / cache.norm(dim=(0, 1), p=norm_ord)
    cache_diff = cache_diff / ada_args.skipped_step_length
    new_rate = list(self.codebook.values())[-1]
    for thres, rate in self.codebook.items():
        if cache_diff < thres:
            new_rate = rate
            break
    ada_args.previous_residual_tiny = ada_args.now_residual_tiny
    return new_rate
| if ada_args.previous_residual is not None: | ||
| ada_args.previous_residual = ada_args.previous_residual.cpu() | ||
| if ada_args.previous_residual_tiny is not None: | ||
| ada_args.previous_residual_tiny = ada_args.previous_residual_tiny.cpu() | ||
| if ada_args.now_residual_tiny is not None: | ||
| ada_args.now_residual_tiny = ada_args.now_residual_tiny.cpu() | ||
|
|
||
| ada_args.previous_residual = None | ||
| ada_args.previous_residual_tiny = None | ||
| ada_args.now_residual_tiny = None |
There was a problem hiding this comment.
Moving the residual tensors to CPU immediately before setting them to None is redundant and inefficient. Setting them to None releases the references, allowing PyTorch to free the GPU memory directly. The .cpu() calls waste GPU-to-CPU bandwidth and CPU memory allocation overhead.
ada_args.previous_residual = None
ada_args.previous_residual_tiny = None
ada_args.now_residual_tiny = None
| def infer(self, block_weights, pre_infer_out): | ||
| if self.scheduler.infer_condition: | ||
| index = self.scheduler.step_index | ||
| caching_records = self.scheduler.caching_records | ||
|
|
||
| if caching_records[index] or self.must_calc(index): | ||
| hidden_states = self.infer_calculating(block_weights, pre_infer_out) | ||
|
|
||
| if index <= self.scheduler.infer_steps - 2: | ||
| self.args_even.skipped_step_length = self.calculate_skip_step_length() | ||
| for i in range(1, self.args_even.skipped_step_length): | ||
| if (index + i) <= self.scheduler.infer_steps - 1: | ||
| self.scheduler.caching_records[index + i] = False | ||
| else: | ||
| hidden_states = self.infer_using_cache(pre_infer_out) | ||
| else: | ||
| index = self.scheduler.step_index | ||
| caching_records = self.scheduler.caching_records_2 | ||
|
|
||
| if caching_records[index] or self.must_calc(index): | ||
| hidden_states = self.infer_calculating(block_weights, pre_infer_out) | ||
|
|
||
| if index <= self.scheduler.infer_steps - 2: | ||
| self.args_odd.skipped_step_length = self.calculate_skip_step_length() | ||
| for i in range(1, self.args_odd.skipped_step_length): | ||
| if (index + i) <= self.scheduler.infer_steps - 1: | ||
| self.scheduler.caching_records_2[index + i] = False | ||
| else: | ||
| hidden_states = self.infer_using_cache(pre_infer_out) | ||
|
|
||
| return hidden_states |
There was a problem hiding this comment.
There is significant code duplication between the if self.scheduler.infer_condition and else branches in the infer method. We can dynamically select the caching records and arguments to simplify the logic and improve maintainability.
def infer(self, block_weights, pre_infer_out):
    index = self.scheduler.step_index
    if self.scheduler.infer_condition:
        caching_records = self.scheduler.caching_records
        ada_args = self.args_even
    else:
        caching_records = self.scheduler.caching_records_2
        ada_args = self.args_odd
    if caching_records[index] or self.must_calc(index):
        hidden_states = self.infer_calculating(block_weights, pre_infer_out)
        if index <= self.scheduler.infer_steps - 2:
            ada_args.skipped_step_length = self.calculate_skip_step_length()
            for i in range(1, ada_args.skipped_step_length):
                if (index + i) <= self.scheduler.infer_steps - 1:
                    caching_records[index + i] = False
    else:
        hidden_states = self.infer_using_cache(pre_infer_out)
    return hidden_states
| feature_caching = self.config.get("feature_caching", "NoCaching") | ||
| if feature_caching in ("NoCaching", "None"): | ||
| if self.cpu_offload and self.offload_granularity == "block": | ||
| self.transformer_infer_class = Flux2OffloadTransformerInfer | ||
| else: | ||
| self.transformer_infer_class = Flux2TransformerInfer | ||
| elif feature_caching == "Ada": | ||
| if self.cpu_offload and self.offload_granularity == "block": | ||
| raise NotImplementedError("Flux2 AdaCache does not support block-level cpu_offload yet") | ||
| self.transformer_infer_class = Flux2TransformerInferAdaCaching | ||
| else: | ||
| self.transformer_infer_class = Flux2TransformerInfer | ||
| raise NotImplementedError(f"Unsupported feature_caching type: {feature_caching}") |
There was a problem hiding this comment.
If feature_caching is set to null in the JSON configuration (which parses to Python None), the check feature_caching in ("NoCaching", "None") will evaluate to False, leading to an unexpected NotImplementedError. Adding None to the tuple ensures robust handling of null values.
| feature_caching = self.config.get("feature_caching", "NoCaching") | |
| if feature_caching in ("NoCaching", "None"): | |
| if self.cpu_offload and self.offload_granularity == "block": | |
| self.transformer_infer_class = Flux2OffloadTransformerInfer | |
| else: | |
| self.transformer_infer_class = Flux2TransformerInfer | |
| elif feature_caching == "Ada": | |
| if self.cpu_offload and self.offload_granularity == "block": | |
| raise NotImplementedError("Flux2 AdaCache does not support block-level cpu_offload yet") | |
| self.transformer_infer_class = Flux2TransformerInferAdaCaching | |
| else: | |
| self.transformer_infer_class = Flux2TransformerInfer | |
| raise NotImplementedError(f"Unsupported feature_caching type: {feature_caching}") | |
| feature_caching = self.config.get("feature_caching", "NoCaching") | |
| if feature_caching in ("NoCaching", "None", None): | |
| if self.cpu_offload and self.offload_granularity == "block": | |
| self.transformer_infer_class = Flux2OffloadTransformerInfer | |
| else: | |
| self.transformer_infer_class = Flux2TransformerInfer | |
| elif feature_caching == "Ada": | |
| if self.cpu_offload and self.offload_granularity == "block": | |
| raise NotImplementedError("Flux2 AdaCache does not support block-level cpu_offload yet") | |
| self.transformer_infer_class = Flux2TransformerInferAdaCaching | |
| else: | |
| raise NotImplementedError(f"Unsupported feature_caching type: {feature_caching}") |
| feature_caching = self.config.get("feature_caching", "NoCaching") | ||
| if feature_caching in ("NoCaching", "None"): | ||
| if self.cpu_offload and self.offload_granularity == "block": | ||
| self.transformer_infer_class = Flux2OffloadTransformerInfer | ||
| else: | ||
| self.transformer_infer_class = Flux2TransformerInfer | ||
| elif feature_caching == "Ada": | ||
| if self.cpu_offload and self.offload_granularity == "block": | ||
| raise NotImplementedError("Flux2 AdaCache does not support block-level cpu_offload yet") | ||
| self.transformer_infer_class = Flux2TransformerInferAdaCaching | ||
| else: | ||
| self.transformer_infer_class = Flux2TransformerInfer | ||
| raise NotImplementedError(f"Unsupported feature_caching type: {feature_caching}") |
There was a problem hiding this comment.
If feature_caching is set to null in the JSON configuration (which parses to Python None), the check feature_caching in ("NoCaching", "None") will evaluate to False, leading to an unexpected NotImplementedError. Adding None to the tuple ensures robust handling of null values.
| feature_caching = self.config.get("feature_caching", "NoCaching") | |
| if feature_caching in ("NoCaching", "None"): | |
| if self.cpu_offload and self.offload_granularity == "block": | |
| self.transformer_infer_class = Flux2OffloadTransformerInfer | |
| else: | |
| self.transformer_infer_class = Flux2TransformerInfer | |
| elif feature_caching == "Ada": | |
| if self.cpu_offload and self.offload_granularity == "block": | |
| raise NotImplementedError("Flux2 AdaCache does not support block-level cpu_offload yet") | |
| self.transformer_infer_class = Flux2TransformerInferAdaCaching | |
| else: | |
| self.transformer_infer_class = Flux2TransformerInfer | |
| raise NotImplementedError(f"Unsupported feature_caching type: {feature_caching}") | |
| feature_caching = self.config.get("feature_caching", "NoCaching") | |
| if feature_caching in ("NoCaching", "None", None): | |
| if self.cpu_offload and self.offload_granularity == "block": | |
| self.transformer_infer_class = Flux2OffloadTransformerInfer | |
| else: | |
| self.transformer_infer_class = Flux2TransformerInfer | |
| elif feature_caching == "Ada": | |
| if self.cpu_offload and self.offload_granularity == "block": | |
| raise NotImplementedError("Flux2 AdaCache does not support block-level cpu_offload yet") | |
| self.transformer_infer_class = Flux2TransformerInferAdaCaching | |
| else: | |
| raise NotImplementedError(f"Unsupported feature_caching type: {feature_caching}") |
| def _get_scheduler_class(self): | ||
| if self.config.get("feature_caching", "NoCaching") in ("NoCaching", "None"): | ||
| return None | ||
| if self.config.get("feature_caching") == "Ada": | ||
| return Flux2SchedulerCaching | ||
| raise NotImplementedError(f"Unsupported feature_caching type: {self.config.get('feature_caching')}") |
There was a problem hiding this comment.
If feature_caching is set to null in the JSON configuration (which parses to Python None), the check self.config.get("feature_caching", "NoCaching") in ("NoCaching", "None") will evaluate to False, leading to an unexpected NotImplementedError. Adding None to the tuple ensures robust handling of null values.
| def _get_scheduler_class(self): | |
| if self.config.get("feature_caching", "NoCaching") in ("NoCaching", "None"): | |
| return None | |
| if self.config.get("feature_caching") == "Ada": | |
| return Flux2SchedulerCaching | |
| raise NotImplementedError(f"Unsupported feature_caching type: {self.config.get('feature_caching')}") | |
| def _get_scheduler_class(self): | |
| feature_caching = self.config.get("feature_caching", "NoCaching") | |
| if feature_caching in ("NoCaching", "None", None): | |
| return None | |
| if feature_caching == "Ada": | |
| return Flux2SchedulerCaching | |
| raise NotImplementedError(f"Unsupported feature_caching type: {feature_caching}") |
| @@ -0,0 +1,15 @@ | |||
| #!/bin/bash | |||
| lightx2v_path= | |||
There was a problem hiding this comment.
| @@ -0,0 +1,15 @@ | |||
| #!/bin/bash | |||
| lightx2v_path= | |||
There was a problem hiding this comment.
No description provided.