added prom-less training / inferencing, some other things

mrq 2024-07-22 19:36:07 -05:00
parent 491ae2a684
commit 75b04686f8
13 changed files with 85 additions and 47 deletions

View File

@@ -167,6 +167,7 @@ class Dataset:
reencode_on_concat: bool = False # whether to concat audio by decode => concat => encode, or naively concat codes
reencode_device: str = "cpu" # "cpu" is slower but saves memory, cuda throws [rank0]: RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
noise_scale: float = 0.25 # scaling noise value
inject_noise_in_prom: bool = False
_frames_per_second: int = 0 # allows setting your own hint
@@ -358,6 +359,7 @@ class LoRA:
rank: int = 128 # rank for the LoRA
alpha: int = 128 # alpha (scaling factor) for the LoRA
training: bool = True #
embeddings: bool = False # train the embedding too
parametrize: bool = False #
rvq_levels: list[int] = field(default_factory=lambda: []) # determines RVQ levels to activate the LoRA
@@ -832,8 +834,7 @@ class Config(BaseConfig):
if self.hyperparameters.scheduler == "":
self.hyperparameters.torch_scheduler = True
if self.dataset.prompt_duration != 0:
self.dataset.prompt_duration_range = [self.dataset.prompt_duration, self.dataset.prompt_duration]
self.dataset.prompt_duration_range = [self.dataset.prompt_duration, self.dataset.prompt_duration]
if self.trainer.backend == "local" and self.distributed:
self.trainer.ddp = True
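For reference, a minimal sketch of how the new dataset options fit together, using the field names from the dataclasses above; the values are illustrative rather than defaults from this commit, and cfg is assumed to be the loaded global config (from .config import cfg, as in the dataset module below).

cfg.dataset.reencode_on_concat = True       # decode => concat => encode instead of naively concatenating codes
cfg.dataset.reencode_device = "cpu"         # keeps re-encoding off CUDA inside forked dataloader workers
cfg.dataset.noise_scale = 0.25              # relative loudness of any injected noise
cfg.dataset.inject_noise_in_prom = True     # mix sampled noise into the input prompt during training
cfg.dataset.prompt_duration_range = [3, 6]  # sample 3-6 second prompts; an empty range or an upper bound of 0 disables prompts entirely (prom-less training)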

View File

@@ -11,7 +11,7 @@ import torch
import itertools
from .config import cfg
from .emb.qnt import trim, trim_random, repeat_extend_audio, concat_audio, merge_audio, decode_to_file
from .emb.qnt import trim, trim_random, repeat_extend_audio, concat_audio, merge_audio, decode_to_file, decode as decode_qnt, encode as encode_qnt
from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler
from .utils.distributed import global_rank, local_rank, world_size
@@ -717,6 +717,9 @@ class Dataset(_Dataset):
def sample_prompts(self, spkr_name, ignore, should_trim=True):
if not cfg.dataset.prompt_duration_range or cfg.dataset.prompt_duration_range[-1] == 0:
return None
prom_list = []
choices = set(self.paths_by_spkr_name[spkr_name]) - {ignore}
@@ -748,7 +751,7 @@
qnt = _load_quants(path, return_metadata=False)
if 0 < trim_length and trim_length < qnt.shape[0]:
qnt = trim( qnt, trim_length, reencode=cfg.dataset.reencode_on_concat )
qnt = trim( qnt, trim_length, reencode=cfg.dataset.reencode_on_concat, device=cfg.dataset.reencode_device )
prom_list.append(qnt)
prom_length += qnt.shape[0]
@@ -758,10 +761,10 @@
# might be better to decode => concat waveforms with silence in between => reencode
# as you technically can't just append encodec sequences together like this without issues
prom = torch.cat(prom_list)
prom = concat_audio( *prom_list, reencode=cfg.dataset.reencode_on_concat, device=cfg.dataset.reencode_device )
if 0 < trim_length and trim_length < prom.shape[0]:
prom = trim( prom, trim_length, reencode=cfg.dataset.reencode_on_concat )
prom = trim( prom, trim_length, reencode=cfg.dataset.reencode_on_concat, device=cfg.dataset.reencode_device )
return prom
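The reencode path exists because, per the comment above, EnCodec code sequences from separate clips can't just be appended along the time axis without artifacts. concat_audio itself is not part of this diff, so the following is only a sketch of what the decode => concat => encode route plausibly amounts to, reusing the decode/encode signatures shown in the .emb.qnt diff below and the same relative imports as this file's header:

import torch
from .config import cfg
from .emb.qnt import decode, encode

def concat_audio_sketch( *quants, reencode=False, device="cpu" ):
	# naive path: append the quantized codes along the time axis
	if not reencode:
		return torch.cat( quants )
	# reencode path: decode each clip to a waveform, splice the waveforms,
	# then re-encode so the codes describe one continuous signal
	wavs = [ decode( qnt, device=device )[0] for qnt in quants ]
	return encode( torch.cat( wavs ), cfg.sample_rate, device=device )[0].t()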
@@ -855,6 +858,15 @@
if task == "tts":
proms = self.sample_prompts(spkr_name, ignore=path)
if cfg.dataset.inject_noise_in_prom:
# sample random noise
noise = self.sample_noise()
# extend the noise to fill the target audio
noise = repeat_extend_audio(noise, proms.shape[0])
# create the input prompt by merging the target audio with the noise
proms = merge_audio( proms, noise, scale=[1, cfg.dataset.noise_scale], device=cfg.dataset.reencode_device )
# VALL-E Continuous (<text><partial resp> => <remaining resp> )
# (this could just be sampled as <text a><text b><audio a> => <audio b>, but I need to experiment with it)
elif task == "tts-c":

View File

@ -431,7 +431,7 @@ def encode_from_file(path, device="cuda"):
Helper Functions
"""
# trims from the start, up to `target`
def trim( qnt, target, reencode=False ):
def trim( qnt, target, reencode=False, device="cuda" ):
length = max( qnt.shape[0], qnt.shape[1] )
if target > 0:
start = 0
@@ -454,8 +454,8 @@ def trim( qnt, target, reencode=False ):
start = start / cfg.dataset.frames_per_second * cfg.sample_rate
end = end / cfg.dataset.frames_per_second * cfg.sample_rate
wav = decode(qnt)[0]
return encode(wav[start:end], cfg.sample_rate)[0].t()
wav = decode(qnt, device=device)[0]
return encode(wav[start:end], cfg.sample_rate, device=device)[0].t()
# trims a random piece of audio, up to `target`
# to-do: try and align to EnCodec window
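As a usage note for the new device argument on trim above: passing "cpu" keeps the decode/encode round-trip off the GPU, which sidesteps the "Cannot re-initialize CUDA in forked subprocess" failure called out in the Dataset config. A small illustrative call, inside dataset code where cfg and trim are already in scope and qnt is a loaded code tensor (the three-second target is arbitrary):

# trim a quantized clip to roughly three seconds of frames, re-encoding on the CPU
target_frames = 3 * cfg.dataset.frames_per_second
qnt = trim( qnt, target_frames, reencode=True, device="cpu" )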

View File

@@ -81,7 +81,7 @@ class Engine():
# freeze non-LoRA params if requested
if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None:
return freeze_non_lora_weights( self.module )
return freeze_non_lora_weights( self.module, embeddings=cfg.lora.embeddings )
for name, param in self.module.named_parameters():
if (freeze_all and param.requires_grad) or (not freeze_all and name in self.hyper_config.frozen_params):

View File

@@ -27,6 +27,8 @@ from deepspeed.accelerator import get_accelerator
from ..utils.distributed import init_distributed, distributed_initialized
from ..utils import wrapper as ml
from ..models.lora import freeze_non_lora_weights
if not distributed_initialized() and cfg.trainer.backend == "deepspeed":
init_distributed(init_deepspeed_dist)
@@ -66,11 +68,10 @@ class Engine(DeepSpeedEngine):
def freeze(self, freeze_all=True):
# freeze non-LoRA params if requested
if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None:
for name, param in self.module.named_parameters():
should = 'lora_' in name
param.requires_grad_(should)
if not should:
self._frozen_params.add(param)
frozen_params = freeze_non_lora_weights( self.module, embeddings=cfg.lora.embeddings )
for param in frozen_params:
self._frozen_params.add( param )
return
if self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params"):

View File

@@ -170,21 +170,18 @@ class TTS():
output_dir.mkdir(parents=True, exist_ok=True)
out_path = output_dir / f"{time.time()}.wav"
prom = self.encode_audio( references, trim_length=input_prompt_length )
prom = self.encode_audio( references, trim_length=input_prompt_length ) if references else None
phns = self.encode_text( line, language=language )
lang = self.encode_lang( language )
prom = to_device(prom, self.device).to(torch.int16)
phns = to_device(phns, self.device).to(torch.uint8 if len(self.symmap) < 256 else torch.int16)
lang = to_device(lang, self.device).to(torch.uint8)
prom = to_device(prom, device=self.device, dtype=torch.int16)
phns = to_device(phns, device=self.device, dtype=torch.uint8 if len(self.symmap) < 256 else torch.int16)
lang = to_device(lang, device=self.device, dtype=torch.uint8)
text_list = [ phns ]
proms_list = [ prom ]
with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
# AR temp: 1
# NAR temp: 0.05
# prom size: 3
if model_ar is not None:
resps_list = model_ar(
text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps,
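Taken together, this is what enables prom-less inference: with no references, encode_audio is skipped, prom stays None, and the None-aware to_device below passes it through, leaving proms_list as [None] for the model to handle. A hedged usage sketch, given an already-initialized TTS instance tts; the keyword names are the ones appearing in this diff and the web UI diff below, everything else left at defaults:

# prompted, zero-shot inference
wav, sr = tts.inference( text="Hello there.", language="en", references=["./reference.wav"] )

# prom-less inference: no reference audio at all
wav, sr = tts.inference( text="Hello there.", language="en", references=[] )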

View File

@@ -174,15 +174,23 @@ class AR_NAR(Base):
"""
for i in range(batch_size):
# other tasks might have the prom be a list and this is just the easiest way to acknowledge that
if task_list[i] == "tts":
# cap quant_level if it exceeds its corresponding resp/prom
if quant_levels[i] >= resps_list[i].shape[-1]:
quant_levels[i] = resps_list[i].shape[-1] - 1
# cap quant_level if it exceeds its corresponding resp/prom
if quant_levels[i] >= resps_list[i].shape[-1]:
quant_levels[i] = resps_list[i].shape[-1] - 1
# proms_list[i] could be a Tensor, list[Tensor], or None
if isinstance( proms_list[i], torch.Tensor ):
if quant_levels[i] >= proms_list[i].shape[-1]:
quant_levels[i] = proms_list[i].shape[-1] - 1
elif isinstance( proms_list[i], list ):
for j, prom in enumerate( proms_list[i] ):
if not isinstance( prom, torch.Tensor ):
continue
if quant_levels[i] >= prom.shape[-1]:
quant_levels[i] = prom.shape[-1] - 1
# only apply stop token for RVQ level 0
if quant_levels[i] > 0:
continue
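The added checks keep the sampled quant_level within the RVQ levels that actually exist in both the response and the prompt, where the prompt may now be a Tensor, a list of Tensors, or None for prom-less samples. A condensed sketch of that clamp, assuming code tensors shaped [timesteps, rvq_levels]:

import torch

def cap_quant_level( quant_level, resp, prom ):
	# never target an RVQ level the response doesn't have
	quant_level = min( quant_level, resp.shape[-1] - 1 )
	# ...nor one the prompt doesn't have; prom may be a Tensor, list[Tensor], or None
	proms = prom if isinstance( prom, list ) else [ prom ]
	for p in proms:
		if isinstance( p, torch.Tensor ):
			quant_level = min( quant_level, p.shape[-1] - 1 )
	return quant_level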

View File

@@ -891,27 +891,30 @@ class Base(nn.Module):
# insert task type as a string
inputs[i].append( ( "task", task_type ) )
# to-do: maybe not split the below blocks up
# might be beneficial in the event I need to use a different sequence, such as STT tasks
# Base-line TTS task
# Sequence: <text><sep><rvq lvl><sep><prom><sep><resp>
# prom /may/ include <task> tokens inside to help guide things, per SpeechX
if f'<{task_type}>' in get_task_symmap():
# insert the text prompt
if text_list is not None:
if text_list is not None and text_list[i] is not None:
inputs[i].append( ( "text", text_list[i] ) )
# insert lang token if we're trained for it
if "lang" in self.capabilities and lang_list is not None:
if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
inputs[i].append( ( "lang", lang_list[i] ) )
# insert RVQ level guidance token if the model is versioned for it
if self.rvq_l_emb is not None:
inputs[i].append( ( "quant_level", torch.Tensor([ quant_level ]).to(device=device, dtype=torch.int16) ) )
# insert input audio prompt
if proms_list is not None:
if proms_list is not None and proms_list[i] is not None:
inputs[i].append( ( "prom", proms_list[i] ) )
# insert tone token if we're trained for it
if "tone" in self.capabilities and tone_list is not None:
if "tone" in self.capabilities and tone_list is not None and tone_list[i] is not None:
inputs[i].append( ( "tone", tone_list[i] ) )
# insert the current output response
if resps_list is not None:
if resps_list is not None and resps_list[i] is not None:
inputs[i].append( ( "resp", resps_list[i] ) )
# Audio length prediction task
@@ -922,10 +925,10 @@ class Base(nn.Module):
raise Exception(f"Requesting task `{task_type}` but corresponding embedding is not defined.")
# insert the text prompt
if text_list is not None:
if text_list is not None and text_list[i] is not None:
inputs[i].append( ( "text", text_list[i] ) )
# insert lang token if we're trained for it
if "lang" in self.capabilities and lang_list is not None:
if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
inputs[i].append( ( "lang", lang_list[i] ) )
# technically will always be level 0 but for the sake of keeping the input formatting coherent...
if self.rvq_l_emb is not None:
@@ -933,17 +936,17 @@ class Base(nn.Module):
quant_levels[i] = 0
inputs[i].append( ( "quant_level", torch.Tensor([ 0 ]).to(device=device, dtype=torch.int16) ) )
# insert input audio prompt
if proms_list is not None:
if proms_list is not None and proms_list[i] is not None:
inputs[i].append( ( "prom", proms_list[i] ) )
# insert tone token if we're trained for it
if "tone" in self.capabilities and tone_list is not None:
if "tone" in self.capabilities and tone_list is not None and tone_list[i] is not None:
inputs[i].append( ( "tone", tone_list[i] ) )
# insert output length tokens (if it exists)
if len_list is not None:
if len_list is not None and len_list[i] is not None:
inputs[i].append( ( "len", len_list[i] ) )
# "encode" length to tokens for 0-9 + stop
elif resps_list is not None:
elif resps_list is not None and resps_list[i] is not None:
# yes this could be encoded better
inputs[i].append( ( "len", torch.Tensor([ 0 ] + [ int(i) for i in str( resps_list[i].shape[0]) ] + [ 10 ]).to(device=device, dtype=torch.int16) ) )
else:
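For the "len" target above, the response length is spelled out digit by digit: a leading 0, one token per decimal digit, and 10 as the stop digit. A small sketch of just that encoding:

import torch

def encode_len_target( n_frames ):
	# e.g. 473 frames -> [0, 4, 7, 3, 10]; 10 marks the end of the digit string
	digits = [ 0 ] + [ int(d) for d in str( n_frames ) ] + [ 10 ]
	return torch.tensor( digits, dtype=torch.int16 )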

View File

@@ -205,10 +205,18 @@ def enable_lora( model, mode = True ):
def disable_lora( model ):
return enable_lora( model, False )
def freeze_non_lora_weights( model ):
def freeze_non_lora_weights( model, embeddings = False ):
frozen_params = []
for name, param in model.named_parameters():
param.requires_grad_('lora_' in name)
return model
should = 'lora_' in name or (embeddings and "_emb" in name)
param.requires_grad_(should)
if not should:
frozen_params.append( param )
return frozen_params
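A short usage sketch of the reworked helper, given some module model with LoRA weights injected and the function above in scope: only parameters whose names contain 'lora_' keep requires_grad (plus any '_emb' embedding tables when embeddings=True), and the frozen parameters are returned so a caller such as the DeepSpeed engine above can record them.

frozen = freeze_non_lora_weights( model, embeddings=cfg.lora.embeddings )
trainable = [ name for name, param in model.named_parameters() if param.requires_grad ]
# expect only the LoRA weights (and, optionally, the embeddings) to remain in `trainable`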
def lora_get_state_dict( state_dict, split = True ):
lora = { name: param for name, param in state_dict.items() if "lora_" in name }

View File

@@ -100,7 +100,7 @@ def run_eval(engines, eval_name, dl):
filename = f"{filename}_{task}"
# flatten prom
if not isinstance(prom, torch.Tensor):
if not isinstance(prom, torch.Tensor) and prom is not None:
prom = torch.concat([ p for p in prom if isinstance(p, torch.Tensor) ])
# to-do, refine the output dir to be sane-er
@@ -114,7 +114,8 @@ def run_eval(engines, eval_name, dl):
ref_audio, sr = qnt.decode_to_file(ref, ref_path)
hyp_audio, sr = qnt.decode_to_file(hyp, hyp_path)
prom_audio, sr = qnt.decode_to_file(prom, prom_path)
if prom is not None:
prom_audio, sr = qnt.decode_to_file(prom, prom_path)
# pseudo loss calculation since we don't get the logits during eval
min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )

View File

@@ -19,7 +19,7 @@ class PoolSampler():
def reset(self):
self.current_pool = [ i for i in self.global_indices ]
if self.shuffle:
random(self.current_pool)
random.shuffle(self.current_pool)
def sample(self, pool = None):
if pool is None:
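The one-line fix above is a genuine bug fix: random is the module itself, so random(self.current_pool) raises TypeError: 'module' object is not callable, and enabling shuffle would throw instead of shuffling; random.shuffle shuffles the list in place. A quick illustration:

import random

pool = [ 0, 1, 2, 3 ]
# random( pool )        # TypeError: 'module' object is not callable
random.shuffle( pool )  # in-place shuffle, returns None
print( pool )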

View File

@@ -173,5 +173,8 @@ def tree_map(fn: Callable, x):
return x
def to_device(x: T, device) -> T:
return tree_map(lambda t: t.to(device), x)
def to_device(x: T | None, **kwargs) -> T:
if x is None:
return
return tree_map(lambda t: t.to(**kwargs), x)
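The generalized helper now forwards arbitrary keyword arguments to Tensor.to and passes None through untouched, which is what lets a missing prompt flow through inference (see the inference.py change above). With the helper in scope, for example:

import torch

prom = torch.zeros( 75, 8 )                                     # illustrative shape: [frames, rvq levels]
prom = to_device( prom, device="cpu", dtype=torch.int16 )       # moved and cast in one call
empty = to_device( None, device="cpu", dtype=torch.int16 )      # simply returns None instead of raising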

View File

@@ -79,7 +79,9 @@ def init_tts(yaml=None, restart=False):
if tts is not None:
if not restart:
return tts
del tts
tts = None
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', yaml)) # os environ so it can be specified in a HuggingFace Space too
@@ -124,8 +126,10 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
tmp = tempfile.NamedTemporaryFile(suffix='.wav')
"""
if not args.references:
raise ValueError("No reference audio provided.")
"""
tts = init_tts()
@@ -134,7 +138,7 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
wav, sr = tts.inference(
text=args.text,
language=args.language,
references=[args.references.split(";")],
references=[args.references.split(";")] if args.references is not None else [],
out_path=tmp.name,
max_ar_steps=args.max_ar_steps,
max_nar_levels=args.max_nar_levels,