GPT_SoVITS/AR/data/bucket_sampler.py (6 changes: 3 additions & 3 deletions)

@@ -39,12 +39,12 @@ def __init__(
         if num_replicas is None:
             if not dist.is_available():
                 raise RuntimeError("Requires distributed package to be available")
-            num_replicas = dist.get_world_size() if torch.cuda.is_available() else 1
+            num_replicas = dist.get_world_size() if dist.is_initialized() else 1
         if rank is None:
             if not dist.is_available():
                 raise RuntimeError("Requires distributed package to be available")
-            rank = dist.get_rank() if torch.cuda.is_available() else 0
-        if torch.cuda.is_available():
+            rank = dist.get_rank() if dist.is_initialized() else 0
+        if torch.cuda.is_available() and dist.is_initialized():
             torch.cuda.set_device(rank)
         if rank >= num_replicas or rank < 0:
             raise ValueError("Invalid rank {}, rank should be in the interval [0, {}]".format(rank, num_replicas - 1))
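Why this matters: dist.get_world_size() and dist.get_rank() raise unless init_process_group() has already run, so gating on torch.cuda.is_available() crashed single-process runs on CUDA machines. A minimal sketch of the corrected pattern in isolation (the helper name is ours, for illustration):

```python
import torch.distributed as dist

def resolve_rank_and_world_size():
    # get_rank()/get_world_size() require an initialized process group,
    # so gate on is_initialized() rather than CUDA availability and fall
    # back to single-process defaults (rank 0, world size 1).
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank(), dist.get_world_size()
    return 0, 1
```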
GPT_SoVITS/s1_train.py (2 changes: 1 addition & 1 deletion)

@@ -118,7 +118,7 @@ def main(args):
         benchmark=False,
         fast_dev_run=False,
         strategy=DDPStrategy(process_group_backend="nccl" if platform.system() != "Windows" else "gloo")
-        if torch.cuda.is_available()
+        if torch.cuda.is_available() and torch.cuda.device_count() > 1
         else "auto",
         precision=config["train"]["precision"],
         logger=logger,
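The added device_count() check keeps Lightning from building a DDPStrategy for a single device, where spinning up a process group (NCCL on Linux, gloo on Windows) buys nothing and can fail outright on single-GPU Windows machines. The same selection logic as a standalone sketch (the helper name is ours):

```python
import platform
import torch
from pytorch_lightning.strategies import DDPStrategy

def pick_strategy():
    # DDP only pays off with multiple CUDA devices. NCCL is unavailable
    # on Windows, so use gloo there; single-GPU and CPU runs hand the
    # choice back to Lightning via "auto".
    if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        backend = "nccl" if platform.system() != "Windows" else "gloo"
        return DDPStrategy(process_group_backend=backend)
    return "auto"
```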
GPT_SoVITS/s2_train.py (27 changes: 19 additions & 8 deletions)

@@ -77,12 +77,13 @@ def run(rank, n_gpus, hps):
         writer = SummaryWriter(log_dir=hps.s2_ckpt_dir)
         writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))

-    dist.init_process_group(
-        backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
-        init_method="env://?use_libuv=False",
-        world_size=n_gpus,
-        rank=rank,
-    )
+    if not (os.name == "nt" and n_gpus == 1):
+        dist.init_process_group(
+            backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
+            init_method="env://?use_libuv=False",
+            world_size=n_gpus,
+            rank=rank,
+        )
     torch.manual_seed(hps.train.seed)
     if torch.cuda.is_available():
         torch.cuda.set_device(rank)
@@ -197,8 +198,18 @@ def run(rank, n_gpus, hps):
         eps=hps.train.eps,
     )
     if torch.cuda.is_available():
-        net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
-        net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
+        if os.name == "nt" and n_gpus == 1:
+            class DummyDDP(torch.nn.Module):
+                def __init__(self, module):
+                    super().__init__()
+                    self.module = module
+                def forward(self, *args, **kwargs):
+                    return self.module(*args, **kwargs)
+            net_g = DummyDDP(net_g)
+            net_d = DummyDDP(net_d)
+        else:
+            net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
+            net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
     else:
         net_g = net_g.to(device)
         net_d = net_d.to(device)
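Because the single-GPU Windows path above now skips dist.init_process_group, wrapping the nets in real DDP would fail, hence the pass-through wrapper. The key detail is that DummyDDP mirrors DDP's .module attribute, so downstream code that unwraps the model (checkpoint saving, for example) keeps working; the same pattern is repeated verbatim in s2_train_v3.py and s2_train_v3_lora.py below. A self-contained sketch (the wrap() helper is ours):

```python
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

class DummyDDP(torch.nn.Module):
    # Pass-through stand-in for DDP in single-process runs: mirrors the
    # .module attribute and forwards calls, but needs no process group.
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

def wrap(net, rank, single_process):
    if single_process:
        return DummyDDP(net)  # safe without init_process_group()
    return DDP(net, device_ids=[rank], find_unused_parameters=True)
```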
GPT_SoVITS/s2_train_v3.py (27 changes: 19 additions & 8 deletions)

@@ -77,12 +77,13 @@ def run(rank, n_gpus, hps):
         writer = SummaryWriter(log_dir=hps.s2_ckpt_dir)
         writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))

-    dist.init_process_group(
-        backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
-        init_method="env://?use_libuv=False",
-        world_size=n_gpus,
-        rank=rank,
-    )
+    if not (os.name == "nt" and n_gpus == 1):
+        dist.init_process_group(
+            backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
+            init_method="env://?use_libuv=False",
+            world_size=n_gpus,
+            rank=rank,
+        )
     torch.manual_seed(hps.train.seed)
     if torch.cuda.is_available():
         torch.cuda.set_device(rank)
@@ -166,8 +167,18 @@ def run(rank, n_gpus, hps):
         # eps=hps.train.eps,
     # )
     if torch.cuda.is_available():
-        net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
-        # net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
+        if os.name == "nt" and n_gpus == 1:
+            class DummyDDP(torch.nn.Module):
+                def __init__(self, module):
+                    super().__init__()
+                    self.module = module
+                def forward(self, *args, **kwargs):
+                    return self.module(*args, **kwargs)
+            net_g = DummyDDP(net_g)
+            # net_d = DummyDDP(net_d)
+        else:
+            net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
+            # net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
     else:
         net_g = net_g.to(device)
         # net_d = net_d.to(device)
GPT_SoVITS/s2_train_v3_lora.py (24 changes: 17 additions & 7 deletions)

@@ -77,12 +77,13 @@ def run(rank, n_gpus, hps):
         writer = SummaryWriter(log_dir=hps.s2_ckpt_dir)
         writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))

-    dist.init_process_group(
-        backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
-        init_method="env://?use_libuv=False",
-        world_size=n_gpus,
-        rank=rank,
-    )
+    if not (os.name == "nt" and n_gpus == 1):
+        dist.init_process_group(
+            backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
+            init_method="env://?use_libuv=False",
+            world_size=n_gpus,
+            rank=rank,
+        )
     torch.manual_seed(hps.train.seed)
     if torch.cuda.is_available():
         torch.cuda.set_device(rank)
@@ -156,7 +157,16 @@ def get_optim(net_g):

 def model2cuda(net_g, rank):
     if torch.cuda.is_available():
-        net_g = DDP(net_g.cuda(rank), device_ids=[rank], find_unused_parameters=True)
+        if os.name == "nt" and n_gpus == 1:
+            class DummyDDP(torch.nn.Module):
+                def __init__(self, module):
+                    super().__init__()
+                    self.module = module
+                def forward(self, *args, **kwargs):
+                    return self.module(*args, **kwargs)
+            net_g = DummyDDP(net_g.cuda(rank))
+        else:
+            net_g = DDP(net_g.cuda(rank), device_ids=[rank], find_unused_parameters=True)
     else:
         net_g = net_g.to(device)
     return net_g
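One wrinkle specific to this file: model2cuda does not take n_gpus as a parameter, so the new check depends on n_gpus being resolvable from an enclosing scope at call time (fine if the function is defined where run()'s n_gpus is visible, a NameError otherwise). A hypothetical variant that threads the dependencies through explicitly, reusing the DummyDDP wrapper sketched above:

```python
import os
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

def model2cuda(net_g, rank, n_gpus, device="cpu"):
    # Hypothetical refactor: n_gpus and the CPU fallback device are
    # explicit arguments instead of names resolved from an outer scope.
    if torch.cuda.is_available():
        if os.name == "nt" and n_gpus == 1:
            return DummyDDP(net_g.cuda(rank))
        return DDP(net_g.cuda(rank), device_ids=[rank], find_unused_parameters=True)
    return net_g.to(device)
```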
GPT_SoVITS/utils.py (17 changes: 15 additions & 2 deletions)

@@ -64,12 +64,25 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
 from time import time as ttime


+import time
+
 def my_save(fea, path):  #####fix issue: torch.save doesn't support chinese path
     dir = os.path.dirname(path)
     name = os.path.basename(path)
-    tmp_path = "%s.pth" % (ttime())
+    tmp_path = "%s.pth" % (time.time())
     torch.save(fea, tmp_path)
-    shutil.move(tmp_path, "%s/%s" % (dir, name))
+    target_path = "%s/%s" % (dir, name)
+    try:
+        shutil.move(tmp_path, target_path)
+    except Exception as e:
+        print(f"Move failed with error {e}, retrying via copy and delete...")
+        if os.path.exists(target_path):
+            try:
+                os.remove(target_path)
+            except:
+                pass
+        shutil.copyfile(tmp_path, target_path)
+        os.remove(tmp_path)


 def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
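The rewritten my_save still stages the checkpoint through an ASCII-named temp file in the current working directory, so torch.save never touches the (possibly non-ASCII) destination path; the new part is the fallback when shutil.move fails, e.g. because the target already exists and is locked on Windows: remove the old target if possible, copy the temp file over, then delete the temp file. A hypothetical usage sketch with an invented path:

```python
import os
import torch

# The destination directory name is non-ASCII, which torch.save can
# choke on directly; my_save (defined above) only ever hands it to
# shutil/os, which cope with such paths fine.
os.makedirs("logs/实验", exist_ok=True)
my_save({"step": 0, "weight": torch.zeros(3)}, "logs/实验/G_0.pth")
```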