Skip to content

关于SDXL的适配问题 #14

@JamieCR1999

Description

@JamieCR1999

请问该方法是否适配SDXL,我自己实验了一下,在图像编辑的时候出现了问题,不知道是否是我的代码有误?
这是把a dog 改成 a cat时的效果:

Image

以下是我修改的代码:
def intermediate_to_latent_sdxl(sd_pipe, sd_params, intermediate=None, intermediate_second=None, freeze_step=0):
    """Run the SDXL denoising loop from an intermediate state down to a clean latent.

    Uses a second-order multistep update (DPM-Solver style: each step after
    ``freeze_step`` combines the current noise prediction with the two most
    recent states ``xis[-1]`` / ``xis[-2]``).

    Args:
        sd_pipe: a diffusers ``StableDiffusionXLPipeline`` (provides
            ``encode_prompt``, ``unet``, ``scheduler``).
        sd_params: dict with keys ``prompt``, ``negative_prompt``, ``seed``,
            ``guidance_scale``, ``num_inference_steps``, ``width``, ``height``.
        intermediate: starting noisy latent; if ``None`` a fresh Gaussian
            latent is drawn.
        intermediate_second: optional replacement state injected at step
            ``freeze_step`` (used to transplant an inversion trajectory).
        freeze_step: number of leading timesteps to skip.

    Returns:
        The final latent (``xis[-1]``) after the sampling loop.
    """
    prompt = sd_params['prompt']
    negative_prompt = sd_params['negative_prompt']
    seed = sd_params['seed']
    guidance_scale = sd_params['guidance_scale']
    num_inference_steps = sd_params['num_inference_steps']
    width = sd_params['width']
    height = sd_params['height']
    # NOTE(review): forced fp32 — if the pipeline was loaded in fp16 the unet
    # call may still run in its own dtype; confirm dtypes match the pipeline.
    dtype = torch.float32

    # SDXL has two text encoders; reuse the same prompt for both.
    (   prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = sd_pipe.encode_prompt(
        prompt=prompt,
        prompt_2=prompt,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt,
        device=sd_pipe.device,
        num_images_per_prompt=1,
        do_classifier_free_guidance=guidance_scale > 1.0,
    )

    torch.manual_seed(seed)
    sd_pipe.scheduler.set_timesteps(num_inference_steps, device='cuda')
    timesteps = sd_pipe.scheduler.timesteps

    xis = []
    do_classifier_free_guidance = guidance_scale > 1.0
    if intermediate is None:
        # BUG FIX: the latent resolution must track the requested image size.
        # SDXL's VAE downsamples by 8, so a 1024x1024 image needs 128x128
        # latents; the previous hard-coded (1, 4, 64, 64) only matches
        # SD1.5's 512x512 default and produces broken SDXL outputs.
        shape = (1, 4, height // 8, width // 8)
        intermediate = torch.randn(shape, generator=None, device='cuda', dtype=dtype)
        print('latents are None')

    # SDXL micro-conditioning (original size / crop / target size) — built by
    # a project helper defined elsewhere in this repo.
    add_time_ids = get_add_time_ids(original_size=(height, width), crops_coords_top_left=(0, 0), target_size=(height, width), device=sd_pipe.device, batch_size=prompt_embeds.shape[0])

    if do_classifier_free_guidance:
        # Batch [uncond, cond] so a single unet call serves both branches.
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
        add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds])
        add_time_ids_cfg = torch.cat([add_time_ids, add_time_ids])
    else:
        add_text_embeds = pooled_prompt_embeds
        add_time_ids_cfg = add_time_ids
    unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids_cfg}

    xis.append(intermediate)
    with torch.no_grad():
        for i, t in enumerate(timesteps):
            if i < freeze_step:
                continue
            latent_model_input = torch.cat([intermediate] * 2) if do_classifier_free_guidance else intermediate
            noise_pred = sd_pipe.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                added_cond_kwargs=unet_added_cond_kwargs,
                return_dict=False,
            )[0]
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # alpha_t: cumulative alpha at the current timestep; alpha_s: at
            # the next (smaller) timestep.  The last step targets the clean
            # state, i.e. alpha = 1.
            if i < num_inference_steps - 1:
                alpha_s = sd_pipe.scheduler.alphas_cumprod[timesteps[i + 1]].to(torch.float32)
                alpha_t = sd_pipe.scheduler.alphas_cumprod[t].to(torch.float32)
            else:
                alpha_s = 1
                alpha_t = sd_pipe.scheduler.alphas_cumprod[t].to(torch.float32)

            # Split cumulative alphas into signal (alpha) / noise (sigma) scales.
            sigma_s = (1 - alpha_s)**0.5
            sigma_t = (1 - alpha_t)**0.5
            alpha_s = alpha_s**0.5
            alpha_t = alpha_t**0.5

            # First-order (DDIM-like) coefficients.
            coef_xt = alpha_s / alpha_t
            coef_eps = sigma_s - sigma_t * coef_xt
            if i == freeze_step:
                if intermediate_second is not None:
                    # Inject the externally supplied state instead of stepping.
                    print('have intermediate_second')
                    intermediate = intermediate_second.clone()
                else:
                    print('dont have intermediate_second')
                    intermediate = coef_xt * intermediate + coef_eps * noise_pred
            else:
                # Second-order multistep update using the previous timestep i-1.
                alpha_p = sd_pipe.scheduler.alphas_cumprod[timesteps[i - 1]].to(torch.float32)
                sigma_p = (1 - alpha_p) ** 0.5
                alpha_p = alpha_p ** 0.5

                # Half-log-SNR style time variables t = sigma / alpha.
                t_p, t_t, t_s = sigma_p / alpha_p, sigma_t / alpha_t, sigma_s / alpha_s

                # Step-size deltas between the three involved times.
                delta_1 = t_t - t_p
                delta_2 = t_s - t_t
                delta_3 = t_s - t_p

                # Coefficients of the two-history-state extrapolation.
                coef_1 = delta_2 * delta_3 * alpha_s / delta_1
                coef_2 = (delta_2/delta_1)**2*(alpha_s/alpha_p)
                coef_3 = (delta_1 - delta_2)*delta_3/(delta_1**2)*(alpha_s / alpha_t)

                # Combine noise prediction with the last two states.
                intermediate = coef_1 * noise_pred + coef_2 * xis[-2] + coef_3 * xis[-1]
            xis.append(intermediate)
    return xis[-1]

