Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions configs/flux2/flux2_klein_i2i_inpaint_mask_cache.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
{
"model_cls": "flux2_klein",
"task": "i2i",
"task_variant": "edit",
"infer_steps": 40,
"sample_guide_scale": 4.0,
"feature_caching": "Ada",
"vae_scale_factor": 16,
"enable_cfg": true,
"patch_size": 2,
"tokenizer_max_length": 512,
"rope_type": "flashinfer",
"max_image_area": 1048576,
"inpaint_mask_enabled": true
}
18 changes: 18 additions & 0 deletions configs/flux2/flux2_klein_i2i_inpaint_mask_cfg_parallel_cache.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
{
"model_cls": "flux2_klein",
"task": "i2i",
"task_variant": "edit",
"infer_steps": 40,
"sample_guide_scale": 4.0,
"vae_scale_factor": 16,
"feature_caching": "Ada",
"enable_cfg": true,
"patch_size": 2,
"tokenizer_max_length": 512,
"rope_type": "flashinfer",
"max_image_area": 1048576,
"inpaint_mask_enabled": true,
"parallel": {
"cfg_p_size": 2
}
}
3 changes: 1 addition & 2 deletions lightx2v/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@
from lightx2v.common.ops import *
from lightx2v.models.runners.bagel.bagel_runner import BagelRunner # noqa: F401
from lightx2v.models.runners.ernie_image.ernie_image_runner import ErnieImageRunner # noqa: F401
from lightx2v.models.runners.flux2.flux2_runner import Flux2DevRunner, Flux2KleinRunner # noqa: F401
from lightx2v.models.runners.hidream_o1_image.hidream_o1_image_runner import HidreamO1ImageRunner # noqa: F401
from lightx2v.models.runners.hunyuan3d.hunyuan3d_shape_runner import Hunyuan3DShapeRunner # noqa: F401

# from lightx2v.models.runners.flux2.flux2_runner import Flux2DevRunner, Flux2KleinRunner # noqa: F401
from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_distill_runner import HunyuanVideo15DistillRunner # noqa: F401
from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_runner import HunyuanVideo15Runner # noqa: F401
from lightx2v.models.runners.longcat_image.longcat_image_runner import LongCatImageRunner # noqa: F401
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
import torch

from lightx2v.models.networks.flux2.infer.transformer_infer import Flux2TransformerInfer


class Flux2TransformerInferCaching(Flux2TransformerInfer):
def __init__(self, config):
super().__init__(config)
self.must_calc_steps = []
if self.config.get("changing_resolution", False):
self.must_calc_steps = self.config["changing_resolution_steps"]

def must_calc(self, step_index):
return step_index in self.must_calc_steps


class Flux2AdaArgs:
def __init__(self, config):
self.previous_residual_tiny = None
self.now_residual_tiny = None
self.norm_ord = 1
self.skipped_step_length = 1
self.previous_residual = None

self.previous_moreg = 1.0
self.moreg_strides = [1]
self.moreg_steps = [int(0.1 * config["infer_steps"]), int(0.9 * config["infer_steps"])]
self.moreg_hyp = [0.385, 8, 1, 2]
self.mograd_mul = 10
self.spatial_dim = config.get("adacache_spatial_dim", 0)


class Flux2TransformerInferAdaCaching(Flux2TransformerInferCaching):
def __init__(self, config):
super().__init__(config)
self.decisive_double_block_id = config.get("num_layers", 10) // 2
self.codebook = {0.03: 12, 0.05: 10, 0.07: 8, 0.09: 6, 0.11: 4, 1.00: 3}
self.args_even = Flux2AdaArgs(config)
self.args_odd = Flux2AdaArgs(config)

def infer(self, block_weights, pre_infer_out):
if self.scheduler.infer_condition:
index = self.scheduler.step_index
caching_records = self.scheduler.caching_records

if caching_records[index] or self.must_calc(index):
hidden_states = self.infer_calculating(block_weights, pre_infer_out)

if index <= self.scheduler.infer_steps - 2:
self.args_even.skipped_step_length = self.calculate_skip_step_length()
for i in range(1, self.args_even.skipped_step_length):
if (index + i) <= self.scheduler.infer_steps - 1:
self.scheduler.caching_records[index + i] = False
else:
hidden_states = self.infer_using_cache(pre_infer_out)
else:
index = self.scheduler.step_index
caching_records = self.scheduler.caching_records_2

if caching_records[index] or self.must_calc(index):
hidden_states = self.infer_calculating(block_weights, pre_infer_out)

