请问该方法是否适配 SDXL?我自己做了实验,但在图像编辑时出现了问题,不知道是否是我的代码有误。
这是把 “a dog” 改成 “a cat” 时的效果:
以下是我修改的代码:
def intermediate_to_latent_sdxl(sd_pipe, sd_params, intermediate=None, intermediate_second=None, freeze_step=0):
    """Step an SDXL latent through the scheduler's timesteps and return the last iterate.

    The first executed step (``i == freeze_step``) is a first-order update with
    the same algebraic form as a deterministic DDIM step; every later step is a
    two-step linear multistep update that combines the predicted noise with the
    two most recent iterates.

    NOTE(review): the pasted source had lost its indentation; the nesting below
    is reconstructed from the control flow -- confirm against the author's file.

    Args:
        sd_pipe: Pipeline object providing ``encode_prompt``, ``scheduler``
            (with ``set_timesteps``, ``timesteps``, ``alphas_cumprod``) and
            ``unet``. Presumably a diffusers ``StableDiffusionXLPipeline`` --
            TODO confirm.
        sd_params: Mapping with the keys ``'prompt'``, ``'negative_prompt'``,
            ``'seed'``, ``'guidance_scale'``, ``'num_inference_steps'``,
            ``'width'`` and ``'height'``.
        intermediate: Starting latent. When ``None`` a random latent is drawn
            whose spatial size is derived from ``height``/``width``.
        intermediate_second: Optional precomputed iterate. When given, it is
            used (cloned) as the result of the first executed step instead of
            the first-order update.
        freeze_step: Number of leading timesteps to skip.

    Returns:
        The last entry of the iterate history (the input latent itself if no
        step was executed).
    """
    prompt = sd_params['prompt']
    negative_prompt = sd_params['negative_prompt']
    seed = sd_params['seed']
    guidance_scale = sd_params['guidance_scale']
    num_inference_steps = sd_params['num_inference_steps']
    width = sd_params['width']
    height = sd_params['height']
    dtype = torch.float32
    do_classifier_free_guidance = guidance_scale > 1.0

    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = sd_pipe.encode_prompt(
        prompt=prompt,
        prompt_2=prompt,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt,
        device=sd_pipe.device,
        num_images_per_prompt=1,
        do_classifier_free_guidance=do_classifier_free_guidance,
    )

    torch.manual_seed(seed)
    sd_pipe.scheduler.set_timesteps(num_inference_steps, device='cuda')
    timesteps = sd_pipe.scheduler.timesteps
    alphas_cumprod = sd_pipe.scheduler.alphas_cumprod

    if intermediate is None:
        # Derive the latent resolution from the requested image size instead of
        # the former hard-coded 64x64, so the latent agrees with the size
        # conditioning passed to the UNet below (e.g. 128x128 for 1024px SDXL).
        vae_scale_factor = getattr(sd_pipe, 'vae_scale_factor', 8)
        shape = (1, 4, height // vae_scale_factor, width // vae_scale_factor)
        intermediate = torch.randn(shape, generator=None, device='cuda', dtype=dtype)
        print('latents are None')

    # get_add_time_ids is a helper defined outside this excerpt.
    add_time_ids = get_add_time_ids(original_size=(height, width), crops_coords_top_left=(0, 0), target_size=(height, width), device=sd_pipe.device, batch_size=prompt_embeds.shape[0])

    if do_classifier_free_guidance:
        # Unconditional half first, conditional half second; chunk(2) below
        # relies on this order.
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
        add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds])
        add_time_ids_cfg = torch.cat([add_time_ids, add_time_ids])
    else:
        add_text_embeds = pooled_prompt_embeds
        add_time_ids_cfg = add_time_ids
    unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids_cfg}

    # History of iterates; the multistep update reads the last two entries.
    xis = [intermediate]
    with torch.no_grad():
        for i, t in enumerate(timesteps):
            if i < freeze_step:
                continue

            latent_model_input = torch.cat([intermediate] * 2) if do_classifier_free_guidance else intermediate
            noise_pred = sd_pipe.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                added_cond_kwargs=unet_added_cond_kwargs,
                return_dict=False,
            )[0]
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # Cumulative alpha at the current timestep (t) and at the step
            # target (s). After the last timestep the target uses 1.
            alpha_t = alphas_cumprod[t].to(torch.float32)
            if i < num_inference_steps - 1:
                alpha_s = alphas_cumprod[timesteps[i + 1]].to(torch.float32)
            else:
                alpha_s = 1

            # From here on alpha_* / sigma_* are the square-rooted coefficients.
            sigma_s = (1 - alpha_s) ** 0.5
            sigma_t = (1 - alpha_t) ** 0.5
            alpha_s = alpha_s ** 0.5
            alpha_t = alpha_t ** 0.5

            if i == freeze_step:
                if intermediate_second is not None:
                    print('have intermediate_second')
                    intermediate = intermediate_second.clone()
                else:
                    print('dont have intermediate_second')
                    # First-order update: no earlier iterate is available yet.
                    coef_xt = alpha_s / alpha_t
                    coef_eps = sigma_s - sigma_t * coef_xt
                    intermediate = coef_xt * intermediate + coef_eps * noise_pred
            else:
                # Coefficients at the previous timestep (i - 1).
                alpha_p = alphas_cumprod[timesteps[i - 1]].to(torch.float32)
                sigma_p = (1 - alpha_p) ** 0.5
                alpha_p = alpha_p ** 0.5
                # sigma / alpha ratio at the previous, current and target steps.
                t_p, t_t, t_s = sigma_p / alpha_p, sigma_t / alpha_t, sigma_s / alpha_s
                # Step sizes measured in that ratio.
                delta_1 = t_t - t_p
                delta_2 = t_s - t_t
                delta_3 = t_s - t_p
                # Multistep weights for the noise and the two latest iterates.
                coef_1 = delta_2 * delta_3 * alpha_s / delta_1
                coef_2 = (delta_2 / delta_1) ** 2 * (alpha_s / alpha_p)
                coef_3 = (delta_1 - delta_2) * delta_3 / (delta_1 ** 2) * (alpha_s / alpha_t)
                intermediate = coef_1 * noise_pred + coef_2 * xis[-2] + coef_3 * xis[-1]

            xis.append(intermediate)

    return xis[-1]
