From 524d7e2dd8ea8117b0230ddb6721d75f03fb0c2f Mon Sep 17 00:00:00 2001 From: wangshankun Date: Tue, 16 Jun 2026 08:38:40 +0000 Subject: [PATCH 1/2] [feat] Add ada cache for flux2 ppt --- .../flux2_klein_i2i_inpaint_mask_cache.json | 15 ++ ...n_i2i_inpaint_mask_cfg_parallel_cache.json | 18 ++ .../flux2/infer/feature_caching/__init__.py | 0 .../feature_caching/transformer_infer.py | 165 ++++++++++++++++++ .../networks/flux2/infer/transformer_infer.py | 91 +++++----- lightx2v/models/networks/flux2/model.py | 29 ++- lightx2v/models/runners/flux2/flux2_runner.py | 20 ++- .../flux2/feature_caching/__init__.py | 0 .../flux2/feature_caching/scheduler.py | 23 +++ ...nfer_flux2_klein_i2i_inpaint_mask_cache.sh | 15 ++ ...ein_i2i_inpaint_mask_cfg_parallel_cache.sh | 15 ++ 11 files changed, 343 insertions(+), 48 deletions(-) create mode 100644 configs/flux2/flux2_klein_i2i_inpaint_mask_cache.json create mode 100644 configs/flux2/flux2_klein_i2i_inpaint_mask_cfg_parallel_cache.json create mode 100644 lightx2v/models/networks/flux2/infer/feature_caching/__init__.py create mode 100644 lightx2v/models/networks/flux2/infer/feature_caching/transformer_infer.py create mode 100644 lightx2v/models/schedulers/flux2/feature_caching/__init__.py create mode 100644 lightx2v/models/schedulers/flux2/feature_caching/scheduler.py create mode 100644 scripts/flux2/infer_flux2_klein_i2i_inpaint_mask_cache.sh create mode 100644 scripts/flux2/infer_flux2_klein_i2i_inpaint_mask_cfg_parallel_cache.sh diff --git a/configs/flux2/flux2_klein_i2i_inpaint_mask_cache.json b/configs/flux2/flux2_klein_i2i_inpaint_mask_cache.json new file mode 100644 index 000000000..442173181 --- /dev/null +++ b/configs/flux2/flux2_klein_i2i_inpaint_mask_cache.json @@ -0,0 +1,15 @@ +{ + "model_cls": "flux2_klein", + "task": "i2i", + "task_variant": "edit", + "infer_steps": 40, + "sample_guide_scale": 4.0, + "feature_caching": "Ada", + "vae_scale_factor": 16, + "enable_cfg": true, + "patch_size": 2, + "tokenizer_max_length": 512, + "rope_type": "flashinfer", + "max_image_area": 1048576, + "inpaint_mask_enabled": true +} diff --git a/configs/flux2/flux2_klein_i2i_inpaint_mask_cfg_parallel_cache.json b/configs/flux2/flux2_klein_i2i_inpaint_mask_cfg_parallel_cache.json new file mode 100644 index 000000000..633a31d8a --- /dev/null +++ b/configs/flux2/flux2_klein_i2i_inpaint_mask_cfg_parallel_cache.json @@ -0,0 +1,18 @@ +{ + "model_cls": "flux2_klein", + "task": "i2i", + "task_variant": "edit", + "infer_steps": 40, + "sample_guide_scale": 4.0, + "vae_scale_factor": 16, + "feature_caching": "Ada", + "enable_cfg": true, + "patch_size": 2, + "tokenizer_max_length": 512, + "rope_type": "flashinfer", + "max_image_area": 1048576, + "inpaint_mask_enabled": true, + "parallel": { + "cfg_p_size": 2 + } +} diff --git a/lightx2v/models/networks/flux2/infer/feature_caching/__init__.py b/lightx2v/models/networks/flux2/infer/feature_caching/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightx2v/models/networks/flux2/infer/feature_caching/transformer_infer.py b/lightx2v/models/networks/flux2/infer/feature_caching/transformer_infer.py new file mode 100644 index 000000000..a1f11fda8 --- /dev/null +++ b/lightx2v/models/networks/flux2/infer/feature_caching/transformer_infer.py @@ -0,0 +1,165 @@ +import torch + +from lightx2v.models.networks.flux2.infer.transformer_infer import Flux2TransformerInfer + + +class Flux2TransformerInferCaching(Flux2TransformerInfer): + def __init__(self, config): + super().__init__(config) + self.must_calc_steps = [] + if self.config.get("changing_resolution", False): + self.must_calc_steps = self.config["changing_resolution_steps"] + + def must_calc(self, step_index): + return step_index in self.must_calc_steps + + +class Flux2AdaArgs: + def __init__(self, config): + self.previous_residual_tiny = None + self.now_residual_tiny = None + self.norm_ord = 1 + self.skipped_step_length = 1 + self.previous_residual = None + + self.previous_moreg = 1.0 + self.moreg_strides = [1] + self.moreg_steps = [int(0.1 * config["infer_steps"]), int(0.9 * config["infer_steps"])] + self.moreg_hyp = [0.385, 8, 1, 2] + self.mograd_mul = 10 + self.spatial_dim = config.get("adacache_spatial_dim", 0) + + +class Flux2TransformerInferAdaCaching(Flux2TransformerInferCaching): + def __init__(self, config): + super().__init__(config) + self.decisive_double_block_id = config.get("num_layers", 10) // 2 + self.codebook = {0.03: 12, 0.05: 10, 0.07: 8, 0.09: 6, 0.11: 4, 1.00: 3} + self.args_even = Flux2AdaArgs(config) + self.args_odd = Flux2AdaArgs(config) + + def infer(self, block_weights, pre_infer_out): + if self.scheduler.infer_condition: + index = self.scheduler.step_index + caching_records = self.scheduler.caching_records + + if caching_records[index] or self.must_calc(index): + hidden_states = self.infer_calculating(block_weights, pre_infer_out) + + if index <= self.scheduler.infer_steps - 2: + self.args_even.skipped_step_length = self.calculate_skip_step_length() + for i in range(1, self.args_even.skipped_step_length): + if (index + i) <= self.scheduler.infer_steps - 1: + self.scheduler.caching_records[index + i] = False + else: + hidden_states = self.infer_using_cache(pre_infer_out) + else: + index = self.scheduler.step_index + caching_records = self.scheduler.caching_records_2 + + if caching_records[index] or self.must_calc(index): + hidden_states = self.infer_calculating(block_weights, pre_infer_out) + + if index <= self.scheduler.infer_steps - 2: + self.args_odd.skipped_step_length = self.calculate_skip_step_length() + for i in range(1, self.args_odd.skipped_step_length): + if (index + i) <= self.scheduler.infer_steps - 1: + self.scheduler.caching_records_2[index + i] = False + else: + hidden_states = self.infer_using_cache(pre_infer_out) + + return hidden_states + + def infer_calculating(self, block_weights, pre_infer_out): + ori_hidden_states = pre_infer_out.hidden_states.clone() + ada_args = self.args_even if self.scheduler.infer_condition else self.args_odd + + def on_decisive_block(gated_img_attn): + ada_args.now_residual_tiny = gated_img_attn.squeeze(0) + + hidden_states = self._infer_forward( + block_weights, + pre_infer_out, + decisive_block_id=self.decisive_double_block_id, + on_decisive_block=on_decisive_block, + ) + + ada_args.previous_residual = hidden_states - ori_hidden_states + return hidden_states + + def infer_using_cache(self, pre_infer_out): + hidden_states = pre_infer_out.hidden_states + if self.scheduler.infer_condition: + hidden_states = hidden_states + self.args_even.previous_residual + else: + hidden_states = hidden_states + self.args_odd.previous_residual + return hidden_states + + def _update_spatial_dim(self, ada_args, residual): + if ada_args.spatial_dim <= 0: + ada_args.spatial_dim = residual.shape[0] + + def _calculate_skip_step_length_for_args(self, ada_args): + if ada_args.previous_residual_tiny is None: + ada_args.previous_residual_tiny = ada_args.now_residual_tiny + return 1 + + cache = ada_args.previous_residual_tiny + res = ada_args.now_residual_tiny + self._update_spatial_dim(ada_args, res) + norm_ord = ada_args.norm_ord + cache_diff = (cache - res).norm(dim=(0, 1), p=norm_ord) / cache.norm(dim=(0, 1), p=norm_ord) + cache_diff = cache_diff / ada_args.skipped_step_length + + if ada_args.moreg_steps[0] <= self.scheduler.step_index <= ada_args.moreg_steps[1]: + moreg = 0 + for i in ada_args.moreg_strides: + moreg_i = (res[i * ada_args.spatial_dim :, :] - res[: -i * ada_args.spatial_dim, :]).norm(p=norm_ord) + moreg_i /= res[i * ada_args.spatial_dim :, :].norm(p=norm_ord) + res[: -i * ada_args.spatial_dim, :].norm(p=norm_ord) + moreg += moreg_i + moreg = moreg / len(ada_args.moreg_strides) + moreg = ((1 / ada_args.moreg_hyp[0] * moreg) ** ada_args.moreg_hyp[1]) / ada_args.moreg_hyp[2] + else: + moreg = 1.0 + + mograd = ada_args.mograd_mul * (moreg - ada_args.previous_moreg) / ada_args.skipped_step_length + ada_args.previous_moreg = moreg + moreg = moreg + abs(mograd) + cache_diff = cache_diff * moreg + + metric_thres, cache_rates = list(self.codebook.keys()), list(self.codebook.values()) + if cache_diff < metric_thres[0]: + new_rate = cache_rates[0] + elif cache_diff < metric_thres[1]: + new_rate = cache_rates[1] + elif cache_diff < metric_thres[2]: + new_rate = cache_rates[2] + elif cache_diff < metric_thres[3]: + new_rate = cache_rates[3] + elif cache_diff < metric_thres[4]: + new_rate = cache_rates[4] + else: + new_rate = cache_rates[-1] + + ada_args.previous_residual_tiny = ada_args.now_residual_tiny + return new_rate + + def calculate_skip_step_length(self): + if self.scheduler.infer_condition: + return self._calculate_skip_step_length_for_args(self.args_even) + return self._calculate_skip_step_length_for_args(self.args_odd) + + def clear(self): + for ada_args in (self.args_even, self.args_odd): + if ada_args.previous_residual is not None: + ada_args.previous_residual = ada_args.previous_residual.cpu() + if ada_args.previous_residual_tiny is not None: + ada_args.previous_residual_tiny = ada_args.previous_residual_tiny.cpu() + if ada_args.now_residual_tiny is not None: + ada_args.now_residual_tiny = ada_args.now_residual_tiny.cpu() + + ada_args.previous_residual = None + ada_args.previous_residual_tiny = None + ada_args.now_residual_tiny = None + + torch.cuda.empty_cache() diff --git a/lightx2v/models/networks/flux2/infer/transformer_infer.py b/lightx2v/models/networks/flux2/infer/transformer_infer.py index b6205a041..b4d120990 100644 --- a/lightx2v/models/networks/flux2/infer/transformer_infer.py +++ b/lightx2v/models/networks/flux2/infer/transformer_infer.py @@ -56,6 +56,7 @@ def infer_double_stream_block( temb_mod_img, temb_mod_txt, image_rotary_emb, + img_attn_hook=None, ): heads = self.config["num_attention_heads"] head_dim = self.config["attention_head_dim"] @@ -134,7 +135,10 @@ def infer_double_stream_block( img_attn_output = block_weights.to_out.apply(img_attn_output) txt_attn_output = block_weights.to_add_out.apply(txt_attn_output) - hidden_states = hidden_states + gate_msa * img_attn_output + gated_img_attn = gate_msa * img_attn_output + if img_attn_hook is not None: + img_attn_hook(gated_img_attn) + hidden_states = hidden_states + gated_img_attn encoder_hidden_states = encoder_hidden_states + c_gate_msa * txt_attn_output norm_hidden_states2 = F.layer_norm(hidden_states, (hidden_states.shape[-1],)) norm_hidden_states2 = (norm_hidden_states2 * (1 + scale_mlp) + shift_mlp).squeeze(0) @@ -237,55 +241,59 @@ def infer_single_stream_block( return hidden_states - def infer(self, block_weights, pre_infer_out): + def _prepare_image_rotary_emb(self, image_rotary_emb, num_txt_tokens): + if self.seq_p_group is None or image_rotary_emb is None: + return image_rotary_emb + + world_size = dist.get_world_size(self.seq_p_group) + cur_rank = dist.get_rank(self.seq_p_group) + + if isinstance(image_rotary_emb, tuple): + freqs_cos, freqs_sin = image_rotary_emb + + txt_cos = freqs_cos[:num_txt_tokens] + img_cos = freqs_cos[num_txt_tokens:] + txt_sin = freqs_sin[:num_txt_tokens] + img_sin = freqs_sin[num_txt_tokens:] + + seqlen = img_cos.shape[0] + padding_size = (world_size - (seqlen % world_size)) % world_size + if padding_size > 0: + img_cos = F.pad(img_cos, (0, 0, 0, padding_size)) + img_sin = F.pad(img_sin, (0, 0, 0, padding_size)) + img_cos = torch.chunk(img_cos, world_size, dim=0)[cur_rank] + img_sin = torch.chunk(img_sin, world_size, dim=0)[cur_rank] + + freqs_cos = torch.cat([txt_cos, img_cos], dim=0) + freqs_sin = torch.cat([txt_sin, img_sin], dim=0) + return (freqs_cos, freqs_sin) + + txt_emb = image_rotary_emb[:num_txt_tokens] + img_emb = image_rotary_emb[num_txt_tokens:] + + seqlen = img_emb.shape[0] + padding_size = (world_size - (seqlen % world_size)) % world_size + if padding_size > 0: + img_emb = F.pad(img_emb, (0, 0, 0, padding_size)) + img_emb = torch.chunk(img_emb, world_size, dim=0)[cur_rank] + return torch.cat([txt_emb, img_emb], dim=0) + + def _infer_forward(self, block_weights, pre_infer_out, decisive_block_id=None, on_decisive_block=None): hidden_states = pre_infer_out.hidden_states encoder_hidden_states = pre_infer_out.encoder_hidden_states timestep = pre_infer_out.timestep image_rotary_emb = pre_infer_out.image_rotary_emb num_txt_tokens = encoder_hidden_states.shape[0] - - if self.seq_p_group is not None and image_rotary_emb is not None: - world_size = dist.get_world_size(self.seq_p_group) - cur_rank = dist.get_rank(self.seq_p_group) - - if isinstance(image_rotary_emb, tuple): - freqs_cos, freqs_sin = image_rotary_emb - - txt_cos = freqs_cos[:num_txt_tokens] - img_cos = freqs_cos[num_txt_tokens:] - txt_sin = freqs_sin[:num_txt_tokens] - img_sin = freqs_sin[num_txt_tokens:] - - seqlen = img_cos.shape[0] - padding_size = (world_size - (seqlen % world_size)) % world_size - if padding_size > 0: - img_cos = F.pad(img_cos, (0, 0, 0, padding_size)) - img_sin = F.pad(img_sin, (0, 0, 0, padding_size)) - img_cos = torch.chunk(img_cos, world_size, dim=0)[cur_rank] - img_sin = torch.chunk(img_sin, world_size, dim=0)[cur_rank] - - freqs_cos = torch.cat([txt_cos, img_cos], dim=0) - freqs_sin = torch.cat([txt_sin, img_sin], dim=0) - image_rotary_emb = (freqs_cos, freqs_sin) - else: - txt_emb = image_rotary_emb[:num_txt_tokens] - img_emb = image_rotary_emb[num_txt_tokens:] - - seqlen = img_emb.shape[0] - padding_size = (world_size - (seqlen % world_size)) % world_size - if padding_size > 0: - img_emb = F.pad(img_emb, (0, 0, 0, padding_size)) - img_emb = torch.chunk(img_emb, world_size, dim=0)[cur_rank] - - image_rotary_emb = torch.cat([txt_emb, img_emb], dim=0) + image_rotary_emb = self._prepare_image_rotary_emb(image_rotary_emb, num_txt_tokens) timestep_act = F.silu(timestep) double_stream_mod_img = block_weights.double_stream_modulation_img_linear.apply(timestep_act) double_stream_mod_txt = block_weights.double_stream_modulation_txt_linear.apply(timestep_act) single_stream_mod = block_weights.single_stream_modulation_linear.apply(timestep_act) - for block in block_weights.double_blocks: + for block_idx, block in enumerate(block_weights.double_blocks): + block_hook = on_decisive_block if block_idx == decisive_block_id else None encoder_hidden_states, hidden_states = self.infer_double_stream_block( block, hidden_states, @@ -293,6 +301,7 @@ def infer(self, block_weights, pre_infer_out): double_stream_mod_img, double_stream_mod_txt, image_rotary_emb, + img_attn_hook=block_hook, ) hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=0) @@ -306,8 +315,10 @@ def infer(self, block_weights, pre_infer_out): image_rotary_emb, num_txt_tokens=num_txt_tokens, ) - hidden_states = hidden_states[num_txt_tokens:, ...] - return hidden_states + return hidden_states[num_txt_tokens:, ...] + + def infer(self, block_weights, pre_infer_out): + return self._infer_forward(block_weights, pre_infer_out) # Backward-compatible alias diff --git a/lightx2v/models/networks/flux2/model.py b/lightx2v/models/networks/flux2/model.py index 998465a3a..d0c389194 100644 --- a/lightx2v/models/networks/flux2/model.py +++ b/lightx2v/models/networks/flux2/model.py @@ -3,6 +3,7 @@ from torch.nn import functional as F from lightx2v.models.networks.base_model import BaseTransformerModel +from lightx2v.models.networks.flux2.infer.feature_caching.transformer_infer import Flux2TransformerInferAdaCaching from lightx2v.models.networks.flux2.infer.offload.transformer_infer import Flux2OffloadTransformerInfer from lightx2v.models.networks.flux2.infer.post_infer import Flux2PostInfer from lightx2v.models.networks.flux2.infer.pre_infer import Flux2DevPreInfer, Flux2PreInfer @@ -103,10 +104,18 @@ class Flux2KleinTransformerModel(_Flux2TransformerModelBase): pre_weight_class = Flux2PreWeights def _init_infer_class(self): - if self.cpu_offload and self.offload_granularity == "block": - self.transformer_infer_class = Flux2OffloadTransformerInfer + feature_caching = self.config.get("feature_caching", "NoCaching") + if feature_caching in ("NoCaching", "None"): + if self.cpu_offload and self.offload_granularity == "block": + self.transformer_infer_class = Flux2OffloadTransformerInfer + else: + self.transformer_infer_class = Flux2TransformerInfer + elif feature_caching == "Ada": + if self.cpu_offload and self.offload_granularity == "block": + raise NotImplementedError("Flux2 AdaCache does not support block-level cpu_offload yet") + self.transformer_infer_class = Flux2TransformerInferAdaCaching else: - self.transformer_infer_class = Flux2TransformerInfer + raise NotImplementedError(f"Unsupported feature_caching type: {feature_caching}") self.pre_infer_class = Flux2PreInfer self.post_infer_class = Flux2PostInfer @@ -219,10 +228,18 @@ class Flux2DevTransformerModel(_Flux2TransformerModelBase): pre_weight_class = Flux2DevPreWeights def _init_infer_class(self): - if self.cpu_offload and self.offload_granularity == "block": - self.transformer_infer_class = Flux2OffloadTransformerInfer + feature_caching = self.config.get("feature_caching", "NoCaching") + if feature_caching in ("NoCaching", "None"): + if self.cpu_offload and self.offload_granularity == "block": + self.transformer_infer_class = Flux2OffloadTransformerInfer + else: + self.transformer_infer_class = Flux2TransformerInfer + elif feature_caching == "Ada": + if self.cpu_offload and self.offload_granularity == "block": + raise NotImplementedError("Flux2 AdaCache does not support block-level cpu_offload yet") + self.transformer_infer_class = Flux2TransformerInferAdaCaching else: - self.transformer_infer_class = Flux2TransformerInfer + raise NotImplementedError(f"Unsupported feature_caching type: {feature_caching}") self.pre_infer_class = Flux2DevPreInfer self.post_infer_class = Flux2PostInfer diff --git a/lightx2v/models/runners/flux2/flux2_runner.py b/lightx2v/models/runners/flux2/flux2_runner.py index ae0182fa4..3c9334295 100644 --- a/lightx2v/models/runners/flux2/flux2_runner.py +++ b/lightx2v/models/runners/flux2/flux2_runner.py @@ -8,6 +8,7 @@ from lightx2v.models.networks.flux2.model import Flux2DevTransformerModel, Flux2KleinTransformerModel from lightx2v.models.runners.default_runner import DefaultRunner +from lightx2v.models.schedulers.flux2.feature_caching.scheduler import Flux2DevSchedulerCaching, Flux2SchedulerCaching from lightx2v.models.schedulers.flux2.scheduler import Flux2DevScheduler, Flux2Scheduler from lightx2v.models.video_encoders.hf.flux2.vae import Flux2VAE from lightx2v.utils.profiler import ProfilingContext4DebugL1, ProfilingContext4DebugL2 @@ -38,6 +39,13 @@ def __init__(self, config): config["vae_scale_factor"] = config.get("vae_scale_factor", 16) super().__init__(config) + def _get_scheduler_class(self): + if self.config.get("feature_caching", "NoCaching") in ("NoCaching", "None"): + return None + if self.config.get("feature_caching") == "Ada": + return Flux2SchedulerCaching + raise NotImplementedError(f"Unsupported feature_caching type: {self.config.get('feature_caching')}") + @ProfilingContext4DebugL2("Load models") def load_model(self): self.text_encoders = self.load_text_encoder() @@ -391,7 +399,11 @@ def load_text_encoder(self): return [text_encoder] def init_scheduler(self): - self.scheduler = Flux2Scheduler(self.config) + caching_scheduler_class = self._get_scheduler_class() + if caching_scheduler_class is not None: + self.scheduler = caching_scheduler_class(self.config) + else: + self.scheduler = Flux2Scheduler(self.config) @ProfilingContext4DebugL1("Run Text Encoder") def run_text_encoder(self, text, image_list=None, neg_prompt=None): @@ -429,7 +441,11 @@ def load_text_encoder(self): return [text_encoder] def init_scheduler(self): - self.scheduler = Flux2DevScheduler(self.config) + caching_scheduler_class = self._get_scheduler_class() + if caching_scheduler_class is not None: + self.scheduler = Flux2DevSchedulerCaching(self.config) + else: + self.scheduler = Flux2DevScheduler(self.config) @ProfilingContext4DebugL1("Run Text Encoder") def run_text_encoder(self, text, image_list=None, neg_prompt=None): diff --git a/lightx2v/models/schedulers/flux2/feature_caching/__init__.py b/lightx2v/models/schedulers/flux2/feature_caching/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightx2v/models/schedulers/flux2/feature_caching/scheduler.py b/lightx2v/models/schedulers/flux2/feature_caching/scheduler.py new file mode 100644 index 000000000..38b8d0cbe --- /dev/null +++ b/lightx2v/models/schedulers/flux2/feature_caching/scheduler.py @@ -0,0 +1,23 @@ +from lightx2v.models.schedulers.flux2.scheduler import Flux2DevScheduler, Flux2Scheduler + + +class Flux2SchedulerCaching(Flux2Scheduler): + def __init__(self, config): + super().__init__(config) + self.caching_records_2 = [True] * self.infer_steps + + def _refresh_caching_records(self): + self.caching_records = [True] * self.infer_steps + self.caching_records_2 = [True] * self.infer_steps + + def set_timesteps(self): + super().set_timesteps() + self._refresh_caching_records() + + def clear(self): + if self.transformer_infer is not None: + self.transformer_infer.clear() + + +class Flux2DevSchedulerCaching(Flux2SchedulerCaching, Flux2DevScheduler): + pass diff --git a/scripts/flux2/infer_flux2_klein_i2i_inpaint_mask_cache.sh b/scripts/flux2/infer_flux2_klein_i2i_inpaint_mask_cache.sh new file mode 100644 index 000000000..7f6f98e25 --- /dev/null +++ b/scripts/flux2/infer_flux2_klein_i2i_inpaint_mask_cache.sh @@ -0,0 +1,15 @@ +#!/bin/bash +lightx2v_path= +model_path="/mnt/miaohua/wangshankun/HF/hub/models--black-forest-labs--FLUX.2-klein-4B/snapshots/ppt_260529_30e" +export CUDA_VISIBLE_DEVICES=5 + +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ + --model_cls flux2_klein \ + --task i2i \ + --model_path $model_path \ + --prompt "remove the masked foreground object and keep the background unchanged" \ + --image_path "${lightx2v_path}/assets/inputs/inpaint_mask" \ + --save_result_path "${lightx2v_path}/save_results/flux2_klein_i2i_inpaint_mask_cache.png" \ + --config_json "${lightx2v_path}/configs/flux2/flux2_klein_i2i_inpaint_mask_cache.json" diff --git a/scripts/flux2/infer_flux2_klein_i2i_inpaint_mask_cfg_parallel_cache.sh b/scripts/flux2/infer_flux2_klein_i2i_inpaint_mask_cfg_parallel_cache.sh new file mode 100644 index 000000000..0e7cab293 --- /dev/null +++ b/scripts/flux2/infer_flux2_klein_i2i_inpaint_mask_cfg_parallel_cache.sh @@ -0,0 +1,15 @@ +#!/bin/bash +lightx2v_path= +model_path="/mnt/miaohua/wangshankun/HF/hub/models--black-forest-labs--FLUX.2-klein-4B/snapshots/ppt_260529_30e" +export CUDA_VISIBLE_DEVICES=5,6 + +source ${lightx2v_path}/scripts/base/base.sh + +torchrun --nproc_per_node=2 -m lightx2v.infer \ + --model_cls flux2_klein \ + --task i2i \ + --model_path $model_path \ + --prompt "remove the masked foreground object and keep the background unchanged" \ + --image_path "${lightx2v_path}/assets/inputs/inpaint_mask" \ + --save_result_path "${lightx2v_path}/save_results/flux2_klein_i2i_inpaint_mask_cache.png" \ + --config_json "${lightx2v_path}/configs/flux2/flux2_klein_i2i_inpaint_mask_cfg_parallel_cache.json" From fc4d32b50646d56072e098fdfb3a28f4c66a97ca Mon Sep 17 00:00:00 2001 From: wangshankun Date: Tue, 16 Jun 2026 09:35:47 +0000 Subject: [PATCH 2/2] fix: code format --- lightx2v/infer.py | 3 +-- lightx2v/pipeline.py | 6 +----- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/lightx2v/infer.py b/lightx2v/infer.py index d5a24e8ab..4c820c97f 100755 --- a/lightx2v/infer.py +++ b/lightx2v/infer.py @@ -8,10 +8,9 @@ from lightx2v.common.ops import * from lightx2v.models.runners.bagel.bagel_runner import BagelRunner # noqa: F401 from lightx2v.models.runners.ernie_image.ernie_image_runner import ErnieImageRunner # noqa: F401 +from lightx2v.models.runners.flux2.flux2_runner import Flux2DevRunner, Flux2KleinRunner # noqa: F401 from lightx2v.models.runners.hidream_o1_image.hidream_o1_image_runner import HidreamO1ImageRunner # noqa: F401 from lightx2v.models.runners.hunyuan3d.hunyuan3d_shape_runner import Hunyuan3DShapeRunner # noqa: F401 - -# from lightx2v.models.runners.flux2.flux2_runner import Flux2DevRunner, Flux2KleinRunner # noqa: F401 from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_distill_runner import HunyuanVideo15DistillRunner # noqa: F401 from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_runner import HunyuanVideo15Runner # noqa: F401 from lightx2v.models.runners.longcat_image.longcat_image_runner import LongCatImageRunner # noqa: F401 diff --git a/lightx2v/pipeline.py b/lightx2v/pipeline.py index 197be6849..daf8ece78 100755 --- a/lightx2v/pipeline.py +++ b/lightx2v/pipeline.py @@ -11,12 +11,8 @@ from loguru import logger from lightx2v.common.ops import * - -try: - from lightx2v.models.runners.flux2.flux2_runner import Flux2DevRunner, Flux2KleinRunner # noqa: F401 -except (ImportError, ModuleNotFoundError) as e: - logger.warning(f"Flux2 runners not available: {e}") from lightx2v.models.runners.ernie_image.ernie_image_runner import ErnieImageRunner # noqa: F401 +from lightx2v.models.runners.flux2.flux2_runner import Flux2DevRunner, Flux2KleinRunner # noqa: F401 from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_runner import HunyuanVideo15Runner # noqa: F401 from lightx2v.models.runners.longcat_image.longcat_image_runner import LongCatImageRunner # noqa: F401 from lightx2v.models.runners.ltx2.ltx2_runner import LTX2Runner # noqa: F401