Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions configs/flux2/flux2_klein_i2i_inpaint_mask_cache.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
{
"model_cls": "flux2_klein",
"task": "i2i",
"task_variant": "edit",
"infer_steps": 40,
"sample_guide_scale": 4.0,
"feature_caching": "Ada",
"vae_scale_factor": 16,
"enable_cfg": true,
"patch_size": 2,
"tokenizer_max_length": 512,
"rope_type": "flashinfer",
"max_image_area": 1048576,
"inpaint_mask_enabled": true
}
18 changes: 18 additions & 0 deletions configs/flux2/flux2_klein_i2i_inpaint_mask_cfg_parallel_cache.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
{
"model_cls": "flux2_klein",
"task": "i2i",
"task_variant": "edit",
"infer_steps": 40,
"sample_guide_scale": 4.0,
"vae_scale_factor": 16,
"feature_caching": "Ada",
"enable_cfg": true,
"patch_size": 2,
"tokenizer_max_length": 512,
"rope_type": "flashinfer",
"max_image_area": 1048576,
"inpaint_mask_enabled": true,
"parallel": {
"cfg_p_size": 2
}
}
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
import torch

from lightx2v.models.networks.flux2.infer.transformer_infer import Flux2TransformerInfer


class Flux2TransformerInferCaching(Flux2TransformerInfer):
    """Flux2 transformer inference base for feature-caching variants.

    Holds the list of step indices that must always be computed in full,
    regardless of what the caching schedule says.
    """

    def __init__(self, config):
        super().__init__(config)
        # Forced-compute steps are only read from the config when
        # "changing_resolution" is switched on; otherwise none are forced.
        resolution_changes = self.config.get("changing_resolution", False)
        self.must_calc_steps = self.config["changing_resolution_steps"] if resolution_changes else []

    def must_calc(self, step_index):
        """Return True when ``step_index`` is listed in ``must_calc_steps``."""
        return step_index in self.must_calc_steps


class Flux2AdaArgs:
    """Mutable per-branch state for the AdaCache skip-length heuristic.

    Args:
        config: mapping that must contain ``"infer_steps"`` and may contain
            ``"adacache_spatial_dim"`` (defaults to 0).
    """

    def __init__(self, config):
        total_steps = config["infer_steps"]

        # Cached tensors; all start empty and are filled during inference.
        self.previous_residual = None
        self.previous_residual_tiny = None
        self.now_residual_tiny = None

        # Norm order used when comparing residuals, and the most recent skip length.
        self.norm_ord = 1
        self.skipped_step_length = 1

        # "moreg" settings: the step window spans 10%..90% of the schedule
        # (both bounds truncated to int).
        self.previous_moreg = 1.0
        self.moreg_strides = [1]
        self.moreg_steps = [int(0.1 * total_steps), int(0.9 * total_steps)]
        self.moreg_hyp = [0.385, 8, 1, 2]
        self.mograd_mul = 10
        self.spatial_dim = config.get("adacache_spatial_dim", 0)


class Flux2TransformerInferAdaCaching(Flux2TransformerInferCaching):
def __init__(self, config):
    """Set up the decisive block index, the rate codebook and per-branch state."""
    super().__init__(config)

    # The double block halfway through the stack (default stack size 10)
    # is the one whose attention output is sampled.
    layer_count = config.get("num_layers", 10)
    self.decisive_double_block_id = layer_count // 2

    # Threshold -> skip length; insertion order is relied upon by the lookup.
    self.codebook = {
        0.03: 12,
        0.05: 10,
        0.07: 8,
        0.09: 6,
        0.11: 4,
        1.00: 3,
    }

    # One independent state object for each value of scheduler.infer_condition.
    self.args_even = Flux2AdaArgs(config)
    self.args_odd = Flux2AdaArgs(config)