if index <= self.scheduler.infer_steps - 2:
self.args_odd.skipped_step_length = self.calculate_skip_step_length()
for i in range(1, self.args_odd.skipped_step_length):
if (index + i) <= self.scheduler.infer_steps - 1:
self.scheduler.caching_records_2[index + i] = False
else:
hidden_states = self.infer_using_cache(pre_infer_out)

return hidden_states
Comment on lines +41 to +71

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There is significant code duplication between the if self.scheduler.infer_condition and else branches in the infer method. We can dynamically select the caching records and arguments to simplify the logic and improve maintainability.

    def infer(self, block_weights, pre_infer_out):
        index = self.scheduler.step_index
        if self.scheduler.infer_condition:
            caching_records = self.scheduler.caching_records
            ada_args = self.args_even
        else:
            caching_records = self.scheduler.caching_records_2
            ada_args = self.args_odd

        if caching_records[index] or self.must_calc(index):
            hidden_states = self.infer_calculating(block_weights, pre_infer_out)

            if index <= self.scheduler.infer_steps - 2:
                ada_args.skipped_step_length = self.calculate_skip_step_length()
                for i in range(1, ada_args.skipped_step_length):
                    if (index + i) <= self.scheduler.infer_steps - 1:
                        caching_records[index + i] = False
        else:
            hidden_states = self.infer_using_cache(pre_infer_out)

        return hidden_states


def infer_calculating(self, block_weights, pre_infer_out):
ori_hidden_states = pre_infer_out.hidden_states.clone()
ada_args = self.args_even if self.scheduler.infer_condition else self.args_odd

def on_decisive_block(gated_img_attn):
ada_args.now_residual_tiny = gated_img_attn.squeeze(0)

hidden_states = self._infer_forward(
block_weights,
pre_infer_out,
decisive_block_id=self.decisive_double_block_id,
on_decisive_block=on_decisive_block,
)

ada_args.previous_residual = hidden_states - ori_hidden_states
return hidden_states

def infer_using_cache(self, pre_infer_out):
hidden_states = pre_infer_out.hidden_states
if self.scheduler.infer_condition:
hidden_states = hidden_states + self.args_even.previous_residual
else:
hidden_states = hidden_states + self.args_odd.previous_residual
return hidden_states

def _update_spatial_dim(self, ada_args, residual):
if ada_args.spatial_dim <= 0:
ada_args.spatial_dim = residual.shape[0]

def _calculate_skip_step_length_for_args(self, ada_args):
if ada_args.previous_residual_tiny is None:
ada_args.previous_residual_tiny = ada_args.now_residual_tiny
return 1

cache = ada_args.previous_residual_tiny
res = ada_args.now_residual_tiny
self._update_spatial_dim(ada_args, res)
norm_ord = ada_args.norm_ord
cache_diff = (cache - res).norm(dim=(0, 1), p=norm_ord) / cache.norm(dim=(0, 1), p=norm_ord)
cache_diff = cache_diff / ada_args.skipped_step_length

if ada_args.moreg_steps[0] <= self.scheduler.step_index <= ada_args.moreg_steps[1]:
moreg = 0
for i in ada_args.moreg_strides:
moreg_i = (res[i * ada_args.spatial_dim :, :] - res[: -i * ada_args.spatial_dim, :]).norm(p=norm_ord)
moreg_i /= res[i * ada_args.spatial_dim :, :].norm(p=norm_ord) + res[: -i * ada_args.spatial_dim, :].norm(p=norm_ord)
moreg += moreg_i
moreg = moreg / len(ada_args.moreg_strides)
moreg = ((1 / ada_args.moreg_hyp[0] * moreg) ** ada_args.moreg_hyp[1]) / ada_args.moreg_hyp[2]
else:
moreg = 1.0

mograd = ada_args.mograd_mul * (moreg - ada_args.previous_moreg) / ada_args.skipped_step_length
ada_args.previous_moreg = moreg
moreg = moreg + abs(mograd)
cache_diff = cache_diff * moreg

metric_thres, cache_rates = list(self.codebook.keys()), list(self.codebook.values())
if cache_diff < metric_thres[0]:
new_rate = cache_rates[0]
elif cache_diff < metric_thres[1]:
new_rate = cache_rates[1]
elif cache_diff < metric_thres[2]:
new_rate = cache_rates[2]
elif cache_diff < metric_thres[3]:
new_rate = cache_rates[3]
elif cache_diff < metric_thres[4]:
new_rate = cache_rates[4]
else:
new_rate = cache_rates[-1]