def latent_to_intermediate_sdxl(sd_pipe, sd_params, latent=None, freeze_step=0):
    """Step an SDXL latent toward higher noise levels and return the last two iterates.

    The scheduler's timesteps are traversed in reverse order. The first step is
    a first-order update; every later step is a two-step linear multistep
    update that combines the predicted noise with the two most recent iterates.

    NOTE(review): the pasted source had lost its indentation; the nesting below
    is reconstructed from the control flow -- confirm against the author's file.

    Args:
        sd_pipe: Pipeline object providing ``encode_prompt``, ``scheduler``
            (with ``set_timesteps``, ``timesteps``, ``alphas_cumprod``) and
            ``unet``. Presumably a diffusers ``StableDiffusionXLPipeline`` --
            TODO confirm.
        sd_params: Mapping with the keys ``'prompt'``, ``'negative_prompt'``,
            ``'seed'``, ``'guidance_scale'``, ``'num_inference_steps'``,
            ``'width'`` and ``'height'``.
        latent: Starting latent. When ``None`` a random latent is drawn whose
            spatial size is derived from ``height``/``width``.
        freeze_step: Number of trailing loop iterations to skip, so the
            traversal stops ``freeze_step`` steps before the end.

    Returns:
        A tuple ``(xis[-1], xis[-2])``: the final iterate and the one before it.

    Raises:
        IndexError: If no step is executed (``num_inference_steps - freeze_step
            <= 0``), because the history then holds only the starting latent.
    """
    prompt = sd_params['prompt']
    negative_prompt = sd_params['negative_prompt']
    seed = sd_params['seed']
    guidance_scale = sd_params['guidance_scale']
    num_inference_steps = sd_params['num_inference_steps']
    width = sd_params['width']
    height = sd_params['height']
    dtype = torch.float32
    do_classifier_free_guidance = guidance_scale > 1.0

    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = sd_pipe.encode_prompt(
        prompt=prompt,
        prompt_2=prompt,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt,
        device=sd_pipe.device,
        num_images_per_prompt=1,
        do_classifier_free_guidance=do_classifier_free_guidance,
    )

    torch.manual_seed(seed)
    sd_pipe.scheduler.set_timesteps(num_inference_steps, device='cuda')
    timesteps = sd_pipe.scheduler.timesteps
    alphas_cumprod = sd_pipe.scheduler.alphas_cumprod

    if latent is None:
        # Derive the latent resolution from the requested image size instead of
        # the former hard-coded 64x64, so the latent agrees with the size
        # conditioning passed to the UNet below (e.g. 128x128 for 1024px SDXL).
        vae_scale_factor = getattr(sd_pipe, 'vae_scale_factor', 8)
        shape = (1, 4, height // vae_scale_factor, width // vae_scale_factor)
        latent = torch.randn(shape, generator=None, device='cuda', dtype=dtype)
        print('latents are None')

    # get_add_time_ids is a helper defined outside this excerpt.
    add_time_ids = get_add_time_ids(original_size=(height, width), crops_coords_top_left=(0, 0), target_size=(height, width), device=sd_pipe.device, batch_size=prompt_embeds.shape[0])

    if do_classifier_free_guidance:
        # Unconditional half first, conditional half second; chunk(2) below
        # relies on this order.
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
        add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds])
        add_time_ids_cfg = torch.cat([add_time_ids, add_time_ids])
    else:
        add_text_embeds = pooled_prompt_embeds
        add_time_ids_cfg = add_time_ids
    unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids_cfg}

    # History of iterates; the multistep update reads the last two entries.
    xis = [latent]
    with torch.no_grad():
        for i, t in enumerate(timesteps):
            if i >= num_inference_steps - freeze_step:
                continue

            # Walk the schedule backwards: `index` is the timestep being
            # stepped to, `index + 1` the one the current latent sits at.
            index = num_inference_steps - i - 1
            # On the very first step there is no `index + 1`; the UNet is
            # queried at timestep 1 instead.
            unet_timestep = timesteps[index + 1] if index < num_inference_steps - 1 else 1

            latent_model_input = torch.cat([latent] * 2) if do_classifier_free_guidance else latent
            noise_pred = sd_pipe.unet(
                latent_model_input,
                unet_timestep,
                encoder_hidden_states=prompt_embeds,
                added_cond_kwargs=unet_added_cond_kwargs,
                return_dict=False,
            )[0]
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # Cumulative alpha at the target step (i) and at the current
            # latent's step (i - 1); the latter is 1 on the first step.
            alpha_i = alphas_cumprod[timesteps[index]].to(torch.float32)
            if index < num_inference_steps - 1:
                alpha_i_minus_1 = alphas_cumprod[timesteps[index + 1]].to(torch.float32)
            else:
                alpha_i_minus_1 = 1

            # From here on alpha_* / sigma_* are the square-rooted coefficients.
            sigma_i = (1 - alpha_i) ** 0.5
            sigma_i_minus_1 = (1 - alpha_i_minus_1) ** 0.5
            alpha_i = alpha_i ** 0.5
            alpha_i_minus_1 = alpha_i_minus_1 ** 0.5

            if i == 0:
                # First-order update: no earlier iterate is available yet.
                latent = (alpha_i / alpha_i_minus_1) * latent + (sigma_i - (alpha_i / alpha_i_minus_1) * sigma_i_minus_1) * noise_pred
            else:
                # Two steps back; on the second iteration that is the starting
                # latent, for which the cumulative alpha is taken as 1.
                alpha_i_minus_2 = 1 if i == 1 else alphas_cumprod[timesteps[index + 2]].to(torch.float32)
                sigma_i_minus_2 = (1 - alpha_i_minus_2) ** 0.5
                alpha_i_minus_2 = alpha_i_minus_2 ** 0.5
                # Step sizes measured in the sigma / alpha ratio.
                h_i = sigma_i / alpha_i - sigma_i_minus_1 / alpha_i_minus_1
                h_i_minus_1 = sigma_i_minus_1 / alpha_i_minus_1 - sigma_i_minus_2 / alpha_i_minus_2
                # Multistep weights for the two latest iterates and the noise.
                coef_x_i_minus_2 = (alpha_i / alpha_i_minus_2) * (h_i ** 2) / (h_i_minus_1 ** 2)
                coef_x_i_minus_1 = (alpha_i / alpha_i_minus_1) * (h_i_minus_1 ** 2 - h_i ** 2) / (h_i_minus_1 ** 2)
                coef_eps = alpha_i * (h_i_minus_1 + h_i) * h_i / h_i_minus_1
                latent = coef_x_i_minus_2 * xis[-2] + coef_x_i_minus_1 * xis[-1] + coef_eps * noise_pred

            xis.append(latent)

    return xis[-1], xis[-2]
请问该方法是否适配 SDXL?我自己做了实验,但在图像编辑时出现了问题,不知道是否是我的代码有误。
这是把 “a dog” 改成 “a cat” 时的效果:
以下是我修改的代码:
`def intermediate_to_latent_sdxl(sd_pipe, sd_params, intermediate=None, intermediate_second = None, freeze_step = 0):
def latent_to_intermediate_sdxl(sd_pipe, sd_params, latent=None, freeze_step = 0):