def infer(self, block_weights, pre_infer_out):
    """Run one transformer step, computing it fully or reusing the cached residual.

    The branch selected by ``self.scheduler.infer_condition`` decides which
    caching-record list and which AdaCache state object are used
    (``caching_records`` / ``args_even`` when true, ``caching_records_2`` /
    ``args_odd`` otherwise). The rest of the logic is shared by both branches.

    Args:
        block_weights: weights forwarded to ``infer_calculating``.
        pre_infer_out: pre-inference output forwarded to the compute/cache path.

    Returns:
        The hidden states produced by ``infer_calculating`` or
        ``infer_using_cache``.
    """
    scheduler = self.scheduler
    index = scheduler.step_index
    if scheduler.infer_condition:
        caching_records, ada_args = scheduler.caching_records, self.args_even
    else:
        caching_records, ada_args = scheduler.caching_records_2, self.args_odd

    if not (caching_records[index] or self.must_calc(index)):
        return self.infer_using_cache(pre_infer_out)

    hidden_states = self.infer_calculating(block_weights, pre_infer_out)

    # The final step has nothing after it to skip, so no skip length is computed there.
    if index <= scheduler.infer_steps - 2:
        ada_args.skipped_step_length = self.calculate_skip_step_length()
        # Mark the following (skip length - 1) steps as cache hits, clamped to the schedule.
        stop = min(index + ada_args.skipped_step_length, scheduler.infer_steps)
        for skipped_index in range(index + 1, stop):
            caching_records[skipped_index] = False

    return hidden_states
Comment on lines +41 to +71

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There is significant code duplication between the `if self.scheduler.infer_condition` and `else` branches of the `infer` method. Selecting the caching records and the AdaCache argument object once, up front, removes the duplication, simplifies the logic and improves maintainability.

    def infer(self, block_weights, pre_infer_out):
        index = self.scheduler.step_index
        if self.scheduler.infer_condition:
            caching_records = self.scheduler.caching_records
            ada_args = self.args_even
        else:
            caching_records = self.scheduler.caching_records_2
            ada_args = self.args_odd

        if caching_records[index] or self.must_calc(index):
            hidden_states = self.infer_calculating(block_weights, pre_infer_out)

            if index <= self.scheduler.infer_steps - 2:
                ada_args.skipped_step_length = self.calculate_skip_step_length()
                for i in range(1, ada_args.skipped_step_length):
                    if (index + i) <= self.scheduler.infer_steps - 1:
                        caching_records[index + i] = False
        else:
            hidden_states = self.infer_using_cache(pre_infer_out)

        return hidden_states


def infer_calculating(self, block_weights, pre_infer_out):
    """Run the full forward pass and store the residual it adds to the input.

    While the forward pass runs, the gated image-attention output of the
    decisive double block is captured (with dim 0 squeezed) into
    ``now_residual_tiny`` of the active branch's state.

    Returns:
        The hidden states returned by ``_infer_forward``.
    """
    # Snapshot the input first so the residual can be formed after the pass.
    input_snapshot = pre_infer_out.hidden_states.clone()
    branch_args = self.args_odd
    if self.scheduler.infer_condition:
        branch_args = self.args_even

    def record_tiny_residual(gated_img_attn):
        branch_args.now_residual_tiny = gated_img_attn.squeeze(0)

    output = self._infer_forward(
        block_weights,
        pre_infer_out,
        decisive_block_id=self.decisive_double_block_id,
        on_decisive_block=record_tiny_residual,
    )

    branch_args.previous_residual = output - input_snapshot
    return output

def infer_using_cache(self, pre_infer_out):
    """Skip the forward pass: return the input plus the active branch's stored residual."""
    hidden_states = pre_infer_out.hidden_states
    branch_args = self.args_even if self.scheduler.infer_condition else self.args_odd
    return hidden_states + branch_args.previous_residual

def _update_spatial_dim(self, ada_args, residual):
    """Fill in a non-positive ``spatial_dim`` with the size of the residual's first dimension."""
    if not ada_args.spatial_dim <= 0:
        # A positive value (e.g. supplied through the config) is kept as-is.
        return
    ada_args.spatial_dim = residual.shape[0]

