added WER/SIM-O metrics, added APOLLO but I need to test it

This commit is contained in:
mrq 2024-12-10 20:13:21 -06:00
parent fc5e6d8599
commit 8568a93dad
18 changed files with 1216 additions and 56 deletions

View File

@ -10,6 +10,8 @@
<caption>LibriSpeech</caption>
<tr>
<th>Text</th>
<th>WER↓</th>
<th>SIM-O↑</th>
<th>Prompt</th>
<th>Our VALL-E</th>
<th>Original VALL-E</th>
@ -24,6 +26,8 @@
<caption>Sampled Dataset</caption>
<tr>
<th>Text</th>
<th>WER↓</th>
<th>SIM-O↑</th>
<th>Prompt</th>
<th>Our VALL-E</th>
<th>F5-TTS</th>

View File

@ -77,7 +77,7 @@ I'm uncertain on how to remedy this, as my options are:
## `transcribe.py`
This script handles taking raw input audio, and outputting adequate metadata containing transcriptions of said audio through `whisperX`.
This script primarily handles taking raw input audio, and outputting adequate metadata containing transcriptions of said audio through `whisperX`.
The process maintains slices `whisperX` thinks its best per the segments outputted, alongside the deduced language (if not specified).
@ -85,6 +85,18 @@ One limiting factor is that transcription transcribes into normal text, rather t
Refer to the `__main__`'s arguments for usage details.
### Metrics
This script also handles calculating `WER` simply by transcribing the given audio file (and reference, if requested), then comparing the word error rate.
This process *heavily* relies on text normalization, which currently is lacking, but transcribing the reference should keep things "normalized" per the transcriber.
### ROCm
Because life is pain, ROCm requires additional steps to ensure that `whisperX` works. A special fork of `CTranslate2` is required, but simplying following [these](https://github.com/arlo-phoenix/CTranslate2-rocm/blob/rocm/README_ROCM.md) steps should fix things.
In the future, I would love to replace WhisperX for something simple.
## `process.py`
This script handles taking raw input audio and its transcribed metadata, and outputs encoded audio (NumPy) files containing encoded audio and associated metadata.
@ -107,4 +119,8 @@ When processing a dataset, this requires already having accompanying metadata ge
Be *very* careful if you opt to output unsegmented and segmented utterances, as the sliced version may end up amongst the top-K similar candidates.
Refer to the `__main__`'s arguments for usage details.
Refer to the `__main__`'s arguments for usage details.
### Metrics
This script also handles calculating `SIM-O` per [keonlee9420/evaluate-zero-shot-tts](https://github.com/keonlee9420/evaluate-zero-shot-tts/blob/master/src/evaluate_zero_shot_tts/utils/speaker_verification/verification.py), by making use of a model to create an embedding of a speaker, then computing cosine similarities on those embeddings.

View File

@ -91,6 +91,9 @@ setup(
"causal-conv1d",
"mamba-ssm",
#
"torcheval",
# attention helpers
"xformers",
"sageattention==1.0.6",

View File

@ -22,7 +22,7 @@ from pathlib import Path
from .utils.distributed import world_size
from .utils.io import torch_load
from .utils import set_seed, prune_missing, md5_hash
from .utils import set_seed, prune_missing, md5_hash, coerce_dtype
@dataclass()
class BaseConfig:
@ -721,15 +721,7 @@ class Trainer:
@cached_property
def dtype(self):
if self.weight_dtype == "float16":
return torch.float16
if self.weight_dtype == "bfloat16":
return torch.bfloat16
if self.weight_dtype == "float8_e5m2":
return torch.float8_e5m2
if self.weight_dtype == "float8_e4m3fn":
return torch.float8_e4m3fn
return torch.float32
return coerce_dtype(self.weight_dtype)
@cached_property
def scale_loss(self):
@ -748,17 +740,7 @@ class Inference:
@property
def dtype(self):
if self.weight_dtype == "float16":
return torch.float16
if self.weight_dtype == "bfloat16":
return torch.bfloat16
if self.weight_dtype == "int8":
return torch.int8
if self.weight_dtype == "float8_e5m2":
return torch.float8_e5m2
if self.weight_dtype == "float8_e4m3fn":
return torch.float8_e4m3fn
return torch.float32
return coerce_dtype(self.weight_dtype)
@dataclass()
class Optimizations:

View File

@ -63,6 +63,13 @@ def sentence_split( s, split_by="sentences", quote_placeholder="<QUOTE>" ):
sentences = nltk.sent_tokenize(s)
return [ sentence.replace(quote_placeholder, '"') for sentence in sentences if sentence ]
# to-do: improve upon this since it's kind of ass
# this might be better to live in emb.g2p
def normalize_text( s ):
s = s.lower()
s = re.sub(r'[^\w\s]', '', s)
return s
@cache
def get_random_prompts( validation=False, min_length=0, tokenized=False ):
duration_range = [ 5.5, 12.0 ] # to-do: pull from cfg.dataset.duration_range
@ -1070,7 +1077,9 @@ class Dataset(_Dataset):
return root / name
def sample_prompts(self, spkr_name, reference, should_trim=True):
if not cfg.dataset.prompt_duration_range or cfg.dataset.prompt_duration_range[-1] == 0:
# return no prompt if explicitly requested for who knows why
# or if there's no other speakers to sample from (Emilia has a lot of singleton speakers, but I still want to make use of them)
if not cfg.dataset.prompt_duration_range or cfg.dataset.prompt_duration_range[-1] == 0 or len(self.paths_by_spkr_name[key]) <= 1:
return None
prom_list = []

View File

@ -20,6 +20,7 @@ import base64
import random
import logging
import time
import torch
_logger = logging.getLogger(__name__)
@ -29,6 +30,8 @@ from .inference import TTS
from .config import cfg
from .data import create_train_dataloader, create_val_dataloader, get_random_prompt
from .emb.qnt import decode_to_file
from .metrics import wer, sim_o
from .utils import setup_logging
from tqdm import tqdm, trange
@ -230,6 +233,8 @@ def main():
elif args.comparison:
raise Exception(f"Unrecognized comparison flag: {args.comparison}")
setup_logging()
# read html template
html = open(args.demo_dir / "index.template.html", "r", encoding="utf-8").read()
@ -318,6 +323,7 @@ def main():
inputs = []
outputs = []
metrics_inputs = []
comparison_inputs = []
for k, sample_dir in samples_dirs.items():
if not sample_dir.exists():
@ -359,9 +365,15 @@ def main():
# segregate comparisons into its own batch because they use different kwargs (and I do not support variadic-batched kwargs)
if args.comparison:
comparison_inputs.append((text, prompt, language, out_path_comparison))
if (args.skip_existing and not out_path_comparison.exists()) or not (args.skip_existing):
comparison_inputs.append((text, prompt, language, out_path_comparison))
metrics_inputs.append((text, language, out_path_comparison, reference))
inputs.append((text, prompt, language, out_path))
if (args.skip_existing and not out_path.exists()) or not (args.skip_existing):
inputs.append((text, prompt, language, out_path))
metrics_inputs.append((text, language, out_path, reference))
outputs.append((k, samples))
@ -371,10 +383,19 @@ def main():
if comparison_inputs:
process_batch( tts, comparison_inputs, sampling_kwargs | (comparison_kwargs["enabled"] if args.comparison else {}) )
metrics_map = {}
for text, language, out_path, reference_path in metrics_inputs:
wer_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name="base" )
sim_o_score = sim_o( out_path, reference_path, device=tts.device, dtype=tts.dtype )
metrics_map[out_path] = (wer_score, sim_o_score)
# collate entries into HTML
for k, samples in outputs:
samples = [
f'\n\t\t\t<tr>\n\t\t\t\t<td>{text}</td>'+
"".join([
f'\n\t\t\t\t<td>{metrics_map[audios[1]][0]:.3f}</td><td>{metrics_map[audios[1]][1]:.3f}</td>'
] ) +
"".join( [
f'\n\t\t\t\t<td><audio controls="controls" preload="none"><source src="{str(audio).replace(str(args.demo_dir), args.audio_path_root) if args.audio_path_root else encode(audio)}"/></audio></td>'
for audio in audios

View File

@ -16,12 +16,13 @@ _logger = logging.getLogger(__name__)
from tqdm.auto import tqdm
from pathlib import Path
from functools import cache
import torchaudio.functional as F
import torchaudio.transforms as T
from ..config import cfg
from ..utils import truncate_json
from ..utils import truncate_json, coerce_dtype
from ..utils.io import json_read, json_write
from .g2p import encode as phonemize
@ -29,19 +30,49 @@ from .qnt import encode as quantize, trim, convert_audio
from ..webui import init_tts
def load_audio( path ):
def load_audio( path, target_sr=None ):
waveform, sr = torchaudio.load( path )
# mix channels
if waveform.shape[0] > 1:
waveform = torch.mean(waveform, dim=0, keepdim=True)
if target_sr is None:
target_sr = cfg.sample_rate
# resample
waveform, sr = convert_audio(waveform, sr, cfg.sample_rate, 1), cfg.sample_rate
waveform, sr = convert_audio(waveform, sr, target_sr, 1), target_sr
return waveform, sr
tts = None
def process(
# this is for computing SIM-O, but can probably technically be used for scoring similar utterances
@cache
def _load_sim_model(device="cuda", dtype="float16", feat_type="wavlm_base_plus", feat_dim=768):
from ..utils.ext.ecapa_tdnn import ECAPA_TDNN_SMALL
model = ECAPA_TDNN_SMALL(feat_dim=feat_dim, feat_type=feat_type, config_path=None)
model = model.to(device=device, dtype=coerce_dtype(dtype))
model = model.eval()
return model
@torch.no_grad()
def speaker_similarity_embedding(
audio,
**model_kwargs,
):
device = model_kwargs.get("device", "cuda")
dtype = model_kwargs.get("dtype", "float16")
model = _load_sim_model(**model_kwargs)
if isinstance(audio, str) or isinstance(audio, Path):
audio = load_audio(audio, 16000)
audio, sr = audio
audio = audio.to(device=device, dtype=coerce_dtype(dtype))
return model(audio)
def batch_similar_utterances(
speaker_path,
yaml,
text=False,
@ -266,7 +297,7 @@ def main():
if args.skip_existing and metadata_keys and "similar" in metadata[metadata_keys[-1]]:
return
similarities = process(
similarities = batch_similar_utterances(
speaker_path=cfg.data_dir / speaker_name,
yaml=args.yaml,
text=args.text,
@ -314,7 +345,7 @@ def main():
add( data_dir, type="noise", texts=False )
elif args.input_speaker:
similarities = process(
similarities = batch_similar_utterances(
speaker_path=args.input_speaker,
yaml=args.yaml,
text=args.text,

View File

@ -11,9 +11,13 @@ import torchaudio
import whisperx
from functools import cache
from tqdm.auto import tqdm
from pathlib import Path
from ..utils import coerce_dtype
def pad(num, zeroes):
return str(num).zfill(zeroes+1)
@ -21,7 +25,132 @@ def process_items( items, stride=0, stride_offset=0 ):
items = sorted( items )
return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ]
# major cringe but should automatically unload models when loading a different one
_cached_models = {
"model": (None, None),
"diarization": (None, None),
"align": (None, None),
}
# yes I can write a decorator to do this
def _load_model(model_name="large-v3", device="cuda", dtype="float16", language="auto"):
cache_key = f'{model_name}:{device}:{dtype}:{language}'
if _cached_models["model"][0] == cache_key:
return _cached_models["model"][1]
del _cached_models["model"]
if not isinstance( dtype, str ):
if dtype == torch.float32:
dtype = "float32"
elif dtype == torch.float16:
dtype = "float16"
elif dtype == torch.bfloat16:
dtype = "bfloat16"
# doesnt support it for some reason
if dtype == "bfloat16":
dtype = "float16"
kwargs = {}
kwargs["compute_type"] = dtype
kwargs["task"] = "transcribe"
kwargs["device"] = device
if language != "auto":
kwargs["language"] = language
model = whisperx.load_model(model_name, **kwargs)
_cached_models["model"] = (cache_key, model)
return model
def _load_diarization_model(device="cuda"):
cache_key = f'{device}'
if _cached_models["diarization"][0] == cache_key:
return _cached_models["diarization"][1]
del _cached_models["diarization"]
model = whisperx.DiarizationPipeline(device=device)
_cached_models["diarization"] = (cache_key, model)
return model
def _load_align_model(language, device="cuda"):
cache_key = f'{language}:{device}'
if _cached_models["align"][0] == cache_key:
return _cached_models["align"][1]
del _cached_models["align"]
model = whisperx.load_align_model(language_code=language, device=device)
_cached_models["align"] = (cache_key, model)
return model
# yes I can just do a for-loop
def unload_model():
del _cached_models["model"]
del _cached_models["diarization"]
del _cached_models["align"]
_cached_models["model"] = (None, None)
_cached_models["diarization"] = (None, None)
_cached_models["align"] = (None, None)
def transcribe(
audio,
language = "auto",
diarize = False,
batch_size = 16,
verbose=False,
align=True,
**model_kwargs,
):
metadata = {
"segments": [],
"language": "",
"text": "",
"start": 0,
"end": 0,
}
# load requested models
device = model_kwargs.get("device", "cuda")
model = _load_model(language=language, **model_kwargs)
diarize_model = _load_diarization_model(device=device) if diarize else None
# audio is a path, load it
if isinstance(audio, str) or isinstance(audio, Path):
#audio = load_audio(audio)
audio = whisperx.load_audio(audio)
result = model.transcribe(audio, batch_size=batch_size)
if language == "auto":
language = result["language"]
if align:
align_model, align_model_metadata = _load_align_model(language=language, device=device)
result = whisperx.align(result["segments"], align_model, align_model_metadata, audio, device, return_char_alignments=False)
if diarize_model is not None:
diarize_segments = diarize_model(audio)
result = whisperx.assign_word_speakers(diarize_segments, result)
text = []
start = 0
end = 0
for segment in result["segments"]:
text.append( segment["text"] )
start = min( start, segment["start"] )
end = max( end, segment["end"] )
metadata["language"] = language
metadata["segments"] = result["segments"]
metadata["text"] = " ".join(text).strip()
metadata["start"] = start
metadata["end"] = end
return metadata
def transcribe_batch(
input_audio = "voices",
input_voice = None,
output_metadata = "training/metadata",
@ -49,14 +178,11 @@ def transcribe(
if input_voice is not None:
only_speakers = [input_voice]
#
model = whisperx.load_model(model_name, device, compute_type=dtype)
"""
align_model, align_model_metadata, align_model_language = (None, None, None)
if diarize:
diarize_model = whisperx.DiarizationPipeline(device=device)
else:
diarize_model = None
model =_load_model(model_name, device, compute_type=dtype)
diarize_model = _load_diarization_model(device=device) if diarize else None
"""
for dataset_name in os.listdir(f'./{input_audio}/'):
if not os.path.isdir(f'./{input_audio}/{dataset_name}/'):
@ -95,7 +221,10 @@ def transcribe(
if os.path.isdir(inpath):
continue
metadata[filename] = transcribe( inpath, model_name=model_name, diarize=diarize, device=device, dtype=dtype )
"""
metadata[filename] = {
"segments": [],
"language": "",
@ -108,15 +237,10 @@ def transcribe(
result = model.transcribe(audio, batch_size=batch_size)
language = result["language"]
"""
if language[:2] not in ["ja"]:
language = "en"
"""
if align_model_language != language:
tqdm.write(f'Loading language: {language}')
align_model, align_model_metadata = whisperx.load_align_model(language_code=language, device=device)
align_model_language = language
align_model, align_model_metadata = _load_align_model(language=language, device=device)
result = whisperx.align(result["segments"], align_model, align_model_metadata, audio, device, return_char_alignments=False)
@ -138,6 +262,7 @@ def transcribe(
metadata[filename]["text"] = " ".join(text).strip()
metadata[filename]["start"] = start
metadata[filename]["end"] = end
"""
open(outpath, 'w', encoding='utf-8').write(json.dumps(metadata))
@ -169,7 +294,7 @@ def main():
args.stride_offset = int(args.device)
args.device = f'cuda:{args.device}'
transcribe(
transcribe_batch(
input_audio = args.input_audio,
input_voice = args.input_voice,
output_metadata = args.output_metadata,

View File

@ -110,8 +110,10 @@ def load_engines(training=True, **model_kwargs):
scheduler_class = None
params = {
"params": [ param for name, param in model.named_parameters() if name not in model.config.frozen_params ],
"lr": cfg.hyperparameters.learning_rate,
}
if cfg.hyperparameters.optimizer.lower() == "adamw":
params["betas"] = (0.9, 0.96)
params["eps"] = 1e-07
@ -129,17 +131,30 @@ def load_engines(training=True, **model_kwargs):
params['d_coef'] = params['lr']
params['lr'] = 1.0
elif cfg.hyperparameters.optimizer.lower() in ["apollo","apollo-mini"]:
optimizer_class = ml.Apollo
is_mini = cfg.hyperparameters.optimizer.lower() == "apollo-mini"
param_kwargs = {
"rank": 1 if is_mini else 256,
"proj": "random",
"scale_type": "tensor" if is_mini else "channel",
"scale": 128 if is_mini else 1,
"update_proj_gap": 200,
"proj_type": "std",
}
# grab any extra configs from the YAML
param_kwargs.update(cfg.hyperparameters.optimizer_params)
# and blank it so it doesn't update the main optimizer kwargs
cfg.hyperparameters.optimizer_params = {}
# settings are stored under params
params["params"] = [dict(params=params["params"], **param_kwargs)]
elif cfg.hyperparameters.optimizer.lower() == "adagrad":
optimizer_class = ml.Adagrad
else:
raise ValueError(f'Optimizer specified not implemented: {cfg.hyperparameters.optimizer}')
params.update(cfg.hyperparameters.optimizer_params)
optimizer = optimizer_class(
[ param for name, param in model.named_parameters() if name not in model.config.frozen_params ],
**params,
)
optimizer = optimizer_class(**params)
if cfg.hyperparameters.scheduler.lower() == "schedulefree":
if cfg.hyperparameters.optimizer.lower() == "adamw":

View File

@ -296,7 +296,7 @@ class TTS():
use_lora=use_lora,
)
with torch.autocast("cuda", dtype=dtype, enabled=amp):
with torch.autocast(self.device, dtype=dtype, enabled=amp):
if model_len is not None:
# extra kwargs
duration_padding = sampling_kwargs.pop("duration_padding", 1.05)
@ -384,7 +384,7 @@ class TTS():
resp = to_device(resp, device=self.device, dtype=torch.int16)
lang = to_device(lang, device=self.device, dtype=torch.uint8)
with torch.autocast("cuda", dtype=dtype, enabled=amp):
with torch.autocast(self.device, dtype=dtype, enabled=amp):
model = model_ar if model_ar is not None else model_nar
if model is not None:
text_list = model(
@ -430,7 +430,7 @@ class TTS():
phns = to_device(phns, device=self.device, dtype=torch.uint8 if len(self.symmap) < 256 else torch.int16)
lang = to_device(lang, device=self.device, dtype=torch.uint8)
with torch.autocast("cuda", dtype=dtype, enabled=amp):
with torch.autocast(self.device, dtype=dtype, enabled=amp):
input_kwargs = dict(
text_list=[phns],
proms_list=[prom],

33
vall_e/metrics.py Normal file
View File

@ -0,0 +1,33 @@
# handles objective metric calculations, such as WER and SIM-O
#from .emb.transcribe import transcribe
from .emb.similar import speaker_similarity_embedding
from .emb.transcribe import transcribe
from .emb.g2p import detect_language
from .data import normalize_text
import torch.nn.functional as F
from pathlib import Path
from torcheval.metrics.functional import word_error_rate
def wer( audio, reference, language="auto", **transcription_kwargs ):
if language == "auto":
language = detect_language( reference )
transcription = transcribe( audio, language=language, align=False, **transcription_kwargs )["text"]
# reference audio needs transcribing too
if isinstance( reference, Path ):
reference = transcribe( reference, language=language, align=False, **transcription_kwargs )["text"]
transcription = normalize_text( transcription )
reference = normalize_text( reference )
return word_error_rate([transcription], [reference]).item()
def sim_o( audio, reference, **kwargs ):
audio_emb = speaker_similarity_embedding( audio, **kwargs )
reference_emb = speaker_similarity_embedding( reference, **kwargs )
return F.cosine_similarity( audio_emb, reference_emb ).item()

View File

@ -15,5 +15,6 @@ from .utils import (
prune_missing,
clamp,
md5_hash,
convert_kwargs
convert_kwargs,
coerce_dtype
)

View File

433
vall_e/utils/ext/apollo.py Normal file
View File

@ -0,0 +1,433 @@
# "borrowed" with love from https://github.com/MadsToftrup/Apollo-dev/blob/main/galore_torch/apollo.py
# to be replaced with the official implementation (https://github.com/zhuhanqing/APOLLO) maybe
import torch
import math
import numpy as np
from torch import nn
from torch.optim import Optimizer
from typing import Any, Callable, Dict, Generator, Iterable, Optional, Sequence, Union, Tuple
from transformers.utils.versions import require_version
class GaLoreProjector:
def __init__(self, rank, verbose=False, update_proj_gap=200, scale=1.0, proj_type='std'):
self.rank = rank
self.verbose = verbose
self.update_proj_gap = update_proj_gap
self.scale = scale
self.ortho_matrix = None
self.proj_type = proj_type
self.svd_count = 0
def project(self, full_rank_grad, iter):
if self.ortho_matrix is not None and self.ortho_matrix.device != full_rank_grad.device:
self.ortho_matrix = self.ortho_matrix.to(full_rank_grad.device)
if self.proj_type == 'std':
if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right')
self.svd_count += 1
low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
else:
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left')
self.svd_count += 1
low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
elif self.proj_type == 'reverse_std':
if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left')
self.svd_count += 1
low_rank_grad = torch.matmul(self.ortho_matrix.t(),full_rank_grad)
else:
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right')
self.svd_count += 1
low_rank_grad = torch.matmul(full_rank_grad,self.ortho_matrix.t())
elif self.proj_type == 'right':
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right')
self.svd_count += 1
low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
elif self.proj_type == 'left':
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left')
self.svd_count += 1
low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
elif self.proj_type == 'full':
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='full')
self.svd_count += 1
low_rank_grad = torch.matmul(self.ortho_matrix[0].t(), full_rank_grad) @ self.ortho_matrix[1].t()
return low_rank_grad
def project_back(self, low_rank_grad):
if self.proj_type == 'std':
if low_rank_grad.shape[0] >= low_rank_grad.shape[1]:
full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
else:
full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
elif self.proj_type == 'reverse_std':
if low_rank_grad.shape[0] <= low_rank_grad.shape[1]: # note this is different from std
full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
else:
full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
elif self.proj_type == 'right':
full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
elif self.proj_type == 'left':
full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
elif self.proj_type == 'full':
full_rank_grad = torch.matmul(self.ortho_matrix[0], low_rank_grad) @ self.ortho_matrix[1]
return full_rank_grad * self.scale
# svd decomposition
def get_orthogonal_matrix(self, weights, rank, type):
module_params = weights
if module_params.data.dtype != torch.float:
float_data = False
original_type = module_params.data.dtype
original_device = module_params.data.device
matrix = module_params.data.float()
else:
float_data = True
matrix = module_params.data
U, s, Vh = torch.linalg.svd(matrix, full_matrices = False)
#make the smaller matrix always to be orthogonal matrix
if type=='right':
A = U[:, :rank] @ torch.diag(s[:rank])
B = Vh[:rank, :]
if not float_data:
B = B.to(original_device).type(original_type)
return B
elif type=='left':
A = U[:, :rank]
B = torch.diag(s[:rank]) @ Vh[:rank, :]
if not float_data:
A = A.to(original_device).type(original_type)
return A
elif type=='full':
A = U[:, :rank]
B = Vh[:rank, :]
if not float_data:
A = A.to(original_device).type(original_type)
B = B.to(original_device).type(original_type)
return [A, B]
else:
raise ValueError('type should be left, right or full')
def stable_randn(
shape: Union[int, Sequence[int]],
seed: int,
device: Optional[Union[str, torch.device]] = None,
dtype: Optional[torch.dtype] = torch.float32,
):
if device is None:
device = torch.device("cpu")
generator = torch.Generator(device=device).manual_seed(seed)
rn = torch.randn(shape, generator=generator, device=generator.device, dtype=dtype)
return rn
def next_seed(seed: int, adv: int = 0xF):
"""
This is a naive helper function to generate a new seed from the given seed.
"""
generator = torch.Generator().manual_seed(seed)
return torch.randint(
0, torch.iinfo(torch.int64).max, (adv,), generator=generator, device=generator.device
).tolist()[-1]
def split_seed(seed: int):
generator = torch.Generator().manual_seed(seed)
return tuple(
torch.randint(0, torch.iinfo(torch.int64).max, (2,), generator=generator, device=generator.device).tolist()
)
class GradientProjector:
def __init__(
self, rank, update_proj_gap=200, alpha=1.0, proj_type="std", seed=0
):
# This is a lazy implementation as we store the projection matrix instead of re-generation every iteration
self.rank = rank
self.update_proj_gap = update_proj_gap
self.alpha = alpha
self.proj_type = proj_type
self.ortho_matrix = None
self.seed = seed
def project(self, full_rank_grad, iter):
if self.proj_type == "std":
if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(
full_rank_grad, self.rank, type="right", seed=self.seed
)
self.seed = next_seed(self.seed)
low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
else:
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(
full_rank_grad, self.rank, type="left", seed=self.seed
)
self.seed = next_seed(self.seed)
low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
elif self.proj_type == "reverse_std":
if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(
full_rank_grad, self.rank, type="left", seed=self.seed
)
self.seed = next_seed(self.seed)
low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
else:
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(
full_rank_grad, self.rank, type="right", seed=self.seed
)
self.seed = next_seed(self.seed)
low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
elif self.proj_type == "right":
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(
full_rank_grad, self.rank, type="right", seed=self.seed
)
self.seed = next_seed(self.seed)
low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
elif self.proj_type == "left":
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(
full_rank_grad, self.rank, type="left", seed=self.seed
)
self.seed = next_seed(self.seed)
low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
elif self.proj_type == "full":
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(
full_rank_grad, self.rank, type="full", seed=self.seed
)
self.seed = next_seed(self.seed)
low_rank_grad = (
torch.matmul(self.ortho_matrix[0].t(), full_rank_grad)
@ self.ortho_matrix[1].t()
)
return low_rank_grad
# random low rank projection
def get_orthogonal_matrix(self, weights, rank, type, seed):
module_params = weights
if module_params.data.dtype != torch.float:
float_data = False
original_type = module_params.data.dtype
original_device = module_params.data.device
matrix = module_params.data.float()
else:
float_data = True
matrix = module_params.data
if type == "left":
proj = stable_randn(
(matrix.shape[0], rank), seed=seed, device=matrix.device, dtype=matrix.dtype
) / math.sqrt(rank)
if not float_data:
proj = proj.to(original_device).type(original_type)
return proj
elif type == "right":
proj = stable_randn(
(rank, matrix.shape[1]), seed=seed, device=matrix.device, dtype=matrix.dtype
) / math.sqrt(rank)
if not float_data:
proj = proj.to(original_device).type(original_type)
return proj
elif type == "full":
raise NotImplementedError("full rank projection is not implemented yet")
else:
raise ValueError("type should be left, right or full")
class Apollo(Optimizer):
"""
Implements Adam algorithm with weight decay fix as introduced in [Decoupled Weight Decay
Regularization](https://arxiv.org/abs/1711.05101).
Parameters:
params (`Iterable[nn.parameter.Parameter]`):
Iterable of parameters to optimize or dictionaries defining parameter groups.
lr (`float`, *optional*, defaults to 0.001):
The learning rate to use.
betas (`Tuple[float,float]`, *optional*, defaults to `(0.9, 0.999)`):
Adam's betas parameters (b1, b2).
eps (`float`, *optional*, defaults to 1e-06):
Adam's epsilon for numerical stability.
weight_decay (`float`, *optional*, defaults to 0.0):
Decoupled weight decay to apply.
correct_bias (`bool`, *optional*, defaults to `True`):
Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`).
no_deprecation_warning (`bool`, *optional*, defaults to `False`):
A flag used to disable the deprecation warning (set to `True` to disable the warning).
"""
def __init__(
self,
params: Iterable[nn.parameter.Parameter],
lr: float = 1e-3,
betas: Tuple[float, float] = (0.9, 0.999),
eps: float = 1e-6,
weight_decay: float = 0.0,
correct_bias: bool = True,
scale_front: bool = False,
):
if lr < 0.0:
raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
if not 0.0 <= betas[0] < 1.0:
raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)")
if not 0.0 <= betas[1] < 1.0:
raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
defaults = {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "correct_bias": correct_bias}
super().__init__(params, defaults)
self.scale_front = scale_front
params_idx = 0
for group in self.param_groups:
for p in group["params"]:
params_idx += 1
if p.requires_grad:
self.state[p]["seed"] = params_idx
@torch.no_grad()
def step(self, closure: Callable = None):
"""
Performs a single optimization step.
Arguments:
closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad
if grad.is_sparse:
raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")
state = self.state[p]
if "step" not in state:
state["step"] = 0
# GaLore Projection
if "rank" in group:
if "projector" not in state:
if group["proj"] == "random":
state["projector"] = GradientProjector(group["rank"],
update_proj_gap=group["update_proj_gap"],
alpha=group["scale"],
proj_type=group["proj_type"],
seed=state["seed"])
elif group["proj"] == "svd":
state["projector"] = GaLoreProjector(group["rank"],
update_proj_gap=group["update_proj_gap"],
scale=group["scale"],
proj_type=group["proj_type"])
grad = state["projector"].project(grad, state["step"])
# State initialization
if "exp_avg" not in state:
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(grad)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(grad)
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
beta1, beta2 = group["betas"]
state["step"] += 1
# Decay the first and second moment running average coefficient
# In-place operations to update the averages at the same time
exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
denom = exp_avg_sq.sqrt().add_(group["eps"])
step_size = group["lr"]
if group["correct_bias"]: # No bias correction for Bert
bias_correction1 = 1.0 - beta1 ** state["step"]
bias_correction2 = 1.0 - beta2 ** state["step"]
step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
# compute norm gradient
norm_grad = exp_avg / denom
if "rank" in group:
if group['scale_type'] == 'channel':
norm_dim = 0 if norm_grad.shape[0] < norm_grad.shape[1] else 1
scaling_factor = (
torch.norm(norm_grad, dim=norm_dim) /
(torch.norm(grad, dim=norm_dim) + 1e-8)
)
if norm_dim == 1:
scaling_factor = scaling_factor.unsqueeze(1)
elif group['scale_type'] == 'tensor':
scaling_factor = (
torch.norm(norm_grad) /
(torch.norm(grad) + 1e-8)
)
scaling_grad = p.grad * scaling_factor
# Use Norm-Growth Limiter in Fira
if "scaling_grad" in state:
scaling_grad_norm = torch.norm(scaling_grad)
limiter = max(
scaling_grad_norm /
(state["scaling_grad"] + 1e-8),
1.01,
) / 1.01
scaling_grad = scaling_grad / limiter
state["scaling_grad"] = scaling_grad_norm / limiter
else:
state["scaling_grad"] = torch.norm(scaling_grad)
norm_grad = scaling_grad * np.sqrt(group["scale"])
p.add_(norm_grad, alpha=-step_size)
# Just adding the square of the weights to the loss function is *not*
# the correct way of using L2 regularization/weight decay with Adam,
# since that will interact with the m and v parameters in strange ways.
#
# Instead we want to decay the weights in a manner that doesn't interact
# with the m/v parameters. This is equivalent to adding the square
# of the weights to the loss with plain (non-momentum) SGD.
# Add weight decay at the end (fixed version)
if group["weight_decay"] > 0.0:
p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
return loss

View File

@ -0,0 +1,467 @@
# borrowed with love from "https://github.com/keonlee9420/evaluate-zero-shot-tts/blob/master/src/evaluate_zero_shot_tts/utils/speaker_verification/models/ecapa_tdnn.py"
# part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio.transforms as trans
#from .utils import UpstreamExpert
""" Res2Conv1d + BatchNorm1d + ReLU
"""
class Res2Conv1dReluBn(nn.Module):
"""
in_channels == out_channels == channels
"""
def __init__(
self,
channels,
kernel_size=1,
stride=1,
padding=0,
dilation=1,
bias=True,
scale=4,
):
super().__init__()
assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
self.scale = scale
self.width = channels // scale
self.nums = scale if scale == 1 else scale - 1
self.convs = []
self.bns = []
for i in range(self.nums):
self.convs.append(
nn.Conv1d(
self.width,
self.width,
kernel_size,
stride,
padding,
dilation,
bias=bias,
)
)
self.bns.append(nn.BatchNorm1d(self.width))
self.convs = nn.ModuleList(self.convs)
self.bns = nn.ModuleList(self.bns)
def forward(self, x):
out = []
spx = torch.split(x, self.width, 1)
for i in range(self.nums):
if i == 0:
sp = spx[i]
else:
sp = sp + spx[i]
# Order: conv -> relu -> bn
sp = self.convs[i](sp)
sp = self.bns[i](F.relu(sp))
out.append(sp)
if self.scale != 1:
out.append(spx[self.nums])
out = torch.cat(out, dim=1)
return out
""" Conv1d + BatchNorm1d + ReLU
"""
class Conv1dReluBn(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0,
dilation=1,
bias=True,
):
super().__init__()
self.conv = nn.Conv1d(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
bias=bias,
)
self.bn = nn.BatchNorm1d(out_channels)
def forward(self, x):
return self.bn(F.relu(self.conv(x)))
""" The SE connection of 1D case.
"""
class SE_Connect(nn.Module):
def __init__(self, channels, se_bottleneck_dim=128):
super().__init__()
self.linear1 = nn.Linear(channels, se_bottleneck_dim)
self.linear2 = nn.Linear(se_bottleneck_dim, channels)
def forward(self, x):
out = x.mean(dim=2)
out = F.relu(self.linear1(out))
out = torch.sigmoid(self.linear2(out))
out = x * out.unsqueeze(2)
return out
""" SE-Res2Block of the ECAPA-TDNN architecture.
"""
# def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
# return nn.Sequential(
# Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0),
# Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale),
# Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0),
# SE_Connect(channels)
# )
class SE_Res2Block(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
scale,
se_bottleneck_dim,
):
super().__init__()
self.Conv1dReluBn1 = Conv1dReluBn(
in_channels, out_channels, kernel_size=1, stride=1, padding=0
)
self.Res2Conv1dReluBn = Res2Conv1dReluBn(
out_channels, kernel_size, stride, padding, dilation, scale=scale
)
self.Conv1dReluBn2 = Conv1dReluBn(
out_channels, out_channels, kernel_size=1, stride=1, padding=0
)
self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)
self.shortcut = None
if in_channels != out_channels:
self.shortcut = nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
)
def forward(self, x):
residual = x
if self.shortcut:
residual = self.shortcut(x)
x = self.Conv1dReluBn1(x)
x = self.Res2Conv1dReluBn(x)
x = self.Conv1dReluBn2(x)
x = self.SE_Connect(x)
return x + residual
""" Attentive weighted mean and standard deviation pooling.
"""
class AttentiveStatsPool(nn.Module):
def __init__(
self, in_dim, attention_channels=128, global_context_att=False
):
super().__init__()
self.global_context_att = global_context_att
# Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
if global_context_att:
self.linear1 = nn.Conv1d(
in_dim * 3, attention_channels, kernel_size=1
) # equals W and b in the paper
else:
self.linear1 = nn.Conv1d(
in_dim, attention_channels, kernel_size=1
) # equals W and b in the paper
self.linear2 = nn.Conv1d(
attention_channels, in_dim, kernel_size=1
) # equals V and k in the paper
def forward(self, x):
if self.global_context_att:
context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
context_std = torch.sqrt(
torch.var(x, dim=-1, keepdim=True) + 1e-10
).expand_as(x)
x_in = torch.cat((x, context_mean, context_std), dim=1)
else:
x_in = x
# DON'T use ReLU here! In experiments, I find ReLU hard to converge.
alpha = torch.tanh(self.linear1(x_in))
# alpha = F.relu(self.linear1(x_in))
alpha = torch.softmax(self.linear2(alpha), dim=2)
mean = torch.sum(alpha * x, dim=2)
residuals = torch.sum(alpha * (x**2), dim=2) - mean**2
std = torch.sqrt(residuals.clamp(min=1e-9))
return torch.cat([mean, std], dim=1)
class ECAPA_TDNN(nn.Module):
def __init__(
self,
feat_dim=80,
channels=512,
emb_dim=192,
global_context_att=False,
feat_type="fbank",
sr=16000,
feature_selection="hidden_states",
update_extract=False,
config_path=None,
):
super().__init__()
self.feat_type = feat_type
self.feature_selection = feature_selection
self.update_extract = update_extract
self.sr = sr
if feat_type == "fbank" or feat_type == "mfcc":
self.update_extract = False
win_len = int(sr * 0.025)
hop_len = int(sr * 0.01)
if feat_type == "fbank":
self.feature_extract = trans.MelSpectrogram(
sample_rate=sr,
n_fft=512,
win_length=win_len,
hop_length=hop_len,
f_min=0.0,
f_max=sr // 2,
pad=0,
n_mels=feat_dim,
)
elif feat_type == "mfcc":
melkwargs = {
"n_fft": 512,
"win_length": win_len,
"hop_length": hop_len,
"f_min": 0.0,
"f_max": sr // 2,
"pad": 0,
}
self.feature_extract = trans.MFCC(
sample_rate=sr,
n_mfcc=feat_dim,
log_mels=False,
melkwargs=melkwargs,
)
else:
"""
if config_path is None:
self.feature_extract = torch.hub.load("s3prl/s3prl", feat_type)
else:
self.feature_extract = UpstreamExpert(config_path)
"""
self.feature_extract = torch.hub.load("s3prl/s3prl", feat_type)
if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
self.feature_extract.model.encoder.layers[23].self_attn,
"fp32_attention",
):
self.feature_extract.model.encoder.layers[
23
].self_attn.fp32_attention = False
if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
self.feature_extract.model.encoder.layers[11].self_attn,
"fp32_attention",
):
self.feature_extract.model.encoder.layers[
11
].self_attn.fp32_attention = False
self.feat_num = self.get_feat_num()
self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
if feat_type != "fbank" and feat_type != "mfcc":
freeze_list = [
"final_proj",
"label_embs_concat",
"mask_emb",
"project_q",
"quantizer",
]
for name, param in self.feature_extract.named_parameters():
for freeze_val in freeze_list:
if freeze_val in name:
param.requires_grad = False
break
if not self.update_extract:
for param in self.feature_extract.parameters():
param.requires_grad = False
self.instance_norm = nn.InstanceNorm1d(feat_dim)
# self.channels = [channels] * 4 + [channels * 3]
self.channels = [channels] * 4 + [1536]
self.layer1 = Conv1dReluBn(
feat_dim, self.channels[0], kernel_size=5, padding=2
)
self.layer2 = SE_Res2Block(
self.channels[0],
self.channels[1],
kernel_size=3,
stride=1,
padding=2,
dilation=2,
scale=8,
se_bottleneck_dim=128,
)
self.layer3 = SE_Res2Block(
self.channels[1],
self.channels[2],
kernel_size=3,
stride=1,
padding=3,
dilation=3,
scale=8,
se_bottleneck_dim=128,
)
self.layer4 = SE_Res2Block(
self.channels[2],
self.channels[3],
kernel_size=3,
stride=1,
padding=4,
dilation=4,
scale=8,
se_bottleneck_dim=128,
)
# self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
cat_channels = channels * 3
self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
self.pooling = AttentiveStatsPool(
self.channels[-1],
attention_channels=128,
global_context_att=global_context_att,
)
self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
def get_feat_num(self):
self.feature_extract.eval()
wav = [
torch.randn(self.sr).to(
next(self.feature_extract.parameters()).device
)
]
with torch.no_grad():
features = self.feature_extract(wav)
select_feature = features[self.feature_selection]
if isinstance(select_feature, (list, tuple)):
return len(select_feature)
else:
return 1
def get_feat(self, x):
if self.update_extract:
x = self.feature_extract([sample for sample in x])
else:
with torch.no_grad():
if self.feat_type == "fbank" or self.feat_type == "mfcc":
x = (
self.feature_extract(x) + 1e-6
) # B x feat_dim x time_len
else:
x = self.feature_extract([sample for sample in x])
if self.feat_type == "fbank":
x = x.log()
if self.feat_type != "fbank" and self.feat_type != "mfcc":
x = x[self.feature_selection]
if isinstance(x, (list, tuple)):
x = torch.stack(x, dim=0)
else:
x = x.unsqueeze(0)
norm_weights = (
F.softmax(self.feature_weight, dim=-1)
.unsqueeze(-1)
.unsqueeze(-1)
.unsqueeze(-1)
)
x = (norm_weights * x).sum(dim=0)
x = torch.transpose(x, 1, 2) + 1e-6
x = self.instance_norm(x)
return x
def forward(self, x):
x = self.get_feat(x)
out1 = self.layer1(x)
out2 = self.layer2(out1)
out3 = self.layer3(out2)
out4 = self.layer4(out3)
out = torch.cat([out2, out3, out4], dim=1)
out = F.relu(self.conv(out))
out = self.bn(self.pooling(out))
out = self.linear(out)
return out
def ECAPA_TDNN_SMALL(
feat_dim,
emb_dim=256,
feat_type="fbank",
sr=16000,
feature_selection="hidden_states",
update_extract=False,
config_path=None,
):
return ECAPA_TDNN(
feat_dim=feat_dim,
channels=512,
emb_dim=emb_dim,
feat_type=feat_type,
sr=sr,
feature_selection=feature_selection,
update_extract=update_extract,
config_path=config_path,
)
if __name__ == "__main__":
x = torch.zeros(2, 32000)
model = ECAPA_TDNN_SMALL(
feat_dim=768,
emb_dim=256,
feat_type="hubert_base",
feature_selection="hidden_states",
update_extract=False,
)
out = model(x)
# print(model)
print(out.shape)

View File

@ -124,6 +124,20 @@ def _get_named_modules(module, attrname):
if hasattr(module, attrname):
yield name, module
def coerce_dtype(s):
# not a string
if not isinstance(s, str):
return s
if s == "float16":
return torch.float16
if s == "bfloat16":
return torch.bfloat16
if s == "float8_e5m2":
return torch.float8_e5m2
if s == "float8_e4m3fn":
return torch.float8_e4m3fn
return torch.float32
def gather_attribute(module, attrname, delete=True, prefix=True):
ret = {}

View File

@ -103,12 +103,18 @@ if cfg.optimizations.tensorrt:
if cfg.optimizations.unsloth:
try:
from .unsloth import apply_unsloth_offloaded_gradient_checkpoint_monkey_patch
from .ext.unsloth import apply_unsloth_offloaded_gradient_checkpoint_monkey_patch
#apply_unsloth_offloaded_gradient_checkpoint_monkey_patch()
except Exception as e:
_logger.warning(f'Error while importing Unsloth: {str(e)}')
pass
try:
from .ext.apollo import Apollo
except Exception as e:
_logger.warning(f'Error while importing APOLLO: {str(e)}')
pass
def compile_model(model, backend="auto"):
if not backend or backend == "auto":
backend = AVAILABLE_COMPILE_BACKENDS[0]