ada_args.previous_residual_tiny = ada_args.now_residual_tiny
return new_rate
Comment on lines +98 to +145

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The motion regulation (moreg) calculation is designed for video models to measure temporal motion across frames. Since Flux2 is a 2D image model, there is no temporal dimension, and moreg is completely redundant. Furthermore, because ada_args.spatial_dim defaults to 0 and is updated to residual.shape[0] (the full image sequence length), the slicing res[i * ada_args.spatial_dim :, :] results in empty tensors, leading to division-by-zero (NaN) errors during norm calculations. Removing the moreg logic entirely avoids these issues, simplifies the code, and improves performance.

    def _calculate_skip_step_length_for_args(self, ada_args):
        if ada_args.previous_residual_tiny is None:
            ada_args.previous_residual_tiny = ada_args.now_residual_tiny
            return 1

        cache = ada_args.previous_residual_tiny
        res = ada_args.now_residual_tiny
        norm_ord = ada_args.norm_ord
        cache_diff = (cache - res).norm(dim=(0, 1), p=norm_ord) / cache.norm(dim=(0, 1), p=norm_ord)
        cache_diff = cache_diff / ada_args.skipped_step_length

        new_rate = list(self.codebook.values())[-1]
        for thres, rate in self.codebook.items():
            if cache_diff < thres:
                new_rate = rate
                break

        ada_args.previous_residual_tiny = ada_args.now_residual_tiny
        return new_rate


def calculate_skip_step_length(self):
if self.scheduler.infer_condition:
return self._calculate_skip_step_length_for_args(self.args_even)
return self._calculate_skip_step_length_for_args(self.args_odd)

def clear(self):
for ada_args in (self.args_even, self.args_odd):
if ada_args.previous_residual is not None:
ada_args.previous_residual = ada_args.previous_residual.cpu()
if ada_args.previous_residual_tiny is not None:
ada_args.previous_residual_tiny = ada_args.previous_residual_tiny.cpu()
if ada_args.now_residual_tiny is not None:
ada_args.now_residual_tiny = ada_args.now_residual_tiny.cpu()

ada_args.previous_residual = None
ada_args.previous_residual_tiny = None
ada_args.now_residual_tiny = None
Comment on lines +154 to +163

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Moving the residual tensors to CPU immediately before setting them to None is redundant and inefficient. Setting them to None releases the references, allowing PyTorch to free the GPU memory directly. The .cpu() calls waste GPU-to-CPU bandwidth and CPU memory allocation overhead.

            ada_args.previous_residual = None
            ada_args.previous_residual_tiny = None
            ada_args.now_residual_tiny = None