def _calculate_skip_step_length_for_args(self, ada_args):
    """Choose how many steps the cached residual may be reused for.

    The relative change between the previous and the current "tiny" residual
    is normalised by the last skip length, scaled by the "moreg" factor, and
    then looked up in ``self.codebook``.

    Args:
        ada_args: per-branch state object. Its ``previous_residual_tiny``,
            ``previous_moreg`` and (when unset) ``spatial_dim`` fields are
            updated in place.

    Returns:
        1 on the first call for a branch (nothing to compare against yet);
        otherwise the codebook rate of the first threshold the scaled change
        falls below, or the last codebook rate if it falls below none.
    """
    if ada_args.previous_residual_tiny is None:
        ada_args.previous_residual_tiny = ada_args.now_residual_tiny
        return 1

    cache = ada_args.previous_residual_tiny
    res = ada_args.now_residual_tiny
    self._update_spatial_dim(ada_args, res)
    norm_ord = ada_args.norm_ord

    cache_diff = (cache - res).norm(dim=(0, 1), p=norm_ord) / cache.norm(dim=(0, 1), p=norm_ord)
    cache_diff = cache_diff / ada_args.skipped_step_length

    # 1.0 is the neutral factor: it is what the original code used outside the
    # step window, and it is also used here when no stride yields a usable term.
    moreg = 1.0
    if ada_args.moreg_steps[0] <= self.scheduler.step_index <= ada_args.moreg_steps[1]:
        seq_len = res.shape[0]
        terms = []
        for stride in ada_args.moreg_strides:
            offset = stride * ada_args.spatial_dim
            # When spatial_dim was defaulted to the full sequence length the
            # two shifted slices are empty; their 0/0 ratio would be NaN and
            # would poison cache_diff and previous_moreg for every later step.
            if not 0 < offset < seq_len:
                continue
            later, earlier = res[offset:, :], res[:-offset, :]
            scale = later.norm(p=norm_ord) + earlier.norm(p=norm_ord)
            if not scale > 0:
                # All-zero (or non-finite) slices: no meaningful ratio.
                continue
            terms.append((later - earlier).norm(p=norm_ord) / scale)
        if terms:
            moreg = sum(terms) / len(terms)
            moreg = ((1 / ada_args.moreg_hyp[0] * moreg) ** ada_args.moreg_hyp[1]) / ada_args.moreg_hyp[2]

    mograd = ada_args.mograd_mul * (moreg - ada_args.previous_moreg) / ada_args.skipped_step_length
    ada_args.previous_moreg = moreg
    moreg = moreg + abs(mograd)
    cache_diff = cache_diff * moreg

    # Walk the codebook in insertion order; fall back to its last rate.
    new_rate = list(self.codebook.values())[-1]
    for threshold, rate in self.codebook.items():
        if cache_diff < threshold:
            new_rate = rate
            break

    ada_args.previous_residual_tiny = ada_args.now_residual_tiny
    return new_rate
Comment on lines +98 to +145

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The motion regulation (`moreg`) calculation is designed for video models, where it measures temporal motion across frames. Since Flux2 is a 2D image model, there is no temporal dimension, and `moreg` is completely redundant. Furthermore, because `ada_args.spatial_dim` defaults to 0 and is then updated to `residual.shape[0]` (the full image sequence length), the slices `res[i * ada_args.spatial_dim :, :]` and `res[: -i * ada_args.spatial_dim, :]` are empty tensors, so the normalisation becomes 0/0 and yields NaN. That NaN propagates into `cache_diff`, every threshold comparison then fails, and the last codebook rate is always selected. Removing the `moreg` logic entirely avoids these issues, simplifies the code, and improves performance.

    def _calculate_skip_step_length_for_args(self, ada_args):
        if ada_args.previous_residual_tiny is None:
            ada_args.previous_residual_tiny = ada_args.now_residual_tiny
            return 1

        cache = ada_args.previous_residual_tiny
        res = ada_args.now_residual_tiny
        norm_ord = ada_args.norm_ord
        cache_diff = (cache - res).norm(dim=(0, 1), p=norm_ord) / cache.norm(dim=(0, 1), p=norm_ord)
        cache_diff = cache_diff / ada_args.skipped_step_length

        new_rate = list(self.codebook.values())[-1]
        for thres, rate in self.codebook.items():
            if cache_diff < thres:
                new_rate = rate
                break

        ada_args.previous_residual_tiny = ada_args.now_residual_tiny
        return new_rate


def calculate_skip_step_length(self):
    """Compute the skip length using the state of the branch picked by ``scheduler.infer_condition``."""
    branch_args = self.args_even if self.scheduler.infer_condition else self.args_odd
    return self._calculate_skip_step_length_for_args(branch_args)

def clear(self):
    """Drop the cached residual tensors of both branches.

    After this call ``previous_residual``, ``previous_residual_tiny`` and
    ``now_residual_tiny`` are ``None`` on ``args_even`` and ``args_odd``.
    """
    for ada_args in (self.args_even, self.args_odd):
        # Rebinding to None is enough to release the references. Copying the
        # tensors to the host first would only cost a device-to-host transfer
        # and a host allocation for data that is discarded on the next line.
        ada_args.previous_residual = None
        ada_args.previous_residual_tiny = None
        ada_args.now_residual_tiny = None
