added WER/SIM-O metrics, added APOLLO but I need to test it
parent fc5e6d8599 · commit 8568a93dad
@@ -10,6 +10,8 @@
 	<caption>LibriSpeech</caption>
 	<tr>
 		<th>Text</th>
+		<th>WER↓</th>
+		<th>SIM-O↑</th>
 		<th>Prompt</th>
 		<th>Our VALL-E</th>
 		<th>Original VALL-E</th>
@@ -24,6 +26,8 @@
 	<caption>Sampled Dataset</caption>
 	<tr>
 		<th>Text</th>
+		<th>WER↓</th>
+		<th>SIM-O↑</th>
 		<th>Prompt</th>
 		<th>Our VALL-E</th>
 		<th>F5-TTS</th>
docs/emb.md (18 lines changed)
@@ -77,7 +77,7 @@ I'm uncertain on how to remedy this, as my options are:

 ## `transcribe.py`

-This script handles taking raw input audio, and outputting adequate metadata containing transcriptions of said audio through `whisperX`.
+This script primarily handles taking raw input audio, and outputting adequate metadata containing transcriptions of said audio through `whisperX`.

 The process keeps the slices `whisperX` thinks are best per the outputted segments, alongside the deduced language (if not specified).

@@ -85,6 +85,18 @@ One limiting factor is that transcription transcribes into normal text, rather t

 Refer to the `__main__`'s arguments for usage details.

+### Metrics
+
+This script also handles calculating `WER`, simply by transcribing the given audio file (and the reference, if requested), then comparing the word error rate.
+
+This process *heavily* relies on text normalization, which is currently lacking, but transcribing the reference should keep things "normalized" per the transcriber.
+
+### ROCm
+
+Because life is pain, ROCm requires additional steps to ensure that `whisperX` works. A special fork of `CTranslate2` is required, but simply following [these](https://github.com/arlo-phoenix/CTranslate2-rocm/blob/rocm/README_ROCM.md) steps should fix things.
+
+In the future, I would love to replace whisperX with something simpler.
+
 ## `process.py`

 This script handles taking raw input audio and its transcribed metadata, and outputs encoded audio (NumPy) files containing encoded audio and associated metadata.

@@ -108,3 +120,7 @@ When processing a dataset, this requires already having accompanying metadata ge
 Be *very* careful if you opt to output unsegmented and segmented utterances, as the sliced version may end up amongst the top-K similar candidates.

 Refer to the `__main__`'s arguments for usage details.
+
+### Metrics
+
+This script also handles calculating `SIM-O` per [keonlee9420/evaluate-zero-shot-tts](https://github.com/keonlee9420/evaluate-zero-shot-tts/blob/master/src/evaluate_zero_shot_tts/utils/speaker_verification/verification.py), by making use of a model to create an embedding of a speaker, then computing cosine similarities on those embeddings.
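For reference, a minimal sketch of driving the two new metric helpers directly (the paths here are hypothetical; the signatures follow the new `vall_e/metrics.py` further below):

	from pathlib import Path
	from vall_e.metrics import wer, sim_o

	out_path = Path("demo/librispeech/speaker/out.wav")        # hypothetical synthesized utterance
	reference = Path("demo/librispeech/speaker/reference.wav") # hypothetical ground-truth utterance

	# WER: transcribe out_path (and the reference, since it is a Path), normalize both, then compare
	wer_score = wer( out_path, reference, language="en", model_name="base", device="cuda", dtype="float16" )
	# SIM-O: cosine similarity between speaker embeddings of the two utterances
	sim_o_score = sim_o( out_path, reference, device="cuda", dtype="float16" )
	print( f"WER: {wer_score:.3f} | SIM-O: {sim_o_score:.3f}" )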
setup.py (3 lines changed)

@@ -91,6 +91,9 @@ setup(
 		"causal-conv1d",
 		"mamba-ssm",
+
+		#
+		"torcheval",

 		# attention helpers
 		"xformers",
 		"sageattention==1.0.6",
@@ -22,7 +22,7 @@ from pathlib import Path

 from .utils.distributed import world_size
 from .utils.io import torch_load
-from .utils import set_seed, prune_missing, md5_hash
+from .utils import set_seed, prune_missing, md5_hash, coerce_dtype

 @dataclass()
 class BaseConfig:

@@ -721,15 +721,7 @@ class Trainer:

 	@cached_property
 	def dtype(self):
-		if self.weight_dtype == "float16":
-			return torch.float16
-		if self.weight_dtype == "bfloat16":
-			return torch.bfloat16
-		if self.weight_dtype == "float8_e5m2":
-			return torch.float8_e5m2
-		if self.weight_dtype == "float8_e4m3fn":
-			return torch.float8_e4m3fn
-		return torch.float32
+		return coerce_dtype(self.weight_dtype)

 	@cached_property
 	def scale_loss(self):

@@ -748,17 +740,7 @@ class Inference:

 	@property
 	def dtype(self):
-		if self.weight_dtype == "float16":
-			return torch.float16
-		if self.weight_dtype == "bfloat16":
-			return torch.bfloat16
-		if self.weight_dtype == "int8":
-			return torch.int8
-		if self.weight_dtype == "float8_e5m2":
-			return torch.float8_e5m2
-		if self.weight_dtype == "float8_e4m3fn":
-			return torch.float8_e4m3fn
-		return torch.float32
+		return coerce_dtype(self.weight_dtype)

 @dataclass()
 class Optimizations:
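The `coerce_dtype` helper itself isn't part of this diff; a minimal sketch of what it presumably does, assuming it simply centralizes the string-to-`torch.dtype` mapping that the removed if-chains handled:

	import torch

	def coerce_dtype( dtype ):
		# assumption: pass actual torch.dtype values through untouched
		if not isinstance( dtype, str ):
			return dtype
		# map the same strings the old Trainer/Inference properties handled
		return {
			"float16": torch.float16,
			"bfloat16": torch.bfloat16,
			"int8": torch.int8,
			"float8_e5m2": torch.float8_e5m2,
			"float8_e4m3fn": torch.float8_e4m3fn,
		}.get( dtype, torch.float32 )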
@@ -63,6 +63,13 @@ def sentence_split( s, split_by="sentences", quote_placeholder="<QUOTE>" ):
 	sentences = nltk.sent_tokenize(s)
 	return [ sentence.replace(quote_placeholder, '"') for sentence in sentences if sentence ]

+# to-do: improve upon this since it's kind of ass
+# this might be better to live in emb.g2p
+def normalize_text( s ):
+	s = s.lower()
+	s = re.sub(r'[^\w\s]', '', s)
+	return s
+
 @cache
 def get_random_prompts( validation=False, min_length=0, tokenized=False ):
 	duration_range = [ 5.5, 12.0 ] # to-do: pull from cfg.dataset.duration_range

@@ -1070,7 +1077,9 @@ class Dataset(_Dataset):
 		return root / name

 	def sample_prompts(self, spkr_name, reference, should_trim=True):
-		if not cfg.dataset.prompt_duration_range or cfg.dataset.prompt_duration_range[-1] == 0:
+		# return no prompt if explicitly requested for who knows why
+		# or if there's no other speakers to sample from (Emilia has a lot of singleton speakers, but I still want to make use of them)
+		if not cfg.dataset.prompt_duration_range or cfg.dataset.prompt_duration_range[-1] == 0 or len(self.paths_by_spkr_name[key]) <= 1:
 			return None

 		prom_list = []
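A quick illustration of what the normalization above yields (the regex strips anything that isn't a word character or whitespace):

	>>> normalize_text( "Hello, World! It's 3 o'clock." )
	'hello world its 3 oclock'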
@@ -20,6 +20,7 @@ import base64
 import random
 import logging
 import time
+import torch

 _logger = logging.getLogger(__name__)

@@ -29,6 +30,8 @@ from .inference import TTS
 from .config import cfg
 from .data import create_train_dataloader, create_val_dataloader, get_random_prompt
 from .emb.qnt import decode_to_file
+from .metrics import wer, sim_o
+from .utils import setup_logging

 from tqdm import tqdm, trange

@@ -230,6 +233,8 @@ def main():
 	elif args.comparison:
 		raise Exception(f"Unrecognized comparison flag: {args.comparison}")

+	setup_logging()
+
 	# read html template
 	html = open(args.demo_dir / "index.template.html", "r", encoding="utf-8").read()

@@ -318,6 +323,7 @@ def main():

 	inputs = []
 	outputs = []
+	metrics_inputs = []
 	comparison_inputs = []
 	for k, sample_dir in samples_dirs.items():
 		if not sample_dir.exists():

@@ -359,9 +365,15 @@ def main():

 		# segregate comparisons into its own batch because they use different kwargs (and I do not support variadic-batched kwargs)
 		if args.comparison:
-			comparison_inputs.append((text, prompt, language, out_path_comparison))
+			if (args.skip_existing and not out_path_comparison.exists()) or not (args.skip_existing):
+				comparison_inputs.append((text, prompt, language, out_path_comparison))
+
+			metrics_inputs.append((text, language, out_path_comparison, reference))

-		inputs.append((text, prompt, language, out_path))
+		if (args.skip_existing and not out_path.exists()) or not (args.skip_existing):
+			inputs.append((text, prompt, language, out_path))
+
+		metrics_inputs.append((text, language, out_path, reference))

 		outputs.append((k, samples))

@@ -371,10 +383,19 @@ def main():

 	if comparison_inputs:
 		process_batch( tts, comparison_inputs, sampling_kwargs | (comparison_kwargs["enabled"] if args.comparison else {}) )

+	metrics_map = {}
+	for text, language, out_path, reference_path in metrics_inputs:
+		wer_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name="base" )
+		sim_o_score = sim_o( out_path, reference_path, device=tts.device, dtype=tts.dtype )
+		metrics_map[out_path] = (wer_score, sim_o_score)
+
 	# collate entries into HTML
 	for k, samples in outputs:
 		samples = [
 			f'\n\t\t\t<tr>\n\t\t\t\t<td>{text}</td>'+
+			"".join([
+				f'\n\t\t\t\t<td>{metrics_map[audios[1]][0]:.3f}</td><td>{metrics_map[audios[1]][1]:.3f}</td>'
+			] ) +
 			"".join( [
 				f'\n\t\t\t\t<td><audio controls="controls" preload="none"><source src="{str(audio).replace(str(args.demo_dir), args.audio_path_root) if args.audio_path_root else encode(audio)}"/></audio></td>'
 				for audio in audios
@@ -16,12 +16,13 @@ _logger = logging.getLogger(__name__)

 from tqdm.auto import tqdm
 from pathlib import Path
+from functools import cache

 import torchaudio.functional as F
 import torchaudio.transforms as T

 from ..config import cfg
-from ..utils import truncate_json
+from ..utils import truncate_json, coerce_dtype
 from ..utils.io import json_read, json_write

 from .g2p import encode as phonemize

@@ -29,19 +30,49 @@ from .qnt import encode as quantize, trim, convert_audio

 from ..webui import init_tts

-def load_audio( path ):
+def load_audio( path, target_sr=None ):
 	waveform, sr = torchaudio.load( path )
 	# mix channels
 	if waveform.shape[0] > 1:
 		waveform = torch.mean(waveform, dim=0, keepdim=True)
+	if target_sr is None:
+		target_sr = cfg.sample_rate
 	# resample
-	waveform, sr = convert_audio(waveform, sr, cfg.sample_rate, 1), cfg.sample_rate
+	waveform, sr = convert_audio(waveform, sr, target_sr, 1), target_sr

 	return waveform, sr

 tts = None

-def process(
+# this is for computing SIM-O, but can probably technically be used for scoring similar utterances
+@cache
+def _load_sim_model(device="cuda", dtype="float16", feat_type="wavlm_base_plus", feat_dim=768):
+	from ..utils.ext.ecapa_tdnn import ECAPA_TDNN_SMALL
+	model = ECAPA_TDNN_SMALL(feat_dim=feat_dim, feat_type=feat_type, config_path=None)
+	model = model.to(device=device, dtype=coerce_dtype(dtype))
+	model = model.eval()
+
+	return model
+
+@torch.no_grad()
+def speaker_similarity_embedding(
+	audio,
+	**model_kwargs,
+):
+	device = model_kwargs.get("device", "cuda")
+	dtype = model_kwargs.get("dtype", "float16")
+
+	model = _load_sim_model(**model_kwargs)
+	if isinstance(audio, str) or isinstance(audio, Path):
+		audio = load_audio(audio, 16000)
+
+	audio, sr = audio
+	audio = audio.to(device=device, dtype=coerce_dtype(dtype))
+
+	return model(audio)
+
+
+def batch_similar_utterances(
 	speaker_path,
 	yaml,
 	text=False,

@@ -266,7 +297,7 @@ def main():
 		if args.skip_existing and metadata_keys and "similar" in metadata[metadata_keys[-1]]:
 			return

-		similarities = process(
+		similarities = batch_similar_utterances(
 			speaker_path=cfg.data_dir / speaker_name,
 			yaml=args.yaml,
 			text=args.text,

@@ -314,7 +345,7 @@ def main():
 			add( data_dir, type="noise", texts=False )

 	elif args.input_speaker:
-		similarities = process(
+		similarities = batch_similar_utterances(
 			speaker_path=args.input_speaker,
 			yaml=args.yaml,
 			text=args.text,
@ -11,9 +11,13 @@ import torchaudio
|
||||||
|
|
||||||
import whisperx
|
import whisperx
|
||||||
|
|
||||||
|
from functools import cache
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
from ..utils import coerce_dtype
|
||||||
|
|
||||||
|
|
||||||
def pad(num, zeroes):
|
def pad(num, zeroes):
|
||||||
return str(num).zfill(zeroes+1)
|
return str(num).zfill(zeroes+1)
|
||||||
|
|
||||||
|
@ -21,7 +25,132 @@ def process_items( items, stride=0, stride_offset=0 ):
|
||||||
items = sorted( items )
|
items = sorted( items )
|
||||||
return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ]
|
return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ]
|
||||||
|
|
||||||
|
# major cringe but should automatically unload models when loading a different one
|
||||||
|
_cached_models = {
|
||||||
|
"model": (None, None),
|
||||||
|
"diarization": (None, None),
|
||||||
|
"align": (None, None),
|
||||||
|
}
|
||||||
|
# yes I can write a decorator to do this
|
||||||
|
def _load_model(model_name="large-v3", device="cuda", dtype="float16", language="auto"):
|
||||||
|
cache_key = f'{model_name}:{device}:{dtype}:{language}'
|
||||||
|
if _cached_models["model"][0] == cache_key:
|
||||||
|
return _cached_models["model"][1]
|
||||||
|
|
||||||
|
del _cached_models["model"]
|
||||||
|
|
||||||
|
if not isinstance( dtype, str ):
|
||||||
|
if dtype == torch.float32:
|
||||||
|
dtype = "float32"
|
||||||
|
elif dtype == torch.float16:
|
||||||
|
dtype = "float16"
|
||||||
|
elif dtype == torch.bfloat16:
|
||||||
|
dtype = "bfloat16"
|
||||||
|
|
||||||
|
# doesnt support it for some reason
|
||||||
|
if dtype == "bfloat16":
|
||||||
|
dtype = "float16"
|
||||||
|
|
||||||
|
kwargs = {}
|
||||||
|
kwargs["compute_type"] = dtype
|
||||||
|
kwargs["task"] = "transcribe"
|
||||||
|
kwargs["device"] = device
|
||||||
|
|
||||||
|
if language != "auto":
|
||||||
|
kwargs["language"] = language
|
||||||
|
|
||||||
|
model = whisperx.load_model(model_name, **kwargs)
|
||||||
|
|
||||||
|
_cached_models["model"] = (cache_key, model)
|
||||||
|
return model
|
||||||
|
|
||||||
|
def _load_diarization_model(device="cuda"):
|
||||||
|
cache_key = f'{device}'
|
||||||
|
|
||||||
|
if _cached_models["diarization"][0] == cache_key:
|
||||||
|
return _cached_models["diarization"][1]
|
||||||
|
del _cached_models["diarization"]
|
||||||
|
model = whisperx.DiarizationPipeline(device=device)
|
||||||
|
_cached_models["diarization"] = (cache_key, model)
|
||||||
|
return model
|
||||||
|
|
||||||
|
def _load_align_model(language, device="cuda"):
|
||||||
|
cache_key = f'{language}:{device}'
|
||||||
|
|
||||||
|
if _cached_models["align"][0] == cache_key:
|
||||||
|
return _cached_models["align"][1]
|
||||||
|
del _cached_models["align"]
|
||||||
|
model = whisperx.load_align_model(language_code=language, device=device)
|
||||||
|
_cached_models["align"] = (cache_key, model)
|
||||||
|
return model
|
||||||
|
|
||||||
|
# yes I can just do a for-loop
|
||||||
|
def unload_model():
|
||||||
|
del _cached_models["model"]
|
||||||
|
del _cached_models["diarization"]
|
||||||
|
del _cached_models["align"]
|
||||||
|
|
||||||
|
_cached_models["model"] = (None, None)
|
||||||
|
_cached_models["diarization"] = (None, None)
|
||||||
|
_cached_models["align"] = (None, None)
|
||||||
|
|
||||||
def transcribe(
|
def transcribe(
|
||||||
|
audio,
|
||||||
|
language = "auto",
|
||||||
|
diarize = False,
|
||||||
|
batch_size = 16,
|
||||||
|
verbose=False,
|
||||||
|
align=True,
|
||||||
|
**model_kwargs,
|
||||||
|
):
|
||||||
|
metadata = {
|
||||||
|
"segments": [],
|
||||||
|
"language": "",
|
||||||
|
"text": "",
|
||||||
|
"start": 0,
|
||||||
|
"end": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
# load requested models
|
||||||
|
device = model_kwargs.get("device", "cuda")
|
||||||
|
model = _load_model(language=language, **model_kwargs)
|
||||||
|
diarize_model = _load_diarization_model(device=device) if diarize else None
|
||||||
|
|
||||||
|
# audio is a path, load it
|
||||||
|
if isinstance(audio, str) or isinstance(audio, Path):
|
||||||
|
#audio = load_audio(audio)
|
||||||
|
audio = whisperx.load_audio(audio)
|
||||||
|
|
||||||
|
result = model.transcribe(audio, batch_size=batch_size)
|
||||||
|
|
||||||
|
if language == "auto":
|
||||||
|
language = result["language"]
|
||||||
|
|
||||||
|
if align:
|
||||||
|
align_model, align_model_metadata = _load_align_model(language=language, device=device)
|
||||||
|
result = whisperx.align(result["segments"], align_model, align_model_metadata, audio, device, return_char_alignments=False)
|
||||||
|
|
||||||
|
if diarize_model is not None:
|
||||||
|
diarize_segments = diarize_model(audio)
|
||||||
|
result = whisperx.assign_word_speakers(diarize_segments, result)
|
||||||
|
|
||||||
|
text = []
|
||||||
|
start = 0
|
||||||
|
end = 0
|
||||||
|
for segment in result["segments"]:
|
||||||
|
text.append( segment["text"] )
|
||||||
|
start = min( start, segment["start"] )
|
||||||
|
end = max( end, segment["end"] )
|
||||||
|
|
||||||
|
metadata["language"] = language
|
||||||
|
metadata["segments"] = result["segments"]
|
||||||
|
metadata["text"] = " ".join(text).strip()
|
||||||
|
metadata["start"] = start
|
||||||
|
metadata["end"] = end
|
||||||
|
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
def transcribe_batch(
|
||||||
input_audio = "voices",
|
input_audio = "voices",
|
||||||
input_voice = None,
|
input_voice = None,
|
||||||
output_metadata = "training/metadata",
|
output_metadata = "training/metadata",
|
||||||
|
@ -49,14 +178,11 @@ def transcribe(
|
||||||
if input_voice is not None:
|
if input_voice is not None:
|
||||||
only_speakers = [input_voice]
|
only_speakers = [input_voice]
|
||||||
|
|
||||||
#
|
"""
|
||||||
model = whisperx.load_model(model_name, device, compute_type=dtype)
|
|
||||||
align_model, align_model_metadata, align_model_language = (None, None, None)
|
align_model, align_model_metadata, align_model_language = (None, None, None)
|
||||||
if diarize:
|
model =_load_model(model_name, device, compute_type=dtype)
|
||||||
diarize_model = whisperx.DiarizationPipeline(device=device)
|
diarize_model = _load_diarization_model(device=device) if diarize else None
|
||||||
else:
|
"""
|
||||||
diarize_model = None
|
|
||||||
|
|
||||||
|
|
||||||
for dataset_name in os.listdir(f'./{input_audio}/'):
|
for dataset_name in os.listdir(f'./{input_audio}/'):
|
||||||
if not os.path.isdir(f'./{input_audio}/{dataset_name}/'):
|
if not os.path.isdir(f'./{input_audio}/{dataset_name}/'):
|
||||||
|
@ -96,6 +222,9 @@ def transcribe(
|
||||||
if os.path.isdir(inpath):
|
if os.path.isdir(inpath):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
metadata[filename] = transcribe( inpath, model_name=model_name, diarize=diarize, device=device, dtype=dtype )
|
||||||
|
|
||||||
|
"""
|
||||||
metadata[filename] = {
|
metadata[filename] = {
|
||||||
"segments": [],
|
"segments": [],
|
||||||
"language": "",
|
"language": "",
|
||||||
|
@ -108,15 +237,10 @@ def transcribe(
|
||||||
result = model.transcribe(audio, batch_size=batch_size)
|
result = model.transcribe(audio, batch_size=batch_size)
|
||||||
language = result["language"]
|
language = result["language"]
|
||||||
|
|
||||||
"""
|
|
||||||
if language[:2] not in ["ja"]:
|
|
||||||
language = "en"
|
|
||||||
"""
|
|
||||||
|
|
||||||
if align_model_language != language:
|
if align_model_language != language:
|
||||||
tqdm.write(f'Loading language: {language}')
|
tqdm.write(f'Loading language: {language}')
|
||||||
align_model, align_model_metadata = whisperx.load_align_model(language_code=language, device=device)
|
|
||||||
align_model_language = language
|
align_model_language = language
|
||||||
|
align_model, align_model_metadata = _load_align_model(language=language, device=device)
|
||||||
|
|
||||||
result = whisperx.align(result["segments"], align_model, align_model_metadata, audio, device, return_char_alignments=False)
|
result = whisperx.align(result["segments"], align_model, align_model_metadata, audio, device, return_char_alignments=False)
|
||||||
|
|
||||||
|
@ -138,6 +262,7 @@ def transcribe(
|
||||||
metadata[filename]["text"] = " ".join(text).strip()
|
metadata[filename]["text"] = " ".join(text).strip()
|
||||||
metadata[filename]["start"] = start
|
metadata[filename]["start"] = start
|
||||||
metadata[filename]["end"] = end
|
metadata[filename]["end"] = end
|
||||||
|
"""
|
||||||
|
|
||||||
open(outpath, 'w', encoding='utf-8').write(json.dumps(metadata))
|
open(outpath, 'w', encoding='utf-8').write(json.dumps(metadata))
|
||||||
|
|
||||||
|
@ -169,7 +294,7 @@ def main():
|
||||||
args.stride_offset = int(args.device)
|
args.stride_offset = int(args.device)
|
||||||
args.device = f'cuda:{args.device}'
|
args.device = f'cuda:{args.device}'
|
||||||
|
|
||||||
transcribe(
|
transcribe_batch(
|
||||||
input_audio = args.input_audio,
|
input_audio = args.input_audio,
|
||||||
input_voice = args.input_voice,
|
input_voice = args.input_voice,
|
||||||
output_metadata = args.output_metadata,
|
output_metadata = args.output_metadata,
|
||||||
|
|
|
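A hedged sketch of calling the new single-file `transcribe()` helper from the hunk above (the audio path is hypothetical; `model_name`, `device`, and `dtype` are forwarded to `_load_model`):

	from vall_e.emb.transcribe import transcribe

	# transcribe one file, auto-detecting the language and word-aligning the segments
	metadata = transcribe( "voices/sample.wav", language="auto", align=True, model_name="large-v3", device="cuda", dtype="float16" )
	print( metadata["language"], metadata["text"] )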
@@ -110,8 +110,10 @@ def load_engines(training=True, **model_kwargs):
 		scheduler_class = None

 		params = {
+			"params": [ param for name, param in model.named_parameters() if name not in model.config.frozen_params ],
 			"lr": cfg.hyperparameters.learning_rate,
 		}
+
 		if cfg.hyperparameters.optimizer.lower() == "adamw":
 			params["betas"] = (0.9, 0.96)
 			params["eps"] = 1e-07

@@ -129,17 +131,30 @@ def load_engines(training=True, **model_kwargs):

 			params['d_coef'] = params['lr']
 			params['lr'] = 1.0
+		elif cfg.hyperparameters.optimizer.lower() in ["apollo","apollo-mini"]:
+			optimizer_class = ml.Apollo
+			is_mini = cfg.hyperparameters.optimizer.lower() == "apollo-mini"
+			param_kwargs = {
+				"rank": 1 if is_mini else 256,
+				"proj": "random",
+				"scale_type": "tensor" if is_mini else "channel",
+				"scale": 128 if is_mini else 1,
+				"update_proj_gap": 200,
+				"proj_type": "std",
+			}
+			# grab any extra configs from the YAML
+			param_kwargs.update(cfg.hyperparameters.optimizer_params)
+			# and blank it so it doesn't update the main optimizer kwargs
+			cfg.hyperparameters.optimizer_params = {}
+			# settings are stored under params
+			params["params"] = [dict(params=params["params"], **param_kwargs)]
 		elif cfg.hyperparameters.optimizer.lower() == "adagrad":
 			optimizer_class = ml.Adagrad
 		else:
 			raise ValueError(f'Optimizer specified not implemented: {cfg.hyperparameters.optimizer}')

 		params.update(cfg.hyperparameters.optimizer_params)
-		optimizer = optimizer_class(
-			[ param for name, param in model.named_parameters() if name not in model.config.frozen_params ],
-			**params,
-		)
+		optimizer = optimizer_class(**params)

 		if cfg.hyperparameters.scheduler.lower() == "schedulefree":
 			if cfg.hyperparameters.optimizer.lower() == "adamw":
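A self-contained sketch of what the `apollo` branch above effectively constructs, with a toy module standing in for the real model (the rank/projection settings mirror the non-mini defaults from the hunk; bias-free parameters are used since the vendored optimizer projects 2-D gradients):

	import torch
	from torch import nn
	from vall_e.utils.ext.apollo import Apollo  # the vendored implementation added below

	model = nn.Linear(1024, 1024, bias=False)  # toy stand-in for the actual model
	param_group = dict(
		params=model.parameters(),
		rank=256, proj="random", scale_type="channel", scale=1,
		update_proj_gap=200, proj_type="std",
	)
	optimizer = Apollo( [param_group], lr=1.0e-4 )

	loss = model( torch.randn(8, 1024) ).pow(2).mean()
	loss.backward()
	optimizer.step()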
@@ -296,7 +296,7 @@ class TTS():
 			use_lora=use_lora,
 		)

-		with torch.autocast("cuda", dtype=dtype, enabled=amp):
+		with torch.autocast(self.device, dtype=dtype, enabled=amp):
 			if model_len is not None:
 				# extra kwargs
 				duration_padding = sampling_kwargs.pop("duration_padding", 1.05)

@@ -384,7 +384,7 @@ class TTS():
 		resp = to_device(resp, device=self.device, dtype=torch.int16)
 		lang = to_device(lang, device=self.device, dtype=torch.uint8)

-		with torch.autocast("cuda", dtype=dtype, enabled=amp):
+		with torch.autocast(self.device, dtype=dtype, enabled=amp):
 			model = model_ar if model_ar is not None else model_nar
 			if model is not None:
 				text_list = model(

@@ -430,7 +430,7 @@ class TTS():
 		phns = to_device(phns, device=self.device, dtype=torch.uint8 if len(self.symmap) < 256 else torch.int16)
 		lang = to_device(lang, device=self.device, dtype=torch.uint8)

-		with torch.autocast("cuda", dtype=dtype, enabled=amp):
+		with torch.autocast(self.device, dtype=dtype, enabled=amp):
 			input_kwargs = dict(
 				text_list=[phns],
 				proms_list=[prom],
vall_e/metrics.py (new file, 33 lines)

@@ -0,0 +1,33 @@
+# handles objective metric calculations, such as WER and SIM-O
+
+#from .emb.transcribe import transcribe
+from .emb.similar import speaker_similarity_embedding
+from .emb.transcribe import transcribe
+from .emb.g2p import detect_language
+from .data import normalize_text
+
+import torch.nn.functional as F
+
+from pathlib import Path
+from torcheval.metrics.functional import word_error_rate
+
+def wer( audio, reference, language="auto", **transcription_kwargs ):
+	if language == "auto":
+		language = detect_language( reference )
+
+	transcription = transcribe( audio, language=language, align=False, **transcription_kwargs )["text"]
+
+	# reference audio needs transcribing too
+	if isinstance( reference, Path ):
+		reference = transcribe( reference, language=language, align=False, **transcription_kwargs )["text"]
+
+	transcription = normalize_text( transcription )
+	reference = normalize_text( reference )
+
+	return word_error_rate([transcription], [reference]).item()
+
+def sim_o( audio, reference, **kwargs ):
+	audio_emb = speaker_similarity_embedding( audio, **kwargs )
+	reference_emb = speaker_similarity_embedding( reference, **kwargs )
+
+	return F.cosine_similarity( audio_emb, reference_emb ).item()
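For intuition on the `WER` half, `torcheval`'s functional metric over already-normalized strings (a toy example, unrelated to any real transcription): one substitution against a four-word reference gives 0.25.

	from torcheval.metrics.functional import word_error_rate

	print( word_error_rate( ["the cat sat down"], ["the cat sat up"] ).item() )  # 0.25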
@@ -15,5 +15,6 @@ from .utils import (
 	prune_missing,
 	clamp,
 	md5_hash,
-	convert_kwargs
+	convert_kwargs,
+	coerce_dtype
 )

vall_e/utils/ext/__init__.py (new file, 0 lines)

vall_e/utils/ext/apollo.py (new file, 433 lines)
@ -0,0 +1,433 @@
|
||||||
|
# "borrowed" with love from https://github.com/MadsToftrup/Apollo-dev/blob/main/galore_torch/apollo.py
|
||||||
|
# to be replaced with the official implementation (https://github.com/zhuhanqing/APOLLO) maybe
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from torch import nn
|
||||||
|
from torch.optim import Optimizer
|
||||||
|
|
||||||
|
from typing import Any, Callable, Dict, Generator, Iterable, Optional, Sequence, Union, Tuple
|
||||||
|
|
||||||
|
from transformers.utils.versions import require_version
|
||||||
|
|
||||||
|
class GaLoreProjector:
|
||||||
|
def __init__(self, rank, verbose=False, update_proj_gap=200, scale=1.0, proj_type='std'):
|
||||||
|
self.rank = rank
|
||||||
|
self.verbose = verbose
|
||||||
|
self.update_proj_gap = update_proj_gap
|
||||||
|
self.scale = scale
|
||||||
|
self.ortho_matrix = None
|
||||||
|
self.proj_type = proj_type
|
||||||
|
self.svd_count = 0
|
||||||
|
|
||||||
|
def project(self, full_rank_grad, iter):
|
||||||
|
|
||||||
|
if self.ortho_matrix is not None and self.ortho_matrix.device != full_rank_grad.device:
|
||||||
|
self.ortho_matrix = self.ortho_matrix.to(full_rank_grad.device)
|
||||||
|
|
||||||
|
if self.proj_type == 'std':
|
||||||
|
if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
|
||||||
|
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
|
||||||
|
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right')
|
||||||
|
self.svd_count += 1
|
||||||
|
low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
|
||||||
|
else:
|
||||||
|
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
|
||||||
|
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left')
|
||||||
|
self.svd_count += 1
|
||||||
|
low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
|
||||||
|
elif self.proj_type == 'reverse_std':
|
||||||
|
if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
|
||||||
|
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
|
||||||
|
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left')
|
||||||
|
self.svd_count += 1
|
||||||
|
low_rank_grad = torch.matmul(self.ortho_matrix.t(),full_rank_grad)
|
||||||
|
else:
|
||||||
|
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
|
||||||
|
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right')
|
||||||
|
self.svd_count += 1
|
||||||
|
low_rank_grad = torch.matmul(full_rank_grad,self.ortho_matrix.t())
|
||||||
|
elif self.proj_type == 'right':
|
||||||
|
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
|
||||||
|
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right')
|
||||||
|
self.svd_count += 1
|
||||||
|
low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
|
||||||
|
elif self.proj_type == 'left':
|
||||||
|
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
|
||||||
|
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left')
|
||||||
|
self.svd_count += 1
|
||||||
|
low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
|
||||||
|
elif self.proj_type == 'full':
|
||||||
|
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
|
||||||
|
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='full')
|
||||||
|
self.svd_count += 1
|
||||||
|
low_rank_grad = torch.matmul(self.ortho_matrix[0].t(), full_rank_grad) @ self.ortho_matrix[1].t()
|
||||||
|
|
||||||
|
return low_rank_grad
|
||||||
|
|
||||||
|
def project_back(self, low_rank_grad):
|
||||||
|
|
||||||
|
if self.proj_type == 'std':
|
||||||
|
if low_rank_grad.shape[0] >= low_rank_grad.shape[1]:
|
||||||
|
full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
|
||||||
|
else:
|
||||||
|
full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
|
||||||
|
elif self.proj_type == 'reverse_std':
|
||||||
|
if low_rank_grad.shape[0] <= low_rank_grad.shape[1]: # note this is different from std
|
||||||
|
full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
|
||||||
|
else:
|
||||||
|
full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
|
||||||
|
elif self.proj_type == 'right':
|
||||||
|
full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
|
||||||
|
elif self.proj_type == 'left':
|
||||||
|
full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
|
||||||
|
elif self.proj_type == 'full':
|
||||||
|
full_rank_grad = torch.matmul(self.ortho_matrix[0], low_rank_grad) @ self.ortho_matrix[1]
|
||||||
|
return full_rank_grad * self.scale
|
||||||
|
|
||||||
|
# svd decomposition
|
||||||
|
def get_orthogonal_matrix(self, weights, rank, type):
|
||||||
|
module_params = weights
|
||||||
|
|
||||||
|
if module_params.data.dtype != torch.float:
|
||||||
|
float_data = False
|
||||||
|
original_type = module_params.data.dtype
|
||||||
|
original_device = module_params.data.device
|
||||||
|
matrix = module_params.data.float()
|
||||||
|
else:
|
||||||
|
float_data = True
|
||||||
|
matrix = module_params.data
|
||||||
|
|
||||||
|
U, s, Vh = torch.linalg.svd(matrix, full_matrices = False)
|
||||||
|
|
||||||
|
#make the smaller matrix always to be orthogonal matrix
|
||||||
|
if type=='right':
|
||||||
|
A = U[:, :rank] @ torch.diag(s[:rank])
|
||||||
|
B = Vh[:rank, :]
|
||||||
|
|
||||||
|
if not float_data:
|
||||||
|
B = B.to(original_device).type(original_type)
|
||||||
|
return B
|
||||||
|
elif type=='left':
|
||||||
|
A = U[:, :rank]
|
||||||
|
B = torch.diag(s[:rank]) @ Vh[:rank, :]
|
||||||
|
if not float_data:
|
||||||
|
A = A.to(original_device).type(original_type)
|
||||||
|
return A
|
||||||
|
elif type=='full':
|
||||||
|
A = U[:, :rank]
|
||||||
|
B = Vh[:rank, :]
|
||||||
|
if not float_data:
|
||||||
|
A = A.to(original_device).type(original_type)
|
||||||
|
B = B.to(original_device).type(original_type)
|
||||||
|
return [A, B]
|
||||||
|
else:
|
||||||
|
raise ValueError('type should be left, right or full')
|
||||||
|
|
||||||
|
def stable_randn(
|
||||||
|
shape: Union[int, Sequence[int]],
|
||||||
|
seed: int,
|
||||||
|
device: Optional[Union[str, torch.device]] = None,
|
||||||
|
dtype: Optional[torch.dtype] = torch.float32,
|
||||||
|
):
|
||||||
|
if device is None:
|
||||||
|
device = torch.device("cpu")
|
||||||
|
generator = torch.Generator(device=device).manual_seed(seed)
|
||||||
|
rn = torch.randn(shape, generator=generator, device=generator.device, dtype=dtype)
|
||||||
|
return rn
|
||||||
|
|
||||||
|
|
||||||
|
def next_seed(seed: int, adv: int = 0xF):
|
||||||
|
"""
|
||||||
|
This is a naive helper function to generate a new seed from the given seed.
|
||||||
|
"""
|
||||||
|
generator = torch.Generator().manual_seed(seed)
|
||||||
|
return torch.randint(
|
||||||
|
0, torch.iinfo(torch.int64).max, (adv,), generator=generator, device=generator.device
|
||||||
|
).tolist()[-1]
|
||||||
|
|
||||||
|
|
||||||
|
def split_seed(seed: int):
|
||||||
|
generator = torch.Generator().manual_seed(seed)
|
||||||
|
return tuple(
|
||||||
|
torch.randint(0, torch.iinfo(torch.int64).max, (2,), generator=generator, device=generator.device).tolist()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GradientProjector:
|
||||||
|
def __init__(
|
||||||
|
self, rank, update_proj_gap=200, alpha=1.0, proj_type="std", seed=0
|
||||||
|
):
|
||||||
|
# This is a lazy implementation as we store the projection matrix instead of re-generation every iteration
|
||||||
|
self.rank = rank
|
||||||
|
self.update_proj_gap = update_proj_gap
|
||||||
|
self.alpha = alpha
|
||||||
|
self.proj_type = proj_type
|
||||||
|
|
||||||
|
self.ortho_matrix = None
|
||||||
|
self.seed = seed
|
||||||
|
|
||||||
|
def project(self, full_rank_grad, iter):
|
||||||
|
|
||||||
|
if self.proj_type == "std":
|
||||||
|
if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
|
||||||
|
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
|
||||||
|
self.ortho_matrix = self.get_orthogonal_matrix(
|
||||||
|
full_rank_grad, self.rank, type="right", seed=self.seed
|
||||||
|
)
|
||||||
|
self.seed = next_seed(self.seed)
|
||||||
|
low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
|
||||||
|
else:
|
||||||
|
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
|
||||||
|
self.ortho_matrix = self.get_orthogonal_matrix(
|
||||||
|
full_rank_grad, self.rank, type="left", seed=self.seed
|
||||||
|
)
|
||||||
|
self.seed = next_seed(self.seed)
|
||||||
|
|
||||||
|
low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
|
||||||
|
elif self.proj_type == "reverse_std":
|
||||||
|
if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
|
||||||
|
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
|
||||||
|
self.ortho_matrix = self.get_orthogonal_matrix(
|
||||||
|
full_rank_grad, self.rank, type="left", seed=self.seed
|
||||||
|
)
|
||||||
|
self.seed = next_seed(self.seed)
|
||||||
|
|
||||||
|
low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
|
||||||
|
else:
|
||||||
|
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
|
||||||
|
self.ortho_matrix = self.get_orthogonal_matrix(
|
||||||
|
full_rank_grad, self.rank, type="right", seed=self.seed
|
||||||
|
)
|
||||||
|
self.seed = next_seed(self.seed)
|
||||||
|
low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
|
||||||
|
elif self.proj_type == "right":
|
||||||
|
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
|
||||||
|
self.ortho_matrix = self.get_orthogonal_matrix(
|
||||||
|
full_rank_grad, self.rank, type="right", seed=self.seed
|
||||||
|
)
|
||||||
|
self.seed = next_seed(self.seed)
|
||||||
|
low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
|
||||||
|
elif self.proj_type == "left":
|
||||||
|
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
|
||||||
|
self.ortho_matrix = self.get_orthogonal_matrix(
|
||||||
|
full_rank_grad, self.rank, type="left", seed=self.seed
|
||||||
|
)
|
||||||
|
self.seed = next_seed(self.seed)
|
||||||
|
low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
|
||||||
|
elif self.proj_type == "full":
|
||||||
|
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
|
||||||
|
self.ortho_matrix = self.get_orthogonal_matrix(
|
||||||
|
full_rank_grad, self.rank, type="full", seed=self.seed
|
||||||
|
)
|
||||||
|
self.seed = next_seed(self.seed)
|
||||||
|
low_rank_grad = (
|
||||||
|
torch.matmul(self.ortho_matrix[0].t(), full_rank_grad)
|
||||||
|
@ self.ortho_matrix[1].t()
|
||||||
|
)
|
||||||
|
|
||||||
|
return low_rank_grad
|
||||||
|
|
||||||
|
# random low rank projection
|
||||||
|
def get_orthogonal_matrix(self, weights, rank, type, seed):
|
||||||
|
module_params = weights
|
||||||
|
|
||||||
|
if module_params.data.dtype != torch.float:
|
||||||
|
float_data = False
|
||||||
|
original_type = module_params.data.dtype
|
||||||
|
original_device = module_params.data.device
|
||||||
|
matrix = module_params.data.float()
|
||||||
|
else:
|
||||||
|
float_data = True
|
||||||
|
matrix = module_params.data
|
||||||
|
|
||||||
|
if type == "left":
|
||||||
|
proj = stable_randn(
|
||||||
|
(matrix.shape[0], rank), seed=seed, device=matrix.device, dtype=matrix.dtype
|
||||||
|
) / math.sqrt(rank)
|
||||||
|
if not float_data:
|
||||||
|
proj = proj.to(original_device).type(original_type)
|
||||||
|
return proj
|
||||||
|
elif type == "right":
|
||||||
|
proj = stable_randn(
|
||||||
|
(rank, matrix.shape[1]), seed=seed, device=matrix.device, dtype=matrix.dtype
|
||||||
|
) / math.sqrt(rank)
|
||||||
|
if not float_data:
|
||||||
|
proj = proj.to(original_device).type(original_type)
|
||||||
|
return proj
|
||||||
|
elif type == "full":
|
||||||
|
raise NotImplementedError("full rank projection is not implemented yet")
|
||||||
|
else:
|
||||||
|
raise ValueError("type should be left, right or full")
|
||||||
|
|
||||||
|
class Apollo(Optimizer):
|
||||||
|
"""
|
||||||
|
Implements Adam algorithm with weight decay fix as introduced in [Decoupled Weight Decay
|
||||||
|
Regularization](https://arxiv.org/abs/1711.05101).
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
params (`Iterable[nn.parameter.Parameter]`):
|
||||||
|
Iterable of parameters to optimize or dictionaries defining parameter groups.
|
||||||
|
lr (`float`, *optional*, defaults to 0.001):
|
||||||
|
The learning rate to use.
|
||||||
|
betas (`Tuple[float,float]`, *optional*, defaults to `(0.9, 0.999)`):
|
||||||
|
Adam's betas parameters (b1, b2).
|
||||||
|
eps (`float`, *optional*, defaults to 1e-06):
|
||||||
|
Adam's epsilon for numerical stability.
|
||||||
|
weight_decay (`float`, *optional*, defaults to 0.0):
|
||||||
|
Decoupled weight decay to apply.
|
||||||
|
correct_bias (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`).
|
||||||
|
no_deprecation_warning (`bool`, *optional*, defaults to `False`):
|
||||||
|
A flag used to disable the deprecation warning (set to `True` to disable the warning).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
params: Iterable[nn.parameter.Parameter],
|
||||||
|
lr: float = 1e-3,
|
||||||
|
betas: Tuple[float, float] = (0.9, 0.999),
|
||||||
|
eps: float = 1e-6,
|
||||||
|
weight_decay: float = 0.0,
|
||||||
|
correct_bias: bool = True,
|
||||||
|
scale_front: bool = False,
|
||||||
|
):
|
||||||
|
if lr < 0.0:
|
||||||
|
raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
|
||||||
|
if not 0.0 <= betas[0] < 1.0:
|
||||||
|
raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)")
|
||||||
|
if not 0.0 <= betas[1] < 1.0:
|
||||||
|
raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
|
||||||
|
if not 0.0 <= eps:
|
||||||
|
raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
|
||||||
|
defaults = {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "correct_bias": correct_bias}
|
||||||
|
super().__init__(params, defaults)
|
||||||
|
|
||||||
|
self.scale_front = scale_front
|
||||||
|
|
||||||
|
params_idx = 0
|
||||||
|
for group in self.param_groups:
|
||||||
|
for p in group["params"]:
|
||||||
|
params_idx += 1
|
||||||
|
if p.requires_grad:
|
||||||
|
self.state[p]["seed"] = params_idx
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def step(self, closure: Callable = None):
|
||||||
|
"""
|
||||||
|
Performs a single optimization step.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
|
||||||
|
"""
|
||||||
|
loss = None
|
||||||
|
if closure is not None:
|
||||||
|
loss = closure()
|
||||||
|
|
||||||
|
for group in self.param_groups:
|
||||||
|
for p in group["params"]:
|
||||||
|
if p.grad is None:
|
||||||
|
continue
|
||||||
|
grad = p.grad
|
||||||
|
if grad.is_sparse:
|
||||||
|
raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")
|
||||||
|
|
||||||
|
state = self.state[p]
|
||||||
|
|
||||||
|
if "step" not in state:
|
||||||
|
state["step"] = 0
|
||||||
|
|
||||||
|
# GaLore Projection
|
||||||
|
if "rank" in group:
|
||||||
|
if "projector" not in state:
|
||||||
|
if group["proj"] == "random":
|
||||||
|
state["projector"] = GradientProjector(group["rank"],
|
||||||
|
update_proj_gap=group["update_proj_gap"],
|
||||||
|
alpha=group["scale"],
|
||||||
|
proj_type=group["proj_type"],
|
||||||
|
seed=state["seed"])
|
||||||
|
|
||||||
|
elif group["proj"] == "svd":
|
||||||
|
state["projector"] = GaLoreProjector(group["rank"],
|
||||||
|
update_proj_gap=group["update_proj_gap"],
|
||||||
|
scale=group["scale"],
|
||||||
|
proj_type=group["proj_type"])
|
||||||
|
|
||||||
|
grad = state["projector"].project(grad, state["step"])
|
||||||
|
|
||||||
|
# State initialization
|
||||||
|
if "exp_avg" not in state:
|
||||||
|
# Exponential moving average of gradient values
|
||||||
|
state["exp_avg"] = torch.zeros_like(grad)
|
||||||
|
# Exponential moving average of squared gradient values
|
||||||
|
state["exp_avg_sq"] = torch.zeros_like(grad)
|
||||||
|
|
||||||
|
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
|
||||||
|
beta1, beta2 = group["betas"]
|
||||||
|
|
||||||
|
state["step"] += 1
|
||||||
|
|
||||||
|
# Decay the first and second moment running average coefficient
|
||||||
|
# In-place operations to update the averages at the same time
|
||||||
|
exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
|
||||||
|
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
|
||||||
|
denom = exp_avg_sq.sqrt().add_(group["eps"])
|
||||||
|
|
||||||
|
step_size = group["lr"]
|
||||||
|
if group["correct_bias"]: # No bias correction for Bert
|
||||||
|
bias_correction1 = 1.0 - beta1 ** state["step"]
|
||||||
|
bias_correction2 = 1.0 - beta2 ** state["step"]
|
||||||
|
step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
|
||||||
|
|
||||||
|
# compute norm gradient
|
||||||
|
norm_grad = exp_avg / denom
|
||||||
|
|
||||||
|
if "rank" in group:
|
||||||
|
if group['scale_type'] == 'channel':
|
||||||
|
norm_dim = 0 if norm_grad.shape[0] < norm_grad.shape[1] else 1
|
||||||
|
scaling_factor = (
|
||||||
|
torch.norm(norm_grad, dim=norm_dim) /
|
||||||
|
(torch.norm(grad, dim=norm_dim) + 1e-8)
|
||||||
|
)
|
||||||
|
if norm_dim == 1:
|
||||||
|
scaling_factor = scaling_factor.unsqueeze(1)
|
||||||
|
|
||||||
|
elif group['scale_type'] == 'tensor':
|
||||||
|
scaling_factor = (
|
||||||
|
torch.norm(norm_grad) /
|
||||||
|
(torch.norm(grad) + 1e-8)
|
||||||
|
)
|
||||||
|
|
||||||
|
scaling_grad = p.grad * scaling_factor
|
||||||
|
|
||||||
|
# Use Norm-Growth Limiter in Fira
|
||||||
|
if "scaling_grad" in state:
|
||||||
|
scaling_grad_norm = torch.norm(scaling_grad)
|
||||||
|
limiter = max(
|
||||||
|
scaling_grad_norm /
|
||||||
|
(state["scaling_grad"] + 1e-8),
|
||||||
|
1.01,
|
||||||
|
) / 1.01
|
||||||
|
scaling_grad = scaling_grad / limiter
|
||||||
|
state["scaling_grad"] = scaling_grad_norm / limiter
|
||||||
|
else:
|
||||||
|
state["scaling_grad"] = torch.norm(scaling_grad)
|
||||||
|
|
||||||
|
norm_grad = scaling_grad * np.sqrt(group["scale"])
|
||||||
|
|
||||||
|
p.add_(norm_grad, alpha=-step_size)
|
||||||
|
|
||||||
|
# Just adding the square of the weights to the loss function is *not*
|
||||||
|
# the correct way of using L2 regularization/weight decay with Adam,
|
||||||
|
# since that will interact with the m and v parameters in strange ways.
|
||||||
|
#
|
||||||
|
# Instead we want to decay the weights in a manner that doesn't interact
|
||||||
|
# with the m/v parameters. This is equivalent to adding the square
|
||||||
|
# of the weights to the loss with plain (non-momentum) SGD.
|
||||||
|
# Add weight decay at the end (fixed version)
|
||||||
|
if group["weight_decay"] > 0.0:
|
||||||
|
p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
|
||||||
|
|
||||||
|
return loss
|
vall_e/utils/ext/ecapa_tdnn.py (new file, 467 lines)
@ -0,0 +1,467 @@
|
||||||
|
# borrowed with love from "https://github.com/keonlee9420/evaluate-zero-shot-tts/blob/master/src/evaluate_zero_shot_tts/utils/speaker_verification/models/ecapa_tdnn.py"
|
||||||
|
# part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torchaudio.transforms as trans
|
||||||
|
|
||||||
|
#from .utils import UpstreamExpert
|
||||||
|
|
||||||
|
""" Res2Conv1d + BatchNorm1d + ReLU
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class Res2Conv1dReluBn(nn.Module):
|
||||||
|
"""
|
||||||
|
in_channels == out_channels == channels
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
dilation=1,
|
||||||
|
bias=True,
|
||||||
|
scale=4,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
|
||||||
|
self.scale = scale
|
||||||
|
self.width = channels // scale
|
||||||
|
self.nums = scale if scale == 1 else scale - 1
|
||||||
|
|
||||||
|
self.convs = []
|
||||||
|
self.bns = []
|
||||||
|
for i in range(self.nums):
|
||||||
|
self.convs.append(
|
||||||
|
nn.Conv1d(
|
||||||
|
self.width,
|
||||||
|
self.width,
|
||||||
|
kernel_size,
|
||||||
|
stride,
|
||||||
|
padding,
|
||||||
|
dilation,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.bns.append(nn.BatchNorm1d(self.width))
|
||||||
|
self.convs = nn.ModuleList(self.convs)
|
||||||
|
self.bns = nn.ModuleList(self.bns)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out = []
|
||||||
|
spx = torch.split(x, self.width, 1)
|
||||||
|
for i in range(self.nums):
|
||||||
|
if i == 0:
|
||||||
|
sp = spx[i]
|
||||||
|
else:
|
||||||
|
sp = sp + spx[i]
|
||||||
|
# Order: conv -> relu -> bn
|
||||||
|
sp = self.convs[i](sp)
|
||||||
|
sp = self.bns[i](F.relu(sp))
|
||||||
|
out.append(sp)
|
||||||
|
if self.scale != 1:
|
||||||
|
out.append(spx[self.nums])
|
||||||
|
out = torch.cat(out, dim=1)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
""" Conv1d + BatchNorm1d + ReLU
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class Conv1dReluBn(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
dilation=1,
|
||||||
|
bias=True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.conv = nn.Conv1d(
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride,
|
||||||
|
padding,
|
||||||
|
dilation,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
self.bn = nn.BatchNorm1d(out_channels)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.bn(F.relu(self.conv(x)))
|
||||||
|
|
||||||
|
|
||||||
|
""" The SE connection of 1D case.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class SE_Connect(nn.Module):
|
||||||
|
def __init__(self, channels, se_bottleneck_dim=128):
|
||||||
|
super().__init__()
|
||||||
|
self.linear1 = nn.Linear(channels, se_bottleneck_dim)
|
||||||
|
self.linear2 = nn.Linear(se_bottleneck_dim, channels)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out = x.mean(dim=2)
|
||||||
|
out = F.relu(self.linear1(out))
|
||||||
|
out = torch.sigmoid(self.linear2(out))
|
||||||
|
out = x * out.unsqueeze(2)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
""" SE-Res2Block of the ECAPA-TDNN architecture.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
# def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
|
||||||
|
# return nn.Sequential(
|
||||||
|
# Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0),
|
||||||
|
# Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale),
|
||||||
|
# Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0),
|
||||||
|
# SE_Connect(channels)
|
||||||
|
# )
|
||||||
|
|
||||||
|
|
||||||
|
class SE_Res2Block(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride,
|
||||||
|
padding,
|
||||||
|
dilation,
|
||||||
|
scale,
|
||||||
|
se_bottleneck_dim,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.Conv1dReluBn1 = Conv1dReluBn(
|
||||||
|
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
||||||
|
)
|
||||||
|
self.Res2Conv1dReluBn = Res2Conv1dReluBn(
|
||||||
|
out_channels, kernel_size, stride, padding, dilation, scale=scale
|
||||||
|
)
|
||||||
|
self.Conv1dReluBn2 = Conv1dReluBn(
|
||||||
|
out_channels, out_channels, kernel_size=1, stride=1, padding=0
|
||||||
|
)
|
||||||
|
self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)
|
||||||
|
|
||||||
|
self.shortcut = None
|
||||||
|
if in_channels != out_channels:
|
||||||
|
self.shortcut = nn.Conv1d(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
residual = x
|
||||||
|
if self.shortcut:
|
||||||
|
residual = self.shortcut(x)
|
||||||
|
|
||||||
|
x = self.Conv1dReluBn1(x)
|
||||||
|
x = self.Res2Conv1dReluBn(x)
|
||||||
|
x = self.Conv1dReluBn2(x)
|
||||||
|
x = self.SE_Connect(x)
|
||||||
|
|
||||||
|
return x + residual
|
||||||
|
|
||||||
|
|
||||||
|
""" Attentive weighted mean and standard deviation pooling.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class AttentiveStatsPool(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, in_dim, attention_channels=128, global_context_att=False
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.global_context_att = global_context_att
|
||||||
|
|
||||||
|
# Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
|
||||||
|
if global_context_att:
|
||||||
|
self.linear1 = nn.Conv1d(
|
||||||
|
in_dim * 3, attention_channels, kernel_size=1
|
||||||
|
) # equals W and b in the paper
|
||||||
|
else:
|
||||||
|
self.linear1 = nn.Conv1d(
|
||||||
|
in_dim, attention_channels, kernel_size=1
|
||||||
|
) # equals W and b in the paper
|
||||||
|
self.linear2 = nn.Conv1d(
|
||||||
|
attention_channels, in_dim, kernel_size=1
|
||||||
|
) # equals V and k in the paper
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.global_context_att:
|
||||||
|
context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
|
||||||
|
context_std = torch.sqrt(
|
||||||
|
torch.var(x, dim=-1, keepdim=True) + 1e-10
|
||||||
|
).expand_as(x)
|
||||||
|
x_in = torch.cat((x, context_mean, context_std), dim=1)
|
||||||
|
else:
|
||||||
|
x_in = x
|
||||||
|
|
||||||
|
# DON'T use ReLU here! In experiments, I find ReLU hard to converge.
|
||||||
|
alpha = torch.tanh(self.linear1(x_in))
|
||||||
|
# alpha = F.relu(self.linear1(x_in))
|
||||||
|
alpha = torch.softmax(self.linear2(alpha), dim=2)
|
||||||
|
mean = torch.sum(alpha * x, dim=2)
|
||||||
|
residuals = torch.sum(alpha * (x**2), dim=2) - mean**2
|
||||||
|
std = torch.sqrt(residuals.clamp(min=1e-9))
|
||||||
|
return torch.cat([mean, std], dim=1)
|
||||||
|
|
||||||
|
|
||||||
|
class ECAPA_TDNN(nn.Module):
    def __init__(
        self,
        feat_dim=80,
        channels=512,
        emb_dim=192,
        global_context_att=False,
        feat_type="fbank",
        sr=16000,
        feature_selection="hidden_states",
        update_extract=False,
        config_path=None,
    ):
        super().__init__()

        self.feat_type = feat_type
        self.feature_selection = feature_selection
        self.update_extract = update_extract
        self.sr = sr

        if feat_type == "fbank" or feat_type == "mfcc":
            self.update_extract = False

        win_len = int(sr * 0.025)
        hop_len = int(sr * 0.01)

        if feat_type == "fbank":
            self.feature_extract = trans.MelSpectrogram(
                sample_rate=sr,
                n_fft=512,
                win_length=win_len,
                hop_length=hop_len,
                f_min=0.0,
                f_max=sr // 2,
                pad=0,
                n_mels=feat_dim,
            )
        elif feat_type == "mfcc":
            melkwargs = {
                "n_fft": 512,
                "win_length": win_len,
                "hop_length": hop_len,
                "f_min": 0.0,
                "f_max": sr // 2,
                "pad": 0,
            }
            self.feature_extract = trans.MFCC(
                sample_rate=sr,
                n_mfcc=feat_dim,
                log_mels=False,
                melkwargs=melkwargs,
            )
        else:
            """
            if config_path is None:
                self.feature_extract = torch.hub.load("s3prl/s3prl", feat_type)
            else:
                self.feature_extract = UpstreamExpert(config_path)
            """
            self.feature_extract = torch.hub.load("s3prl/s3prl", feat_type)
            if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
                self.feature_extract.model.encoder.layers[23].self_attn,
                "fp32_attention",
            ):
                self.feature_extract.model.encoder.layers[
                    23
                ].self_attn.fp32_attention = False
            if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
                self.feature_extract.model.encoder.layers[11].self_attn,
                "fp32_attention",
            ):
                self.feature_extract.model.encoder.layers[
                    11
                ].self_attn.fp32_attention = False

            self.feat_num = self.get_feat_num()
            self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))

        if feat_type != "fbank" and feat_type != "mfcc":
            freeze_list = [
                "final_proj",
                "label_embs_concat",
                "mask_emb",
                "project_q",
                "quantizer",
            ]
            for name, param in self.feature_extract.named_parameters():
                for freeze_val in freeze_list:
                    if freeze_val in name:
                        param.requires_grad = False
                        break

        if not self.update_extract:
            for param in self.feature_extract.parameters():
                param.requires_grad = False

        self.instance_norm = nn.InstanceNorm1d(feat_dim)
        # self.channels = [channels] * 4 + [channels * 3]
        self.channels = [channels] * 4 + [1536]

        self.layer1 = Conv1dReluBn(
            feat_dim, self.channels[0], kernel_size=5, padding=2
        )
        self.layer2 = SE_Res2Block(
            self.channels[0],
            self.channels[1],
            kernel_size=3,
            stride=1,
            padding=2,
            dilation=2,
            scale=8,
            se_bottleneck_dim=128,
        )
        self.layer3 = SE_Res2Block(
            self.channels[1],
            self.channels[2],
            kernel_size=3,
            stride=1,
            padding=3,
            dilation=3,
            scale=8,
            se_bottleneck_dim=128,
        )
        self.layer4 = SE_Res2Block(
            self.channels[2],
            self.channels[3],
            kernel_size=3,
            stride=1,
            padding=4,
            dilation=4,
            scale=8,
            se_bottleneck_dim=128,
        )

        # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
        cat_channels = channels * 3
        self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
        self.pooling = AttentiveStatsPool(
            self.channels[-1],
            attention_channels=128,
            global_context_att=global_context_att,
        )
        self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
        self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)

    def get_feat_num(self):
        self.feature_extract.eval()
        wav = [
            torch.randn(self.sr).to(
                next(self.feature_extract.parameters()).device
            )
        ]
        with torch.no_grad():
            features = self.feature_extract(wav)
        select_feature = features[self.feature_selection]
        if isinstance(select_feature, (list, tuple)):
            return len(select_feature)
        else:
            return 1

    def get_feat(self, x):
        if self.update_extract:
            x = self.feature_extract([sample for sample in x])
        else:
            with torch.no_grad():
                if self.feat_type == "fbank" or self.feat_type == "mfcc":
                    x = (
                        self.feature_extract(x) + 1e-6
                    )  # B x feat_dim x time_len
                else:
                    x = self.feature_extract([sample for sample in x])

        if self.feat_type == "fbank":
            x = x.log()

        if self.feat_type != "fbank" and self.feat_type != "mfcc":
            x = x[self.feature_selection]
            if isinstance(x, (list, tuple)):
                x = torch.stack(x, dim=0)
            else:
                x = x.unsqueeze(0)
            norm_weights = (
                F.softmax(self.feature_weight, dim=-1)
                .unsqueeze(-1)
                .unsqueeze(-1)
                .unsqueeze(-1)
            )
            x = (norm_weights * x).sum(dim=0)
            x = torch.transpose(x, 1, 2) + 1e-6

        x = self.instance_norm(x)
        return x

    def forward(self, x):
        x = self.get_feat(x)

        out1 = self.layer1(x)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        out4 = self.layer4(out3)

        out = torch.cat([out2, out3, out4], dim=1)
        out = F.relu(self.conv(out))
        out = self.bn(self.pooling(out))
        out = self.linear(out)

        return out


def ECAPA_TDNN_SMALL(
    feat_dim,
    emb_dim=256,
    feat_type="fbank",
    sr=16000,
    feature_selection="hidden_states",
    update_extract=False,
    config_path=None,
):
    return ECAPA_TDNN(
        feat_dim=feat_dim,
        channels=512,
        emb_dim=emb_dim,
        feat_type=feat_type,
        sr=sr,
        feature_selection=feature_selection,
        update_extract=update_extract,
        config_path=config_path,
    )


if __name__ == "__main__":
|
||||||
|
x = torch.zeros(2, 32000)
|
||||||
|
model = ECAPA_TDNN_SMALL(
|
||||||
|
feat_dim=768,
|
||||||
|
emb_dim=256,
|
||||||
|
feat_type="hubert_base",
|
||||||
|
feature_selection="hidden_states",
|
||||||
|
update_extract=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
out = model(x)
|
||||||
|
# print(model)
|
||||||
|
print(out.shape)
|
|
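
    # Illustrative sketch: embeddings from two utterances can be compared with cosine
    # similarity to score how close the speakers sound. `y` is just a second dummy
    # waveform standing in for a reference utterance here.
    y = torch.zeros(2, 32000)
    sim = F.cosine_similarity(model(x), model(y), dim=-1)
    print(sim.shape)  # one similarity score per batch item, in [-1, 1]
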
@@ -124,6 +124,20 @@ def _get_named_modules(module, attrname):
         if hasattr(module, attrname):
             yield name, module
 
+def coerce_dtype(s):
+    # not a string
+    if not isinstance(s, str):
+        return s
+
+    if s == "float16":
+        return torch.float16
+    if s == "bfloat16":
+        return torch.bfloat16
+    if s == "float8_e5m2":
+        return torch.float8_e5m2
+    if s == "float8_e4m3fn":
+        return torch.float8_e4m3fn
+    return torch.float32
 
 def gather_attribute(module, attrname, delete=True, prefix=True):
     ret = {}
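
# Usage note for coerce_dtype above: coerce_dtype("bfloat16") -> torch.bfloat16 and
# coerce_dtype("float8_e4m3fn") -> torch.float8_e4m3fn; any other string falls through to
# torch.float32, and non-string inputs (e.g. an actual torch.dtype) pass through unchanged.
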
@@ -103,12 +103,18 @@ if cfg.optimizations.tensorrt:
 
 if cfg.optimizations.unsloth:
     try:
-        from .unsloth import apply_unsloth_offloaded_gradient_checkpoint_monkey_patch
+        from .ext.unsloth import apply_unsloth_offloaded_gradient_checkpoint_monkey_patch
         #apply_unsloth_offloaded_gradient_checkpoint_monkey_patch()
     except Exception as e:
         _logger.warning(f'Error while importing Unsloth: {str(e)}')
         pass
 
+try:
+    from .ext.apollo import Apollo
+except Exception as e:
+    _logger.warning(f'Error while importing APOLLO: {str(e)}')
+    pass
+
 def compile_model(model, backend="auto"):
     if not backend or backend == "auto":
         backend = AVAILABLE_COMPILE_BACKENDS[0]
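
# Hypothetical sketch of wiring up the optional APOLLO optimizer imported above; the
# exact constructor arguments depend on the vendored .ext.apollo implementation, so the
# `lr` keyword below is an assumption following the usual torch.optim convention:
#
#   optimizer = Apollo(model.parameters(), lr=1.0e-4)  # assumed signature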