From c2c6d912acf4e0ac76a0379ea2352a08f6824833 Mon Sep 17 00:00:00 2001
From: mrq
Date: Tue, 17 Dec 2024 10:11:14 -0600
Subject: [PATCH] actually do speaker verification

---
 vall_e/demo.py                 |   2 +-
 vall_e/emb/similar.py          |  33 ++-
 vall_e/models/__init__.py      |   1 +
 vall_e/utils/ext/ecapa_tdnn.py | 467 +++++++++++++++++++++++++++++++++
 4 files changed, 498 insertions(+), 5 deletions(-)
 create mode 100644 vall_e/utils/ext/ecapa_tdnn.py

diff --git a/vall_e/demo.py b/vall_e/demo.py
index 782183b..8da990b 100644
--- a/vall_e/demo.py
+++ b/vall_e/demo.py
@@ -136,7 +136,7 @@ def main():
 	parser.add_argument("--comparison", type=str, default=None)
 
 	parser.add_argument("--transcription-model", type=str, default="openai/whisper-base")
-	parser.add_argument("--speaker-similarity-model", type=str, default="microsoft/wavlm-base-sv")
+	parser.add_argument("--speaker-similarity-model", type=str, default="microsoft/wavlm-large")
 
 	args = parser.parse_args()
 
diff --git a/vall_e/emb/similar.py b/vall_e/emb/similar.py
index fc921ae..d21d147 100644
--- a/vall_e/emb/similar.py
+++ b/vall_e/emb/similar.py
@@ -28,6 +28,8 @@ from ..utils.io import json_read, json_write
 
 from .g2p import encode as phonemize
 from .qnt import encode as quantize, trim, convert_audio
+from ..models import download_model
+
 from ..webui import init_tts
 
 def load_audio( path, target_sr=None ):
@@ -46,20 +48,40 @@ tts = None
 
 # this is for computing SIM-O, but can probably technically be used for scoring similar utterances
 @cache
-def _load_sim_model(device="cuda", dtype="float16", model_name='microsoft/wavlm-base-sv'):
+def _load_sim_model(device="cuda", dtype="float16", model_name='microsoft/wavlm-large'):
+	from ..utils.ext.ecapa_tdnn import ECAPA_TDNN_SMALL
+	model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='wavlm_large')
+
+	finetune_path = Path("./data/models/wavlm_large_finetune.pth")
+	if not finetune_path.exists():
+		download_model(finetune_path)
+
+	state_dict = torch.load( finetune_path )
+	state_dict = state_dict['model']
+	del state_dict['loss_calculator.projection.weight']
+	model.load_state_dict( state_dict )
+
+	model = model.to(device=device, dtype=coerce_dtype(dtype))
+	model = model.eval()
+
+	return model, None
+
+	"""
 	logging.getLogger('s3prl').setLevel(logging.DEBUG)
 	logging.getLogger('speechbrain').setLevel(logging.DEBUG)
-
-	#from ..utils.ext.ecapa_tdnn import ECAPA_TDNN_SMALL
 	from transformers import Wav2Vec2FeatureExtractor, WavLMForXVector
-
 	model = WavLMForXVector.from_pretrained(model_name)
+	finetune_path = Path("./data/models/wavlm_large_finetune.pth")
+	if finetune_path.exists():
+		state_dict = torch.load( finetune_path )
+		model.load_state_dict( state_dict['model'] )
 
 	model = model.to(device=device, dtype=coerce_dtype(dtype))
 	model = model.eval()
 
 	feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
 	return model, feature_extractor
+	"""
 
 @torch.no_grad()
 def speaker_similarity_embedding(
@@ -75,12 +97,15 @@
 		audio = load_audio(audio, 16000)
 	audio, sr = audio
 
+	embeddings = model(audio.to(device=device, dtype=coerce_dtype(dtype)))
+	"""
 	features = feature_extractor(audio, return_tensors="pt", sampling_rate=sr)
 	features = features.input_values.squeeze(0).to(dtype=coerce_dtype(dtype), device=device)
 
 	output = model(input_values=features)
 	embeddings = output.embeddings
 	embeddings = torch.nn.functional.normalize(embeddings, dim=-1).cpu()
+	"""
 	return embeddings
 
 def batch_similar_utterances(
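With this change, `speaker_similarity_embedding` returns raw ECAPA embeddings (the L2-normalization step now only survives in the commented-out WavLMForXVector branch), so a SIM-O score reduces to the cosine similarity of two such vectors. A minimal sketch of the downstream scoring, assuming the patched function accepts an audio path and returns a `(1, emb_dim)` tensor; the file names here are hypothetical:

    import torch.nn.functional as F

    from vall_e.emb.similar import speaker_similarity_embedding

    # embed the ground-truth reference and the synthesized utterance (paths are illustrative)
    ref = speaker_similarity_embedding("reference.wav")
    hyp = speaker_similarity_embedding("generated.wav")

    # cosine similarity normalizes its inputs, so the un-normalized ECAPA embeddings are fine
    sim_o = F.cosine_similarity(ref.float(), hyp.float(), dim=-1).item()
    print(f"SIM-O: {sim_o:.3f}")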
diff --git a/vall_e/models/__init__.py b/vall_e/models/__init__.py
index 8b4c74f..b8a7f09 100755
--- a/vall_e/models/__init__.py
+++ b/vall_e/models/__init__.py
@@ -13,6 +13,7 @@ DEFAULT_MODEL_DIR = Path(__file__).parent.parent.parent / 'data/models'
 DEFAULT_MODEL_PATH = DEFAULT_MODEL_DIR / "ar+nar-len-llama-8.sft"
 DEFAULT_MODEL_URLS = {
 	'ar+nar-len-llama-8.sft': 'https://huggingface.co/ecker/vall-e/resolve/main/models/ckpt/ar%2Bnar-len-llama-8/ckpt/fp32.sft',
+	'wavlm_large_finetune.pth': 'https://huggingface.co/Dongchao/UniAudio/resolve/main/wavlm_large_finetune.pth',
 }
 
 if not DEFAULT_MODEL_PATH.exists() and Path("./data/models/ar+nar-len-llama-8.sft").exists():
diff --git a/vall_e/utils/ext/ecapa_tdnn.py b/vall_e/utils/ext/ecapa_tdnn.py
new file mode 100644
index 0000000..29a99fb
--- /dev/null
+++ b/vall_e/utils/ext/ecapa_tdnn.py
@@ -0,0 +1,467 @@
+# borrowed with love from "https://github.com/keonlee9420/evaluate-zero-shot-tts/blob/master/src/evaluate_zero_shot_tts/utils/speaker_verification/models/ecapa_tdnn.py"
+# part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchaudio.transforms as trans
+
+#from .utils import UpstreamExpert
+
+""" Res2Conv1d + BatchNorm1d + ReLU
+"""
+
+
+class Res2Conv1dReluBn(nn.Module):
+    """
+    in_channels == out_channels == channels
+    """
+
+    def __init__(
+        self,
+        channels,
+        kernel_size=1,
+        stride=1,
+        padding=0,
+        dilation=1,
+        bias=True,
+        scale=4,
+    ):
+        super().__init__()
+        assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
+        self.scale = scale
+        self.width = channels // scale
+        self.nums = scale if scale == 1 else scale - 1
+
+        self.convs = []
+        self.bns = []
+        for i in range(self.nums):
+            self.convs.append(
+                nn.Conv1d(
+                    self.width,
+                    self.width,
+                    kernel_size,
+                    stride,
+                    padding,
+                    dilation,
+                    bias=bias,
+                )
+            )
+            self.bns.append(nn.BatchNorm1d(self.width))
+        self.convs = nn.ModuleList(self.convs)
+        self.bns = nn.ModuleList(self.bns)
+
+    def forward(self, x):
+        out = []
+        spx = torch.split(x, self.width, 1)
+        for i in range(self.nums):
+            if i == 0:
+                sp = spx[i]
+            else:
+                sp = sp + spx[i]
+            # Order: conv -> relu -> bn
+            sp = self.convs[i](sp)
+            sp = self.bns[i](F.relu(sp))
+            out.append(sp)
+        if self.scale != 1:
+            out.append(spx[self.nums])
+        out = torch.cat(out, dim=1)
+
+        return out
+
+
+""" Conv1d + BatchNorm1d + ReLU
+"""
+
+
+class Conv1dReluBn(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size=1,
+        stride=1,
+        padding=0,
+        dilation=1,
+        bias=True,
+    ):
+        super().__init__()
+        self.conv = nn.Conv1d(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            padding,
+            dilation,
+            bias=bias,
+        )
+        self.bn = nn.BatchNorm1d(out_channels)
+
+    def forward(self, x):
+        return self.bn(F.relu(self.conv(x)))
+
+
+""" The SE connection of 1D case.
+"""
+
+
+class SE_Connect(nn.Module):
+    def __init__(self, channels, se_bottleneck_dim=128):
+        super().__init__()
+        self.linear1 = nn.Linear(channels, se_bottleneck_dim)
+        self.linear2 = nn.Linear(se_bottleneck_dim, channels)
+
+    def forward(self, x):
+        out = x.mean(dim=2)
+        out = F.relu(self.linear1(out))
+        out = torch.sigmoid(self.linear2(out))
+        out = x * out.unsqueeze(2)
+
+        return out
+
+
+""" SE-Res2Block of the ECAPA-TDNN architecture.
+""" + + +# def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale): +# return nn.Sequential( +# Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0), +# Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale), +# Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0), +# SE_Connect(channels) +# ) + + +class SE_Res2Block(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + scale, + se_bottleneck_dim, + ): + super().__init__() + self.Conv1dReluBn1 = Conv1dReluBn( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + self.Res2Conv1dReluBn = Res2Conv1dReluBn( + out_channels, kernel_size, stride, padding, dilation, scale=scale + ) + self.Conv1dReluBn2 = Conv1dReluBn( + out_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim) + + self.shortcut = None + if in_channels != out_channels: + self.shortcut = nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + ) + + def forward(self, x): + residual = x + if self.shortcut: + residual = self.shortcut(x) + + x = self.Conv1dReluBn1(x) + x = self.Res2Conv1dReluBn(x) + x = self.Conv1dReluBn2(x) + x = self.SE_Connect(x) + + return x + residual + + +""" Attentive weighted mean and standard deviation pooling. +""" + + +class AttentiveStatsPool(nn.Module): + def __init__( + self, in_dim, attention_channels=128, global_context_att=False + ): + super().__init__() + self.global_context_att = global_context_att + + # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs. + if global_context_att: + self.linear1 = nn.Conv1d( + in_dim * 3, attention_channels, kernel_size=1 + ) # equals W and b in the paper + else: + self.linear1 = nn.Conv1d( + in_dim, attention_channels, kernel_size=1 + ) # equals W and b in the paper + self.linear2 = nn.Conv1d( + attention_channels, in_dim, kernel_size=1 + ) # equals V and k in the paper + + def forward(self, x): + if self.global_context_att: + context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x) + context_std = torch.sqrt( + torch.var(x, dim=-1, keepdim=True) + 1e-10 + ).expand_as(x) + x_in = torch.cat((x, context_mean, context_std), dim=1) + else: + x_in = x + + # DON'T use ReLU here! In experiments, I find ReLU hard to converge. 
+        alpha = torch.tanh(self.linear1(x_in))
+        # alpha = F.relu(self.linear1(x_in))
+        alpha = torch.softmax(self.linear2(alpha), dim=2)
+        mean = torch.sum(alpha * x, dim=2)
+        residuals = torch.sum(alpha * (x**2), dim=2) - mean**2
+        std = torch.sqrt(residuals.clamp(min=1e-9))
+        return torch.cat([mean, std], dim=1)
+
+
+class ECAPA_TDNN(nn.Module):
+    def __init__(
+        self,
+        feat_dim=80,
+        channels=512,
+        emb_dim=192,
+        global_context_att=False,
+        feat_type="fbank",
+        sr=16000,
+        feature_selection="hidden_states",
+        update_extract=False,
+        config_path=None,
+    ):
+        super().__init__()
+
+        self.feat_type = feat_type
+        self.feature_selection = feature_selection
+        self.update_extract = update_extract
+        self.sr = sr
+
+        if feat_type == "fbank" or feat_type == "mfcc":
+            self.update_extract = False
+
+        win_len = int(sr * 0.025)
+        hop_len = int(sr * 0.01)
+
+        if feat_type == "fbank":
+            self.feature_extract = trans.MelSpectrogram(
+                sample_rate=sr,
+                n_fft=512,
+                win_length=win_len,
+                hop_length=hop_len,
+                f_min=0.0,
+                f_max=sr // 2,
+                pad=0,
+                n_mels=feat_dim,
+            )
+        elif feat_type == "mfcc":
+            melkwargs = {
+                "n_fft": 512,
+                "win_length": win_len,
+                "hop_length": hop_len,
+                "f_min": 0.0,
+                "f_max": sr // 2,
+                "pad": 0,
+            }
+            self.feature_extract = trans.MFCC(
+                sample_rate=sr,
+                n_mfcc=feat_dim,
+                log_mels=False,
+                melkwargs=melkwargs,
+            )
+        else:
+            """
+            if config_path is None:
+                self.feature_extract = torch.hub.load("s3prl/s3prl", feat_type)
+            else:
+                self.feature_extract = UpstreamExpert(config_path)
+            """
+            self.feature_extract = torch.hub.load("s3prl/s3prl", feat_type)
+            if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
+                self.feature_extract.model.encoder.layers[23].self_attn,
+                "fp32_attention",
+            ):
+                self.feature_extract.model.encoder.layers[
+                    23
+                ].self_attn.fp32_attention = False
+            if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
+                self.feature_extract.model.encoder.layers[11].self_attn,
+                "fp32_attention",
+            ):
+                self.feature_extract.model.encoder.layers[
+                    11
+                ].self_attn.fp32_attention = False
+
+        self.feat_num = self.get_feat_num()
+        self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
+
+        if feat_type != "fbank" and feat_type != "mfcc":
+            freeze_list = [
+                "final_proj",
+                "label_embs_concat",
+                "mask_emb",
+                "project_q",
+                "quantizer",
+            ]
+            for name, param in self.feature_extract.named_parameters():
+                for freeze_val in freeze_list:
+                    if freeze_val in name:
+                        param.requires_grad = False
+                        break
+
+        if not self.update_extract:
+            for param in self.feature_extract.parameters():
+                param.requires_grad = False
+
+        self.instance_norm = nn.InstanceNorm1d(feat_dim)
+        # self.channels = [channels] * 4 + [channels * 3]
+        self.channels = [channels] * 4 + [1536]
+
+        self.layer1 = Conv1dReluBn(
+            feat_dim, self.channels[0], kernel_size=5, padding=2
+        )
+        self.layer2 = SE_Res2Block(
+            self.channels[0],
+            self.channels[1],
+            kernel_size=3,
+            stride=1,
+            padding=2,
+            dilation=2,
+            scale=8,
+            se_bottleneck_dim=128,
+        )
+        self.layer3 = SE_Res2Block(
+            self.channels[1],
+            self.channels[2],
+            kernel_size=3,
+            stride=1,
+            padding=3,
+            dilation=3,
+            scale=8,
+            se_bottleneck_dim=128,
+        )
+        self.layer4 = SE_Res2Block(
+            self.channels[2],
+            self.channels[3],
+            kernel_size=3,
+            stride=1,
+            padding=4,
+            dilation=4,
+            scale=8,
+            se_bottleneck_dim=128,
+        )
+
+        # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
+        cat_channels = channels * 3
+        self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
+        self.pooling = AttentiveStatsPool(
+            self.channels[-1],
+            attention_channels=128,
+            global_context_att=global_context_att,
+        )
+        self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
+        self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
+
+    def get_feat_num(self):
+        self.feature_extract.eval()
+        wav = [
+            torch.randn(self.sr).to(
+                next(self.feature_extract.parameters()).device
+            )
+        ]
+        with torch.no_grad():
+            features = self.feature_extract(wav)
+        select_feature = features[self.feature_selection]
+        if isinstance(select_feature, (list, tuple)):
+            return len(select_feature)
+        else:
+            return 1
+
+    def get_feat(self, x):
+        if self.update_extract:
+            x = self.feature_extract([sample for sample in x])
+        else:
+            with torch.no_grad():
+                if self.feat_type == "fbank" or self.feat_type == "mfcc":
+                    x = (
+                        self.feature_extract(x) + 1e-6
+                    )  # B x feat_dim x time_len
+                else:
+                    x = self.feature_extract([sample for sample in x])
+
+        if self.feat_type == "fbank":
+            x = x.log()
+
+        if self.feat_type != "fbank" and self.feat_type != "mfcc":
+            x = x[self.feature_selection]
+            if isinstance(x, (list, tuple)):
+                x = torch.stack(x, dim=0)
+            else:
+                x = x.unsqueeze(0)
+            norm_weights = (
+                F.softmax(self.feature_weight, dim=-1)
+                .unsqueeze(-1)
+                .unsqueeze(-1)
+                .unsqueeze(-1)
+            )
+            x = (norm_weights * x).sum(dim=0)
+            x = torch.transpose(x, 1, 2) + 1e-6
+
+        x = self.instance_norm(x)
+        return x
+
+    def forward(self, x):
+        x = self.get_feat(x)
+
+        out1 = self.layer1(x)
+        out2 = self.layer2(out1)
+        out3 = self.layer3(out2)
+        out4 = self.layer4(out3)
+
+        out = torch.cat([out2, out3, out4], dim=1)
+        out = F.relu(self.conv(out))
+        out = self.bn(self.pooling(out))
+        out = self.linear(out)
+
+        return out
+
+
+def ECAPA_TDNN_SMALL(
+    feat_dim,
+    emb_dim=256,
+    feat_type="fbank",
+    sr=16000,
+    feature_selection="hidden_states",
+    update_extract=False,
+    config_path=None,
+):
+    return ECAPA_TDNN(
+        feat_dim=feat_dim,
+        channels=512,
+        emb_dim=emb_dim,
+        feat_type=feat_type,
+        sr=sr,
+        feature_selection=feature_selection,
+        update_extract=update_extract,
+        config_path=config_path,
+    )
+
+
+if __name__ == "__main__":
+    x = torch.zeros(2, 32000)
+    model = ECAPA_TDNN_SMALL(
+        feat_dim=768,
+        emb_dim=256,
+        feat_type="hubert_base",
+        feature_selection="hidden_states",
+        update_extract=False,
+    )
+
+    out = model(x)
+    # print(model)
+    print(out.shape)
\ No newline at end of file
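For reference, the vendored module can be exercised standalone; the sketch below mirrors the checkpoint handling in the patched `_load_sim_model` and assumes `wavlm_large_finetune.pth` has already been fetched to the path used above. The deleted key is presumably a training-only loss head with no counterpart in `ECAPA_TDNN`:

    import torch

    from vall_e.utils.ext.ecapa_tdnn import ECAPA_TDNN_SMALL

    # WavLM-large features in, 256-dim speaker embeddings out
    model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type="wavlm_large")

    state_dict = torch.load("./data/models/wavlm_large_finetune.pth", map_location="cpu")["model"]
    del state_dict["loss_calculator.projection.weight"]  # training-only head, not part of the model
    model.load_state_dict(state_dict)
    model = model.eval()

    # a batch of two 16 kHz waveforms, two seconds each; real use passes loaded audio
    with torch.no_grad():
        embeddings = model(torch.randn(2, 32000))
    print(embeddings.shape)  # expected: torch.Size([2, 256])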