Comment on lines +154 to +163

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Moving the residual tensors to CPU immediately before setting them to None is redundant and inefficient. Setting them to None releases the references, allowing PyTorch to free the GPU memory directly. The .cpu() calls waste GPU-to-CPU bandwidth and CPU memory allocation overhead.

            ada_args.previous_residual = None
            ada_args.previous_residual_tiny = None
            ada_args.now_residual_tiny = None


torch.cuda.empty_cache()
91 changes: 51 additions & 40 deletions lightx2v/models/networks/flux2/infer/transformer_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def infer_double_stream_block(
temb_mod_img,
temb_mod_txt,
image_rotary_emb,
img_attn_hook=None,
):
heads = self.config["num_attention_heads"]
head_dim = self.config["attention_head_dim"]
Expand Down Expand Up @@ -134,7 +135,10 @@ def infer_double_stream_block(
img_attn_output = block_weights.to_out.apply(img_attn_output)
txt_attn_output = block_weights.to_add_out.apply(txt_attn_output)

hidden_states = hidden_states + gate_msa * img_attn_output
gated_img_attn = gate_msa * img_attn_output
if img_attn_hook is not None:
img_attn_hook(gated_img_attn)
hidden_states = hidden_states + gated_img_attn
encoder_hidden_states = encoder_hidden_states + c_gate_msa * txt_attn_output
norm_hidden_states2 = F.layer_norm(hidden_states, (hidden_states.shape[-1],))
norm_hidden_states2 = (norm_hidden_states2 * (1 + scale_mlp) + shift_mlp).squeeze(0)
Expand Down Expand Up @@ -237,62 +241,67 @@ def infer_single_stream_block(

return hidden_states

def infer(self, block_weights, pre_infer_out):
def _prepare_image_rotary_emb(self, image_rotary_emb, num_txt_tokens):
    """Keep only this rank's share of the image rotary embedding.

    When there is no sequence-parallel group, or no embedding, the input is
    returned unchanged. Otherwise the first ``num_txt_tokens`` rows are kept
    whole, and the remaining rows are padded with
    ``F.pad(..., (0, 0, 0, pad))`` up to a multiple of the group size, chunked
    along dim 0, and this rank's chunk is concatenated after the text rows.

    Args:
        image_rotary_emb: a tensor, a ``(cos, sin)`` tuple of tensors, or None.
        num_txt_tokens: number of leading rows that belong to the text tokens.

    Returns:
        The same kind of object as the input (tensor, 2-tuple, or None).
    """
    group = self.seq_p_group
    if group is None or image_rotary_emb is None:
        return image_rotary_emb

    world_size = dist.get_world_size(group)
    cur_rank = dist.get_rank(group)

    def pad_rows_for(img_rows):
        # Rows needed to make the image part divisible by the group size.
        return (world_size - (img_rows.shape[0] % world_size)) % world_size

    def shard(table, pad_rows):
        # Text rows stay intact; image rows are padded, chunked, and this rank's chunk kept.
        txt_rows, img_rows = table[:num_txt_tokens], table[num_txt_tokens:]
        if pad_rows > 0:
            img_rows = F.pad(img_rows, (0, 0, 0, pad_rows))
        local_rows = torch.chunk(img_rows, world_size, dim=0)[cur_rank]
        return torch.cat([txt_rows, local_rows], dim=0)

    if isinstance(image_rotary_emb, tuple):
        freqs_cos, freqs_sin = image_rotary_emb
        # The padding amount is derived from the cosine table and reused for the sine table.
        pad_rows = pad_rows_for(freqs_cos[num_txt_tokens:])
        return (shard(freqs_cos, pad_rows), shard(freqs_sin, pad_rows))

    return shard(image_rotary_emb, pad_rows_for(image_rotary_emb[num_txt_tokens:]))

def _infer_forward(self, block_weights, pre_infer_out, decisive_block_id=None, on_decisive_block=None):
hidden_states = pre_infer_out.hidden_states
encoder_hidden_states = pre_infer_out.encoder_hidden_states
timestep = pre_infer_out.timestep
image_rotary_emb = pre_infer_out.image_rotary_emb

num_txt_tokens = encoder_hidden_states.shape[0]

if self.seq_p_group is not None and image_rotary_emb is not None:
world_size = dist.get_world_size(self.seq_p_group)
cur_rank = dist.get_rank(self.seq_p_group)

if isinstance(image_rotary_emb, tuple):
freqs_cos, freqs_sin = image_rotary_emb

txt_cos = freqs_cos[:num_txt_tokens]
img_cos = freqs_cos[num_txt_tokens:]
txt_sin = freqs_sin[:num_txt_tokens]
img_sin = freqs_sin[num_txt_tokens:]

seqlen = img_cos.shape[0]
padding_size = (world_size - (seqlen % world_size)) % world_size
if padding_size > 0:
img_cos = F.pad(img_cos, (0, 0, 0, padding_size))
img_sin = F.pad(img_sin, (0, 0, 0, padding_size))
img_cos = torch.chunk(img_cos, world_size, dim=0)[cur_rank]
img_sin = torch.chunk(img_sin, world_size, dim=0)[cur_rank]

freqs_cos = torch.cat([txt_cos, img_cos], dim=0)
freqs_sin = torch.cat([txt_sin, img_sin], dim=0)
image_rotary_emb = (freqs_cos, freqs_sin)
else:
txt_emb = image_rotary_emb[:num_txt_tokens]
img_emb = image_rotary_emb[num_txt_tokens:]

seqlen = img_emb.shape[0]
padding_size = (world_size - (seqlen % world_size)) % world_size
if padding_size > 0:
img_emb = F.pad(img_emb, (0, 0, 0, padding_size))
img_emb = torch.chunk(img_emb, world_size, dim=0)[cur_rank]

image_rotary_emb = torch.cat([txt_emb, img_emb], dim=0)
image_rotary_emb = self._prepare_image_rotary_emb(image_rotary_emb, num_txt_tokens)

timestep_act = F.silu(timestep)
double_stream_mod_img = block_weights.double_stream_modulation_img_linear.apply(timestep_act)
double_stream_mod_txt = block_weights.double_stream_modulation_txt_linear.apply(timestep_act)
single_stream_mod = block_weights.single_stream_modulation_linear.apply(timestep_act)

for block in block_weights.double_blocks:
for block_idx, block in enumerate(block_weights.double_blocks):
block_hook = on_decisive_block if block_idx == decisive_block_id else None
encoder_hidden_states, hidden_states = self.infer_double_stream_block(
block,
hidden_states,
encoder_hidden_states,
double_stream_mod_img,
double_stream_mod_txt,
image_rotary_emb,
img_attn_hook=block_hook,
)

hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=0)
Expand All @@ -306,8 +315,10 @@ def infer(self, block_weights, pre_infer_out):
image_rotary_emb,
num_txt_tokens=num_txt_tokens,
)
hidden_states = hidden_states[num_txt_tokens:, ...]
return hidden_states
return hidden_states[num_txt_tokens:, ...]