def latent_to_intermediate_sdxl(sd_pipe, sd_params, latent=None, freeze_step=0):
    """Invert an SDXL latent back toward noise (reverse of the sampling loop).

    Walks the scheduler's timesteps in reverse order, re-noising ``latent``
    with a second-order multistep update, and stops ``freeze_step`` steps
    before the end so the result can be transplanted into a forward
    sampling run.

    Args:
        sd_pipe: a diffusers ``StableDiffusionXLPipeline``.
        sd_params: dict with keys ``prompt``, ``negative_prompt``, ``seed``,
            ``guidance_scale``, ``num_inference_steps``, ``width``, ``height``.
        latent: clean latent to invert; if ``None`` a fresh Gaussian latent
            is drawn.
        freeze_step: number of trailing inversion steps to skip.

    Returns:
        Tuple ``(xis[-1], xis[-2])`` — the last two states of the inversion
        trajectory (matching what the forward pass needs to resume).
    """
    prompt = sd_params['prompt']
    negative_prompt = sd_params['negative_prompt']
    seed = sd_params['seed']
    guidance_scale = sd_params['guidance_scale']
    num_inference_steps = sd_params['num_inference_steps']
    width = sd_params['width']
    height = sd_params['height']
    # NOTE(review): forced fp32 — confirm it matches the pipeline's dtype.
    dtype = torch.float32

    # SDXL has two text encoders; reuse the same prompt for both.
    (   prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = sd_pipe.encode_prompt(
        prompt=prompt,
        prompt_2=prompt,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt,
        device=sd_pipe.device,
        num_images_per_prompt=1,
        do_classifier_free_guidance=guidance_scale > 1.0,
    )

    torch.manual_seed(seed)
    sd_pipe.scheduler.set_timesteps(num_inference_steps, device='cuda')
    timesteps = sd_pipe.scheduler.timesteps

    xis = []
    do_classifier_free_guidance = guidance_scale > 1.0
    if latent is None:
        # BUG FIX: derive the latent resolution from the requested image size
        # (VAE downsamples by 8).  The previous hard-coded (1, 4, 64, 64) is
        # SD1.5's 512x512 shape and is wrong for SDXL's 1024x1024 default.
        shape = (1, 4, height // 8, width // 8)
        latent = torch.randn(shape, generator=None, device='cuda', dtype=dtype)
        print('latents are None')

    # SDXL micro-conditioning — project helper defined elsewhere in this repo.
    add_time_ids = get_add_time_ids(original_size=(height, width), crops_coords_top_left=(0, 0), target_size=(height, width), device=sd_pipe.device, batch_size=prompt_embeds.shape[0])

    if do_classifier_free_guidance:
        # Batch [uncond, cond] so a single unet call serves both branches.
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
        add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds])
        add_time_ids_cfg = torch.cat([add_time_ids, add_time_ids])
    else:
        add_text_embeds = pooled_prompt_embeds
        add_time_ids_cfg = add_time_ids
    unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids_cfg}

    xis.append(latent)
    with torch.no_grad():
        for i, t in enumerate(timesteps):
            # Stop the inversion freeze_step steps early.
            if i >= num_inference_steps - freeze_step:
                continue
            # Walk the schedule backwards: index counts down from the end.
            index = num_inference_steps - i - 1
            # Evaluate the unet at the *previous* (smaller) timestep; the very
            # first inversion step uses t=1, i.e. (nearly) clean data.
            time = timesteps[index + 1] if index < num_inference_steps - 1 else 1
            latent_model_input = torch.cat([latent] * 2) if do_classifier_free_guidance else latent
            noise_pred = sd_pipe.unet(
                latent_model_input,
                time,
                encoder_hidden_states=prompt_embeds,
                added_cond_kwargs=unet_added_cond_kwargs,
                return_dict=False,
            )[0]
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # Cumulative alphas at the target (i) and source (i-1) positions;
            # the clean end of the schedule has alpha = 1.
            if index < num_inference_steps - 1:
                alpha_i = sd_pipe.scheduler.alphas_cumprod[timesteps[index]].to(torch.float32)
                alpha_i_minus_1 = sd_pipe.scheduler.alphas_cumprod[timesteps[index + 1]].to(torch.float32)
            else:
                alpha_i = sd_pipe.scheduler.alphas_cumprod[timesteps[index]].to(torch.float32)
                alpha_i_minus_1 = 1

            # Split cumulative alphas into signal (alpha) / noise (sigma) scales.
            sigma_i = (1 - alpha_i)**0.5
            sigma_i_minus_1 = (1 - alpha_i_minus_1)**0.5
            alpha_i = alpha_i**0.5
            alpha_i_minus_1 = alpha_i_minus_1**0.5

            if i == 0:
                # First-order (DDIM-inversion-like) step — no history yet.
                latent = (alpha_i/alpha_i_minus_1)*latent+(sigma_i-(alpha_i/alpha_i_minus_1)*sigma_i_minus_1) * noise_pred
            else:
                # Second-order multistep step using two positions of history;
                # the guard keeps timesteps[index + 2] in range.
                alpha_i_minus_2 = 1 if i == 1 else sd_pipe.scheduler.alphas_cumprod[timesteps[index + 2]].to(torch.float32)
                sigma_i_minus_2 = (1 - alpha_i_minus_2) ** 0.5
                alpha_i_minus_2 = alpha_i_minus_2 ** 0.5

                # Step sizes in the t = sigma / alpha time variable.
                h_i = sigma_i/alpha_i - sigma_i_minus_1/alpha_i_minus_1
                h_i_minus_1 = sigma_i_minus_1/alpha_i_minus_1 - sigma_i_minus_2/alpha_i_minus_2

                # Extrapolation coefficients for the two history states.
                coef_x_i_minus_2 = (alpha_i/alpha_i_minus_2)*(h_i**2)/(h_i_minus_1**2)
                coef_x_i_minus_1 = (alpha_i/alpha_i_minus_1)*(h_i_minus_1**2 - h_i**2)/(h_i_minus_1**2)
                coef_eps = alpha_i*(h_i_minus_1 + h_i)*h_i/h_i_minus_1
                latent = coef_x_i_minus_2 * xis[-2] + coef_x_i_minus_1 * xis[-1] + coef_eps * noise_pred
            xis.append(latent)
    return xis[-1], xis[-2]

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions