From 8568a93dad417e118b09d3917f836d22a9f4a8b6 Mon Sep 17 00:00:00 2001 From: mrq Date: Tue, 10 Dec 2024 20:13:21 -0600 Subject: [PATCH] added WER/SIM-O metrics, added APOLLO but I need to test it --- data/demo/index.template.html | 4 + docs/emb.md | 20 +- setup.py | 3 + vall_e/config.py | 24 +- vall_e/data.py | 11 +- vall_e/demo.py | 25 +- vall_e/emb/similar.py | 43 ++- vall_e/emb/transcribe.py | 153 +++++++++- vall_e/engines/__init__.py | 25 +- vall_e/inference.py | 6 +- vall_e/metrics.py | 33 +++ vall_e/utils/__init__.py | 3 +- vall_e/utils/ext/__init__.py | 0 vall_e/utils/ext/apollo.py | 433 +++++++++++++++++++++++++++ vall_e/utils/ext/ecapa_tdnn.py | 467 ++++++++++++++++++++++++++++++ vall_e/utils/{ => ext}/unsloth.py | 0 vall_e/utils/utils.py | 14 + vall_e/utils/wrapper.py | 8 +- 18 files changed, 1216 insertions(+), 56 deletions(-) create mode 100644 vall_e/metrics.py create mode 100644 vall_e/utils/ext/__init__.py create mode 100644 vall_e/utils/ext/apollo.py create mode 100644 vall_e/utils/ext/ecapa_tdnn.py rename vall_e/utils/{ => ext}/unsloth.py (100%) diff --git a/data/demo/index.template.html b/data/demo/index.template.html index 207952a..dcb1e64 100644 --- a/data/demo/index.template.html +++ b/data/demo/index.template.html @@ -10,6 +10,8 @@ LibriSpeech Text + WER↓ + SIM-O↑ Prompt Our VALL-E Original VALL-E @@ -24,6 +26,8 @@ Sampled Dataset Text + WER↓ + SIM-O↑ Prompt Our VALL-E F5-TTS diff --git a/docs/emb.md b/docs/emb.md index a82ef25..03e2de9 100644 --- a/docs/emb.md +++ b/docs/emb.md @@ -77,7 +77,7 @@ I'm uncertain on how to remedy this, as my options are: ## `transcribe.py` -This script handles taking raw input audio, and outputting adequate metadata containing transcriptions of said audio through `whisperX`. +This script primarily handles taking raw input audio, and outputting adequate metadata containing transcriptions of said audio through `whisperX`. The process maintains slices `whisperX` thinks its best per the segments outputted, alongside the deduced language (if not specified). @@ -85,6 +85,18 @@ One limiting factor is that transcription transcribes into normal text, rather t Refer to the `__main__`'s arguments for usage details. +### Metrics + +This script also handles calculating `WER` simply by transcribing the given audio file (and reference, if requested), then comparing the word error rate. + +This process *heavily* relies on text normalization, which currently is lacking, but transcribing the reference should keep things "normalized" per the transcriber. + +### ROCm + +Because life is pain, ROCm requires additional steps to ensure that `whisperX` works. A special fork of `CTranslate2` is required, but simplying following [these](https://github.com/arlo-phoenix/CTranslate2-rocm/blob/rocm/README_ROCM.md) steps should fix things. + +In the future, I would love to replace WhisperX for something simple. + ## `process.py` This script handles taking raw input audio and its transcribed metadata, and outputs encoded audio (NumPy) files containing encoded audio and associated metadata. @@ -107,4 +119,8 @@ When processing a dataset, this requires already having accompanying metadata ge Be *very* careful if you opt to output unsegmented and segmented utterances, as the sliced version may end up amongst the top-K similar candidates. -Refer to the `__main__`'s arguments for usage details. \ No newline at end of file +Refer to the `__main__`'s arguments for usage details. 
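For reference, the `WER` metric described above under `transcribe.py` boils down to transcribing the output (and, if needed, the reference audio), running both strings through the same naive normalization, then scoring with `torcheval`. A minimal sketch of that scoring step, assuming `torcheval` is available (it is added to `setup.py` in this patch) and that the transcription text comes from the whisperX-backed helper:

```python
import re
from torcheval.metrics.functional import word_error_rate

def normalize_text(s: str) -> str:
    # lowercase and strip punctuation; crude, but consistent as long as both
    # the hypothesis and the reference pass through the same transcriber/normalizer
    return re.sub(r'[^\w\s]', '', s.lower())

def naive_wer(hypothesis: str, reference: str) -> float:
    # hypothesis/reference are plain transcription strings
    return word_error_rate(
        [normalize_text(hypothesis)],
        [normalize_text(reference)],
    ).item()
```

`naive_wer` is a hypothetical standalone name for illustration; the in-tree version lives in `vall_e/metrics.py` as `wer()` and additionally handles transcribing file paths before scoring.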
+ +### Metrics + +This script also handles calculating `SIM-O` per [keonlee9420/evaluate-zero-shot-tts](https://github.com/keonlee9420/evaluate-zero-shot-tts/blob/master/src/evaluate_zero_shot_tts/utils/speaker_verification/verification.py), by making use of a model to create an embedding of a speaker, then computing cosine similarities on those embeddings. \ No newline at end of file diff --git a/setup.py b/setup.py index 4d41fd2..33294a4 100755 --- a/setup.py +++ b/setup.py @@ -91,6 +91,9 @@ setup( "causal-conv1d", "mamba-ssm", + # + "torcheval", + # attention helpers "xformers", "sageattention==1.0.6", diff --git a/vall_e/config.py b/vall_e/config.py index 2164914..ca3cafa 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -22,7 +22,7 @@ from pathlib import Path from .utils.distributed import world_size from .utils.io import torch_load -from .utils import set_seed, prune_missing, md5_hash +from .utils import set_seed, prune_missing, md5_hash, coerce_dtype @dataclass() class BaseConfig: @@ -721,15 +721,7 @@ class Trainer: @cached_property def dtype(self): - if self.weight_dtype == "float16": - return torch.float16 - if self.weight_dtype == "bfloat16": - return torch.bfloat16 - if self.weight_dtype == "float8_e5m2": - return torch.float8_e5m2 - if self.weight_dtype == "float8_e4m3fn": - return torch.float8_e4m3fn - return torch.float32 + return coerce_dtype(self.weight_dtype) @cached_property def scale_loss(self): @@ -748,17 +740,7 @@ class Inference: @property def dtype(self): - if self.weight_dtype == "float16": - return torch.float16 - if self.weight_dtype == "bfloat16": - return torch.bfloat16 - if self.weight_dtype == "int8": - return torch.int8 - if self.weight_dtype == "float8_e5m2": - return torch.float8_e5m2 - if self.weight_dtype == "float8_e4m3fn": - return torch.float8_e4m3fn - return torch.float32 + return coerce_dtype(self.weight_dtype) @dataclass() class Optimizations: diff --git a/vall_e/data.py b/vall_e/data.py index 1ff2610..08bbf79 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -63,6 +63,13 @@ def sentence_split( s, split_by="sentences", quote_placeholder="" ): sentences = nltk.sent_tokenize(s) return [ sentence.replace(quote_placeholder, '"') for sentence in sentences if sentence ] +# to-do: improve upon this since it's kind of ass +# this might be better to live in emb.g2p +def normalize_text( s ): + s = s.lower() + s = re.sub(r'[^\w\s]', '', s) + return s + @cache def get_random_prompts( validation=False, min_length=0, tokenized=False ): duration_range = [ 5.5, 12.0 ] # to-do: pull from cfg.dataset.duration_range @@ -1070,7 +1077,9 @@ class Dataset(_Dataset): return root / name def sample_prompts(self, spkr_name, reference, should_trim=True): - if not cfg.dataset.prompt_duration_range or cfg.dataset.prompt_duration_range[-1] == 0: + # return no prompt if explicitly requested for who knows why + # or if there's no other speakers to sample from (Emilia has a lot of singleton speakers, but I still want to make use of them) + if not cfg.dataset.prompt_duration_range or cfg.dataset.prompt_duration_range[-1] == 0 or len(self.paths_by_spkr_name[key]) <= 1: return None prom_list = [] diff --git a/vall_e/demo.py b/vall_e/demo.py index 659238c..a050102 100644 --- a/vall_e/demo.py +++ b/vall_e/demo.py @@ -20,6 +20,7 @@ import base64 import random import logging import time +import torch _logger = logging.getLogger(__name__) @@ -29,6 +30,8 @@ from .inference import TTS from .config import cfg from .data import create_train_dataloader, create_val_dataloader, 
get_random_prompt from .emb.qnt import decode_to_file +from .metrics import wer, sim_o +from .utils import setup_logging from tqdm import tqdm, trange @@ -230,6 +233,8 @@ def main(): elif args.comparison: raise Exception(f"Unrecognized comparison flag: {args.comparison}") + setup_logging() + # read html template html = open(args.demo_dir / "index.template.html", "r", encoding="utf-8").read() @@ -318,6 +323,7 @@ def main(): inputs = [] outputs = [] + metrics_inputs = [] comparison_inputs = [] for k, sample_dir in samples_dirs.items(): if not sample_dir.exists(): @@ -359,9 +365,15 @@ def main(): # segregate comparisons into its own batch because they use different kwargs (and I do not support variadic-batched kwargs) if args.comparison: - comparison_inputs.append((text, prompt, language, out_path_comparison)) + if (args.skip_existing and not out_path_comparison.exists()) or not (args.skip_existing): + comparison_inputs.append((text, prompt, language, out_path_comparison)) + + metrics_inputs.append((text, language, out_path_comparison, reference)) - inputs.append((text, prompt, language, out_path)) + if (args.skip_existing and not out_path.exists()) or not (args.skip_existing): + inputs.append((text, prompt, language, out_path)) + + metrics_inputs.append((text, language, out_path, reference)) outputs.append((k, samples)) @@ -371,10 +383,19 @@ def main(): if comparison_inputs: process_batch( tts, comparison_inputs, sampling_kwargs | (comparison_kwargs["enabled"] if args.comparison else {}) ) + metrics_map = {} + for text, language, out_path, reference_path in metrics_inputs: + wer_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name="base" ) + sim_o_score = sim_o( out_path, reference_path, device=tts.device, dtype=tts.dtype ) + metrics_map[out_path] = (wer_score, sim_o_score) + # collate entries into HTML for k, samples in outputs: samples = [ f'\n\t\t\t\n\t\t\t\t{text}'+ + "".join([ + f'\n\t\t\t\t{metrics_map[audios[1]][0]:.3f}{metrics_map[audios[1]][1]:.3f}' + ] ) + "".join( [ f'\n\t\t\t\t' for audio in audios diff --git a/vall_e/emb/similar.py b/vall_e/emb/similar.py index 7b41468..42b16db 100644 --- a/vall_e/emb/similar.py +++ b/vall_e/emb/similar.py @@ -16,12 +16,13 @@ _logger = logging.getLogger(__name__) from tqdm.auto import tqdm from pathlib import Path +from functools import cache import torchaudio.functional as F import torchaudio.transforms as T from ..config import cfg -from ..utils import truncate_json +from ..utils import truncate_json, coerce_dtype from ..utils.io import json_read, json_write from .g2p import encode as phonemize @@ -29,19 +30,49 @@ from .qnt import encode as quantize, trim, convert_audio from ..webui import init_tts -def load_audio( path ): +def load_audio( path, target_sr=None ): waveform, sr = torchaudio.load( path ) # mix channels if waveform.shape[0] > 1: waveform = torch.mean(waveform, dim=0, keepdim=True) + if target_sr is None: + target_sr = cfg.sample_rate # resample - waveform, sr = convert_audio(waveform, sr, cfg.sample_rate, 1), cfg.sample_rate + waveform, sr = convert_audio(waveform, sr, target_sr, 1), target_sr return waveform, sr tts = None -def process( +# this is for computing SIM-O, but can probably technically be used for scoring similar utterances +@cache +def _load_sim_model(device="cuda", dtype="float16", feat_type="wavlm_base_plus", feat_dim=768): + from ..utils.ext.ecapa_tdnn import ECAPA_TDNN_SMALL + model = ECAPA_TDNN_SMALL(feat_dim=feat_dim, feat_type=feat_type, config_path=None) + model = 
model.to(device=device, dtype=coerce_dtype(dtype)) + model = model.eval() + + return model + +@torch.no_grad() +def speaker_similarity_embedding( + audio, + **model_kwargs, +): + device = model_kwargs.get("device", "cuda") + dtype = model_kwargs.get("dtype", "float16") + + model = _load_sim_model(**model_kwargs) + if isinstance(audio, str) or isinstance(audio, Path): + audio = load_audio(audio, 16000) + + audio, sr = audio + audio = audio.to(device=device, dtype=coerce_dtype(dtype)) + + return model(audio) + + +def batch_similar_utterances( speaker_path, yaml, text=False, @@ -266,7 +297,7 @@ def main(): if args.skip_existing and metadata_keys and "similar" in metadata[metadata_keys[-1]]: return - similarities = process( + similarities = batch_similar_utterances( speaker_path=cfg.data_dir / speaker_name, yaml=args.yaml, text=args.text, @@ -314,7 +345,7 @@ def main(): add( data_dir, type="noise", texts=False ) elif args.input_speaker: - similarities = process( + similarities = batch_similar_utterances( speaker_path=args.input_speaker, yaml=args.yaml, text=args.text, diff --git a/vall_e/emb/transcribe.py b/vall_e/emb/transcribe.py index 6d66c86..5d9ec5f 100644 --- a/vall_e/emb/transcribe.py +++ b/vall_e/emb/transcribe.py @@ -11,9 +11,13 @@ import torchaudio import whisperx +from functools import cache from tqdm.auto import tqdm from pathlib import Path +from ..utils import coerce_dtype + + def pad(num, zeroes): return str(num).zfill(zeroes+1) @@ -21,7 +25,132 @@ def process_items( items, stride=0, stride_offset=0 ): items = sorted( items ) return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ] +# major cringe but should automatically unload models when loading a different one +_cached_models = { + "model": (None, None), + "diarization": (None, None), + "align": (None, None), +} +# yes I can write a decorator to do this +def _load_model(model_name="large-v3", device="cuda", dtype="float16", language="auto"): + cache_key = f'{model_name}:{device}:{dtype}:{language}' + if _cached_models["model"][0] == cache_key: + return _cached_models["model"][1] + + del _cached_models["model"] + + if not isinstance( dtype, str ): + if dtype == torch.float32: + dtype = "float32" + elif dtype == torch.float16: + dtype = "float16" + elif dtype == torch.bfloat16: + dtype = "bfloat16" + + # doesnt support it for some reason + if dtype == "bfloat16": + dtype = "float16" + + kwargs = {} + kwargs["compute_type"] = dtype + kwargs["task"] = "transcribe" + kwargs["device"] = device + + if language != "auto": + kwargs["language"] = language + + model = whisperx.load_model(model_name, **kwargs) + + _cached_models["model"] = (cache_key, model) + return model + +def _load_diarization_model(device="cuda"): + cache_key = f'{device}' + + if _cached_models["diarization"][0] == cache_key: + return _cached_models["diarization"][1] + del _cached_models["diarization"] + model = whisperx.DiarizationPipeline(device=device) + _cached_models["diarization"] = (cache_key, model) + return model + +def _load_align_model(language, device="cuda"): + cache_key = f'{language}:{device}' + + if _cached_models["align"][0] == cache_key: + return _cached_models["align"][1] + del _cached_models["align"] + model = whisperx.load_align_model(language_code=language, device=device) + _cached_models["align"] = (cache_key, model) + return model + +# yes I can just do a for-loop +def unload_model(): + del _cached_models["model"] + del _cached_models["diarization"] + del _cached_models["align"] + + 
_cached_models["model"] = (None, None) + _cached_models["diarization"] = (None, None) + _cached_models["align"] = (None, None) + def transcribe( + audio, + language = "auto", + diarize = False, + batch_size = 16, + verbose=False, + align=True, + **model_kwargs, +): + metadata = { + "segments": [], + "language": "", + "text": "", + "start": 0, + "end": 0, + } + + # load requested models + device = model_kwargs.get("device", "cuda") + model = _load_model(language=language, **model_kwargs) + diarize_model = _load_diarization_model(device=device) if diarize else None + + # audio is a path, load it + if isinstance(audio, str) or isinstance(audio, Path): + #audio = load_audio(audio) + audio = whisperx.load_audio(audio) + + result = model.transcribe(audio, batch_size=batch_size) + + if language == "auto": + language = result["language"] + + if align: + align_model, align_model_metadata = _load_align_model(language=language, device=device) + result = whisperx.align(result["segments"], align_model, align_model_metadata, audio, device, return_char_alignments=False) + + if diarize_model is not None: + diarize_segments = diarize_model(audio) + result = whisperx.assign_word_speakers(diarize_segments, result) + + text = [] + start = 0 + end = 0 + for segment in result["segments"]: + text.append( segment["text"] ) + start = min( start, segment["start"] ) + end = max( end, segment["end"] ) + + metadata["language"] = language + metadata["segments"] = result["segments"] + metadata["text"] = " ".join(text).strip() + metadata["start"] = start + metadata["end"] = end + + return metadata + +def transcribe_batch( input_audio = "voices", input_voice = None, output_metadata = "training/metadata", @@ -49,14 +178,11 @@ def transcribe( if input_voice is not None: only_speakers = [input_voice] - # - model = whisperx.load_model(model_name, device, compute_type=dtype) + """ align_model, align_model_metadata, align_model_language = (None, None, None) - if diarize: - diarize_model = whisperx.DiarizationPipeline(device=device) - else: - diarize_model = None - + model =_load_model(model_name, device, compute_type=dtype) + diarize_model = _load_diarization_model(device=device) if diarize else None + """ for dataset_name in os.listdir(f'./{input_audio}/'): if not os.path.isdir(f'./{input_audio}/{dataset_name}/'): @@ -95,7 +221,10 @@ def transcribe( if os.path.isdir(inpath): continue + + metadata[filename] = transcribe( inpath, model_name=model_name, diarize=diarize, device=device, dtype=dtype ) + """ metadata[filename] = { "segments": [], "language": "", @@ -108,15 +237,10 @@ def transcribe( result = model.transcribe(audio, batch_size=batch_size) language = result["language"] - """ - if language[:2] not in ["ja"]: - language = "en" - """ - if align_model_language != language: tqdm.write(f'Loading language: {language}') - align_model, align_model_metadata = whisperx.load_align_model(language_code=language, device=device) align_model_language = language + align_model, align_model_metadata = _load_align_model(language=language, device=device) result = whisperx.align(result["segments"], align_model, align_model_metadata, audio, device, return_char_alignments=False) @@ -138,6 +262,7 @@ def transcribe( metadata[filename]["text"] = " ".join(text).strip() metadata[filename]["start"] = start metadata[filename]["end"] = end + """ open(outpath, 'w', encoding='utf-8').write(json.dumps(metadata)) @@ -169,7 +294,7 @@ def main(): args.stride_offset = int(args.device) args.device = f'cuda:{args.device}' - transcribe( + transcribe_batch( 
input_audio = args.input_audio, input_voice = args.input_voice, output_metadata = args.output_metadata, diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py index ba62a32..ae9bcc0 100755 --- a/vall_e/engines/__init__.py +++ b/vall_e/engines/__init__.py @@ -110,8 +110,10 @@ def load_engines(training=True, **model_kwargs): scheduler_class = None params = { + "params": [ param for name, param in model.named_parameters() if name not in model.config.frozen_params ], "lr": cfg.hyperparameters.learning_rate, } + if cfg.hyperparameters.optimizer.lower() == "adamw": params["betas"] = (0.9, 0.96) params["eps"] = 1e-07 @@ -129,17 +131,30 @@ def load_engines(training=True, **model_kwargs): params['d_coef'] = params['lr'] params['lr'] = 1.0 + elif cfg.hyperparameters.optimizer.lower() in ["apollo","apollo-mini"]: + optimizer_class = ml.Apollo + is_mini = cfg.hyperparameters.optimizer.lower() == "apollo-mini" + param_kwargs = { + "rank": 1 if is_mini else 256, + "proj": "random", + "scale_type": "tensor" if is_mini else "channel", + "scale": 128 if is_mini else 1, + "update_proj_gap": 200, + "proj_type": "std", + } + # grab any extra configs from the YAML + param_kwargs.update(cfg.hyperparameters.optimizer_params) + # and blank it so it doesn't update the main optimizer kwargs + cfg.hyperparameters.optimizer_params = {} + # settings are stored under params + params["params"] = [dict(params=params["params"], **param_kwargs)] elif cfg.hyperparameters.optimizer.lower() == "adagrad": optimizer_class = ml.Adagrad else: raise ValueError(f'Optimizer specified not implemented: {cfg.hyperparameters.optimizer}') params.update(cfg.hyperparameters.optimizer_params) - - optimizer = optimizer_class( - [ param for name, param in model.named_parameters() if name not in model.config.frozen_params ], - **params, - ) + optimizer = optimizer_class(**params) if cfg.hyperparameters.scheduler.lower() == "schedulefree": if cfg.hyperparameters.optimizer.lower() == "adamw": diff --git a/vall_e/inference.py b/vall_e/inference.py index 5b92213..cd10384 100755 --- a/vall_e/inference.py +++ b/vall_e/inference.py @@ -296,7 +296,7 @@ class TTS(): use_lora=use_lora, ) - with torch.autocast("cuda", dtype=dtype, enabled=amp): + with torch.autocast(self.device, dtype=dtype, enabled=amp): if model_len is not None: # extra kwargs duration_padding = sampling_kwargs.pop("duration_padding", 1.05) @@ -384,7 +384,7 @@ class TTS(): resp = to_device(resp, device=self.device, dtype=torch.int16) lang = to_device(lang, device=self.device, dtype=torch.uint8) - with torch.autocast("cuda", dtype=dtype, enabled=amp): + with torch.autocast(self.device, dtype=dtype, enabled=amp): model = model_ar if model_ar is not None else model_nar if model is not None: text_list = model( @@ -430,7 +430,7 @@ class TTS(): phns = to_device(phns, device=self.device, dtype=torch.uint8 if len(self.symmap) < 256 else torch.int16) lang = to_device(lang, device=self.device, dtype=torch.uint8) - with torch.autocast("cuda", dtype=dtype, enabled=amp): + with torch.autocast(self.device, dtype=dtype, enabled=amp): input_kwargs = dict( text_list=[phns], proms_list=[prom], diff --git a/vall_e/metrics.py b/vall_e/metrics.py new file mode 100644 index 0000000..67388b8 --- /dev/null +++ b/vall_e/metrics.py @@ -0,0 +1,33 @@ +# handles objective metric calculations, such as WER and SIM-O + +#from .emb.transcribe import transcribe +from .emb.similar import speaker_similarity_embedding +from .emb.transcribe import transcribe +from .emb.g2p import detect_language +from .data 
import normalize_text + +import torch.nn.functional as F + +from pathlib import Path +from torcheval.metrics.functional import word_error_rate + +def wer( audio, reference, language="auto", **transcription_kwargs ): + if language == "auto": + language = detect_language( reference ) + + transcription = transcribe( audio, language=language, align=False, **transcription_kwargs )["text"] + + # reference audio needs transcribing too + if isinstance( reference, Path ): + reference = transcribe( reference, language=language, align=False, **transcription_kwargs )["text"] + + transcription = normalize_text( transcription ) + reference = normalize_text( reference ) + + return word_error_rate([transcription], [reference]).item() + +def sim_o( audio, reference, **kwargs ): + audio_emb = speaker_similarity_embedding( audio, **kwargs ) + reference_emb = speaker_similarity_embedding( reference, **kwargs ) + + return F.cosine_similarity( audio_emb, reference_emb ).item() \ No newline at end of file diff --git a/vall_e/utils/__init__.py b/vall_e/utils/__init__.py index 3dbe232..846d179 100755 --- a/vall_e/utils/__init__.py +++ b/vall_e/utils/__init__.py @@ -15,5 +15,6 @@ from .utils import ( prune_missing, clamp, md5_hash, - convert_kwargs + convert_kwargs, + coerce_dtype ) \ No newline at end of file diff --git a/vall_e/utils/ext/__init__.py b/vall_e/utils/ext/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vall_e/utils/ext/apollo.py b/vall_e/utils/ext/apollo.py new file mode 100644 index 0000000..51b4983 --- /dev/null +++ b/vall_e/utils/ext/apollo.py @@ -0,0 +1,433 @@ +# "borrowed" with love from https://github.com/MadsToftrup/Apollo-dev/blob/main/galore_torch/apollo.py +# to be replaced with the official implementation (https://github.com/zhuhanqing/APOLLO) maybe + +import torch +import math +import numpy as np + +from torch import nn +from torch.optim import Optimizer + +from typing import Any, Callable, Dict, Generator, Iterable, Optional, Sequence, Union, Tuple + +from transformers.utils.versions import require_version + +class GaLoreProjector: + def __init__(self, rank, verbose=False, update_proj_gap=200, scale=1.0, proj_type='std'): + self.rank = rank + self.verbose = verbose + self.update_proj_gap = update_proj_gap + self.scale = scale + self.ortho_matrix = None + self.proj_type = proj_type + self.svd_count = 0 + + def project(self, full_rank_grad, iter): + + if self.ortho_matrix is not None and self.ortho_matrix.device != full_rank_grad.device: + self.ortho_matrix = self.ortho_matrix.to(full_rank_grad.device) + + if self.proj_type == 'std': + if full_rank_grad.shape[0] >= full_rank_grad.shape[1]: + if self.ortho_matrix is None or iter % self.update_proj_gap == 0: + self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right') + self.svd_count += 1 + low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t()) + else: + if self.ortho_matrix is None or iter % self.update_proj_gap == 0: + self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left') + self.svd_count += 1 + low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad) + elif self.proj_type == 'reverse_std': + if full_rank_grad.shape[0] >= full_rank_grad.shape[1]: + if self.ortho_matrix is None or iter % self.update_proj_gap == 0: + self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left') + self.svd_count += 1 + low_rank_grad = torch.matmul(self.ortho_matrix.t(),full_rank_grad) + else: + if self.ortho_matrix is None or 
iter % self.update_proj_gap == 0: + self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right') + self.svd_count += 1 + low_rank_grad = torch.matmul(full_rank_grad,self.ortho_matrix.t()) + elif self.proj_type == 'right': + if self.ortho_matrix is None or iter % self.update_proj_gap == 0: + self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right') + self.svd_count += 1 + low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t()) + elif self.proj_type == 'left': + if self.ortho_matrix is None or iter % self.update_proj_gap == 0: + self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left') + self.svd_count += 1 + low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad) + elif self.proj_type == 'full': + if self.ortho_matrix is None or iter % self.update_proj_gap == 0: + self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='full') + self.svd_count += 1 + low_rank_grad = torch.matmul(self.ortho_matrix[0].t(), full_rank_grad) @ self.ortho_matrix[1].t() + + return low_rank_grad + + def project_back(self, low_rank_grad): + + if self.proj_type == 'std': + if low_rank_grad.shape[0] >= low_rank_grad.shape[1]: + full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix) + else: + full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad) + elif self.proj_type == 'reverse_std': + if low_rank_grad.shape[0] <= low_rank_grad.shape[1]: # note this is different from std + full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad) + else: + full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix) + elif self.proj_type == 'right': + full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix) + elif self.proj_type == 'left': + full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad) + elif self.proj_type == 'full': + full_rank_grad = torch.matmul(self.ortho_matrix[0], low_rank_grad) @ self.ortho_matrix[1] + return full_rank_grad * self.scale + + # svd decomposition + def get_orthogonal_matrix(self, weights, rank, type): + module_params = weights + + if module_params.data.dtype != torch.float: + float_data = False + original_type = module_params.data.dtype + original_device = module_params.data.device + matrix = module_params.data.float() + else: + float_data = True + matrix = module_params.data + + U, s, Vh = torch.linalg.svd(matrix, full_matrices = False) + + #make the smaller matrix always to be orthogonal matrix + if type=='right': + A = U[:, :rank] @ torch.diag(s[:rank]) + B = Vh[:rank, :] + + if not float_data: + B = B.to(original_device).type(original_type) + return B + elif type=='left': + A = U[:, :rank] + B = torch.diag(s[:rank]) @ Vh[:rank, :] + if not float_data: + A = A.to(original_device).type(original_type) + return A + elif type=='full': + A = U[:, :rank] + B = Vh[:rank, :] + if not float_data: + A = A.to(original_device).type(original_type) + B = B.to(original_device).type(original_type) + return [A, B] + else: + raise ValueError('type should be left, right or full') + +def stable_randn( + shape: Union[int, Sequence[int]], + seed: int, + device: Optional[Union[str, torch.device]] = None, + dtype: Optional[torch.dtype] = torch.float32, +): + if device is None: + device = torch.device("cpu") + generator = torch.Generator(device=device).manual_seed(seed) + rn = torch.randn(shape, generator=generator, device=generator.device, dtype=dtype) + return rn + + +def next_seed(seed: int, adv: int = 0xF): + """ + This is a naive helper function 
to generate a new seed from the given seed. + """ + generator = torch.Generator().manual_seed(seed) + return torch.randint( + 0, torch.iinfo(torch.int64).max, (adv,), generator=generator, device=generator.device + ).tolist()[-1] + + +def split_seed(seed: int): + generator = torch.Generator().manual_seed(seed) + return tuple( + torch.randint(0, torch.iinfo(torch.int64).max, (2,), generator=generator, device=generator.device).tolist() + ) + + +class GradientProjector: + def __init__( + self, rank, update_proj_gap=200, alpha=1.0, proj_type="std", seed=0 + ): + # This is a lazy implementation as we store the projection matrix instead of re-generation every iteration + self.rank = rank + self.update_proj_gap = update_proj_gap + self.alpha = alpha + self.proj_type = proj_type + + self.ortho_matrix = None + self.seed = seed + + def project(self, full_rank_grad, iter): + + if self.proj_type == "std": + if full_rank_grad.shape[0] >= full_rank_grad.shape[1]: + if self.ortho_matrix is None or iter % self.update_proj_gap == 0: + self.ortho_matrix = self.get_orthogonal_matrix( + full_rank_grad, self.rank, type="right", seed=self.seed + ) + self.seed = next_seed(self.seed) + low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t()) + else: + if self.ortho_matrix is None or iter % self.update_proj_gap == 0: + self.ortho_matrix = self.get_orthogonal_matrix( + full_rank_grad, self.rank, type="left", seed=self.seed + ) + self.seed = next_seed(self.seed) + + low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad) + elif self.proj_type == "reverse_std": + if full_rank_grad.shape[0] >= full_rank_grad.shape[1]: + if self.ortho_matrix is None or iter % self.update_proj_gap == 0: + self.ortho_matrix = self.get_orthogonal_matrix( + full_rank_grad, self.rank, type="left", seed=self.seed + ) + self.seed = next_seed(self.seed) + + low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad) + else: + if self.ortho_matrix is None or iter % self.update_proj_gap == 0: + self.ortho_matrix = self.get_orthogonal_matrix( + full_rank_grad, self.rank, type="right", seed=self.seed + ) + self.seed = next_seed(self.seed) + low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t()) + elif self.proj_type == "right": + if self.ortho_matrix is None or iter % self.update_proj_gap == 0: + self.ortho_matrix = self.get_orthogonal_matrix( + full_rank_grad, self.rank, type="right", seed=self.seed + ) + self.seed = next_seed(self.seed) + low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t()) + elif self.proj_type == "left": + if self.ortho_matrix is None or iter % self.update_proj_gap == 0: + self.ortho_matrix = self.get_orthogonal_matrix( + full_rank_grad, self.rank, type="left", seed=self.seed + ) + self.seed = next_seed(self.seed) + low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad) + elif self.proj_type == "full": + if self.ortho_matrix is None or iter % self.update_proj_gap == 0: + self.ortho_matrix = self.get_orthogonal_matrix( + full_rank_grad, self.rank, type="full", seed=self.seed + ) + self.seed = next_seed(self.seed) + low_rank_grad = ( + torch.matmul(self.ortho_matrix[0].t(), full_rank_grad) + @ self.ortho_matrix[1].t() + ) + + return low_rank_grad + + # random low rank projection + def get_orthogonal_matrix(self, weights, rank, type, seed): + module_params = weights + + if module_params.data.dtype != torch.float: + float_data = False + original_type = module_params.data.dtype + original_device = module_params.data.device + matrix = module_params.data.float() + 
else: + float_data = True + matrix = module_params.data + + if type == "left": + proj = stable_randn( + (matrix.shape[0], rank), seed=seed, device=matrix.device, dtype=matrix.dtype + ) / math.sqrt(rank) + if not float_data: + proj = proj.to(original_device).type(original_type) + return proj + elif type == "right": + proj = stable_randn( + (rank, matrix.shape[1]), seed=seed, device=matrix.device, dtype=matrix.dtype + ) / math.sqrt(rank) + if not float_data: + proj = proj.to(original_device).type(original_type) + return proj + elif type == "full": + raise NotImplementedError("full rank projection is not implemented yet") + else: + raise ValueError("type should be left, right or full") + +class Apollo(Optimizer): + """ + Implements Adam algorithm with weight decay fix as introduced in [Decoupled Weight Decay + Regularization](https://arxiv.org/abs/1711.05101). + + Parameters: + params (`Iterable[nn.parameter.Parameter]`): + Iterable of parameters to optimize or dictionaries defining parameter groups. + lr (`float`, *optional*, defaults to 0.001): + The learning rate to use. + betas (`Tuple[float,float]`, *optional*, defaults to `(0.9, 0.999)`): + Adam's betas parameters (b1, b2). + eps (`float`, *optional*, defaults to 1e-06): + Adam's epsilon for numerical stability. + weight_decay (`float`, *optional*, defaults to 0.0): + Decoupled weight decay to apply. + correct_bias (`bool`, *optional*, defaults to `True`): + Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`). + no_deprecation_warning (`bool`, *optional*, defaults to `False`): + A flag used to disable the deprecation warning (set to `True` to disable the warning). + """ + + def __init__( + self, + params: Iterable[nn.parameter.Parameter], + lr: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-6, + weight_decay: float = 0.0, + correct_bias: bool = True, + scale_front: bool = False, + ): + if lr < 0.0: + raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0") + defaults = {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "correct_bias": correct_bias} + super().__init__(params, defaults) + + self.scale_front = scale_front + + params_idx = 0 + for group in self.param_groups: + for p in group["params"]: + params_idx += 1 + if p.requires_grad: + self.state[p]["seed"] = params_idx + + @torch.no_grad() + def step(self, closure: Callable = None): + """ + Performs a single optimization step. + + Arguments: + closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss. 
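+
+        Returns:
+            The loss returned by `closure` when one is provided, otherwise `None`.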
+ """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad + if grad.is_sparse: + raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead") + + state = self.state[p] + + if "step" not in state: + state["step"] = 0 + + # GaLore Projection + if "rank" in group: + if "projector" not in state: + if group["proj"] == "random": + state["projector"] = GradientProjector(group["rank"], + update_proj_gap=group["update_proj_gap"], + alpha=group["scale"], + proj_type=group["proj_type"], + seed=state["seed"]) + + elif group["proj"] == "svd": + state["projector"] = GaLoreProjector(group["rank"], + update_proj_gap=group["update_proj_gap"], + scale=group["scale"], + proj_type=group["proj_type"]) + + grad = state["projector"].project(grad, state["step"]) + + # State initialization + if "exp_avg" not in state: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(grad) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like(grad) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + beta1, beta2 = group["betas"] + + state["step"] += 1 + + # Decay the first and second moment running average coefficient + # In-place operations to update the averages at the same time + exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1)) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) + denom = exp_avg_sq.sqrt().add_(group["eps"]) + + step_size = group["lr"] + if group["correct_bias"]: # No bias correction for Bert + bias_correction1 = 1.0 - beta1 ** state["step"] + bias_correction2 = 1.0 - beta2 ** state["step"] + step_size = step_size * math.sqrt(bias_correction2) / bias_correction1 + + # compute norm gradient + norm_grad = exp_avg / denom + + if "rank" in group: + if group['scale_type'] == 'channel': + norm_dim = 0 if norm_grad.shape[0] < norm_grad.shape[1] else 1 + scaling_factor = ( + torch.norm(norm_grad, dim=norm_dim) / + (torch.norm(grad, dim=norm_dim) + 1e-8) + ) + if norm_dim == 1: + scaling_factor = scaling_factor.unsqueeze(1) + + elif group['scale_type'] == 'tensor': + scaling_factor = ( + torch.norm(norm_grad) / + (torch.norm(grad) + 1e-8) + ) + + scaling_grad = p.grad * scaling_factor + + # Use Norm-Growth Limiter in Fira + if "scaling_grad" in state: + scaling_grad_norm = torch.norm(scaling_grad) + limiter = max( + scaling_grad_norm / + (state["scaling_grad"] + 1e-8), + 1.01, + ) / 1.01 + scaling_grad = scaling_grad / limiter + state["scaling_grad"] = scaling_grad_norm / limiter + else: + state["scaling_grad"] = torch.norm(scaling_grad) + + norm_grad = scaling_grad * np.sqrt(group["scale"]) + + p.add_(norm_grad, alpha=-step_size) + + # Just adding the square of the weights to the loss function is *not* + # the correct way of using L2 regularization/weight decay with Adam, + # since that will interact with the m and v parameters in strange ways. + # + # Instead we want to decay the weights in a manner that doesn't interact + # with the m/v parameters. This is equivalent to adding the square + # of the weights to the loss with plain (non-momentum) SGD. 
+ # Add weight decay at the end (fixed version) + if group["weight_decay"] > 0.0: + p.add_(p, alpha=(-group["lr"] * group["weight_decay"])) + + return loss \ No newline at end of file diff --git a/vall_e/utils/ext/ecapa_tdnn.py b/vall_e/utils/ext/ecapa_tdnn.py new file mode 100644 index 0000000..29a99fb --- /dev/null +++ b/vall_e/utils/ext/ecapa_tdnn.py @@ -0,0 +1,467 @@ +# borrowed with love from "https://github.com/keonlee9420/evaluate-zero-shot-tts/blob/master/src/evaluate_zero_shot_tts/utils/speaker_verification/models/ecapa_tdnn.py" +# part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio.transforms as trans + +#from .utils import UpstreamExpert + +""" Res2Conv1d + BatchNorm1d + ReLU +""" + + +class Res2Conv1dReluBn(nn.Module): + """ + in_channels == out_channels == channels + """ + + def __init__( + self, + channels, + kernel_size=1, + stride=1, + padding=0, + dilation=1, + bias=True, + scale=4, + ): + super().__init__() + assert channels % scale == 0, "{} % {} != 0".format(channels, scale) + self.scale = scale + self.width = channels // scale + self.nums = scale if scale == 1 else scale - 1 + + self.convs = [] + self.bns = [] + for i in range(self.nums): + self.convs.append( + nn.Conv1d( + self.width, + self.width, + kernel_size, + stride, + padding, + dilation, + bias=bias, + ) + ) + self.bns.append(nn.BatchNorm1d(self.width)) + self.convs = nn.ModuleList(self.convs) + self.bns = nn.ModuleList(self.bns) + + def forward(self, x): + out = [] + spx = torch.split(x, self.width, 1) + for i in range(self.nums): + if i == 0: + sp = spx[i] + else: + sp = sp + spx[i] + # Order: conv -> relu -> bn + sp = self.convs[i](sp) + sp = self.bns[i](F.relu(sp)) + out.append(sp) + if self.scale != 1: + out.append(spx[self.nums]) + out = torch.cat(out, dim=1) + + return out + + +""" Conv1d + BatchNorm1d + ReLU +""" + + +class Conv1dReluBn(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + dilation=1, + bias=True, + ): + super().__init__() + self.conv = nn.Conv1d( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + bias=bias, + ) + self.bn = nn.BatchNorm1d(out_channels) + + def forward(self, x): + return self.bn(F.relu(self.conv(x))) + + +""" The SE connection of 1D case. +""" + + +class SE_Connect(nn.Module): + def __init__(self, channels, se_bottleneck_dim=128): + super().__init__() + self.linear1 = nn.Linear(channels, se_bottleneck_dim) + self.linear2 = nn.Linear(se_bottleneck_dim, channels) + + def forward(self, x): + out = x.mean(dim=2) + out = F.relu(self.linear1(out)) + out = torch.sigmoid(self.linear2(out)) + out = x * out.unsqueeze(2) + + return out + + +""" SE-Res2Block of the ECAPA-TDNN architecture. 
+""" + + +# def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale): +# return nn.Sequential( +# Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0), +# Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale), +# Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0), +# SE_Connect(channels) +# ) + + +class SE_Res2Block(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + scale, + se_bottleneck_dim, + ): + super().__init__() + self.Conv1dReluBn1 = Conv1dReluBn( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + self.Res2Conv1dReluBn = Res2Conv1dReluBn( + out_channels, kernel_size, stride, padding, dilation, scale=scale + ) + self.Conv1dReluBn2 = Conv1dReluBn( + out_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim) + + self.shortcut = None + if in_channels != out_channels: + self.shortcut = nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + ) + + def forward(self, x): + residual = x + if self.shortcut: + residual = self.shortcut(x) + + x = self.Conv1dReluBn1(x) + x = self.Res2Conv1dReluBn(x) + x = self.Conv1dReluBn2(x) + x = self.SE_Connect(x) + + return x + residual + + +""" Attentive weighted mean and standard deviation pooling. +""" + + +class AttentiveStatsPool(nn.Module): + def __init__( + self, in_dim, attention_channels=128, global_context_att=False + ): + super().__init__() + self.global_context_att = global_context_att + + # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs. + if global_context_att: + self.linear1 = nn.Conv1d( + in_dim * 3, attention_channels, kernel_size=1 + ) # equals W and b in the paper + else: + self.linear1 = nn.Conv1d( + in_dim, attention_channels, kernel_size=1 + ) # equals W and b in the paper + self.linear2 = nn.Conv1d( + attention_channels, in_dim, kernel_size=1 + ) # equals V and k in the paper + + def forward(self, x): + if self.global_context_att: + context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x) + context_std = torch.sqrt( + torch.var(x, dim=-1, keepdim=True) + 1e-10 + ).expand_as(x) + x_in = torch.cat((x, context_mean, context_std), dim=1) + else: + x_in = x + + # DON'T use ReLU here! In experiments, I find ReLU hard to converge. 
+ alpha = torch.tanh(self.linear1(x_in)) + # alpha = F.relu(self.linear1(x_in)) + alpha = torch.softmax(self.linear2(alpha), dim=2) + mean = torch.sum(alpha * x, dim=2) + residuals = torch.sum(alpha * (x**2), dim=2) - mean**2 + std = torch.sqrt(residuals.clamp(min=1e-9)) + return torch.cat([mean, std], dim=1) + + +class ECAPA_TDNN(nn.Module): + def __init__( + self, + feat_dim=80, + channels=512, + emb_dim=192, + global_context_att=False, + feat_type="fbank", + sr=16000, + feature_selection="hidden_states", + update_extract=False, + config_path=None, + ): + super().__init__() + + self.feat_type = feat_type + self.feature_selection = feature_selection + self.update_extract = update_extract + self.sr = sr + + if feat_type == "fbank" or feat_type == "mfcc": + self.update_extract = False + + win_len = int(sr * 0.025) + hop_len = int(sr * 0.01) + + if feat_type == "fbank": + self.feature_extract = trans.MelSpectrogram( + sample_rate=sr, + n_fft=512, + win_length=win_len, + hop_length=hop_len, + f_min=0.0, + f_max=sr // 2, + pad=0, + n_mels=feat_dim, + ) + elif feat_type == "mfcc": + melkwargs = { + "n_fft": 512, + "win_length": win_len, + "hop_length": hop_len, + "f_min": 0.0, + "f_max": sr // 2, + "pad": 0, + } + self.feature_extract = trans.MFCC( + sample_rate=sr, + n_mfcc=feat_dim, + log_mels=False, + melkwargs=melkwargs, + ) + else: + """ + if config_path is None: + self.feature_extract = torch.hub.load("s3prl/s3prl", feat_type) + else: + self.feature_extract = UpstreamExpert(config_path) + """ + self.feature_extract = torch.hub.load("s3prl/s3prl", feat_type) + if len(self.feature_extract.model.encoder.layers) == 24 and hasattr( + self.feature_extract.model.encoder.layers[23].self_attn, + "fp32_attention", + ): + self.feature_extract.model.encoder.layers[ + 23 + ].self_attn.fp32_attention = False + if len(self.feature_extract.model.encoder.layers) == 24 and hasattr( + self.feature_extract.model.encoder.layers[11].self_attn, + "fp32_attention", + ): + self.feature_extract.model.encoder.layers[ + 11 + ].self_attn.fp32_attention = False + + self.feat_num = self.get_feat_num() + self.feature_weight = nn.Parameter(torch.zeros(self.feat_num)) + + if feat_type != "fbank" and feat_type != "mfcc": + freeze_list = [ + "final_proj", + "label_embs_concat", + "mask_emb", + "project_q", + "quantizer", + ] + for name, param in self.feature_extract.named_parameters(): + for freeze_val in freeze_list: + if freeze_val in name: + param.requires_grad = False + break + + if not self.update_extract: + for param in self.feature_extract.parameters(): + param.requires_grad = False + + self.instance_norm = nn.InstanceNorm1d(feat_dim) + # self.channels = [channels] * 4 + [channels * 3] + self.channels = [channels] * 4 + [1536] + + self.layer1 = Conv1dReluBn( + feat_dim, self.channels[0], kernel_size=5, padding=2 + ) + self.layer2 = SE_Res2Block( + self.channels[0], + self.channels[1], + kernel_size=3, + stride=1, + padding=2, + dilation=2, + scale=8, + se_bottleneck_dim=128, + ) + self.layer3 = SE_Res2Block( + self.channels[1], + self.channels[2], + kernel_size=3, + stride=1, + padding=3, + dilation=3, + scale=8, + se_bottleneck_dim=128, + ) + self.layer4 = SE_Res2Block( + self.channels[2], + self.channels[3], + kernel_size=3, + stride=1, + padding=4, + dilation=4, + scale=8, + se_bottleneck_dim=128, + ) + + # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1) + cat_channels = channels * 3 + self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1) + self.pooling = 
AttentiveStatsPool( + self.channels[-1], + attention_channels=128, + global_context_att=global_context_att, + ) + self.bn = nn.BatchNorm1d(self.channels[-1] * 2) + self.linear = nn.Linear(self.channels[-1] * 2, emb_dim) + + def get_feat_num(self): + self.feature_extract.eval() + wav = [ + torch.randn(self.sr).to( + next(self.feature_extract.parameters()).device + ) + ] + with torch.no_grad(): + features = self.feature_extract(wav) + select_feature = features[self.feature_selection] + if isinstance(select_feature, (list, tuple)): + return len(select_feature) + else: + return 1 + + def get_feat(self, x): + if self.update_extract: + x = self.feature_extract([sample for sample in x]) + else: + with torch.no_grad(): + if self.feat_type == "fbank" or self.feat_type == "mfcc": + x = ( + self.feature_extract(x) + 1e-6 + ) # B x feat_dim x time_len + else: + x = self.feature_extract([sample for sample in x]) + + if self.feat_type == "fbank": + x = x.log() + + if self.feat_type != "fbank" and self.feat_type != "mfcc": + x = x[self.feature_selection] + if isinstance(x, (list, tuple)): + x = torch.stack(x, dim=0) + else: + x = x.unsqueeze(0) + norm_weights = ( + F.softmax(self.feature_weight, dim=-1) + .unsqueeze(-1) + .unsqueeze(-1) + .unsqueeze(-1) + ) + x = (norm_weights * x).sum(dim=0) + x = torch.transpose(x, 1, 2) + 1e-6 + + x = self.instance_norm(x) + return x + + def forward(self, x): + x = self.get_feat(x) + + out1 = self.layer1(x) + out2 = self.layer2(out1) + out3 = self.layer3(out2) + out4 = self.layer4(out3) + + out = torch.cat([out2, out3, out4], dim=1) + out = F.relu(self.conv(out)) + out = self.bn(self.pooling(out)) + out = self.linear(out) + + return out + + +def ECAPA_TDNN_SMALL( + feat_dim, + emb_dim=256, + feat_type="fbank", + sr=16000, + feature_selection="hidden_states", + update_extract=False, + config_path=None, +): + return ECAPA_TDNN( + feat_dim=feat_dim, + channels=512, + emb_dim=emb_dim, + feat_type=feat_type, + sr=sr, + feature_selection=feature_selection, + update_extract=update_extract, + config_path=config_path, + ) + + +if __name__ == "__main__": + x = torch.zeros(2, 32000) + model = ECAPA_TDNN_SMALL( + feat_dim=768, + emb_dim=256, + feat_type="hubert_base", + feature_selection="hidden_states", + update_extract=False, + ) + + out = model(x) + # print(model) + print(out.shape) \ No newline at end of file diff --git a/vall_e/utils/unsloth.py b/vall_e/utils/ext/unsloth.py similarity index 100% rename from vall_e/utils/unsloth.py rename to vall_e/utils/ext/unsloth.py diff --git a/vall_e/utils/utils.py b/vall_e/utils/utils.py index 6e789ef..6c82ac2 100755 --- a/vall_e/utils/utils.py +++ b/vall_e/utils/utils.py @@ -124,6 +124,20 @@ def _get_named_modules(module, attrname): if hasattr(module, attrname): yield name, module +def coerce_dtype(s): + # not a string + if not isinstance(s, str): + return s + + if s == "float16": + return torch.float16 + if s == "bfloat16": + return torch.bfloat16 + if s == "float8_e5m2": + return torch.float8_e5m2 + if s == "float8_e4m3fn": + return torch.float8_e4m3fn + return torch.float32 def gather_attribute(module, attrname, delete=True, prefix=True): ret = {} diff --git a/vall_e/utils/wrapper.py b/vall_e/utils/wrapper.py index c9aa087..cb40e9b 100755 --- a/vall_e/utils/wrapper.py +++ b/vall_e/utils/wrapper.py @@ -103,12 +103,18 @@ if cfg.optimizations.tensorrt: if cfg.optimizations.unsloth: try: - from .unsloth import apply_unsloth_offloaded_gradient_checkpoint_monkey_patch + from .ext.unsloth import 
apply_unsloth_offloaded_gradient_checkpoint_monkey_patch #apply_unsloth_offloaded_gradient_checkpoint_monkey_patch() except Exception as e: _logger.warning(f'Error while importing Unsloth: {str(e)}') pass +try: + from .ext.apollo import Apollo +except Exception as e: + _logger.warning(f'Error while importing APOLLO: {str(e)}') + pass + def compile_model(model, backend="auto"): if not backend or backend == "auto": backend = AVAILABLE_COMPILE_BACKENDS[0]
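
As a closing note on the optimizer wiring above: APOLLO reads its low-rank settings (`rank`, `proj`, `scale_type`, `scale`, `update_proj_gap`, `proj_type`) out of each parameter group rather than from top-level constructor kwargs, which is why `load_engines` wraps the parameter list in a dict before instantiating it. Below is a minimal standalone sketch under that assumption, using the "apollo-mini" values from `engines/__init__.py`; the toy model, the `lr` value, and the import path resolving outside the package are placeholders/assumptions.

```python
import torch
from vall_e.utils.ext.apollo import Apollo

# bias-free so every parameter is 2D; APOLLO's projectors assume matrix-shaped gradients
model = torch.nn.Linear(1024, 1024, bias=False)

param_groups = [dict(
    params=model.parameters(),
    rank=1,                # "apollo-mini": rank-1 projection
    proj="random",         # random projection (GradientProjector) rather than SVD (GaLoreProjector)
    scale_type="tensor",   # tensor-wise gradient scaling
    scale=128,
    update_proj_gap=200,   # regenerate the projection every 200 steps
    proj_type="std",
)]

optimizer = Apollo(param_groups, lr=1e-4)

x = torch.randn(8, 1024)
loss = model(x).square().mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```

Parameters whose group lacks these keys fall back to a plain AdamW-style update, so mixing projected and non-projected groups should also work.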