# Public inference entry point: delegates to _infer_forward, leaving its
# decisive_block_id / on_decisive_block arguments at their None defaults.
def infer(self, block_weights, pre_infer_out):
return self._infer_forward(block_weights, pre_infer_out)


# Backward-compatible alias
Expand Down
29 changes: 23 additions & 6 deletions lightx2v/models/networks/flux2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from torch.nn import functional as F

from lightx2v.models.networks.base_model import BaseTransformerModel
from lightx2v.models.networks.flux2.infer.feature_caching.transformer_infer import Flux2TransformerInferAdaCaching
from lightx2v.models.networks.flux2.infer.offload.transformer_infer import Flux2OffloadTransformerInfer
from lightx2v.models.networks.flux2.infer.post_infer import Flux2PostInfer
from lightx2v.models.networks.flux2.infer.pre_infer import Flux2DevPreInfer, Flux2PreInfer
Expand Down Expand Up @@ -103,10 +104,18 @@ class Flux2KleinTransformerModel(_Flux2TransformerModelBase):
pre_weight_class = Flux2PreWeights

def _init_infer_class(self):
if self.cpu_offload and self.offload_granularity == "block":
self.transformer_infer_class = Flux2OffloadTransformerInfer
feature_caching = self.config.get("feature_caching", "NoCaching")
if feature_caching in ("NoCaching", "None"):
if self.cpu_offload and self.offload_granularity == "block":
self.transformer_infer_class = Flux2OffloadTransformerInfer
else:
self.transformer_infer_class = Flux2TransformerInfer
elif feature_caching == "Ada":
if self.cpu_offload and self.offload_granularity == "block":
raise NotImplementedError("Flux2 AdaCache does not support block-level cpu_offload yet")
self.transformer_infer_class = Flux2TransformerInferAdaCaching
else:
self.transformer_infer_class = Flux2TransformerInfer
raise NotImplementedError(f"Unsupported feature_caching type: {feature_caching}")
Comment on lines +107 to +118

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If feature_caching is set to null in the JSON configuration (which parses to Python None), the check feature_caching in ("NoCaching", "None") will evaluate to False, leading to an unexpected NotImplementedError. Adding None to the tuple ensures robust handling of null values.

