added WER/SIM-O metrics, added APOLLO but I need to test it
This commit is contained in:
parent
fc5e6d8599
commit
8568a93dad
|
@ -10,6 +10,8 @@
|
|||
<caption>LibriSpeech</caption>
|
||||
<tr>
|
||||
<th>Text</th>
|
||||
<th>WER↓</th>
|
||||
<th>SIM-O↑</th>
|
||||
<th>Prompt</th>
|
||||
<th>Our VALL-E</th>
|
||||
<th>Original VALL-E</th>
|
||||
|
@ -24,6 +26,8 @@
|
|||
<caption>Sampled Dataset</caption>
|
||||
<tr>
|
||||
<th>Text</th>
|
||||
<th>WER↓</th>
|
||||
<th>SIM-O↑</th>
|
||||
<th>Prompt</th>
|
||||
<th>Our VALL-E</th>
|
||||
<th>F5-TTS</th>
|
||||
|
|
18
docs/emb.md
18
docs/emb.md
|
@ -77,7 +77,7 @@ I'm uncertain on how to remedy this, as my options are:
|
|||
|
||||
## `transcribe.py`
|
||||
|
||||
This script handles taking raw input audio, and outputting adequate metadata containing transcriptions of said audio through `whisperX`.
|
||||
This script primarily handles taking raw input audio, and outputting adequate metadata containing transcriptions of said audio through `whisperX`.
|
||||
|
||||
The process maintains slices `whisperX` thinks its best per the segments outputted, alongside the deduced language (if not specified).
|
||||
|
||||
|
@ -85,6 +85,18 @@ One limiting factor is that transcription transcribes into normal text, rather t
|
|||
|
||||
Refer to the `__main__`'s arguments for usage details.
|
||||
|
||||
### Metrics
|
||||
|
||||
This script also handles calculating `WER` simply by transcribing the given audio file (and reference, if requested), then comparing the word error rate.
|
||||
|
||||
This process *heavily* relies on text normalization, which currently is lacking, but transcribing the reference should keep things "normalized" per the transcriber.
|
||||
|
||||
### ROCm
|
||||
|
||||
Because life is pain, ROCm requires additional steps to ensure that `whisperX` works. A special fork of `CTranslate2` is required, but simplying following [these](https://github.com/arlo-phoenix/CTranslate2-rocm/blob/rocm/README_ROCM.md) steps should fix things.
|
||||
|
||||
In the future, I would love to replace WhisperX for something simple.
|
||||
|
||||
## `process.py`
|
||||
|
||||
This script handles taking raw input audio and its transcribed metadata, and outputs encoded audio (NumPy) files containing encoded audio and associated metadata.
|
||||
|
@ -108,3 +120,7 @@ When processing a dataset, this requires already having accompanying metadata ge
|
|||
Be *very* careful if you opt to output unsegmented and segmented utterances, as the sliced version may end up amongst the top-K similar candidates.
|
||||
|
||||
Refer to the `__main__`'s arguments for usage details.
|
||||
|
||||
### Metrics
|
||||
|
||||
This script also handles calculating `SIM-O` per [keonlee9420/evaluate-zero-shot-tts](https://github.com/keonlee9420/evaluate-zero-shot-tts/blob/master/src/evaluate_zero_shot_tts/utils/speaker_verification/verification.py), by making use of a model to create an embedding of a speaker, then computing cosine similarities on those embeddings.
|
3
setup.py
3
setup.py
|
@ -91,6 +91,9 @@ setup(
|
|||
"causal-conv1d",
|
||||
"mamba-ssm",
|
||||
|
||||
#
|
||||
"torcheval",
|
||||
|
||||
# attention helpers
|
||||
"xformers",
|
||||
"sageattention==1.0.6",
|
||||
|
|
|
@ -22,7 +22,7 @@ from pathlib import Path
|
|||
|
||||
from .utils.distributed import world_size
|
||||
from .utils.io import torch_load
|
||||
from .utils import set_seed, prune_missing, md5_hash
|
||||
from .utils import set_seed, prune_missing, md5_hash, coerce_dtype
|
||||
|
||||
@dataclass()
|
||||
class BaseConfig:
|
||||
|
@ -721,15 +721,7 @@ class Trainer:
|
|||
|
||||
@cached_property
|
||||
def dtype(self):
|
||||
if self.weight_dtype == "float16":
|
||||
return torch.float16
|
||||
if self.weight_dtype == "bfloat16":
|
||||
return torch.bfloat16
|
||||
if self.weight_dtype == "float8_e5m2":
|
||||
return torch.float8_e5m2
|
||||
if self.weight_dtype == "float8_e4m3fn":
|
||||
return torch.float8_e4m3fn
|
||||
return torch.float32
|
||||
return coerce_dtype(self.weight_dtype)
|
||||
|
||||
@cached_property
|
||||
def scale_loss(self):
|
||||
|
@ -748,17 +740,7 @@ class Inference:
|
|||
|
||||
@property
|
||||
def dtype(self):
|
||||
if self.weight_dtype == "float16":
|
||||
return torch.float16
|
||||
if self.weight_dtype == "bfloat16":
|
||||
return torch.bfloat16
|
||||
if self.weight_dtype == "int8":
|
||||
return torch.int8
|
||||
if self.weight_dtype == "float8_e5m2":
|
||||
return torch.float8_e5m2
|
||||
if self.weight_dtype == "float8_e4m3fn":
|
||||
return torch.float8_e4m3fn
|
||||
return torch.float32
|
||||
return coerce_dtype(self.weight_dtype)
|
||||
|
||||
@dataclass()
|
||||
class Optimizations:
|
||||
|
|
|
@ -63,6 +63,13 @@ def sentence_split( s, split_by="sentences", quote_placeholder="<QUOTE>" ):
|
|||
sentences = nltk.sent_tokenize(s)
|
||||
return [ sentence.replace(quote_placeholder, '"') for sentence in sentences if sentence ]
|
||||
|
||||
# to-do: improve upon this since it's kind of ass
|
||||
# this might be better to live in emb.g2p
|
||||
def normalize_text( s ):
|
||||
s = s.lower()
|
||||
s = re.sub(r'[^\w\s]', '', s)
|
||||
return s
|
||||
|
||||
@cache
|
||||
def get_random_prompts( validation=False, min_length=0, tokenized=False ):
|
||||
duration_range = [ 5.5, 12.0 ] # to-do: pull from cfg.dataset.duration_range
|
||||
|
@ -1070,7 +1077,9 @@ class Dataset(_Dataset):
|
|||
return root / name
|
||||
|
||||
def sample_prompts(self, spkr_name, reference, should_trim=True):
|
||||
if not cfg.dataset.prompt_duration_range or cfg.dataset.prompt_duration_range[-1] == 0:
|
||||
# return no prompt if explicitly requested for who knows why
|
||||
# or if there's no other speakers to sample from (Emilia has a lot of singleton speakers, but I still want to make use of them)
|
||||
if not cfg.dataset.prompt_duration_range or cfg.dataset.prompt_duration_range[-1] == 0 or len(self.paths_by_spkr_name[key]) <= 1:
|
||||
return None
|
||||
|
||||
prom_list = []
|
||||
|
|
|
@ -20,6 +20,7 @@ import base64
|
|||
import random
|
||||
import logging
|
||||
import time
|
||||
import torch
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -29,6 +30,8 @@ from .inference import TTS
|
|||
from .config import cfg
|
||||
from .data import create_train_dataloader, create_val_dataloader, get_random_prompt
|
||||
from .emb.qnt import decode_to_file
|
||||
from .metrics import wer, sim_o
|
||||
from .utils import setup_logging
|
||||
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
|
@ -230,6 +233,8 @@ def main():
|
|||
elif args.comparison:
|
||||
raise Exception(f"Unrecognized comparison flag: {args.comparison}")
|
||||
|
||||
setup_logging()
|
||||
|
||||
# read html template
|
||||
html = open(args.demo_dir / "index.template.html", "r", encoding="utf-8").read()
|
||||
|
||||
|
@ -318,6 +323,7 @@ def main():
|
|||
|
||||
inputs = []
|
||||
outputs = []
|
||||
metrics_inputs = []
|
||||
comparison_inputs = []
|
||||
for k, sample_dir in samples_dirs.items():
|
||||
if not sample_dir.exists():
|
||||
|
@ -359,9 +365,15 @@ def main():
|
|||
|
||||
# segregate comparisons into its own batch because they use different kwargs (and I do not support variadic-batched kwargs)
|
||||
if args.comparison:
|
||||
comparison_inputs.append((text, prompt, language, out_path_comparison))
|
||||
if (args.skip_existing and not out_path_comparison.exists()) or not (args.skip_existing):
|
||||
comparison_inputs.append((text, prompt, language, out_path_comparison))
|
||||
|
||||
inputs.append((text, prompt, language, out_path))
|
||||
metrics_inputs.append((text, language, out_path_comparison, reference))
|
||||
|
||||
if (args.skip_existing and not out_path.exists()) or not (args.skip_existing):
|
||||
inputs.append((text, prompt, language, out_path))
|
||||
|
||||
metrics_inputs.append((text, language, out_path, reference))
|
||||
|
||||
outputs.append((k, samples))
|
||||
|
||||
|
@ -371,10 +383,19 @@ def main():
|
|||
if comparison_inputs:
|
||||
process_batch( tts, comparison_inputs, sampling_kwargs | (comparison_kwargs["enabled"] if args.comparison else {}) )
|
||||
|
||||
metrics_map = {}
|
||||
for text, language, out_path, reference_path in metrics_inputs:
|
||||
wer_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name="base" )
|
||||
sim_o_score = sim_o( out_path, reference_path, device=tts.device, dtype=tts.dtype )
|
||||
metrics_map[out_path] = (wer_score, sim_o_score)
|
||||
|
||||
# collate entries into HTML
|
||||
for k, samples in outputs:
|
||||
samples = [
|
||||
f'\n\t\t\t<tr>\n\t\t\t\t<td>{text}</td>'+
|
||||
"".join([
|
||||
f'\n\t\t\t\t<td>{metrics_map[audios[1]][0]:.3f}</td><td>{metrics_map[audios[1]][1]:.3f}</td>'
|
||||
] ) +
|
||||
"".join( [
|
||||
f'\n\t\t\t\t<td><audio controls="controls" preload="none"><source src="{str(audio).replace(str(args.demo_dir), args.audio_path_root) if args.audio_path_root else encode(audio)}"/></audio></td>'
|
||||
for audio in audios
|
||||
|
|
|
@ -16,12 +16,13 @@ _logger = logging.getLogger(__name__)
|
|||
|
||||
from tqdm.auto import tqdm
|
||||
from pathlib import Path
|
||||
from functools import cache
|
||||
|
||||
import torchaudio.functional as F
|
||||
import torchaudio.transforms as T
|
||||
|
||||
from ..config import cfg
|
||||
from ..utils import truncate_json
|
||||
from ..utils import truncate_json, coerce_dtype
|
||||
from ..utils.io import json_read, json_write
|
||||
|
||||
from .g2p import encode as phonemize
|
||||
|
@ -29,19 +30,49 @@ from .qnt import encode as quantize, trim, convert_audio
|
|||
|
||||
from ..webui import init_tts
|
||||
|
||||
def load_audio( path ):
|
||||
def load_audio( path, target_sr=None ):
|
||||
waveform, sr = torchaudio.load( path )
|
||||
# mix channels
|
||||
if waveform.shape[0] > 1:
|
||||
waveform = torch.mean(waveform, dim=0, keepdim=True)
|
||||
if target_sr is None:
|
||||
target_sr = cfg.sample_rate
|
||||
# resample
|
||||
waveform, sr = convert_audio(waveform, sr, cfg.sample_rate, 1), cfg.sample_rate
|
||||
waveform, sr = convert_audio(waveform, sr, target_sr, 1), target_sr
|
||||
|
||||
return waveform, sr
|
||||
|
||||
tts = None
|
||||
|
||||
def process(
|
||||
# this is for computing SIM-O, but can probably technically be used for scoring similar utterances
|
||||
@cache
|
||||
def _load_sim_model(device="cuda", dtype="float16", feat_type="wavlm_base_plus", feat_dim=768):
|
||||
from ..utils.ext.ecapa_tdnn import ECAPA_TDNN_SMALL
|
||||
model = ECAPA_TDNN_SMALL(feat_dim=feat_dim, feat_type=feat_type, config_path=None)
|
||||
model = model.to(device=device, dtype=coerce_dtype(dtype))
|
||||
model = model.eval()
|
||||
|
||||
return model
|
||||
|
||||
@torch.no_grad()
|
||||
def speaker_similarity_embedding(
|
||||
audio,
|
||||
**model_kwargs,
|
||||
):
|
||||
device = model_kwargs.get("device", "cuda")
|
||||
dtype = model_kwargs.get("dtype", "float16")
|
||||
|
||||
model = _load_sim_model(**model_kwargs)
|
||||
if isinstance(audio, str) or isinstance(audio, Path):
|
||||
audio = load_audio(audio, 16000)
|
||||
|
||||
audio, sr = audio
|
||||
audio = audio.to(device=device, dtype=coerce_dtype(dtype))
|
||||
|
||||
return model(audio)
|
||||
|
||||
|
||||
def batch_similar_utterances(
|
||||
speaker_path,
|
||||
yaml,
|
||||
text=False,
|
||||
|
@ -266,7 +297,7 @@ def main():
|
|||
if args.skip_existing and metadata_keys and "similar" in metadata[metadata_keys[-1]]:
|
||||
return
|
||||
|
||||
similarities = process(
|
||||
similarities = batch_similar_utterances(
|
||||
speaker_path=cfg.data_dir / speaker_name,
|
||||
yaml=args.yaml,
|
||||
text=args.text,
|
||||
|
@ -314,7 +345,7 @@ def main():
|
|||
add( data_dir, type="noise", texts=False )
|
||||
|
||||
elif args.input_speaker:
|
||||
similarities = process(
|
||||
similarities = batch_similar_utterances(
|
||||
speaker_path=args.input_speaker,
|
||||
yaml=args.yaml,
|
||||
text=args.text,
|
||||
|
|
|
@ -11,9 +11,13 @@ import torchaudio
|
|||
|
||||
import whisperx
|
||||
|
||||
from functools import cache
|
||||
from tqdm.auto import tqdm
|
||||
from pathlib import Path
|
||||
|
||||
from ..utils import coerce_dtype
|
||||
|
||||
|
||||
def pad(num, zeroes):
|
||||
return str(num).zfill(zeroes+1)
|
||||
|
||||
|
@ -21,7 +25,132 @@ def process_items( items, stride=0, stride_offset=0 ):
|
|||
items = sorted( items )
|
||||
return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ]
|
||||
|
||||
# major cringe but should automatically unload models when loading a different one
|
||||
_cached_models = {
|
||||
"model": (None, None),
|
||||
"diarization": (None, None),
|
||||
"align": (None, None),
|
||||
}
|
||||
# yes I can write a decorator to do this
|
||||
def _load_model(model_name="large-v3", device="cuda", dtype="float16", language="auto"):
|
||||
cache_key = f'{model_name}:{device}:{dtype}:{language}'
|
||||
if _cached_models["model"][0] == cache_key:
|
||||
return _cached_models["model"][1]
|
||||
|
||||
del _cached_models["model"]
|
||||
|
||||
if not isinstance( dtype, str ):
|
||||
if dtype == torch.float32:
|
||||
dtype = "float32"
|
||||
elif dtype == torch.float16:
|
||||
dtype = "float16"
|
||||
elif dtype == torch.bfloat16:
|
||||
dtype = "bfloat16"
|
||||
|
||||
# doesnt support it for some reason
|
||||
if dtype == "bfloat16":
|
||||
dtype = "float16"
|
||||
|
||||
kwargs = {}
|
||||
kwargs["compute_type"] = dtype
|
||||
kwargs["task"] = "transcribe"
|
||||
kwargs["device"] = device
|
||||
|
||||
if language != "auto":
|
||||
kwargs["language"] = language
|
||||
|
||||
model = whisperx.load_model(model_name, **kwargs)
|
||||
|
||||
_cached_models["model"] = (cache_key, model)
|
||||
return model
|
||||
|
||||
def _load_diarization_model(device="cuda"):
|
||||
cache_key = f'{device}'
|
||||
|
||||
if _cached_models["diarization"][0] == cache_key:
|
||||
return _cached_models["diarization"][1]
|
||||
del _cached_models["diarization"]
|
||||
model = whisperx.DiarizationPipeline(device=device)
|
||||
_cached_models["diarization"] = (cache_key, model)
|
||||
return model
|
||||
|
||||
def _load_align_model(language, device="cuda"):
|
||||
cache_key = f'{language}:{device}'
|
||||
|
||||
if _cached_models["align"][0] == cache_key:
|
||||
return _cached_models["align"][1]
|
||||
del _cached_models["align"]
|
||||
model = whisperx.load_align_model(language_code=language, device=device)
|
||||
_cached_models["align"] = (cache_key, model)
|
||||
return model
|
||||
|
||||
# yes I can just do a for-loop
|
||||
def unload_model():
|
||||
del _cached_models["model"]
|
||||
del _cached_models["diarization"]
|
||||
del _cached_models["align"]
|
||||
|
||||
_cached_models["model"] = (None, None)
|
||||
_cached_models["diarization"] = (None, None)
|
||||
_cached_models["align"] = (None, None)
|
||||
|
||||
def transcribe(
|
||||
audio,
|
||||
language = "auto",
|
||||
diarize = False,
|
||||
batch_size = 16,
|
||||
verbose=False,
|
||||
align=True,
|
||||
**model_kwargs,
|
||||
):
|
||||
metadata = {
|
||||
"segments": [],
|
||||
"language": "",
|
||||
"text": "",
|
||||
"start": 0,
|
||||
"end": 0,
|
||||
}
|
||||
|
||||
# load requested models
|
||||
device = model_kwargs.get("device", "cuda")
|
||||
model = _load_model(language=language, **model_kwargs)
|
||||
diarize_model = _load_diarization_model(device=device) if diarize else None
|
||||
|
||||
# audio is a path, load it
|
||||
if isinstance(audio, str) or isinstance(audio, Path):
|
||||
#audio = load_audio(audio)
|
||||
audio = whisperx.load_audio(audio)
|
||||
|
||||
result = model.transcribe(audio, batch_size=batch_size)
|
||||
|
||||
if language == "auto":
|
||||
language = result["language"]
|
||||
|
||||
if align:
|
||||
align_model, align_model_metadata = _load_align_model(language=language, device=device)
|
||||
result = whisperx.align(result["segments"], align_model, align_model_metadata, audio, device, return_char_alignments=False)
|
||||
|
||||
if diarize_model is not None:
|
||||
diarize_segments = diarize_model(audio)
|
||||
result = whisperx.assign_word_speakers(diarize_segments, result)
|
||||
|
||||
text = []
|
||||
start = 0
|
||||
end = 0
|
||||
for segment in result["segments"]:
|
||||
text.append( segment["text"] )
|
||||
start = min( start, segment["start"] )
|
||||
end = max( end, segment["end"] )
|
||||
|
||||
metadata["language"] = language
|
||||
metadata["segments"] = result["segments"]
|
||||
metadata["text"] = " ".join(text).strip()
|
||||
metadata["start"] = start
|
||||
metadata["end"] = end
|
||||
|
||||
return metadata
|
||||
|
||||
def transcribe_batch(
|
||||
input_audio = "voices",
|
||||
input_voice = None,
|
||||
output_metadata = "training/metadata",
|
||||
|
@ -49,14 +178,11 @@ def transcribe(
|
|||
if input_voice is not None:
|
||||
only_speakers = [input_voice]
|
||||
|
||||
#
|
||||
model = whisperx.load_model(model_name, device, compute_type=dtype)
|
||||
"""
|
||||
align_model, align_model_metadata, align_model_language = (None, None, None)
|
||||
if diarize:
|
||||
diarize_model = whisperx.DiarizationPipeline(device=device)
|
||||
else:
|
||||
diarize_model = None
|
||||
|
||||
model =_load_model(model_name, device, compute_type=dtype)
|
||||
diarize_model = _load_diarization_model(device=device) if diarize else None
|
||||
"""
|
||||
|
||||
for dataset_name in os.listdir(f'./{input_audio}/'):
|
||||
if not os.path.isdir(f'./{input_audio}/{dataset_name}/'):
|
||||
|
@ -96,6 +222,9 @@ def transcribe(
|
|||
if os.path.isdir(inpath):
|
||||
continue
|
||||
|
||||
metadata[filename] = transcribe( inpath, model_name=model_name, diarize=diarize, device=device, dtype=dtype )
|
||||
|
||||
"""
|
||||
metadata[filename] = {
|
||||
"segments": [],
|
||||
"language": "",
|
||||
|
@ -108,15 +237,10 @@ def transcribe(
|
|||
result = model.transcribe(audio, batch_size=batch_size)
|
||||
language = result["language"]
|
||||
|
||||
"""
|
||||
if language[:2] not in ["ja"]:
|
||||
language = "en"
|
||||
"""
|
||||
|
||||
if align_model_language != language:
|
||||
tqdm.write(f'Loading language: {language}')
|
||||
align_model, align_model_metadata = whisperx.load_align_model(language_code=language, device=device)
|
||||
align_model_language = language
|
||||
align_model, align_model_metadata = _load_align_model(language=language, device=device)
|
||||
|
||||
result = whisperx.align(result["segments"], align_model, align_model_metadata, audio, device, return_char_alignments=False)
|
||||
|
||||
|
@ -138,6 +262,7 @@ def transcribe(
|
|||
metadata[filename]["text"] = " ".join(text).strip()
|
||||
metadata[filename]["start"] = start
|
||||
metadata[filename]["end"] = end
|
||||
"""
|
||||
|
||||
open(outpath, 'w', encoding='utf-8').write(json.dumps(metadata))
|
||||
|
||||
|
@ -169,7 +294,7 @@ def main():
|
|||
args.stride_offset = int(args.device)
|
||||
args.device = f'cuda:{args.device}'
|
||||
|
||||
transcribe(
|
||||
transcribe_batch(
|
||||
input_audio = args.input_audio,
|
||||
input_voice = args.input_voice,
|
||||
output_metadata = args.output_metadata,
|
||||
|
|
|
@ -110,8 +110,10 @@ def load_engines(training=True, **model_kwargs):
|
|||
scheduler_class = None
|
||||
|
||||
params = {
|
||||
"params": [ param for name, param in model.named_parameters() if name not in model.config.frozen_params ],
|
||||
"lr": cfg.hyperparameters.learning_rate,
|
||||
}
|
||||
|
||||
if cfg.hyperparameters.optimizer.lower() == "adamw":
|
||||
params["betas"] = (0.9, 0.96)
|
||||
params["eps"] = 1e-07
|
||||
|
@ -129,17 +131,30 @@ def load_engines(training=True, **model_kwargs):
|
|||
|
||||
params['d_coef'] = params['lr']
|
||||
params['lr'] = 1.0
|
||||
elif cfg.hyperparameters.optimizer.lower() in ["apollo","apollo-mini"]:
|
||||
optimizer_class = ml.Apollo
|
||||
is_mini = cfg.hyperparameters.optimizer.lower() == "apollo-mini"
|
||||
param_kwargs = {
|
||||
"rank": 1 if is_mini else 256,
|
||||
"proj": "random",
|
||||
"scale_type": "tensor" if is_mini else "channel",
|
||||
"scale": 128 if is_mini else 1,
|
||||
"update_proj_gap": 200,
|
||||
"proj_type": "std",
|
||||
}
|
||||
# grab any extra configs from the YAML
|
||||
param_kwargs.update(cfg.hyperparameters.optimizer_params)
|
||||
# and blank it so it doesn't update the main optimizer kwargs
|
||||
cfg.hyperparameters.optimizer_params = {}
|
||||
# settings are stored under params
|
||||
params["params"] = [dict(params=params["params"], **param_kwargs)]
|
||||
elif cfg.hyperparameters.optimizer.lower() == "adagrad":
|
||||
optimizer_class = ml.Adagrad
|
||||
else:
|
||||
raise ValueError(f'Optimizer specified not implemented: {cfg.hyperparameters.optimizer}')
|
||||
|
||||
params.update(cfg.hyperparameters.optimizer_params)
|
||||
|
||||
optimizer = optimizer_class(
|
||||
[ param for name, param in model.named_parameters() if name not in model.config.frozen_params ],
|
||||
**params,
|
||||
)
|
||||
optimizer = optimizer_class(**params)
|
||||
|
||||
if cfg.hyperparameters.scheduler.lower() == "schedulefree":
|
||||
if cfg.hyperparameters.optimizer.lower() == "adamw":
|
||||
|
|
|
@ -296,7 +296,7 @@ class TTS():
|
|||
use_lora=use_lora,
|
||||
)
|
||||
|
||||
with torch.autocast("cuda", dtype=dtype, enabled=amp):
|
||||
with torch.autocast(self.device, dtype=dtype, enabled=amp):
|
||||
if model_len is not None:
|
||||
# extra kwargs
|
||||
duration_padding = sampling_kwargs.pop("duration_padding", 1.05)
|
||||
|
@ -384,7 +384,7 @@ class TTS():
|
|||
resp = to_device(resp, device=self.device, dtype=torch.int16)
|
||||
lang = to_device(lang, device=self.device, dtype=torch.uint8)
|
||||
|
||||
with torch.autocast("cuda", dtype=dtype, enabled=amp):
|
||||
with torch.autocast(self.device, dtype=dtype, enabled=amp):
|
||||
model = model_ar if model_ar is not None else model_nar
|
||||
if model is not None:
|
||||
text_list = model(
|
||||
|
@ -430,7 +430,7 @@ class TTS():
|
|||
phns = to_device(phns, device=self.device, dtype=torch.uint8 if len(self.symmap) < 256 else torch.int16)
|
||||
lang = to_device(lang, device=self.device, dtype=torch.uint8)
|
||||
|
||||
with torch.autocast("cuda", dtype=dtype, enabled=amp):
|
||||
with torch.autocast(self.device, dtype=dtype, enabled=amp):
|
||||
input_kwargs = dict(
|
||||
text_list=[phns],
|
||||
proms_list=[prom],
|
||||
|
|
33
vall_e/metrics.py
Normal file
33
vall_e/metrics.py
Normal file
|
@ -0,0 +1,33 @@
|
|||
# handles objective metric calculations, such as WER and SIM-O
|
||||
|
||||
#from .emb.transcribe import transcribe
|
||||
from .emb.similar import speaker_similarity_embedding
|
||||
from .emb.transcribe import transcribe
|
||||
from .emb.g2p import detect_language
|
||||
from .data import normalize_text
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
||||
from pathlib import Path
|
||||
from torcheval.metrics.functional import word_error_rate
|
||||
|
||||
def wer( audio, reference, language="auto", **transcription_kwargs ):
|
||||
if language == "auto":
|
||||
language = detect_language( reference )
|
||||
|
||||
transcription = transcribe( audio, language=language, align=False, **transcription_kwargs )["text"]
|
||||
|
||||
# reference audio needs transcribing too
|
||||
if isinstance( reference, Path ):
|
||||
reference = transcribe( reference, language=language, align=False, **transcription_kwargs )["text"]
|
||||
|
||||
transcription = normalize_text( transcription )
|
||||
reference = normalize_text( reference )
|
||||
|
||||
return word_error_rate([transcription], [reference]).item()
|
||||
|
||||
def sim_o( audio, reference, **kwargs ):
|
||||
audio_emb = speaker_similarity_embedding( audio, **kwargs )
|
||||
reference_emb = speaker_similarity_embedding( reference, **kwargs )
|
||||
|
||||
return F.cosine_similarity( audio_emb, reference_emb ).item()
|
|
@ -15,5 +15,6 @@ from .utils import (
|
|||
prune_missing,
|
||||
clamp,
|
||||
md5_hash,
|
||||
convert_kwargs
|
||||
convert_kwargs,
|
||||
coerce_dtype
|
||||
)
|
0
vall_e/utils/ext/__init__.py
Normal file
0
vall_e/utils/ext/__init__.py
Normal file
433
vall_e/utils/ext/apollo.py
Normal file
433
vall_e/utils/ext/apollo.py
Normal file
|
@ -0,0 +1,433 @@
|
|||
# "borrowed" with love from https://github.com/MadsToftrup/Apollo-dev/blob/main/galore_torch/apollo.py
|
||||
# to be replaced with the official implementation (https://github.com/zhuhanqing/APOLLO) maybe
|
||||
|
||||
import torch
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
from torch import nn
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from typing import Any, Callable, Dict, Generator, Iterable, Optional, Sequence, Union, Tuple
|
||||
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
class GaLoreProjector:
|
||||
def __init__(self, rank, verbose=False, update_proj_gap=200, scale=1.0, proj_type='std'):
|
||||
self.rank = rank
|
||||
self.verbose = verbose
|
||||
self.update_proj_gap = update_proj_gap
|
||||
self.scale = scale
|
||||
self.ortho_matrix = None
|
||||
self.proj_type = proj_type
|
||||
self.svd_count = 0
|
||||
|
||||
def project(self, full_rank_grad, iter):
|
||||
|
||||
if self.ortho_matrix is not None and self.ortho_matrix.device != full_rank_grad.device:
|
||||
self.ortho_matrix = self.ortho_matrix.to(full_rank_grad.device)
|
||||
|
||||
if self.proj_type == 'std':
|
||||
if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
|
||||
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
|
||||
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right')
|
||||
self.svd_count += 1
|
||||
low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
|
||||
else:
|
||||
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
|
||||
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left')
|
||||
self.svd_count += 1
|
||||
low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
|
||||
elif self.proj_type == 'reverse_std':
|
||||
if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
|
||||
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
|
||||
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left')
|
||||
self.svd_count += 1
|
||||
low_rank_grad = torch.matmul(self.ortho_matrix.t(),full_rank_grad)
|
||||
else:
|
||||
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
|
||||
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right')
|
||||
self.svd_count += 1
|
||||
low_rank_grad = torch.matmul(full_rank_grad,self.ortho_matrix.t())
|
||||
elif self.proj_type == 'right':
|
||||
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
|
||||
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right')
|
||||
self.svd_count += 1
|
||||
low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
|
||||
elif self.proj_type == 'left':
|
||||
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
|
||||
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left')
|
||||
self.svd_count += 1
|
||||
low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
|
||||
elif self.proj_type == 'full':
|
||||
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
|
||||
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='full')
|
||||
self.svd_count += 1
|
||||
low_rank_grad = torch.matmul(self.ortho_matrix[0].t(), full_rank_grad) @ self.ortho_matrix[1].t()
|
||||
|
||||
return low_rank_grad
|
||||
|
||||
def project_back(self, low_rank_grad):
|
||||
|
||||
if self.proj_type == 'std':
|
||||
if low_rank_grad.shape[0] >= low_rank_grad.shape[1]:
|
||||
full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
|
||||
else:
|
||||
full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
|
||||
elif self.proj_type == 'reverse_std':
|
||||
if low_rank_grad.shape[0] <= low_rank_grad.shape[1]: # note this is different from std
|
||||
full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
|
||||
else:
|
||||
full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
|
||||
elif self.proj_type == 'right':
|
||||
full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
|
||||
elif self.proj_type == 'left':
|
||||
full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
|
||||
elif self.proj_type == 'full':
|
||||
full_rank_grad = torch.matmul(self.ortho_matrix[0], low_rank_grad) @ self.ortho_matrix[1]
|
||||
return full_rank_grad * self.scale
|
||||
|
||||
# svd decomposition
|
||||
def get_orthogonal_matrix(self, weights, rank, type):
|
||||
module_params = weights
|
||||
|
||||
if module_params.data.dtype != torch.float:
|
||||
float_data = False
|
||||
original_type = module_params.data.dtype
|
||||
original_device = module_params.data.device
|
||||
matrix = module_params.data.float()
|
||||
else:
|
||||
float_data = True
|
||||
matrix = module_params.data
|
||||
|
||||
U, s, Vh = torch.linalg.svd(matrix, full_matrices = False)
|
||||
|
||||
#make the smaller matrix always to be orthogonal matrix
|
||||
if type=='right':
|
||||
A = U[:, :rank] @ torch.diag(s[:rank])
|
||||
B = Vh[:rank, :]
|
||||
|
||||
if not float_data:
|
||||
B = B.to(original_device).type(original_type)
|
||||
return B
|
||||
elif type=='left':
|
||||
A = U[:, :rank]
|
||||
B = torch.diag(s[:rank]) @ Vh[:rank, :]
|
||||
if not float_data:
|
||||
A = A.to(original_device).type(original_type)
|
||||
return A
|
||||
elif type=='full':
|
||||
A = U[:, :rank]
|
||||
B = Vh[:rank, :]
|
||||
if not float_data:
|
||||
A = A.to(original_device).type(original_type)
|
||||
B = B.to(original_device).type(original_type)
|
||||
return [A, B]
|
||||
else:
|
||||
raise ValueError('type should be left, right or full')
|
||||
|
||||
def stable_randn(
|
||||
shape: Union[int, Sequence[int]],
|
||||
seed: int,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
dtype: Optional[torch.dtype] = torch.float32,
|
||||
):
|
||||
if device is None:
|
||||
device = torch.device("cpu")
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
rn = torch.randn(shape, generator=generator, device=generator.device, dtype=dtype)
|
||||
return rn
|
||||
|
||||
|
||||
def next_seed(seed: int, adv: int = 0xF):
|
||||
"""
|
||||
This is a naive helper function to generate a new seed from the given seed.
|
||||
"""
|
||||
generator = torch.Generator().manual_seed(seed)
|
||||
return torch.randint(
|
||||
0, torch.iinfo(torch.int64).max, (adv,), generator=generator, device=generator.device
|
||||
).tolist()[-1]
|
||||
|
||||
|
||||
def split_seed(seed: int):
|
||||
generator = torch.Generator().manual_seed(seed)
|
||||
return tuple(
|
||||
torch.randint(0, torch.iinfo(torch.int64).max, (2,), generator=generator, device=generator.device).tolist()
|
||||
)
|
||||
|
||||
|
||||
class GradientProjector:
|
||||
def __init__(
|
||||
self, rank, update_proj_gap=200, alpha=1.0, proj_type="std", seed=0
|
||||
):
|
||||
# This is a lazy implementation as we store the projection matrix instead of re-generation every iteration
|
||||
self.rank = rank
|
||||
self.update_proj_gap = update_proj_gap
|
||||
self.alpha = alpha
|
||||
self.proj_type = proj_type
|
||||
|
||||
self.ortho_matrix = None
|
||||
self.seed = seed
|
||||
|
||||
def project(self, full_rank_grad, iter):
|
||||
|
||||
if self.proj_type == "std":
|
||||
if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
|
||||
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
|
||||
self.ortho_matrix = self.get_orthogonal_matrix(
|
||||
full_rank_grad, self.rank, type="right", seed=self.seed
|
||||
)
|
||||
self.seed = next_seed(self.seed)
|
||||
low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
|
||||
else:
|
||||
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
|
||||
self.ortho_matrix = self.get_orthogonal_matrix(
|
||||
full_rank_grad, self.rank, type="left", seed=self.seed
|
||||
)
|
||||
self.seed = next_seed(self.seed)
|
||||
|
||||
low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
|
||||
elif self.proj_type == "reverse_std":
|
||||
if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
|
||||
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
|
||||
self.ortho_matrix = self.get_orthogonal_matrix(
|
||||
full_rank_grad, self.rank, type="left", seed=self.seed
|
||||
)
|
||||
self.seed = next_seed(self.seed)
|
||||
|
||||
low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
|
||||
else:
|
||||
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
|
||||
self.ortho_matrix = self.get_orthogonal_matrix(
|
||||
full_rank_grad, self.rank, type="right", seed=self.seed
|
||||
)
|
||||
self.seed = next_seed(self.seed)
|
||||
low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
|
||||
elif self.proj_type == "right":
|
||||
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
|
||||
self.ortho_matrix = self.get_orthogonal_matrix(
|
||||
full_rank_grad, self.rank, type="right", seed=self.seed
|
||||
)
|
||||
self.seed = next_seed(self.seed)
|
||||
low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
|
||||
elif self.proj_type == "left":
|
||||
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
|
||||
self.ortho_matrix = self.get_orthogonal_matrix(
|
||||
full_rank_grad, self.rank, type="left", seed=self.seed
|
||||
)
|
||||
self.seed = next_seed(self.seed)
|
||||
low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
|
||||
elif self.proj_type == "full":
|
||||
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
|
||||
self.ortho_matrix = self.get_orthogonal_matrix(
|
||||
full_rank_grad, self.rank, type="full", seed=self.seed
|
||||
)
|
||||
self.seed = next_seed(self.seed)
|
||||
low_rank_grad = (
|
||||
torch.matmul(self.ortho_matrix[0].t(), full_rank_grad)
|
||||
@ self.ortho_matrix[1].t()
|
||||
)
|
||||
|
||||
return low_rank_grad
|
||||
|
||||
# random low rank projection
|
||||
def get_orthogonal_matrix(self, weights, rank, type, seed):
|
||||
module_params = weights
|
||||
|
||||
if module_params.data.dtype != torch.float:
|
||||
float_data = False
|
||||
original_type = module_params.data.dtype
|
||||
original_device = module_params.data.device
|
||||
matrix = module_params.data.float()
|
||||
else:
|
||||
float_data = True
|
||||
matrix = module_params.data
|
||||
|
||||
if type == "left":
|
||||
proj = stable_randn(
|
||||
(matrix.shape[0], rank), seed=seed, device=matrix.device, dtype=matrix.dtype
|
||||
) / math.sqrt(rank)
|
||||
if not float_data:
|
||||
proj = proj.to(original_device).type(original_type)
|
||||
return proj
|
||||
elif type == "right":
|
||||
proj = stable_randn(
|
||||
(rank, matrix.shape[1]), seed=seed, device=matrix.device, dtype=matrix.dtype
|
||||
) / math.sqrt(rank)
|
||||
if not float_data:
|
||||
proj = proj.to(original_device).type(original_type)
|
||||
return proj
|
||||
elif type == "full":
|
||||
raise NotImplementedError("full rank projection is not implemented yet")
|
||||
else:
|
||||
raise ValueError("type should be left, right or full")
|
||||
|
||||
class Apollo(Optimizer):
|
||||
"""
|
||||
Implements Adam algorithm with weight decay fix as introduced in [Decoupled Weight Decay
|
||||
Regularization](https://arxiv.org/abs/1711.05101).
|
||||
|
||||
Parameters:
|
||||
params (`Iterable[nn.parameter.Parameter]`):
|
||||
Iterable of parameters to optimize or dictionaries defining parameter groups.
|
||||
lr (`float`, *optional*, defaults to 0.001):
|
||||
The learning rate to use.
|
||||
betas (`Tuple[float,float]`, *optional*, defaults to `(0.9, 0.999)`):
|
||||
Adam's betas parameters (b1, b2).
|
||||
eps (`float`, *optional*, defaults to 1e-06):
|
||||
Adam's epsilon for numerical stability.
|
||||
weight_decay (`float`, *optional*, defaults to 0.0):
|
||||
Decoupled weight decay to apply.
|
||||
correct_bias (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`).
|
||||
no_deprecation_warning (`bool`, *optional*, defaults to `False`):
|
||||
A flag used to disable the deprecation warning (set to `True` to disable the warning).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params: Iterable[nn.parameter.Parameter],
|
||||
lr: float = 1e-3,
|
||||
betas: Tuple[float, float] = (0.9, 0.999),
|
||||
eps: float = 1e-6,
|
||||
weight_decay: float = 0.0,
|
||||
correct_bias: bool = True,
|
||||
scale_front: bool = False,
|
||||
):
|
||||
if lr < 0.0:
|
||||
raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
|
||||
if not 0.0 <= betas[0] < 1.0:
|
||||
raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)")
|
||||
if not 0.0 <= betas[1] < 1.0:
|
||||
raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
|
||||
if not 0.0 <= eps:
|
||||
raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
|
||||
defaults = {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "correct_bias": correct_bias}
|
||||
super().__init__(params, defaults)
|
||||
|
||||
self.scale_front = scale_front
|
||||
|
||||
params_idx = 0
|
||||
for group in self.param_groups:
|
||||
for p in group["params"]:
|
||||
params_idx += 1
|
||||
if p.requires_grad:
|
||||
self.state[p]["seed"] = params_idx
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self, closure: Callable = None):
|
||||
"""
|
||||
Performs a single optimization step.
|
||||
|
||||
Arguments:
|
||||
closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
|
||||
"""
|
||||
loss = None
|
||||
if closure is not None:
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
for p in group["params"]:
|
||||
if p.grad is None:
|
||||
continue
|
||||
grad = p.grad
|
||||
if grad.is_sparse:
|
||||
raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")
|
||||
|
||||
state = self.state[p]
|
||||
|
||||
if "step" not in state:
|
||||
state["step"] = 0
|
||||
|
||||
# GaLore Projection
|
||||
if "rank" in group:
|
||||
if "projector" not in state:
|
||||
if group["proj"] == "random":
|
||||
state["projector"] = GradientProjector(group["rank"],
|
||||
update_proj_gap=group["update_proj_gap"],
|
||||
alpha=group["scale"],
|
||||
proj_type=group["proj_type"],
|
||||
seed=state["seed"])
|
||||
|
||||
elif group["proj"] == "svd":
|
||||
state["projector"] = GaLoreProjector(group["rank"],
|
||||
update_proj_gap=group["update_proj_gap"],
|
||||
scale=group["scale"],
|
||||
proj_type=group["proj_type"])
|
||||
|
||||
grad = state["projector"].project(grad, state["step"])
|
||||
|
||||
# State initialization
|
||||
if "exp_avg" not in state:
|
||||
# Exponential moving average of gradient values
|
||||
state["exp_avg"] = torch.zeros_like(grad)
|
||||
# Exponential moving average of squared gradient values
|
||||
state["exp_avg_sq"] = torch.zeros_like(grad)
|
||||
|
||||
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
|
||||
beta1, beta2 = group["betas"]
|
||||
|
||||
state["step"] += 1
|
||||
|
||||
# Decay the first and second moment running average coefficient
|
||||
# In-place operations to update the averages at the same time
|
||||
exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
|
||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
|
||||
denom = exp_avg_sq.sqrt().add_(group["eps"])
|
||||
|
||||
step_size = group["lr"]
|
||||
if group["correct_bias"]: # No bias correction for Bert
|
||||
bias_correction1 = 1.0 - beta1 ** state["step"]
|
||||
bias_correction2 = 1.0 - beta2 ** state["step"]
|
||||
step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
|
||||
|
||||
# compute norm gradient
|
||||
norm_grad = exp_avg / denom
|
||||
|
||||
if "rank" in group:
|
||||
if group['scale_type'] == 'channel':
|
||||
norm_dim = 0 if norm_grad.shape[0] < norm_grad.shape[1] else 1
|
||||
scaling_factor = (
|
||||
torch.norm(norm_grad, dim=norm_dim) /
|
||||
(torch.norm(grad, dim=norm_dim) + 1e-8)
|
||||
)
|
||||
if norm_dim == 1:
|
||||
scaling_factor = scaling_factor.unsqueeze(1)
|
||||
|
||||
elif group['scale_type'] == 'tensor':
|
||||
scaling_factor = (
|
||||
torch.norm(norm_grad) /
|
||||
(torch.norm(grad) + 1e-8)
|
||||
)
|
||||
|
||||
scaling_grad = p.grad * scaling_factor
|
||||
|
||||
# Use Norm-Growth Limiter in Fira
|
||||
if "scaling_grad" in state:
|
||||
scaling_grad_norm = torch.norm(scaling_grad)
|
||||
limiter = max(
|
||||
scaling_grad_norm /
|
||||
(state["scaling_grad"] + 1e-8),
|
||||
1.01,
|
||||
) / 1.01
|
||||
scaling_grad = scaling_grad / limiter
|
||||
state["scaling_grad"] = scaling_grad_norm / limiter
|
||||
else:
|
||||
state["scaling_grad"] = torch.norm(scaling_grad)
|
||||
|
||||
norm_grad = scaling_grad * np.sqrt(group["scale"])
|
||||
|
||||
p.add_(norm_grad, alpha=-step_size)
|
||||
|
||||
# Just adding the square of the weights to the loss function is *not*
|
||||
# the correct way of using L2 regularization/weight decay with Adam,
|
||||
# since that will interact with the m and v parameters in strange ways.
|
||||
#
|
||||
# Instead we want to decay the weights in a manner that doesn't interact
|
||||
# with the m/v parameters. This is equivalent to adding the square
|
||||
# of the weights to the loss with plain (non-momentum) SGD.
|
||||
# Add weight decay at the end (fixed version)
|
||||
if group["weight_decay"] > 0.0:
|
||||
p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
|
||||
|
||||
return loss
|
467
vall_e/utils/ext/ecapa_tdnn.py
Normal file
467
vall_e/utils/ext/ecapa_tdnn.py
Normal file
|
@ -0,0 +1,467 @@
|
|||
# borrowed with love from "https://github.com/keonlee9420/evaluate-zero-shot-tts/blob/master/src/evaluate_zero_shot_tts/utils/speaker_verification/models/ecapa_tdnn.py"
|
||||
# part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchaudio.transforms as trans
|
||||
|
||||
#from .utils import UpstreamExpert
|
||||
|
||||
""" Res2Conv1d + BatchNorm1d + ReLU
|
||||
"""
|
||||
|
||||
|
||||
class Res2Conv1dReluBn(nn.Module):
|
||||
"""
|
||||
in_channels == out_channels == channels
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
bias=True,
|
||||
scale=4,
|
||||
):
|
||||
super().__init__()
|
||||
assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
|
||||
self.scale = scale
|
||||
self.width = channels // scale
|
||||
self.nums = scale if scale == 1 else scale - 1
|
||||
|
||||
self.convs = []
|
||||
self.bns = []
|
||||
for i in range(self.nums):
|
||||
self.convs.append(
|
||||
nn.Conv1d(
|
||||
self.width,
|
||||
self.width,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
bias=bias,
|
||||
)
|
||||
)
|
||||
self.bns.append(nn.BatchNorm1d(self.width))
|
||||
self.convs = nn.ModuleList(self.convs)
|
||||
self.bns = nn.ModuleList(self.bns)
|
||||
|
||||
def forward(self, x):
|
||||
out = []
|
||||
spx = torch.split(x, self.width, 1)
|
||||
for i in range(self.nums):
|
||||
if i == 0:
|
||||
sp = spx[i]
|
||||
else:
|
||||
sp = sp + spx[i]
|
||||
# Order: conv -> relu -> bn
|
||||
sp = self.convs[i](sp)
|
||||
sp = self.bns[i](F.relu(sp))
|
||||
out.append(sp)
|
||||
if self.scale != 1:
|
||||
out.append(spx[self.nums])
|
||||
out = torch.cat(out, dim=1)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
""" Conv1d + BatchNorm1d + ReLU
|
||||
"""
|
||||
|
||||
|
||||
class Conv1dReluBn(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
bias=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv1d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
bias=bias,
|
||||
)
|
||||
self.bn = nn.BatchNorm1d(out_channels)
|
||||
|
||||
def forward(self, x):
|
||||
return self.bn(F.relu(self.conv(x)))
|
||||
|
||||
|
||||
""" The SE connection of 1D case.
|
||||
"""
|
||||
|
||||
|
||||
class SE_Connect(nn.Module):
|
||||
def __init__(self, channels, se_bottleneck_dim=128):
|
||||
super().__init__()
|
||||
self.linear1 = nn.Linear(channels, se_bottleneck_dim)
|
||||
self.linear2 = nn.Linear(se_bottleneck_dim, channels)
|
||||
|
||||
def forward(self, x):
|
||||
out = x.mean(dim=2)
|
||||
out = F.relu(self.linear1(out))
|
||||
out = torch.sigmoid(self.linear2(out))
|
||||
out = x * out.unsqueeze(2)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
""" SE-Res2Block of the ECAPA-TDNN architecture.
|
||||
"""
|
||||
|
||||
|
||||
# def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
|
||||
# return nn.Sequential(
|
||||
# Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0),
|
||||
# Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale),
|
||||
# Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0),
|
||||
# SE_Connect(channels)
|
||||
# )
|
||||
|
||||
|
||||
class SE_Res2Block(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
scale,
|
||||
se_bottleneck_dim,
|
||||
):
|
||||
super().__init__()
|
||||
self.Conv1dReluBn1 = Conv1dReluBn(
|
||||
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
self.Res2Conv1dReluBn = Res2Conv1dReluBn(
|
||||
out_channels, kernel_size, stride, padding, dilation, scale=scale
|
||||
)
|
||||
self.Conv1dReluBn2 = Conv1dReluBn(
|
||||
out_channels, out_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)
|
||||
|
||||
self.shortcut = None
|
||||
if in_channels != out_channels:
|
||||
self.shortcut = nn.Conv1d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=1,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
if self.shortcut:
|
||||
residual = self.shortcut(x)
|
||||
|
||||
x = self.Conv1dReluBn1(x)
|
||||
x = self.Res2Conv1dReluBn(x)
|
||||
x = self.Conv1dReluBn2(x)
|
||||
x = self.SE_Connect(x)
|
||||
|
||||
return x + residual
|
||||
|
||||
|
||||
""" Attentive weighted mean and standard deviation pooling.
|
||||
"""
|
||||
|
||||
|
||||
class AttentiveStatsPool(nn.Module):
|
||||
def __init__(
|
||||
self, in_dim, attention_channels=128, global_context_att=False
|
||||
):
|
||||
super().__init__()
|
||||
self.global_context_att = global_context_att
|
||||
|
||||
# Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
|
||||
if global_context_att:
|
||||
self.linear1 = nn.Conv1d(
|
||||
in_dim * 3, attention_channels, kernel_size=1
|
||||
) # equals W and b in the paper
|
||||
else:
|
||||
self.linear1 = nn.Conv1d(
|
||||
in_dim, attention_channels, kernel_size=1
|
||||
) # equals W and b in the paper
|
||||
self.linear2 = nn.Conv1d(
|
||||
attention_channels, in_dim, kernel_size=1
|
||||
) # equals V and k in the paper
|
||||
|
||||
def forward(self, x):
|
||||
if self.global_context_att:
|
||||
context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
|
||||
context_std = torch.sqrt(
|
||||
torch.var(x, dim=-1, keepdim=True) + 1e-10
|
||||
).expand_as(x)
|
||||
x_in = torch.cat((x, context_mean, context_std), dim=1)
|
||||
else:
|
||||
x_in = x
|
||||
|
||||
# DON'T use ReLU here! In experiments, I find ReLU hard to converge.
|
||||
alpha = torch.tanh(self.linear1(x_in))
|
||||
# alpha = F.relu(self.linear1(x_in))
|
||||
alpha = torch.softmax(self.linear2(alpha), dim=2)
|
||||
mean = torch.sum(alpha * x, dim=2)
|
||||
residuals = torch.sum(alpha * (x**2), dim=2) - mean**2
|
||||
std = torch.sqrt(residuals.clamp(min=1e-9))
|
||||
return torch.cat([mean, std], dim=1)
|
||||
|
||||
|
||||
class ECAPA_TDNN(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
feat_dim=80,
|
||||
channels=512,
|
||||
emb_dim=192,
|
||||
global_context_att=False,
|
||||
feat_type="fbank",
|
||||
sr=16000,
|
||||
feature_selection="hidden_states",
|
||||
update_extract=False,
|
||||
config_path=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.feat_type = feat_type
|
||||
self.feature_selection = feature_selection
|
||||
self.update_extract = update_extract
|
||||
self.sr = sr
|
||||
|
||||
if feat_type == "fbank" or feat_type == "mfcc":
|
||||
self.update_extract = False
|
||||
|
||||
win_len = int(sr * 0.025)
|
||||
hop_len = int(sr * 0.01)
|
||||
|
||||
if feat_type == "fbank":
|
||||
self.feature_extract = trans.MelSpectrogram(
|
||||
sample_rate=sr,
|
||||
n_fft=512,
|
||||
win_length=win_len,
|
||||
hop_length=hop_len,
|
||||
f_min=0.0,
|
||||
f_max=sr // 2,
|
||||
pad=0,
|
||||
n_mels=feat_dim,
|
||||
)
|
||||
elif feat_type == "mfcc":
|
||||
melkwargs = {
|
||||
"n_fft": 512,
|
||||
"win_length": win_len,
|
||||
"hop_length": hop_len,
|
||||
"f_min": 0.0,
|
||||
"f_max": sr // 2,
|
||||
"pad": 0,
|
||||
}
|
||||
self.feature_extract = trans.MFCC(
|
||||
sample_rate=sr,
|
||||
n_mfcc=feat_dim,
|
||||
log_mels=False,
|
||||
melkwargs=melkwargs,
|
||||
)
|
||||
else:
|
||||
"""
|
||||
if config_path is None:
|
||||
self.feature_extract = torch.hub.load("s3prl/s3prl", feat_type)
|
||||
else:
|
||||
self.feature_extract = UpstreamExpert(config_path)
|
||||
"""
|
||||
self.feature_extract = torch.hub.load("s3prl/s3prl", feat_type)
|
||||
if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
|
||||
self.feature_extract.model.encoder.layers[23].self_attn,
|
||||
"fp32_attention",
|
||||
):
|
||||
self.feature_extract.model.encoder.layers[
|
||||
23
|
||||
].self_attn.fp32_attention = False
|
||||
if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
|
||||
self.feature_extract.model.encoder.layers[11].self_attn,
|
||||
"fp32_attention",
|
||||
):
|
||||
self.feature_extract.model.encoder.layers[
|
||||
11
|
||||
].self_attn.fp32_attention = False
|
||||
|
||||
self.feat_num = self.get_feat_num()
|
||||
self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
|
||||
|
||||
if feat_type != "fbank" and feat_type != "mfcc":
|
||||
freeze_list = [
|
||||
"final_proj",
|
||||
"label_embs_concat",
|
||||
"mask_emb",
|
||||
"project_q",
|
||||
"quantizer",
|
||||
]
|
||||
for name, param in self.feature_extract.named_parameters():
|
||||
for freeze_val in freeze_list:
|
||||
if freeze_val in name:
|
||||
param.requires_grad = False
|
||||
break
|
||||
|
||||
if not self.update_extract:
|
||||
for param in self.feature_extract.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
self.instance_norm = nn.InstanceNorm1d(feat_dim)
|
||||
# self.channels = [channels] * 4 + [channels * 3]
|
||||
self.channels = [channels] * 4 + [1536]
|
||||
|
||||
self.layer1 = Conv1dReluBn(
|
||||
feat_dim, self.channels[0], kernel_size=5, padding=2
|
||||
)
|
||||
self.layer2 = SE_Res2Block(
|
||||
self.channels[0],
|
||||
self.channels[1],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=2,
|
||||
dilation=2,
|
||||
scale=8,
|
||||
se_bottleneck_dim=128,
|
||||
)
|
||||
self.layer3 = SE_Res2Block(
|
||||
self.channels[1],
|
||||
self.channels[2],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=3,
|
||||
dilation=3,
|
||||
scale=8,
|
||||
se_bottleneck_dim=128,
|
||||
)
|
||||
self.layer4 = SE_Res2Block(
|
||||
self.channels[2],
|
||||
self.channels[3],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=4,
|
||||
dilation=4,
|
||||
scale=8,
|
||||
se_bottleneck_dim=128,
|
||||
)
|
||||
|
||||
# self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
|
||||
cat_channels = channels * 3
|
||||
self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
|
||||
self.pooling = AttentiveStatsPool(
|
||||
self.channels[-1],
|
||||
attention_channels=128,
|
||||
global_context_att=global_context_att,
|
||||
)
|
||||
self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
|
||||
self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
|
||||
|
||||
def get_feat_num(self):
|
||||
self.feature_extract.eval()
|
||||
wav = [
|
||||
torch.randn(self.sr).to(
|
||||
next(self.feature_extract.parameters()).device
|
||||
)
|
||||
]
|
||||
with torch.no_grad():
|
||||
features = self.feature_extract(wav)
|
||||
select_feature = features[self.feature_selection]
|
||||
if isinstance(select_feature, (list, tuple)):
|
||||
return len(select_feature)
|
||||
else:
|
||||
return 1
|
||||
|
||||
def get_feat(self, x):
|
||||
if self.update_extract:
|
||||
x = self.feature_extract([sample for sample in x])
|
||||
else:
|
||||
with torch.no_grad():
|
||||
if self.feat_type == "fbank" or self.feat_type == "mfcc":
|
||||
x = (
|
||||
self.feature_extract(x) + 1e-6
|
||||
) # B x feat_dim x time_len
|
||||
else:
|
||||
x = self.feature_extract([sample for sample in x])
|
||||
|
||||
if self.feat_type == "fbank":
|
||||
x = x.log()
|
||||
|
||||
if self.feat_type != "fbank" and self.feat_type != "mfcc":
|
||||
x = x[self.feature_selection]
|
||||
if isinstance(x, (list, tuple)):
|
||||
x = torch.stack(x, dim=0)
|
||||
else:
|
||||
x = x.unsqueeze(0)
|
||||
norm_weights = (
|
||||
F.softmax(self.feature_weight, dim=-1)
|
||||
.unsqueeze(-1)
|
||||
.unsqueeze(-1)
|
||||
.unsqueeze(-1)
|
||||
)
|
||||
x = (norm_weights * x).sum(dim=0)
|
||||
x = torch.transpose(x, 1, 2) + 1e-6
|
||||
|
||||
x = self.instance_norm(x)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
x = self.get_feat(x)
|
||||
|
||||
out1 = self.layer1(x)
|
||||
out2 = self.layer2(out1)
|
||||
out3 = self.layer3(out2)
|
||||
out4 = self.layer4(out3)
|
||||
|
||||
out = torch.cat([out2, out3, out4], dim=1)
|
||||
out = F.relu(self.conv(out))
|
||||
out = self.bn(self.pooling(out))
|
||||
out = self.linear(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def ECAPA_TDNN_SMALL(
|
||||
feat_dim,
|
||||
emb_dim=256,
|
||||
feat_type="fbank",
|
||||
sr=16000,
|
||||
feature_selection="hidden_states",
|
||||
update_extract=False,
|
||||
config_path=None,
|
||||
):
|
||||
return ECAPA_TDNN(
|
||||
feat_dim=feat_dim,
|
||||
channels=512,
|
||||
emb_dim=emb_dim,
|
||||
feat_type=feat_type,
|
||||
sr=sr,
|
||||
feature_selection=feature_selection,
|
||||
update_extract=update_extract,
|
||||
config_path=config_path,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
x = torch.zeros(2, 32000)
|
||||
model = ECAPA_TDNN_SMALL(
|
||||
feat_dim=768,
|
||||
emb_dim=256,
|
||||
feat_type="hubert_base",
|
||||
feature_selection="hidden_states",
|
||||
update_extract=False,
|
||||
)
|
||||
|
||||
out = model(x)
|
||||
# print(model)
|
||||
print(out.shape)
|
|
@ -124,6 +124,20 @@ def _get_named_modules(module, attrname):
|
|||
if hasattr(module, attrname):
|
||||
yield name, module
|
||||
|
||||
def coerce_dtype(s):
|
||||
# not a string
|
||||
if not isinstance(s, str):
|
||||
return s
|
||||
|
||||
if s == "float16":
|
||||
return torch.float16
|
||||
if s == "bfloat16":
|
||||
return torch.bfloat16
|
||||
if s == "float8_e5m2":
|
||||
return torch.float8_e5m2
|
||||
if s == "float8_e4m3fn":
|
||||
return torch.float8_e4m3fn
|
||||
return torch.float32
|
||||
|
||||
def gather_attribute(module, attrname, delete=True, prefix=True):
|
||||
ret = {}
|
||||
|
|
|
@ -103,12 +103,18 @@ if cfg.optimizations.tensorrt:
|
|||
|
||||
if cfg.optimizations.unsloth:
|
||||
try:
|
||||
from .unsloth import apply_unsloth_offloaded_gradient_checkpoint_monkey_patch
|
||||
from .ext.unsloth import apply_unsloth_offloaded_gradient_checkpoint_monkey_patch
|
||||
#apply_unsloth_offloaded_gradient_checkpoint_monkey_patch()
|
||||
except Exception as e:
|
||||
_logger.warning(f'Error while importing Unsloth: {str(e)}')
|
||||
pass
|
||||
|
||||
try:
|
||||
from .ext.apollo import Apollo
|
||||
except Exception as e:
|
||||
_logger.warning(f'Error while importing APOLLO: {str(e)}')
|
||||
pass
|
||||
|
||||
def compile_model(model, backend="auto"):
|
||||
if not backend or backend == "auto":
|
||||
backend = AVAILABLE_COMPILE_BACKENDS[0]
|
||||
|
|
Loading…
Reference in New Issue
Block a user