added prom-less training / inferencing, some other things

parent 491ae2a684, commit 75b04686f8

@@ -167,6 +167,7 @@ class Dataset:
     reencode_on_concat: bool = False # whether to concat audio by decode => concat => encode, or naively concat codes
     reencode_device: str = "cpu" # "cpu" is slower but saves memory, cuda throws [rank0]: RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
     noise_scale: float = 0.25 # scaling noise value
+    inject_noise_in_prom: bool = False

     _frames_per_second: int = 0 # allows setting your own hint

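The new `inject_noise_in_prom` flag sits next to the existing re-encode/noise knobs. A minimal sketch of turning the feature on programmatically (field names are the ones shown above; the `vall_e.config` import path is assumed, and these values would normally come from the YAML config):

from vall_e.config import cfg  # assumed import path

cfg.dataset.inject_noise_in_prom = True  # mix noise into each sampled acoustic prompt
cfg.dataset.noise_scale = 0.25           # relative level of the injected noise
cfg.dataset.reencode_on_concat = True    # decode => concat => re-encode instead of naively concatenating codes
cfg.dataset.reencode_device = "cpu"      # avoids the CUDA-in-forked-subprocess error noted above
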
@@ -358,6 +359,7 @@ class LoRA:
     rank: int = 128 # rank for the LoRA
     alpha: int = 128 # rank for the LoRA
     training: bool = True #
+    embeddings: bool = False # train the embedding too
     parametrize: bool = False #
     rvq_levels: list[int] = field(default_factory=lambda: []) # determines RVQ levels to activate the LoRA

@@ -832,8 +834,7 @@ class Config(BaseConfig):
         if self.hyperparameters.scheduler == "":
             self.hyperparameters.torch_scheduler = True

         if self.dataset.prompt_duration != 0:
             self.dataset.prompt_duration_range = [self.dataset.prompt_duration, self.dataset.prompt_duration]
-            self.dataset.prompt_duration_range = [self.dataset.prompt_duration, self.dataset.prompt_duration]

         if self.trainer.backend == "local" and self.distributed:
             self.trainer.ddp = True

@@ -11,7 +11,7 @@ import torch
 import itertools

 from .config import cfg
-from .emb.qnt import trim, trim_random, repeat_extend_audio, concat_audio, merge_audio, decode_to_file
+from .emb.qnt import trim, trim_random, repeat_extend_audio, concat_audio, merge_audio, decode_to_file, decode as decode_qnt, encode as encode_qnt
 from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler
 from .utils.distributed import global_rank, local_rank, world_size

@@ -717,6 +717,9 @@ class Dataset(_Dataset):

     def sample_prompts(self, spkr_name, ignore, should_trim=True):
+        if not cfg.dataset.prompt_duration_range or cfg.dataset.prompt_duration_range[-1] == 0:
+            return None
+
         prom_list = []

         choices = set(self.paths_by_spkr_name[spkr_name]) - {ignore}

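This early return is what enables prom-less training: if the prompt duration range is unset or tops out at 0, no acoustic prompt is sampled at all. A small sketch of driving it from the config (instance and variable names are illustrative):

cfg.dataset.prompt_duration = 0             # leave at 0 so it doesn't overwrite the range (see the Config hunk above)
cfg.dataset.prompt_duration_range = [0, 0]  # max of 0 => sample_prompts() returns None

proms = dataset.sample_prompts(spkr_name, ignore=path)
assert proms is None  # downstream code now treats a missing prom as "train/infer without a speaker prompt"
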
@@ -748,7 +751,7 @@ class Dataset(_Dataset):
             qnt = _load_quants(path, return_metadata=False)

             if 0 < trim_length and trim_length < qnt.shape[0]:
-                qnt = trim( qnt, trim_length, reencode=cfg.dataset.reencode_on_concat )
+                qnt = trim( qnt, trim_length, reencode=cfg.dataset.reencode_on_concat, device=cfg.dataset.reencode_device )

             prom_list.append(qnt)
             prom_length += qnt.shape[0]

@@ -758,10 +761,10 @@ class Dataset(_Dataset):

         # might be better to decode => concat waveforms with silence in between => reencode
         # as you technically can't just append encodec sequences together like this without issues
-        prom = torch.cat(prom_list)
+        prom = concat_audio( *prom_list, reencode=cfg.dataset.reencode_on_concat, device=cfg.dataset.reencode_device )

         if 0 < trim_length and trim_length < prom.shape[0]:
-            prom = trim( prom, trim_length, reencode=cfg.dataset.reencode_on_concat )
+            prom = trim( prom, trim_length, reencode=cfg.dataset.reencode_on_concat, device=cfg.dataset.reencode_device )

         return prom

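As the comment notes, EnCodec code sequences can't be naively concatenated without artifacts at the seam, which is what `reencode_on_concat` addresses. A rough sketch of the idea (not the repo's exact `concat_audio` implementation; it uses the `decode_qnt`/`encode_qnt` aliases imported above and omits the suggested silence padding between clips):

import torch

def reencode_concat(prom_list, device="cpu"):
    # decode each code sequence to a waveform, join along the time axis, then re-encode
    wavs = [ decode_qnt(q, device=device)[0] for q in prom_list ]
    wav = torch.cat(wavs, dim=-1)
    return encode_qnt(wav, cfg.sample_rate, device=device)[0].t()
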
@@ -855,6 +858,15 @@ class Dataset(_Dataset):
         if task == "tts":
             proms = self.sample_prompts(spkr_name, ignore=path)

+            if cfg.dataset.inject_noise_in_prom:
+                # sample random noise
+                noise = self.sample_noise()
+                # extend the noise to fill the target audio
+                noise = repeat_extend_audio(noise, proms.shape[0])
+                # create the input prompt by merging the target audio with the noise
+                proms = merge_audio( proms, noise, scale=[1, cfg.dataset.noise_scale], device=cfg.dataset.reencode_device )

         # VALL-E Continuous (<text><partial resp> => <remaining resp> )
         # (this could just be sampled as <text a><text b><audio a> => <audio b>, but I need to experiment with it)
         elif task == "tts-c":

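The same noise-injection step, pulled out as a standalone helper for clarity. The `proms is None` guard is an assumption on my part: with prom-less training enabled, `sample_prompts()` can return None and there is nothing to mix noise into.

def maybe_inject_noise(self, proms):
    # assumed guard: skip when prompts are disabled or the feature is off
    if proms is None or not cfg.dataset.inject_noise_in_prom:
        return proms
    noise = self.sample_noise()                          # random noise clip
    noise = repeat_extend_audio(noise, proms.shape[0])   # tile/trim the noise to the prompt's length in frames
    # mix the prompt at full level, the noise scaled by cfg.dataset.noise_scale
    return merge_audio( proms, noise, scale=[1, cfg.dataset.noise_scale], device=cfg.dataset.reencode_device )
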
@@ -431,7 +431,7 @@ def encode_from_file(path, device="cuda"):
 Helper Functions
 """
 # trims from the start, up to `target`
-def trim( qnt, target, reencode=False ):
+def trim( qnt, target, reencode=False, device="cuda" ):
     length = max( qnt.shape[0], qnt.shape[1] )
     if target > 0:
         start = 0

@@ -454,8 +454,8 @@ def trim( qnt, target, reencode=False ):
         start = start / cfg.dataset.frames_per_second * cfg.sample_rate
         end = end / cfg.dataset.frames_per_second * cfg.sample_rate

-        wav = decode(qnt)[0]
-        return encode(wav[start:end], cfg.sample_rate)[0].t()
+        wav = decode(qnt, device=device)[0]
+        return encode(wav[start:end], cfg.sample_rate, device=device)[0].t()

 # trims a random piece of audio, up to `target`
 # to-do: try and align to EnCodec window

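Example usage of the extended signature (values are illustrative): trim a code sequence to three seconds worth of frames and keep the re-encode pass on the CPU.

target_frames = 3 * cfg.dataset.frames_per_second
qnt = trim( qnt, target_frames, reencode=True, device="cpu" )
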
@@ -81,7 +81,7 @@ class Engine():

         # freeze non-LoRA params if requested
         if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None:
-            return freeze_non_lora_weights( self.module )
+            return freeze_non_lora_weights( self.module, embeddings=cfg.lora.embeddings )

         for name, param in self.module.named_parameters():
             if (freeze_all and param.requires_grad) or (not freeze_all and name in self.hyper_config.frozen_params):

@@ -27,6 +27,8 @@ from deepspeed.accelerator import get_accelerator
 from ..utils.distributed import init_distributed, distributed_initialized
 from ..utils import wrapper as ml

+from ..models.lora import freeze_non_lora_weights
+
 if not distributed_initialized() and cfg.trainer.backend == "deepspeed":
     init_distributed(init_deepspeed_dist)

@@ -66,11 +68,10 @@ class Engine(DeepSpeedEngine):
     def freeze(self, freeze_all=True):
         # freeze non-LoRA params if requested
         if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None:
-            for name, param in self.module.named_parameters():
-                should = 'lora_' in name
-                param.requires_grad_(should)
-                if not should:
-                    self._frozen_params.add(param)
+            frozen_params = freeze_non_lora_weights( self.module, embeddings=cfg.lora.embeddings )
+            for param in frozen_params:
+                self._frozen_params.add( param )
+
             return

         if self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params"):

@@ -170,21 +170,18 @@ class TTS():
         output_dir.mkdir(parents=True, exist_ok=True)
         out_path = output_dir / f"{time.time()}.wav"

-        prom = self.encode_audio( references, trim_length=input_prompt_length )
+        prom = self.encode_audio( references, trim_length=input_prompt_length ) if references else None
         phns = self.encode_text( line, language=language )
         lang = self.encode_lang( language )

-        prom = to_device(prom, self.device).to(torch.int16)
-        phns = to_device(phns, self.device).to(torch.uint8 if len(self.symmap) < 256 else torch.int16)
-        lang = to_device(lang, self.device).to(torch.uint8)
+        prom = to_device(prom, device=self.device, dtype=torch.int16)
+        phns = to_device(phns, device=self.device, dtype=torch.uint8 if len(self.symmap) < 256 else torch.int16)
+        lang = to_device(lang, device=self.device, dtype=torch.uint8)

-        text_list = [ phns ]
-        proms_list = [ prom ]
-
         with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
             # AR temp: 1
             # NAR temp: 0.05
             # prom size: 3
             if model_ar is not None:
                 resps_list = model_ar(
                     text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps,

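With the guards above, inference no longer requires reference audio. A sketch of a prom-less call, using the argument names from the web UI hunk at the end of this commit (the full `TTS.inference()` signature may take more options, and the step count is illustrative):

wav, sr = tts.inference(
    text="The quick brown fox jumps over the lazy dog.",
    language="en",
    references=[],          # no reference audio: prom stays None end-to-end
    out_path="output.wav",
    max_ar_steps=450,       # illustrative value
)
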
@@ -174,15 +174,23 @@ class AR_NAR(Base):
         """

         for i in range(batch_size):
-            # cap quant_level if it exceeds its corresponding resp/prom
-            if quant_levels[i] >= resps_list[i].shape[-1]:
-                quant_levels[i] = resps_list[i].shape[-1] - 1
+            # other tasks might have the prom be a list and this is just the easiest way to acknowledge that
+            if task_list[i] == "tts":
+                # cap quant_level if it exceeds its corresponding resp/prom
+                if quant_levels[i] >= resps_list[i].shape[-1]:
+                    quant_levels[i] = resps_list[i].shape[-1] - 1
+
+                # proms_list[i] could be a Tensor, list[Tensor], or None
+                if isinstance( proms_list[i], torch.Tensor ):
+                    if quant_levels[i] >= proms_list[i].shape[-1]:
+                        quant_levels[i] = proms_list[i].shape[-1] - 1
+
+                elif isinstance( proms_list[i], list ):
+                    for j, prom in enumerate( proms_list[i] ):
+                        if not isinstance( prom, torch.Tensor ):
+                            continue
+
+                        if quant_levels[i] >= prom.shape[-1]:
+                            quant_levels[i] = prom.shape[-1] - 1

             # only apply stop token for RVQ level 0
             if quant_levels[i] > 0:
                 continue

@@ -891,27 +891,30 @@ class Base(nn.Module):
             # insert task type as a string
             inputs[i].append( ( "task", task_type ) )

             # to-do: maybe not split the below blocks up
             # might be beneficial in the event I need to use a difference sequence, such as STT tasks

             # Base-line TTS task
             # Sequence: <text><sep><rvq lvl><sep><prom><sep><resp>
             # prom /may/ include <task> tokens inside to help guide things, per SpeechX
             if f'<{task_type}>' in get_task_symmap():
                 # insert the text prompt
-                if text_list is not None:
+                if text_list is not None and text_list[i] is not None:
                     inputs[i].append( ( "text", text_list[i] ) )
                 # insert lang token if we're trained for it
-                if "lang" in self.capabilities and lang_list is not None:
+                if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
                     inputs[i].append( ( "lang", lang_list[i] ) )
                 # insert RVQ level guidance token if the model is versioned for it
                 if self.rvq_l_emb is not None:
                     inputs[i].append( ( "quant_level", torch.Tensor([ quant_level ]).to(device=device, dtype=torch.int16) ) )
                 # insert input audio prompt
-                if proms_list is not None:
+                if proms_list is not None and proms_list[i] is not None:
                     inputs[i].append( ( "prom", proms_list[i] ) )
                 # insert tone token if we're trained for it
-                if "tone" in self.capabilities and tone_list is not None:
+                if "tone" in self.capabilities and tone_list is not None and tone_list[i] is not None:
                     inputs[i].append( ( "tone", tone_list[i] ) )
                 # insert the current output response
-                if resps_list is not None:
+                if resps_list is not None and resps_list[i] is not None:
                     inputs[i].append( ( "resp", resps_list[i] ) )

             # Audio length prediction task

@@ -922,10 +925,10 @@ class Base(nn.Module):
                     raise Exception(f"Requesting task `{task_type}` but corresponding embedding is not defined.")

                 # insert the text prompt
-                if text_list is not None:
+                if text_list is not None and text_list[i] is not None:
                     inputs[i].append( ( "text", text_list[i] ) )
                 # insert lang token if we're trained for it
-                if "lang" in self.capabilities and lang_list is not None:
+                if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
                     inputs[i].append( ( "lang", lang_list[i] ) )
                 # technically will always be level 0 but for the sake of keeing the input formatting coherent...
                 if self.rvq_l_emb is not None:

@@ -933,17 +936,17 @@ class Base(nn.Module):
                         quant_levels[i] = 0
                     inputs[i].append( ( "quant_level", torch.Tensor([ 0 ]).to(device=device, dtype=torch.int16) ) )
                 # insert input audio prompt
-                if proms_list is not None:
+                if proms_list is not None and proms_list[i] is not None:
                     inputs[i].append( ( "prom", proms_list[i] ) )
                 # insert tone token if we're trained for it
-                if "tone" in self.capabilities and tone_list is not None:
+                if "tone" in self.capabilities and tone_list is not None and tone_list[i] is not None:
                     inputs[i].append( ( "tone", tone_list[i] ) )

                 # insert output length tokens (if it exists)
-                if len_list is not None:
+                if len_list is not None and len_list[i] is not None:
                     inputs[i].append( ( "len", len_list[i] ) )
                 # "encode" length to tokens for 0-9 + stop
-                elif resps_list is not None:
+                elif resps_list is not None and resps_list[i] is not None:
                     # yes this could be encoded better
                     inputs[i].append( ( "len", torch.Tensor([ 0 ] + [ int(i) for i in str( resps_list[i].shape[0]) ] + [ 10 ]).to(device=device, dtype=torch.int16) ) )
                 else:

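The net effect of these per-item checks is that a single batch can mix prompted and prom-less samples. A tiny sketch (tensor names are hypothetical):

proms_list = [ prom_a, None, [prom_b, prom_c] ]  # Tensor, no prompt, list of Tensors
for i in range(len(proms_list)):
    if proms_list is not None and proms_list[i] is not None:
        inputs[i].append( ( "prom", proms_list[i] ) )  # item 1 simply gets no "prom" entry
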
@@ -205,10 +205,18 @@ def enable_lora( model, mode = True ):
 def disable_lora( model ):
     return enable_lora( model, False )

-def freeze_non_lora_weights( model ):
+def freeze_non_lora_weights( model, embeddings = False ):
+    frozen_params = []
+
     for name, param in model.named_parameters():
-        param.requires_grad_('lora_' in name)
-    return model
+        should = 'lora_' in name or (embeddings and "_emb" in name)
+
+        param.requires_grad_(should)
+
+        if not should:
+            frozen_params.append( param )
+
+    return frozen_params

 def lora_get_state_dict( state_dict, split = True ):
     lora = { name: param for name, param in state_dict.items() if "lora_" in name }

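Usage sketch of the reworked helper: it now returns the list of parameters it froze (rather than the model), which is what both engine `freeze()` paths above collect into `_frozen_params`.

frozen = freeze_non_lora_weights( model, embeddings=cfg.lora.embeddings )
print(f"froze {len(frozen)} parameter tensors")

trainable = [ p for p in model.parameters() if p.requires_grad ]  # only lora_* (plus "_emb" modules if embeddings=True)
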
@@ -100,7 +100,7 @@ def run_eval(engines, eval_name, dl):
             filename = f"{filename}_{task}"

             # flatten prom
-            if not isinstance(prom, torch.Tensor):
+            if not isinstance(prom, torch.Tensor) and prom is not None:
                 prom = torch.concat([ p for p in prom if isinstance(p, torch.Tensor) ])

             # to-do, refine the output dir to be sane-er

@@ -114,7 +114,8 @@ def run_eval(engines, eval_name, dl):

             ref_audio, sr = qnt.decode_to_file(ref, ref_path)
             hyp_audio, sr = qnt.decode_to_file(hyp, hyp_path)
-            prom_audio, sr = qnt.decode_to_file(prom, prom_path)
+            if prom is not None:
+                prom_audio, sr = qnt.decode_to_file(prom, prom_path)

             # pseudo loss calculation since we don't get the logits during eval
             min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )

@@ -19,7 +19,7 @@ class PoolSampler():
     def reset(self):
         self.current_pool = [ i for i in self.global_indices ]
         if self.shuffle:
-            random(self.current_pool)
+            random.shuffle(self.current_pool)

     def sample(self, pool = None):
         if pool is None:

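Small but real bug fix: the old line tried to call `random` directly, which raises a TypeError instead of shuffling. For reference:

import random

pool = list(range(10))
random.shuffle(pool)  # shuffles in place and returns None
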
@@ -173,5 +173,8 @@ def tree_map(fn: Callable, x):
     return x


-def to_device(x: T, device) -> T:
-    return tree_map(lambda t: t.to(device), x)
+def to_device(x: T | None, **kwargs) -> T:
+    if x is None:
+        return
+
+    return tree_map(lambda t: t.to(**kwargs), x)

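The reworked `to_device` forwards keyword arguments straight to `Tensor.to()` and lets None pass through, which is exactly what the inference path above relies on:

import torch

prom = to_device(prom, device="cuda", dtype=torch.int16)  # returns None if prom is None
phns = to_device(phns, device="cuda", dtype=torch.uint8)
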
@@ -79,7 +79,9 @@ def init_tts(yaml=None, restart=False):
     if tts is not None:
         if not restart:
             return tts
+        del tts
+        tts = None

     parser = argparse.ArgumentParser(allow_abbrev=False)
     parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', yaml)) # os environ so it can be specified in a HuggingFace Space too

@@ -124,8 +126,10 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):

     tmp = tempfile.NamedTemporaryFile(suffix='.wav')

+    """
     if not args.references:
         raise ValueError("No reference audio provided.")
+    """

     tts = init_tts()

@@ -134,7 +138,7 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
     wav, sr = tts.inference(
         text=args.text,
         language=args.language,
-        references=[args.references.split(";")],
+        references=[args.references.split(";")] if args.references is not None else [],
         out_path=tmp.name,
         max_ar_steps=args.max_ar_steps,
         max_nar_levels=args.max_nar_levels,