Suggested change
feature_caching = self.config.get("feature_caching", "NoCaching")
if feature_caching in ("NoCaching", "None"):
if self.cpu_offload and self.offload_granularity == "block":
self.transformer_infer_class = Flux2OffloadTransformerInfer
else:
self.transformer_infer_class = Flux2TransformerInfer
elif feature_caching == "Ada":
if self.cpu_offload and self.offload_granularity == "block":
raise NotImplementedError("Flux2 AdaCache does not support block-level cpu_offload yet")
self.transformer_infer_class = Flux2TransformerInferAdaCaching
else:
self.transformer_infer_class = Flux2TransformerInfer
raise NotImplementedError(f"Unsupported feature_caching type: {feature_caching}")
feature_caching = self.config.get("feature_caching", "NoCaching")
if feature_caching in ("NoCaching", "None", None):
if self.cpu_offload and self.offload_granularity == "block":
self.transformer_infer_class = Flux2OffloadTransformerInfer
else:
self.transformer_infer_class = Flux2TransformerInfer
elif feature_caching == "Ada":
if self.cpu_offload and self.offload_granularity == "block":
raise NotImplementedError("Flux2 AdaCache does not support block-level cpu_offload yet")
self.transformer_infer_class = Flux2TransformerInferAdaCaching
else:
raise NotImplementedError(f"Unsupported feature_caching type: {feature_caching}")

self.pre_infer_class = Flux2PreInfer
self.post_infer_class = Flux2PostInfer

Expand Down Expand Up @@ -219,10 +228,18 @@ class Flux2DevTransformerModel(_Flux2TransformerModelBase):
pre_weight_class = Flux2DevPreWeights

def _init_infer_class(self):
if self.cpu_offload and self.offload_granularity == "block":
self.transformer_infer_class = Flux2OffloadTransformerInfer
feature_caching = self.config.get("feature_caching", "NoCaching")
if feature_caching in ("NoCaching", "None"):
if self.cpu_offload and self.offload_granularity == "block":
self.transformer_infer_class = Flux2OffloadTransformerInfer
else:
self.transformer_infer_class = Flux2TransformerInfer
elif feature_caching == "Ada":
if self.cpu_offload and self.offload_granularity == "block":
raise NotImplementedError("Flux2 AdaCache does not support block-level cpu_offload yet")
self.transformer_infer_class = Flux2TransformerInferAdaCaching
else:
self.transformer_infer_class = Flux2TransformerInfer
raise NotImplementedError(f"Unsupported feature_caching type: {feature_caching}")
Comment on lines +231 to +242

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If feature_caching is set to null in the JSON configuration (which parses to Python None), the check feature_caching in ("NoCaching", "None") will evaluate to False, leading to an unexpected NotImplementedError. Adding None to the tuple ensures robust handling of null values.

Suggested change
feature_caching = self.config.get("feature_caching", "NoCaching")
if feature_caching in ("NoCaching", "None"):
if self.cpu_offload and self.offload_granularity == "block":
self.transformer_infer_class = Flux2OffloadTransformerInfer
else:
self.transformer_infer_class = Flux2TransformerInfer
elif feature_caching == "Ada":
if self.cpu_offload and self.offload_granularity == "block":
raise NotImplementedError("Flux2 AdaCache does not support block-level cpu_offload yet")
self.transformer_infer_class = Flux2TransformerInferAdaCaching
else:
self.transformer_infer_class = Flux2TransformerInfer
raise NotImplementedError(f"Unsupported feature_caching type: {feature_caching}")
feature_caching = self.config.get("feature_caching", "NoCaching")
if feature_caching in ("NoCaching", "None", None):
if self.cpu_offload and self.offload_granularity == "block":
self.transformer_infer_class = Flux2OffloadTransformerInfer
else:
self.transformer_infer_class = Flux2TransformerInfer
elif feature_caching == "Ada":
if self.cpu_offload and self.offload_granularity == "block":
raise NotImplementedError("Flux2 AdaCache does not support block-level cpu_offload yet")
self.transformer_infer_class = Flux2TransformerInferAdaCaching
else:
raise NotImplementedError(f"Unsupported feature_caching type: {feature_caching}")

self.pre_infer_class = Flux2DevPreInfer
self.post_infer_class = Flux2PostInfer

Expand Down
Loading
Loading