diff --git a/vall_e/demo.py b/vall_e/demo.py
index 38a17b7..9244958 100644
--- a/vall_e/demo.py
+++ b/vall_e/demo.py
@@ -67,7 +67,6 @@ def process_batch( tts, inputs, kwargs={} ):
 		languages=[ x[2] for x in inputs ],
 		out_paths=[ x[3] for x in inputs ],
 	)
-	safe_batched_inference( tts, **kwargs )
 
 # Would be downright sugoi if I could incorporate this with into __main__
 
@@ -136,7 +135,7 @@ def main():
 	parser.add_argument("--comparison", type=str, default=None)
 
 	parser.add_argument("--transcription-model", type=str, default="base")
-	parser.add_argument("--speaker-similarity-model", type=str, default="wavlm_base_plus")
+	parser.add_argument("--speaker-similarity-model", type=str, default="microsoft/wavlm-base-sv")
 
 	args = parser.parse_args()
 
@@ -397,7 +396,7 @@ def main():
 		total_metrics = (0, 0)
 		for text, language, out_path, reference_path in tqdm(metrics_inputs, desc="Calculating metrics"):
 			wer_score, cer_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name=args.transcription_model )
-			sim_o_score = sim_o( out_path, reference_path, device=tts.device, dtype=tts.dtype, feat_type=args.speaker_similarity_model )
+			sim_o_score = sim_o( out_path, reference_path, device=tts.device, dtype=tts.dtype, model_name=args.speaker_similarity_model )
 			metrics_map[out_path] = (wer_score, cer_score, sim_o_score)
 
 	# collate entries into HTML
diff --git a/vall_e/emb/similar.py b/vall_e/emb/similar.py
index 8648770..fc921ae 100644
--- a/vall_e/emb/similar.py
+++ b/vall_e/emb/similar.py
@@ -46,29 +46,20 @@ tts = None
 
 # this is for computing SIM-O, but can probably technically be used for scoring similar utterances
 @cache
-def _load_sim_model(device="cuda", dtype="float16", feat_type="wavlm_base_plus", feat_dim="auto"):
+def _load_sim_model(device="cuda", dtype="float16", model_name='microsoft/wavlm-base-sv'):
 	logging.getLogger('s3prl').setLevel(logging.DEBUG)
 	logging.getLogger('speechbrain').setLevel(logging.DEBUG)
 
-	from ..utils.ext.ecapa_tdnn import ECAPA_TDNN_SMALL
+	#from ..utils.ext.ecapa_tdnn import ECAPA_TDNN_SMALL
+	from transformers import Wav2Vec2FeatureExtractor, WavLMForXVector
 
-	if feat_dim == "auto":
-		if feat_type == "fbank":
-			feat_dim = 40
-		elif feat_type == "wavlm_base_plus":
-			feat_dim = 768
-		elif feat_type == "wavlm_large":
-			feat_dim = 1024
-		elif feat_type == "hubert_large_ll60k":
-			feat_dim = 1024
-		elif feat_type == "wav2vec2_xlsr":
-			feat_dim = 1024
-
-	model = ECAPA_TDNN_SMALL(feat_dim=feat_dim, feat_type=feat_type, config_path=None)
+	model = WavLMForXVector.from_pretrained(model_name)
 	model = model.to(device=device, dtype=coerce_dtype(dtype))
 	model = model.eval()
+
+	feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
 
-	return model
+	return model, feature_extractor
 
 @torch.no_grad()
 def speaker_similarity_embedding(
@@ -78,15 +69,19 @@ def speaker_similarity_embedding(
 	device = model_kwargs.get("device", "cuda")
 	dtype = model_kwargs.get("dtype", "float16")
 
-	model = _load_sim_model(**model_kwargs)
+	model, feature_extractor = _load_sim_model(**model_kwargs)
+
 	if isinstance(audio, str) or isinstance(audio, Path):
 		audio = load_audio(audio, 16000)
 
 	audio, sr = audio
-	audio = audio.to(device=device, dtype=coerce_dtype(dtype))
-
-	return model(audio)
-
+	features = feature_extractor(audio, return_tensors="pt", sampling_rate=sr)
+	features = features.input_values.squeeze(0).to(dtype=coerce_dtype(dtype), device=device)
+
+	output = model(input_values=features)
+	embeddings = output.embeddings
+	embeddings = torch.nn.functional.normalize(embeddings, dim=-1).cpu()
+	return embeddings
 
 def batch_similar_utterances(
 	speaker_path,
diff --git a/vall_e/inference.py b/vall_e/inference.py
index cd10384..d8d9141 100755
--- a/vall_e/inference.py
+++ b/vall_e/inference.py
@@ -280,7 +280,7 @@ class TTS():
 				buffer[i].append(x)
 
 			# flush
-			if len(buffer[0]) >= batch_size:
+			if buffer:
 				batches.append(buffer)
 				buffer = ([], [], [], [])
 
diff --git a/vall_e/metrics.py b/vall_e/metrics.py
index 032e38b..c38107d 100644
--- a/vall_e/metrics.py
+++ b/vall_e/metrics.py
@@ -10,7 +10,12 @@ import torch.nn.functional as F
 from pathlib import Path
 
 from torcheval.metrics.functional import word_error_rate
-from torchmetrics import CharErrorRate
+
+# cringe warning message
+try:
+	from torchmetrics.text import CharErrorRate
+except Exception as e:
+	from torchmetrics import CharErrorRate
 
 def wer( audio, reference, language="auto", normalize=True, phonemize=True, **transcription_kwargs ):
 	if language == "auto":
@@ -45,4 +50,4 @@ def sim_o( audio, reference, **kwargs ):
 	audio_emb = speaker_similarity_embedding( audio, **kwargs )
 	reference_emb = speaker_similarity_embedding( reference, **kwargs )
 
-	return F.cosine_similarity( audio_emb, reference_emb ).item()
\ No newline at end of file
+	return F.cosine_similarity( audio_emb, reference_emb, dim=-1 ).item()
\ No newline at end of file
diff --git a/vall_e/utils/ext/ecapa_tdnn.py b/vall_e/utils/ext/ecapa_tdnn.py
deleted file mode 100644
index 8f28f69..0000000
--- a/vall_e/utils/ext/ecapa_tdnn.py
+++ /dev/null
@@ -1,468 +0,0 @@
-# borrowed with love from "https://github.com/keonlee9420/evaluate-zero-shot-tts/blob/master/src/evaluate_zero_shot_tts/utils/speaker_verification/models/ecapa_tdnn.py"
-# (which was from https://github.com/microsoft/UniSpeech/blob/main/downstreams/speaker_verification/models/ecapa_tdnn.py)
-# part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torchaudio.transforms as trans
-
-#from .utils import UpstreamExpert
-
-""" Res2Conv1d + BatchNorm1d + ReLU
-"""
-
-
-class Res2Conv1dReluBn(nn.Module):
-    """
-    in_channels == out_channels == channels
-    """
-
-    def __init__(
-        self,
-        channels,
-        kernel_size=1,
-        stride=1,
-        padding=0,
-        dilation=1,
-        bias=True,
-        scale=4,
-    ):
-        super().__init__()
-        assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
-        self.scale = scale
-        self.width = channels // scale
-        self.nums = scale if scale == 1 else scale - 1
-
-        self.convs = []
-        self.bns = []
-        for i in range(self.nums):
-            self.convs.append(
-                nn.Conv1d(
-                    self.width,
-                    self.width,
-                    kernel_size,
-                    stride,
-                    padding,
-                    dilation,
-                    bias=bias,
-                )
-            )
-            self.bns.append(nn.BatchNorm1d(self.width))
-        self.convs = nn.ModuleList(self.convs)
-        self.bns = nn.ModuleList(self.bns)
-
-    def forward(self, x):
-        out = []
-        spx = torch.split(x, self.width, 1)
-        for i in range(self.nums):
-            if i == 0:
-                sp = spx[i]
-            else:
-                sp = sp + spx[i]
-            # Order: conv -> relu -> bn
-            sp = self.convs[i](sp)
-            sp = self.bns[i](F.relu(sp))
-            out.append(sp)
-        if self.scale != 1:
-            out.append(spx[self.nums])
-        out = torch.cat(out, dim=1)
-
-        return out
-
-
-""" Conv1d + BatchNorm1d + ReLU
-"""
-
-
-class Conv1dReluBn(nn.Module):
-    def __init__(
-        self,
-        in_channels,
-        out_channels,
-        kernel_size=1,
-        stride=1,
-        padding=0,
-        dilation=1,
-        bias=True,
-    ):
-        super().__init__()
-        self.conv = nn.Conv1d(
-            in_channels,
-            out_channels,
-            kernel_size,
-            stride,
-            padding,
-            dilation,
-            bias=bias,
-        )
-        self.bn = nn.BatchNorm1d(out_channels)
-
-    def forward(self, x):
-        return self.bn(F.relu(self.conv(x)))
-
-
-""" The SE connection of 1D case.
-"""
-
-
-class SE_Connect(nn.Module):
-    def __init__(self, channels, se_bottleneck_dim=128):
-        super().__init__()
-        self.linear1 = nn.Linear(channels, se_bottleneck_dim)
-        self.linear2 = nn.Linear(se_bottleneck_dim, channels)
-
-    def forward(self, x):
-        out = x.mean(dim=2)
-        out = F.relu(self.linear1(out))
-        out = torch.sigmoid(self.linear2(out))
-        out = x * out.unsqueeze(2)
-
-        return out
-
-
-""" SE-Res2Block of the ECAPA-TDNN architecture.
-"""
-
-
-# def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
-#     return nn.Sequential(
-#         Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0),
-#         Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale),
-#         Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0),
-#         SE_Connect(channels)
-#     )
-
-
-class SE_Res2Block(nn.Module):
-    def __init__(
-        self,
-        in_channels,
-        out_channels,
-        kernel_size,
-        stride,
-        padding,
-        dilation,
-        scale,
-        se_bottleneck_dim,
-    ):
-        super().__init__()
-        self.Conv1dReluBn1 = Conv1dReluBn(
-            in_channels, out_channels, kernel_size=1, stride=1, padding=0
-        )
-        self.Res2Conv1dReluBn = Res2Conv1dReluBn(
-            out_channels, kernel_size, stride, padding, dilation, scale=scale
-        )
-        self.Conv1dReluBn2 = Conv1dReluBn(
-            out_channels, out_channels, kernel_size=1, stride=1, padding=0
-        )
-        self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)
-
-        self.shortcut = None
-        if in_channels != out_channels:
-            self.shortcut = nn.Conv1d(
-                in_channels=in_channels,
-                out_channels=out_channels,
-                kernel_size=1,
-            )
-
-    def forward(self, x):
-        residual = x
-        if self.shortcut:
-            residual = self.shortcut(x)
-
-        x = self.Conv1dReluBn1(x)
-        x = self.Res2Conv1dReluBn(x)
-        x = self.Conv1dReluBn2(x)
-        x = self.SE_Connect(x)
-
-        return x + residual
-
-
-""" Attentive weighted mean and standard deviation pooling.
-"""
-
-
-class AttentiveStatsPool(nn.Module):
-    def __init__(
-        self, in_dim, attention_channels=128, global_context_att=False
-    ):
-        super().__init__()
-        self.global_context_att = global_context_att
-
-        # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
-        if global_context_att:
-            self.linear1 = nn.Conv1d(
-                in_dim * 3, attention_channels, kernel_size=1
-            )  # equals W and b in the paper
-        else:
-            self.linear1 = nn.Conv1d(
-                in_dim, attention_channels, kernel_size=1
-            )  # equals W and b in the paper
-        self.linear2 = nn.Conv1d(
-            attention_channels, in_dim, kernel_size=1
-        )  # equals V and k in the paper
-
-    def forward(self, x):
-        if self.global_context_att:
-            context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
-            context_std = torch.sqrt(
-                torch.var(x, dim=-1, keepdim=True) + 1e-10
-            ).expand_as(x)
-            x_in = torch.cat((x, context_mean, context_std), dim=1)
-        else:
-            x_in = x
-
-        # DON'T use ReLU here! In experiments, I find ReLU hard to converge.
-        alpha = torch.tanh(self.linear1(x_in))
-        # alpha = F.relu(self.linear1(x_in))
-        alpha = torch.softmax(self.linear2(alpha), dim=2)
-        mean = torch.sum(alpha * x, dim=2)
-        residuals = torch.sum(alpha * (x**2), dim=2) - mean**2
-        std = torch.sqrt(residuals.clamp(min=1e-9))
-        return torch.cat([mean, std], dim=1)
-
-
-class ECAPA_TDNN(nn.Module):
-    def __init__(
-        self,
-        feat_dim=80,
-        channels=512,
-        emb_dim=192,
-        global_context_att=False,
-        feat_type="fbank",
-        sr=16000,
-        feature_selection="hidden_states",
-        update_extract=False,
-        config_path=None,
-    ):
-        super().__init__()
-
-        self.feat_type = feat_type
-        self.feature_selection = feature_selection
-        self.update_extract = update_extract
-        self.sr = sr
-
-        if feat_type == "fbank" or feat_type == "mfcc":
-            self.update_extract = False
-
-        win_len = int(sr * 0.025)
-        hop_len = int(sr * 0.01)
-
-        if feat_type == "fbank":
-            self.feature_extract = trans.MelSpectrogram(
-                sample_rate=sr,
-                n_fft=512,
-                win_length=win_len,
-                hop_length=hop_len,
-                f_min=0.0,
-                f_max=sr // 2,
-                pad=0,
-                n_mels=feat_dim,
-            )
-        elif feat_type == "mfcc":
-            melkwargs = {
-                "n_fft": 512,
-                "win_length": win_len,
-                "hop_length": hop_len,
-                "f_min": 0.0,
-                "f_max": sr // 2,
-                "pad": 0,
-            }
-            self.feature_extract = trans.MFCC(
-                sample_rate=sr,
-                n_mfcc=feat_dim,
-                log_mels=False,
-                melkwargs=melkwargs,
-            )
-        else:
-            """
-            if config_path is None:
-                self.feature_extract = torch.hub.load("s3prl/s3prl", feat_type)
-            else:
-                self.feature_extract = UpstreamExpert(config_path)
-            """
-            self.feature_extract = torch.hub.load("s3prl/s3prl", feat_type)
-
-            if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
-                self.feature_extract.model.encoder.layers[23].self_attn,
-                "fp32_attention",
-            ):
-                self.feature_extract.model.encoder.layers[
-                    23
-                ].self_attn.fp32_attention = False
-            if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
-                self.feature_extract.model.encoder.layers[11].self_attn,
-                "fp32_attention",
-            ):
-                self.feature_extract.model.encoder.layers[
-                    11
-                ].self_attn.fp32_attention = False
-
-        self.feat_num = self.get_feat_num()
-        self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
-
-        if feat_type != "fbank" and feat_type != "mfcc":
-            freeze_list = [
-                "final_proj",
-                "label_embs_concat",
-                "mask_emb",
-                "project_q",
-                "quantizer",
-            ]
-            for name, param in self.feature_extract.named_parameters():
-                for freeze_val in freeze_list:
-                    if freeze_val in name:
-                        param.requires_grad = False
-                        break
-
-        if not self.update_extract:
-            for param in self.feature_extract.parameters():
-                param.requires_grad = False
-
-        self.instance_norm = nn.InstanceNorm1d(feat_dim)
-        # self.channels = [channels] * 4 + [channels * 3]
-        self.channels = [channels] * 4 + [1536]
-
-        self.layer1 = Conv1dReluBn(
-            feat_dim, self.channels[0], kernel_size=5, padding=2
-        )
-        self.layer2 = SE_Res2Block(
-            self.channels[0],
-            self.channels[1],
-            kernel_size=3,
-            stride=1,
-            padding=2,
-            dilation=2,
-            scale=8,
-            se_bottleneck_dim=128,
-        )
-        self.layer3 = SE_Res2Block(
-            self.channels[1],
-            self.channels[2],
-            kernel_size=3,
-            stride=1,
-            padding=3,
-            dilation=3,
-            scale=8,
-            se_bottleneck_dim=128,
-        )
-        self.layer4 = SE_Res2Block(
-            self.channels[2],
-            self.channels[3],
-            kernel_size=3,
-            stride=1,
-            padding=4,
-            dilation=4,
-            scale=8,
-            se_bottleneck_dim=128,
-        )
-
-        # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
-        cat_channels = channels * 3
-        self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
-        self.pooling = AttentiveStatsPool(
-            self.channels[-1],
-            attention_channels=128,
-            global_context_att=global_context_att,
-        )
-        self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
-        self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
-
-    def get_feat_num(self):
-        self.feature_extract.eval()
-        wav = [
-            torch.randn(self.sr).to(
-                next(self.feature_extract.parameters()).device
-            )
-        ]
-        with torch.no_grad():
-            features = self.feature_extract(wav)
-        select_feature = features[self.feature_selection]
-        if isinstance(select_feature, (list, tuple)):
-            return len(select_feature)
-        else:
-            return 1
-
-    def get_feat(self, x):
-        if self.update_extract:
-            x = self.feature_extract([sample for sample in x])
-        else:
-            with torch.no_grad():
-                if self.feat_type == "fbank" or self.feat_type == "mfcc":
-                    x = (
-                        self.feature_extract(x) + 1e-6
-                    )  # B x feat_dim x time_len
-                else:
-                    x = self.feature_extract([sample for sample in x])
-
-        if self.feat_type == "fbank":
-            x = x.log()
-
-        if self.feat_type != "fbank" and self.feat_type != "mfcc":
-            x = x[self.feature_selection]
-            if isinstance(x, (list, tuple)):
-                x = torch.stack(x, dim=0)
-            else:
-                x = x.unsqueeze(0)
-            norm_weights = (
-                F.softmax(self.feature_weight, dim=-1)
-                .unsqueeze(-1)
-                .unsqueeze(-1)
-                .unsqueeze(-1)
-            )
-            x = (norm_weights * x).sum(dim=0)
-            x = torch.transpose(x, 1, 2) + 1e-6
-
-        x = self.instance_norm(x)
-        return x
-
-    def forward(self, x):
-        x = self.get_feat(x)
-
-        out1 = self.layer1(x)
-        out2 = self.layer2(out1)
-        out3 = self.layer3(out2)
-        out4 = self.layer4(out3)
-
-        out = torch.cat([out2, out3, out4], dim=1)
-        out = F.relu(self.conv(out))
-        out = self.bn(self.pooling(out))
-        out = self.linear(out)
-
-        return out
-
-
-def ECAPA_TDNN_SMALL(
-    feat_dim,
-    emb_dim=256,
-    feat_type="fbank",
-    sr=16000,
-    feature_selection="hidden_states",
-    update_extract=False,
-    config_path=None,
-):
-    return ECAPA_TDNN(
-        feat_dim=feat_dim,
-        channels=512,
-        emb_dim=emb_dim,
-        feat_type=feat_type,
-        sr=sr,
-        feature_selection=feature_selection,
-        update_extract=update_extract,
-        config_path=config_path,
-    )
-
-
-if __name__ == "__main__":
-    x = torch.zeros(2, 32000)
-    model = ECAPA_TDNN_SMALL(
-        feat_dim=768,
-        emb_dim=256,
-        feat_type="hubert_base",
-        feature_selection="hidden_states",
-        update_extract=False,
-    )
-
-    out = model(x)
-    # print(model)
-    print(out.shape)
\ No newline at end of file
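
Reviewer note: a minimal, self-contained sketch of the SIM-O path this patch switches to, for checking the new flow outside the repo. It assumes only transformers and torchaudio are installed; the embed() helper and the wav paths are hypothetical, but the checkpoint id (microsoft/wavlm-base-sv), the x-vector .embeddings output, the L2 normalization, and the dim=-1 cosine similarity mirror the changed code in similar.py and metrics.py.

import torch
import torch.nn.functional as F
import torchaudio
from transformers import Wav2Vec2FeatureExtractor, WavLMForXVector

model_name = "microsoft/wavlm-base-sv"
model = WavLMForXVector.from_pretrained(model_name).eval()
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)

@torch.no_grad()
def embed(path):
    # load -> mono -> 16 kHz, as the speaker-verification checkpoint expects
    wav, sr = torchaudio.load(path)
    wav = wav.mean(dim=0)
    if sr != 16000:
        wav = torchaudio.functional.resample(wav, sr, 16000)
    features = feature_extractor(wav.numpy(), sampling_rate=16000, return_tensors="pt")
    embeddings = model(input_values=features.input_values).embeddings
    return F.normalize(embeddings, dim=-1)

# SIM-O: cosine similarity between generated and reference utterance embeddings
print(F.cosine_similarity(embed("out.wav"), embed("ref.wav"), dim=-1).item())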
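
Reviewer note on the inference.py hunk: buffer there is a 4-tuple of parallel lists, and a non-empty tuple is always truthy, so `if buffer:` fires on every pass where the old `len(buffer[0]) >= batch_size` only fired on a full batch; if this check is meant as the final flush of a trailing partial batch, the usual form tests the contents, e.g. `if buffer[0]:`. A simplified, hypothetical make_batches() below shows the pattern with a single list, where truthiness does work as a remainder check.

def make_batches(items, batch_size):
    batches, buffer = [], []
    for item in items:
        buffer.append(item)
        if len(buffer) >= batch_size:  # in-loop flush: only full batches
            batches.append(buffer)
            buffer = []
    if buffer:  # final flush: keep the short remainder instead of dropping it
        batches.append(buffer)
    return batches

assert make_batches(list(range(5)), 2) == [[0, 1], [2, 3], [4]]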