Skip to content

[feat] Add ada cache for flux2 ppt#1154

Merged
helloyongyang merged 2 commits into
mainfrom
dev/flux2_adacache
Jun 16, 2026
Merged

[feat] Add ada cache for flux2 ppt#1154
helloyongyang merged 2 commits into
mainfrom
dev/flux2_adacache

Conversation

@wangshankun

Copy link
Copy Markdown
Collaborator

No description provided.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces adaptive feature caching (AdaCache) support for Flux2 models, including new configuration files, shell scripts, and the implementation of caching-enabled transformer inference classes and schedulers. The review feedback highlights several critical and medium-severity issues: removing redundant and error-prone motion regulation (moreg) logic that causes division-by-zero errors in 2D image models, eliminating inefficient GPU-to-CPU tensor transfers before clearing references, reducing code duplication in the caching inference logic, handling null values for feature_caching to prevent unexpected NotImplementedError exceptions, and defaulting the path variable in shell scripts to prevent execution failures.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +98 to +145
def _update_spatial_dim(self, ada_args, residual):
if ada_args.spatial_dim <= 0:
ada_args.spatial_dim = residual.shape[0]

def _calculate_skip_step_length_for_args(self, ada_args):
if ada_args.previous_residual_tiny is None:
ada_args.previous_residual_tiny = ada_args.now_residual_tiny
return 1

cache = ada_args.previous_residual_tiny
res = ada_args.now_residual_tiny
self._update_spatial_dim(ada_args, res)
norm_ord = ada_args.norm_ord
cache_diff = (cache - res).norm(dim=(0, 1), p=norm_ord) / cache.norm(dim=(0, 1), p=norm_ord)
cache_diff = cache_diff / ada_args.skipped_step_length

if ada_args.moreg_steps[0] <= self.scheduler.step_index <= ada_args.moreg_steps[1]:
moreg = 0
for i in ada_args.moreg_strides:
moreg_i = (res[i * ada_args.spatial_dim :, :] - res[: -i * ada_args.spatial_dim, :]).norm(p=norm_ord)
moreg_i /= res[i * ada_args.spatial_dim :, :].norm(p=norm_ord) + res[: -i * ada_args.spatial_dim, :].norm(p=norm_ord)
moreg += moreg_i
moreg = moreg / len(ada_args.moreg_strides)
moreg = ((1 / ada_args.moreg_hyp[0] * moreg) ** ada_args.moreg_hyp[1]) / ada_args.moreg_hyp[2]
else:
moreg = 1.0

mograd = ada_args.mograd_mul * (moreg - ada_args.previous_moreg) / ada_args.skipped_step_length
ada_args.previous_moreg = moreg
moreg = moreg + abs(mograd)
cache_diff = cache_diff * moreg

metric_thres, cache_rates = list(self.codebook.keys()), list(self.codebook.values())
if cache_diff < metric_thres[0]:
new_rate = cache_rates[0]
elif cache_diff < metric_thres[1]:
new_rate = cache_rates[1]
elif cache_diff < metric_thres[2]:
new_rate = cache_rates[2]
elif cache_diff < metric_thres[3]:
new_rate = cache_rates[3]
elif cache_diff < metric_thres[4]:
new_rate = cache_rates[4]
else:
new_rate = cache_rates[-1]

ada_args.previous_residual_tiny = ada_args.now_residual_tiny
return new_rate

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The motion regulation (moreg) calculation is designed for video models to measure temporal motion across frames. Since Flux2 is a 2D image model, there is no temporal dimension, and moreg is completely redundant. Furthermore, because ada_args.spatial_dim defaults to 0 and is updated to residual.shape[0] (the full image sequence length), the slicing res[i * ada_args.spatial_dim :, :] results in empty tensors, leading to division-by-zero (NaN) errors during norm calculations. Removing the moreg logic entirely avoids these issues, simplifies the code, and improves performance.

    def _calculate_skip_step_length_for_args(self, ada_args):
        if ada_args.previous_residual_tiny is None:
            ada_args.previous_residual_tiny = ada_args.now_residual_tiny
            return 1

        cache = ada_args.previous_residual_tiny
        res = ada_args.now_residual_tiny
        norm_ord = ada_args.norm_ord
        cache_diff = (cache - res).norm(dim=(0, 1), p=norm_ord) / cache.norm(dim=(0, 1), p=norm_ord)
        cache_diff = cache_diff / ada_args.skipped_step_length

        new_rate = list(self.codebook.values())[-1]
        for thres, rate in self.codebook.items():
            if cache_diff < thres:
                new_rate = rate
                break

        ada_args.previous_residual_tiny = ada_args.now_residual_tiny
        return new_rate

Comment on lines +154 to +163
if ada_args.previous_residual is not None:
ada_args.previous_residual = ada_args.previous_residual.cpu()
if ada_args.previous_residual_tiny is not None:
ada_args.previous_residual_tiny = ada_args.previous_residual_tiny.cpu()
if ada_args.now_residual_tiny is not None:
ada_args.now_residual_tiny = ada_args.now_residual_tiny.cpu()

