diff --git a/tortoise_tts/config.py b/tortoise_tts/config.py
index 8c3b937..8c718cc 100755
--- a/tortoise_tts/config.py
+++ b/tortoise_tts/config.py
@@ -175,6 +175,7 @@ class Dataset:
 class Model:
     name: str = "" # vanity name for the model
     training: bool = False
+    frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training

     def get(self, name=None):
         return [ self ] if not name or self.name == name else []
@@ -190,7 +191,7 @@ class Model:

     @property
     def lora_policy(self):
-        include = ["model"] # by default only adapt the main model (not embeddings nor classifier/output projection/LM head/whatever)
+        include = ["gpt"] # by default only adapt the main model (not embeddings nor classifier/output projection/LM head/whatever)
         exclude = []

         return dict(include=include, exclude=exclude)
@@ -202,6 +203,8 @@ class LoRA:
     rank: int = 8 # rank for the LoRA
     alpha: int = 16 # rank for the LoRA
     training: bool = True #
+    parametrize: bool = False #
+    module: str = "linear" # linear | conv1d

     @property
     def full_name(self):
diff --git a/tortoise_tts/data.py b/tortoise_tts/data.py
index 3a911a3..56227f4 100755
--- a/tortoise_tts/data.py
+++ b/tortoise_tts/data.py
@@ -497,9 +497,13 @@ class Dataset(_Dataset):
         if key not in cfg.hdf5:
             raise RuntimeError(f'Key of Path ({path}) not in HDF5: {key}')

-        text = cfg.hdf5[key]["text"][:]
-        mel = cfg.hdf5[key]["audio"][:]
-        latents = cfg.hdf5[key]["latents"][:]
+        try:
+            text = cfg.hdf5[key]["text"][:]
+            mel = cfg.hdf5[key]["audio"][:]
+            latents = cfg.hdf5[key]["latents"][:]
+        except Exception as e:
+            print( key, cfg.hdf5[key].keys() )
+            raise e

         text = torch.from_numpy(text).to(self.text_dtype)
         mel = torch.from_numpy(mel).to(torch.int16)
diff --git a/tortoise_tts/engines/__init__.py b/tortoise_tts/engines/__init__.py
index 7cc973d..38b4c28 100755
--- a/tortoise_tts/engines/__init__.py
+++ b/tortoise_tts/engines/__init__.py
@@ -56,9 +56,11 @@ def load_engines(training=True):
             model.model = ml.replace_embedding( model.model )

         for lora in cfg.loras:
-            model.model = apply_lora( model.model, rank = lora.rank, alpha = lora.alpha, policy = model.config.lora_policy )
+            if hasattr(model, "gpt"):
+                #model.gpt = apply_lora( model.gpt, rank = lora.rank, alpha = lora.alpha, policy = model.config.lora_policy, parametrize = lora.parametrize )
+                model = apply_lora( model, rank = lora.rank, alpha = lora.alpha, policy = model.config.lora_policy, use_parametrize = lora.parametrize )

-        if backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer):
+        if not inferencing and (backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer)):
             optimizer_class = None
             scheduler_class = None
@@ -124,7 +126,7 @@ def load_engines(training=True):
                 loads_state_dict = True

         stats = None
-        if loads_state_dict:
+        if loads_state_dict and load_path.exists():
             state = torch.load(load_path, map_location=torch.device(cfg.device))

             # state dict is not just the module, extract the extra trainer details
diff --git a/tortoise_tts/models/__init__.py b/tortoise_tts/models/__init__.py
index 711b31c..ab3cb93 100755
--- a/tortoise_tts/models/__init__.py
+++ b/tortoise_tts/models/__init__.py
@@ -59,8 +59,8 @@ def get_model(config, training=True):
     name = config.name

     model = load_model(config.name)
-
-    config.training = False
+    config.training = "autoregressive" in config.name
+    model.config = config

     print(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters")
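
The `lora_policy` include term changing from "model" to "gpt" matters because `passes_policy` in lora.py below matches policy terms as substrings of dotted module names, so "gpt" selects the GPT-2 backbone's submodules. A minimal sketch of that matching, using hypothetical module names (not code from the patch):

    policy = dict(include=["gpt"], exclude=[])

    def matches(name, policy):
        # mirror of passes_policy: exclusion wins, then inclusion is required
        if any(term in name for term in policy.get("exclude", [])):
            return False
        return any(term in name for term in policy.get("include", []))

    assert matches("gpt.h.0.attn.c_attn", policy)  # adapted by apply_lora
    assert not matches("mel_embedding", policy)    # skipped
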
diff --git a/tortoise_tts/models/lora.py b/tortoise_tts/models/lora.py
index a6f7a67..30ca4c3 100644
--- a/tortoise_tts/models/lora.py
+++ b/tortoise_tts/models/lora.py
@@ -4,18 +4,17 @@ import torch
 import torch.nn.functional as F
 import torch.nn.utils.parametrize as parametrize

+from transformers.pytorch_utils import Conv1D
+
 from torch import Tensor, nn

 import math
 from typing import Optional, List

-# to-do: set cfg to decide
-USE_PARAMETRIZATION = True
-
 # LoRA Linear for replacement
 # Pros: simple, just needs to reuse the replace_linear and copy weights
 # Cons: does not work with other Linears (bnb, bitnet, te's fp8, etc), cannot apply multiple LoRAs (although for audio why would you)
-class Linear(nn.Linear):
+class LoRALinear(nn.Linear):
     def __init__(
         self,
@@ -27,7 +26,7 @@
         alpha: int = 1,
         dropout: float = 0.1,

-        merge_weights: bool = True,
+        merge_weights: bool = False,
         **kwargs,
     ):
         super().__init__(in_features=in_features, out_features=out_features, bias=bias, **kwargs)
@@ -37,6 +36,7 @@
         self.dropout = nn.Dropout(p=dropout) if dropout > 0 else lambda x: x
         self.merge_weights = merge_weights
         self.merged = False
+        self.enabled = True

         self.lora_B = nn.Parameter( self.weight.new_zeros( (out_features, rank) ) )
         self.lora_A = nn.Parameter( self.weight.new_zeros( (rank, in_features) ) )
@@ -67,7 +67,7 @@
             self.merged = True

     def forward(self, x: torch.Tensor):
-        if not self.merged:
+        if not self.merged and self.enabled:
             result = F.linear(x, self.weight, bias=self.bias)
             result += (self.dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
             return result
@@ -85,7 +85,7 @@
 # Uses parametrization to inject LoRA weights
 # Pros: should work with any Linears
 # Cons: TBD
-class ParameterizedLinear(nn.Module):
+class ParameterizedLoRA(nn.Module):
     def __init__(
         self,
@@ -120,7 +120,6 @@
     def forward(self, x: torch.Tensor):
         if self.enabled:
             return x + torch.matmul(self.lora_B, self.dropout(self.lora_A)).view(x.shape) * self.scaling
-
         return x

     @classmethod
@@ -133,10 +132,21 @@
         # M$'s LoRA class arranges things to where this isn't necessary
         return cls( in_features = layer.out_features, out_features = layer.in_features, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)

+    @classmethod
+    def from_conv1d( cls, layer, device = None, dtype = None, **kwargs ):
+        if device is None:
+            device = layer.weight.device
+        if dtype is None:
+            dtype = layer.weight.dtype
+
+        in_channels, out_channels = layer.weight.shape
+        # swap because we're feeding the output as our input
+        # M$'s LoRA class arranges things to where this isn't necessary
+        return cls( in_features = out_channels, out_features = in_channels, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)
+
 def passes_policy( policy, name ):
     if policy is None:
         return True
-
     if "exclude" in policy:
         for term in policy["exclude"]:
             if term in name:
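
For context on the `use_parametrize` path handled by `apply_lora` below: with `torch.nn.utils.parametrize`, every read of `layer.weight` routes the stored tensor through the registered module's `forward()`, which is where `ParameterizedLoRA` adds its `(lora_B @ lora_A) * scaling` delta; the base layer is never replaced, and the adapter can be switched off via its `enabled` flag. A toy sketch with a stand-in parametrization (not code from the patch):

    import torch
    import torch.nn as nn
    import torch.nn.utils.parametrize as parametrize

    class AddDelta(nn.Module):  # stand-in for ParameterizedLoRA
        def __init__(self, shape):
            super().__init__()
            self.delta = nn.Parameter(torch.zeros(shape))
            self.enabled = True

        def forward(self, weight):
            # ParameterizedLoRA returns weight + (lora_B @ lora_A) * scaling here
            return weight + self.delta if self.enabled else weight

    layer = nn.Linear(4, 4)
    parametrize.register_parametrization(layer, "weight", AddDelta(layer.weight.shape))
    out = layer(torch.randn(1, 4))  # weight is recomputed through AddDelta on each access
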
@@ -149,30 +159,50 @@ def passes_policy( policy, name ):
     return False

-
-def apply_lora( model, register = True, merge = False, policy = None, **kwargs ):
+def apply_lora( model, register = True, merge = False, policy = None, use_parametrize = False, **kwargs ):
     device = next(model.parameters()).device
     dtype = next(model.parameters()).dtype

-    klass = Linear
-    target = nn.Linear
-
-    device = next(model.parameters()).device
-    dtype = next(model.parameters()).dtype
-
-    modules = [ k.split('.') for k, m in model.named_modules() if isinstance(m, target) and not isinstance(m, klass) and passes_policy( policy, k ) ]
+    modules = [ k.split('.') for k, m in model.named_modules() if passes_policy( policy, k ) ]

     for *parent, k in modules:
         name = '.'.join(parent)
         layer = getattr( model.get_submodule(name), k )

-        if USE_PARAMETRIZATION:
-            parametrize.register_parametrization( layer, "weight", ParameterizedLinear.from_linear( layer, device=device, dtype=dtype, **kwargs ) )
-            # parametrize.remove_parametrizations( layer, "weight", leave_parametrized=merge )
+        if isinstance( layer, nn.Linear ):
+            target = nn.Linear
+            klass = ParameterizedLoRA if use_parametrize else LoRALinear
+            replacer = klass.from_linear
+        elif isinstance( layer, nn.Conv1d ):
+            target = nn.Conv1d
+            klass = ParameterizedLoRA if use_parametrize else LoRAConv1d
+            replacer = klass.from_conv1d
+        elif isinstance( layer, Conv1D ):
+            target = Conv1D
+            klass = ParameterizedLoRA if use_parametrize else LoRAConv1d
+            replacer = klass.from_conv1d
         else:
-            setattr( model.get_submodule(name), k, Linear.from_linear( layer, device=device, dtype=dtype, **kwargs ) )
+            continue
+
+        replacement = replacer( layer, device=device, dtype=dtype, **kwargs )
+
+        if use_parametrize:
+            parametrize.register_parametrization( layer, "weight", replacement )
+        else:
+            setattr( model.get_submodule(name), k, replacement )

     return model

+def enable_lora( model, mode = True ):
+    for name, module in model.named_modules():
+        if not isinstance( module, ParameterizedLoRA ) and not isinstance( module, LoRALinear ):
+            continue
+        module.enabled = mode
+    return model
+
+def disable_lora( model ):
+    return enable_lora( model, False )
+
 def freeze_non_lora_weights( model ):
     for name, param in model.named_parameters():
         param.requires_grad_('lora_' in name)
diff --git a/tortoise_tts/models/unified_voice.py b/tortoise_tts/models/unified_voice.py
index 2461962..7b08a5b 100644
--- a/tortoise_tts/models/unified_voice.py
+++ b/tortoise_tts/models/unified_voice.py
@@ -273,9 +273,14 @@ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text
         n_embd=model_dim,
         n_layer=layers,
         n_head=heads,
-        gradient_checkpointing=checkpointing,
         use_cache=not checkpointing)
     gpt = GPT2Model(gpt_config)
+
+    if checkpointing:
+        gpt.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
+            use_reentrant=False
+        ))
+
     # Override the built in positional embeddings
     del gpt.wpe
     gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
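
A note on the unified_voice.py change above: recent transformers releases no longer accept `gradient_checkpointing` as a `GPT2Config` kwarg and expect the model-level `gradient_checkpointing_enable()` call instead, and `use_reentrant=False` selects the non-reentrant checkpointing implementation, which behaves better when most parameters are frozen, as they are under LoRA. A standalone sketch with placeholder config values (not code from the patch):

    from transformers import GPT2Config, GPT2Model

    gpt = GPT2Model(GPT2Config(n_embd=512, n_layer=8, n_head=8, use_cache=False))
    gpt.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(use_reentrant=False))
    gpt.train()  # checkpointing is only active in training mode
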
batch["mel"] ], batch_first = True) + wav_lengths = torch.Tensor([ x for x in batch["wav_length"] ]).to(dtype=torch.int32) engine.forward(conditioning_latents, text_inputs, text_lengths, mel_codes, wav_lengths) @@ -48,7 +50,7 @@ def train_feeder(engine, batch): stats |= {k: v.item() for k, v in stat.items()} engine.tokens_processed += sum([ text.shape[0] for text in batch["text"] ]) - engine.tokens_processed += sum([ resps.shape[0] for resps in batch["resps"] ]) + engine.tokens_processed += sum([ mel.shape[-1] for mel in batch["mel"] ]) return loss, stats @@ -76,9 +78,11 @@ def run_eval(engines, eval_name, dl): ref_path.parent.mkdir(parents=True, exist_ok=True) prom_path.parent.mkdir(parents=True, exist_ok=True) + """ ref_audio, sr = qnt.decode_to_file(ref, ref_path) hyp_audio, sr = qnt.decode_to_file(hyp, hyp_path) prom_audio, sr = qnt.decode_to_file(prom, prom_path) + """ # pseudo loss calculation since we don't get the logits during eval min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] ) @@ -137,10 +141,10 @@ def train(): print(traceback.format_exc()) engines.train() - qnt.unload_model() + #qnt.unload_model() do_gc() - qnt.unload_model() + #qnt.unload_model() if args.eval: return eval_fn(engines=trainer.load_engines())