Skip to content

Commit 74069ad

Browse files
committed
SD2 v autodetection fix
1 parent 477869c commit 74069ad

1 file changed

Lines changed: 3 additions & 2 deletions

File tree

modules/sd_models_config.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,13 @@ def is_using_v_parameterization_for_sd2(state_dict):
5858
with torch.no_grad():
5959
unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." in k}
6060
unet.load_state_dict(unet_sd, strict=True)
61-
unet.to(device=device, dtype=torch.float)
61+
unet.to(device=device, dtype=devices.dtype_unet)
6262

6363
test_cond = torch.ones((1, 2, 1024), device=device) * 0.5
6464
x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5
6565

66-
out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().item()
66+
with devices.autocast():
67+
out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().cpu().item()
6768

6869
return out < -1
6970

0 commit comments

Comments
 (0)