ada_args.previous_residual = None
ada_args.previous_residual_tiny = None
ada_args.now_residual_tiny = None

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Moving the residual tensors to CPU immediately before setting them to None is redundant and inefficient. Setting them to None releases the references, allowing PyTorch to free the GPU memory directly. The .cpu() calls waste GPU-to-CPU bandwidth and CPU memory allocation overhead.

            ada_args.previous_residual = None
            ada_args.previous_residual_tiny = None
            ada_args.now_residual_tiny = None

Comment on lines +41 to +71
def infer(self, block_weights, pre_infer_out):
if self.scheduler.infer_condition:
index = self.scheduler.step_index
caching_records = self.scheduler.caching_records

if caching_records[index] or self.must_calc(index):
hidden_states = self.infer_calculating(block_weights, pre_infer_out)

if index <= self.scheduler.infer_steps - 2:
self.args_even.skipped_step_length = self.calculate_skip_step_length()
for i in range(1, self.args_even.skipped_step_length):
if (index + i) <= self.scheduler.infer_steps - 1:
self.scheduler.caching_records[index + i] = False
else:
hidden_states = self.infer_using_cache(pre_infer_out)
else:
index = self.scheduler.step_index
caching_records = self.scheduler.caching_records_2

if caching_records[index] or self.must_calc(index):
hidden_states = self.infer_calculating(block_weights, pre_infer_out)

if index <= self.scheduler.infer_steps - 2:
self.args_odd.skipped_step_length = self.calculate_skip_step_length()
for i in range(1, self.args_odd.skipped_step_length):
if (index + i) <= self.scheduler.infer_steps - 1:
self.scheduler.caching_records_2[index + i] = False
else:
hidden_states = self.infer_using_cache(pre_infer_out)

return hidden_states

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There is significant code duplication between the if self.scheduler.infer_condition and else branches in the infer method. We can dynamically select the caching records and arguments to simplify the logic and improve maintainability.

    def infer(self, block_weights, pre_infer_out):
        index = self.scheduler.step_index
        if self.scheduler.infer_condition:
            caching_records = self.scheduler.caching_records
            ada_args = self.args_even
        else:
            caching_records = self.scheduler.caching_records_2
            ada_args = self.args_odd

        if caching_records[index] or self.must_calc(index):
            hidden_states = self.infer_calculating(block_weights, pre_infer_out)

            if index <= self.scheduler.infer_steps - 2:
                ada_args.skipped_step_length = self.calculate_skip_step_length()
                for i in range(1, ada_args.skipped_step_length):
                    if (index + i) <= self.scheduler.infer_steps - 1:
                        caching_records[index + i] = False
        else:
            hidden_states = self.infer_using_cache(pre_infer_out)

        return hidden_states

Comment on lines +107 to +118
feature_caching = self.config.get("feature_caching", "NoCaching")
if feature_caching in ("NoCaching", "None"):
if self.cpu_offload and self.offload_granularity == "block":
self.transformer_infer_class = Flux2OffloadTransformerInfer
else:
self.transformer_infer_class = Flux2TransformerInfer
elif feature_caching == "Ada":
if self.cpu_offload and self.offload_granularity == "block":
raise NotImplementedError("Flux2 AdaCache does not support block-level cpu_offload yet")
self.transformer_infer_class = Flux2TransformerInferAdaCaching
else:
self.transformer_infer_class = Flux2TransformerInfer
raise NotImplementedError(f"Unsupported feature_caching type: {feature_caching}")

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If feature_caching is set to null in the JSON configuration (which parses to Python None), the check feature_caching in ("NoCaching", "None") will evaluate to False, leading to an unexpected NotImplementedError. Adding None to the tuple ensures robust handling of null values.

Suggested change
feature_caching = self.config.get("feature_caching", "NoCaching")
if feature_caching in ("NoCaching", "None"):
if self.cpu_offload and self.offload_granularity == "block":
self.transformer_infer_class = Flux2OffloadTransformerInfer
else:
self.transformer_infer_class = Flux2TransformerInfer
elif feature_caching == "Ada":
if self.cpu_offload and self.offload_granularity == "block":
raise NotImplementedError("Flux2 AdaCache does not support block-level cpu_offload yet")
self.transformer_infer_class = Flux2TransformerInferAdaCaching
else:
self.transformer_infer_class = Flux2TransformerInfer
raise NotImplementedError(f"Unsupported feature_caching type: {feature_caching}")
feature_caching = self.config.get("feature_caching", "NoCaching")
if feature_caching in ("NoCaching", "None", None):
if self.cpu_offload and self.offload_granularity == "block":
self.transformer_infer_class = Flux2OffloadTransformerInfer
else:
self.transformer_infer_class = Flux2TransformerInfer
elif feature_caching == "Ada":
if self.cpu_offload and self.offload_granularity == "block":
raise NotImplementedError("Flux2 AdaCache does not support block-level cpu_offload yet")
self.transformer_infer_class = Flux2TransformerInferAdaCaching
else:
raise NotImplementedError(f"Unsupported feature_caching type: {feature_caching}")

Comment on lines +231 to +242
feature_caching = self.config.get("feature_caching", "NoCaching")
if feature_caching in ("NoCaching", "None"):
if self.cpu_offload and self.offload_granularity == "block":
self.transformer_infer_class = Flux2OffloadTransformerInfer
else:
self.transformer_infer_class = Flux2TransformerInfer
elif feature_caching == "Ada":
if self.cpu_offload and self.offload_granularity == "block":
raise NotImplementedError("Flux2 AdaCache does not support block-level cpu_offload yet")
self.transformer_infer_class = Flux2TransformerInferAdaCaching
else:
self.transformer_infer_class = Flux2TransformerInfer
raise NotImplementedError(f"Unsupported feature_caching type: {feature_caching}")

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If feature_caching is set to null in the JSON configuration (which parses to Python None), the check feature_caching in ("NoCaching", "None") will evaluate to False, leading to an unexpected NotImplementedError. Adding None to the tuple ensures robust handling of null values.

Suggested change
feature_caching = self.config.get("feature_caching", "NoCaching")
if feature_caching in ("NoCaching", "None"):
if self.cpu_offload and self.offload_granularity == "block":
self.transformer_infer_class = Flux2OffloadTransformerInfer
else:
self.transformer_infer_class = Flux2TransformerInfer
elif feature_caching == "Ada":
if self.cpu_offload and self.offload_granularity == "block":
raise NotImplementedError("Flux2 AdaCache does not support block-level cpu_offload yet")
self.transformer_infer_class = Flux2TransformerInferAdaCaching
else:
self.transformer_infer_class = Flux2TransformerInfer
raise NotImplementedError(f"Unsupported feature_caching type: {feature_caching}")
feature_caching = self.config.get("feature_caching", "NoCaching")
if feature_caching in ("NoCaching", "None", None):
if self.cpu_offload and self.offload_granularity == "block":
self.transformer_infer_class = Flux2OffloadTransformerInfer
else:
self.transformer_infer_class = Flux2TransformerInfer
elif feature_caching == "Ada":
if self.cpu_offload and self.offload_granularity == "block":
raise NotImplementedError("Flux2 AdaCache does not support block-level cpu_offload yet")
self.transformer_infer_class = Flux2TransformerInferAdaCaching
else:
raise NotImplementedError(f"Unsupported feature_caching type: {feature_caching}")

Comment on lines +42 to +47
def _get_scheduler_class(self):
if self.config.get("feature_caching", "NoCaching") in ("NoCaching", "None"):
return None
if self.config.get("feature_caching") == "Ada":
return Flux2SchedulerCaching
raise NotImplementedError(f"Unsupported feature_caching type: {self.config.get('feature_caching')}")

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If feature_caching is set to null in the JSON configuration (which parses to Python None), the check self.config.get("feature_caching", "NoCaching") in ("NoCaching", "None") will evaluate to False, leading to an unexpected NotImplementedError. Adding None to the tuple ensures robust handling of null values.

Suggested change
def _get_scheduler_class(self):
if self.config.get("feature_caching", "NoCaching") in ("NoCaching", "None"):
return None
if self.config.get("feature_caching") == "Ada":
return Flux2SchedulerCaching
raise NotImplementedError(f"Unsupported feature_caching type: {self.config.get('feature_caching')}")
def _get_scheduler_class(self):
feature_caching = self.config.get("feature_caching", "NoCaching")
if feature_caching in ("NoCaching", "None", None):
return None
if feature_caching == "Ada":
return Flux2SchedulerCaching
raise NotImplementedError(f"Unsupported feature_caching type: {feature_caching}")

@@ -0,0 +1,15 @@
#!/bin/bash
lightx2v_path=

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Leaving lightx2v_path empty will cause the script to fail when trying to source /scripts/base/base.sh or resolve asset paths. Defaulting it to . (current directory) makes the script runnable out-of-the-box from the repository root.

Suggested change
lightx2v_path=
lightx2v_path="."

@@ -0,0 +1,15 @@
#!/bin/bash
lightx2v_path=

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Leaving lightx2v_path empty will cause the script to fail when trying to source /scripts/base/base.sh or resolve asset paths. Defaulting it to . (current directory) makes the script runnable out-of-the-box from the repository root.

Suggested change
lightx2v_path=
lightx2v_path="."

@helloyongyang helloyongyang merged commit 9c14628 into main Jun 16, 2026
2 checks passed
@helloyongyang helloyongyang deleted the dev/flux2_adacache branch June 16, 2026 10:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants