Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion api.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,13 +599,38 @@ def get_phones_and_bert(text, language, version, final=False):
phones_list.append(phones)
norm_text_list.append(norm_text)
bert_list.append(bert)
bert = torch.cat(bert_list, dim=1)
# Fix: handle empty bert_list to avoid torch.cat() crash
if bert_list:
bert = torch.cat(bert_list, dim=1)
else:
phones_total = sum(phones_list, []) if phones_list else []
bert = torch.zeros(
(1024, max(len(phones_total), 1)),
dtype=torch.float16 if is_half == True else torch.float32,
).to(device)
phones = sum(phones_list, [])
norm_text = "".join(norm_text_list)

if not final and len(phones) < 6:
return get_phones_and_bert("." + text, language, version, final=True)

# Fix: align BERT feature length with phones length to avoid dimension mismatch
# e.g., RuntimeError: tensor a (44) must match tensor b (45) at non-singleton dimension 1
phones_len = len(phones)
bert_len = bert.shape[1]
if phones_len != bert_len:
import torch.nn.functional as F
logger.warning(f"[TTS] BERT length mismatch: phones={phones_len}, bert={bert_len}, adjusting...")
bert = bert.transpose(1, 2) # (1024, seq) -> (seq, 1024)
if phones_len > bert_len:
# Interpolate to enlarge
bert = F.interpolate(bert.unsqueeze(0), size=phones_len, mode='linear', align_corners=False)
else:
# Truncate excess
bert = bert[:, :phones_len, :]
bert = bert.squeeze(0).transpose(0, 1) # (seq, 1024) -> (1024, seq)
bert = bert.to(device)

return phones, bert.to(torch.float16 if is_half == True else torch.float32), norm_text


Expand Down