diff --git a/api.py b/api.py
index cc0896a20..092b5105e 100644
--- a/api.py
+++ b/api.py
@@ -599,13 +599,38 @@ def get_phones_and_bert(text, language, version, final=False):
             phones_list.append(phones)
             norm_text_list.append(norm_text)
             bert_list.append(bert)
-        bert = torch.cat(bert_list, dim=1)
+        # Fix: handle empty bert_list to avoid a torch.cat() crash on empty input
+        if bert_list:
+            bert = torch.cat(bert_list, dim=1)
+        else:
+            phones_total = sum(phones_list, []) if phones_list else []
+            bert = torch.zeros(
+                (1024, max(len(phones_total), 1)),
+                dtype=torch.float16 if is_half == True else torch.float32,
+            ).to(device)
         phones = sum(phones_list, [])
         norm_text = "".join(norm_text_list)
 
     if not final and len(phones) < 6:
         return get_phones_and_bert("." + text, language, version, final=True)
 
+    # Fix: align BERT feature length with phones length to avoid a dimension mismatch,
+    # e.g. RuntimeError: tensor a (44) must match tensor b (45) at non-singleton dimension 1
+    phones_len = len(phones)
+    bert_len = bert.shape[1]
+    if phones_len != bert_len:
+        import torch.nn.functional as F
+        logger.warning(f"[TTS] BERT length mismatch: phones={phones_len}, bert={bert_len}, adjusting...")
+        bert = bert.unsqueeze(0)  # (1024, seq) -> (1, 1024, seq) for interpolation/slicing
+        if phones_len > bert_len:
+            # Interpolate along the sequence dimension to enlarge
+            bert = F.interpolate(bert, size=phones_len, mode='linear', align_corners=False)
+        else:
+            # Truncate excess frames
+            bert = bert[:, :, :phones_len]
+        bert = bert.squeeze(0)  # (1, 1024, phones_len) -> (1024, phones_len)
+        bert = bert.to(device)
+
     return phones, bert.to(torch.float16 if is_half == True else torch.float32), norm_text
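
A minimal standalone sketch (not part of the patch) of the length-alignment step above, assuming `bert` has shape (1024, seq_len) as in get_phones_and_bert; the helper name align_bert and the dummy lengths are hypothetical:

import torch
import torch.nn.functional as F

def align_bert(bert: torch.Tensor, phones_len: int) -> torch.Tensor:
    """Stretch or truncate BERT features along the sequence axis to phones_len."""
    bert_len = bert.shape[1]
    if phones_len == bert_len:
        return bert
    if phones_len > bert_len:
        # Linear interpolation along the last (sequence) dimension: (1024, seq) -> (1024, phones_len)
        return F.interpolate(bert.unsqueeze(0), size=phones_len, mode='linear', align_corners=False).squeeze(0)
    # Drop trailing frames
    return bert[:, :phones_len]

bert = torch.randn(1024, 44)
print(align_bert(bert, 45).shape)  # torch.Size([1024, 45])
print(align_bert(bert, 40).shape)  # torch.Size([1024, 40])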