torch.cuda.empty_cache()
91 changes: 51 additions & 40 deletions lightx2v/models/networks/flux2/infer/transformer_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def infer_double_stream_block(
temb_mod_img,
temb_mod_txt,
image_rotary_emb,
img_attn_hook=None,
):
heads = self.config["num_attention_heads"]
head_dim = self.config["attention_head_dim"]
Expand Down Expand Up @@ -134,7 +135,10 @@ def infer_double_stream_block(
img_attn_output = block_weights.to_out.apply(img_attn_output)
txt_attn_output = block_weights.to_add_out.apply(txt_attn_output)

hidden_states = hidden_states + gate_msa * img_attn_output
gated_img_attn = gate_msa * img_attn_output
if img_attn_hook is not None:
img_attn_hook(gated_img_attn)
hidden_states = hidden_states + gated_img_attn
encoder_hidden_states = encoder_hidden_states + c_gate_msa * txt_attn_output
norm_hidden_states2 = F.layer_norm(hidden_states, (hidden_states.shape[-1],))
norm_hidden_states2 = (norm_hidden_states2 * (1 + scale_mlp) + shift_mlp).squeeze(0)
Expand Down Expand Up @@ -237,62 +241,67 @@ def infer_single_stream_block(

return hidden_states

def infer(self, block_weights, pre_infer_out):
def _prepare_image_rotary_emb(self, image_rotary_emb, num_txt_tokens):
if self.seq_p_group is None or image_rotary_emb is None:
return image_rotary_emb

world_size = dist.get_world_size(self.seq_p_group)
cur_rank = dist.get_rank(self.seq_p_group)

if isinstance(image_rotary_emb, tuple):
freqs_cos, freqs_sin = image_rotary_emb

txt_cos = freqs_cos[:num_txt_tokens]
img_cos = freqs_cos[num_txt_tokens:]
txt_sin = freqs_sin[:num_txt_tokens]
img_sin = freqs_sin[num_txt_tokens:]

seqlen = img_cos.shape[0]
padding_size = (world_size - (seqlen % world_size)) % world_size
if padding_size > 0:
img_cos = F.pad(img_cos, (0, 0, 0, padding_size))
img_sin = F.pad(img_sin, (0, 0, 0, padding_size))
img_cos = torch.chunk(img_cos, world_size, dim=0)[cur_rank]
img_sin = torch.chunk(img_sin, world_size, dim=0)[cur_rank]

freqs_cos = torch.cat([txt_cos, img_cos], dim=0)
freqs_sin = torch.cat([txt_sin, img_sin], dim=0)
return (freqs_cos, freqs_sin)

txt_emb = image_rotary_emb[:num_txt_tokens]
img_emb = image_rotary_emb[num_txt_tokens:]

seqlen = img_emb.shape[0]
padding_size = (world_size - (seqlen % world_size)) % world_size
if padding_size > 0:
img_emb = F.pad(img_emb, (0, 0, 0, padding_size))
img_emb = torch.chunk(img_emb, world_size, dim=0)[cur_rank]
return torch.cat([txt_emb, img_emb], dim=0)

def _infer_forward(self, block_weights, pre_infer_out, decisive_block_id=None, on_decisive_block=None):
hidden_states = pre_infer_out.hidden_states
encoder_hidden_states = pre_infer_out.encoder_hidden_states
timestep = pre_infer_out.timestep
image_rotary_emb = pre_infer_out.image_rotary_emb

num_txt_tokens = encoder_hidden_states.shape[0]

if self.seq_p_group is not None and image_rotary_emb is not None:
world_size = dist.get_world_size(self.seq_p_group)
cur_rank = dist.get_rank(self.seq_p_group)

if isinstance(image_rotary_emb, tuple):
freqs_cos, freqs_sin = image_rotary_emb

txt_cos = freqs_cos[:num_txt_tokens]
img_cos = freqs_cos[num_txt_tokens:]
txt_sin = freqs_sin[:num_txt_tokens]
img_sin = freqs_sin[num_txt_tokens:]

seqlen = img_cos.shape[0]
padding_size = (world_size - (seqlen % world_size)) % world_size
if padding_size > 0:
img_cos = F.pad(img_cos, (0, 0, 0, padding_size))
img_sin = F.pad(img_sin, (0, 0, 0, padding_size))
img_cos = torch.chunk(img_cos, world_size, dim=0)[cur_rank]
img_sin = torch.chunk(img_sin, world_size, dim=0)[cur_rank]

freqs_cos = torch.cat([txt_cos, img_cos], dim=0)
freqs_sin = torch.cat([txt_sin, img_sin], dim=0)
image_rotary_emb = (freqs_cos, freqs_sin)
else:
txt_emb = image_rotary_emb[:num_txt_tokens]
img_emb = image_rotary_emb[num_txt_tokens:]

seqlen = img_emb.shape[0]
padding_size = (world_size - (seqlen % world_size)) % world_size
if padding_size > 0:
img_emb = F.pad(img_emb, (0, 0, 0, padding_size))
img_emb = torch.chunk(img_emb, world_size, dim=0)[cur_rank]

image_rotary_emb = torch.cat([txt_emb, img_emb], dim=0)
image_rotary_emb = self._prepare_image_rotary_emb(image_rotary_emb, num_txt_tokens)

timestep_act = F.silu(timestep)
double_stream_mod_img = block_weights.double_stream_modulation_img_linear.apply(timestep_act)
double_stream_mod_txt = block_weights.double_stream_modulation_txt_linear.apply(timestep_act)
single_stream_mod = block_weights.single_stream_modulation_linear.apply(timestep_act)

for block in block_weights.double_blocks:
for block_idx, block in enumerate(block_weights.double_blocks):
block_hook = on_decisive_block if block_idx == decisive_block_id else None
encoder_hidden_states, hidden_states = self.infer_double_stream_block(
block,
hidden_states,
encoder_hidden_states,
double_stream_mod_img,
double_stream_mod_txt,
image_rotary_emb,
img_attn_hook=block_hook,
)

hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=0)
Expand All @@ -306,8 +315,10 @@ def infer(self, block_weights, pre_infer_out):
image_rotary_emb,
num_txt_tokens=num_txt_tokens,
)
hidden_states = hidden_states[num_txt_tokens:, ...]
return hidden_states
return hidden_states[num_txt_tokens:, ...]

def infer(self, block_weights, pre_infer_out):
return self._infer_forward(block_weights, pre_infer_out)


# Backward-compatible alias
Expand Down
Loading
Loading