From 8eb9a4056b669b903982a8a0a709db3414dd6f29 Mon Sep 17 00:00:00 2001
From: mrq
Date: Tue, 22 Oct 2024 18:12:39 -0500
Subject: [PATCH] modified default arguments (ar temp = 0 and rep pen = 1.125
 seems to be stable, at least given the few things i tested), do not pass top
 k/top p/min p to NAR even though technically none of those things should
 matter when greedy sampling

---
 vall_e/__main__.py      |  2 +-
 vall_e/config.py        |  1 +
 vall_e/data.py          |  3 ++
 vall_e/demo.py          | 12 +++--
 vall_e/models/ar_nar.py | 14 +++---
 vall_e/models/base.py   |  4 +-
 vall_e/utils/unsloth.py | 98 +++++++++++++++++++++++++++++++++++++++++
 vall_e/utils/wrapper.py |  8 ++++
 vall_e/webui.py         |  2 +-
 9 files changed, 130 insertions(+), 14 deletions(-)
 create mode 100644 vall_e/utils/unsloth.py
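
Note on the new defaults: since the NAR is now greedy-sampled (argmax), truncation
filters like top-k/top-p/min-p are no-ops for it; they only discard non-maximal
candidates and can never change which token wins. Repetition penalty still matters
under greedy decoding because it rescales logits before the argmax. A toy
illustration of that distinction (not code from this repo; it assumes the common
CTRL/HF-style penalty that divides the positive logits of previously seen tokens,
and the logits/history values are made up):

    import torch

    logits = torch.tensor([2.0, 1.9, 0.5])
    previous = [0]   # hypothetical history: token 0 was already emitted
    penalty = 1.125  # the new default

    # any top-k/top-p/min-p filter keeps the highest-scoring token,
    # so the greedy pick is unchanged
    assert logits.topk(2).indices[0] == logits.argmax()

    # the repetition penalty rescales already-seen tokens before the
    # argmax, which can demote the would-be winner (here 0 -> 1)
    penalized = logits.clone()
    for t in set(previous):
        penalized[t] = penalized[t] / penalty if penalized[t] > 0 else penalized[t] * penalty
    print(logits.argmax().item(), penalized.argmax().item())  # 0 1
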
diff --git a/vall_e/__main__.py b/vall_e/__main__.py
index 8f96d3e..362de39 100755
--- a/vall_e/__main__.py
+++ b/vall_e/__main__.py
@@ -31,7 +31,7 @@ def main():
 	parser.add_argument("--top-p", type=float, default=1.0)
 	parser.add_argument("--top-k", type=int, default=0)
 	parser.add_argument("--min-p", type=float, default=0.0)
-	parser.add_argument("--repetition-penalty", type=float, default=1.0)
+	parser.add_argument("--repetition-penalty", type=float, default=1.125)
 	parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
 	parser.add_argument("--length-penalty", type=float, default=0.0)
 	parser.add_argument("--beam-width", type=int, default=0)
diff --git a/vall_e/config.py b/vall_e/config.py
index 91c77be..e78ed23 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -692,6 +692,7 @@ class Optimizations:
 	# | {"assign": [[ f'layers.{i}.' for i in range(0,6) ], [ f'layers.{i}.' for i in range(6,12) ]]} will assign layers 0-5 to device 1, and 6-12 to device 2
 
 	tensorrt: bool = False
+	unsloth: bool = False # unsloth gradient checkpointing (it just offloads tensors to the CPU during backwards, I don't think it's significant enough to bother with on small models)
 
 @dataclass()
 class Config(BaseConfig):
diff --git a/vall_e/data.py b/vall_e/data.py
index 4ba7593..b3d67b8 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -746,10 +746,13 @@ class Dataset(_Dataset):
 				flattened[bucket] = [*_interleaved_reorder(flattened[bucket], self.get_speaker)]
 			# flatten paths
 			self.paths = list(itertools.chain.from_iterable(flattened.values()))
+		elif self.sampler_order == "random":
+			random.shuffle( self.paths )
 		else:
 			# just interleave
 			self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
 
+
 		# dict of speakers keyed by speaker group
 		self.spkrs_by_spkr_group = {}
 
diff --git a/vall_e/demo.py b/vall_e/demo.py
index 175e16f..1865cca 100644
--- a/vall_e/demo.py
+++ b/vall_e/demo.py
@@ -46,6 +46,7 @@ def main():
 	parser.add_argument("--demo-dir", type=Path, default=None)
 	parser.add_argument("--skip-existing", action="store_true")
 	parser.add_argument("--dataset-dir-name", type=str, default="dataset")
+	parser.add_argument("--dataset-dir-name-prefix", type=str, default=None)
 	parser.add_argument("--sample-from-dataset", action="store_true")
 	parser.add_argument("--skip-loading-dataloader", action="store_true")
 	parser.add_argument("--dataset-samples", type=int, default=0)
@@ -209,6 +210,7 @@ def main():
 	if args.sample_from_dataset:
 		cfg.dataset.cache = False
 		cfg.dataset.sample_type = "path" if len(cfg.dataset.training) < cfg.evaluation.batch_size else "speaker"
+		cfg.dataset.sample_order = "random"
 		cfg.dataset.tasks_list = [ 'tts' ]
 
 		samples_dirs["dataset"] = args.demo_dir / args.dataset_dir_name
@@ -221,16 +223,20 @@
 		num = args.dataset_samples if args.dataset_samples else length
 
 		for i in trange( num, desc="Sampling dataset for samples" ):
-			index = i if not cfg.dataset.sample_shuffle else random.randint( i, length )
+			index = i if not cfg.dataset.sample_shuffle else random.randint( 0, len( dataloader.dataset ) )
 			batch = dataloader.dataset[i]
 
-			dir = args.demo_dir / args.dataset_dir_name / f'{i}'
+			if args.dataset_dir_name_prefix:
+				dir = args.demo_dir / args.dataset_dir_name / f'{args.dataset_dir_name_prefix}_{i}'
+			else:
+				dir = args.demo_dir / args.dataset_dir_name / f'{i}'
 
 			(dir / "out").mkdir(parents=True, exist_ok=True)
 
 			metadata = batch["metadata"]
 
-			text = get_random_prompt() if args.random_prompts else metadata["text"]
+			#text = get_random_prompt() if args.random_prompts else metadata["text"]
+			text = get_random_prompt() if i >= (num // 2) else metadata["text"]
 			language = metadata["language"].lower()
 
 			prompt = dir / "prompt.wav"
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 58d0c69..7342766 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -232,10 +232,10 @@ class AR_NAR(Base):
 				quant_levels=quant_levels,
 
 				temperature=sampling_temperature,
-				min_temperature=sampling_min_temperature,
-				top_p=sampling_top_p,
-				top_k=sampling_top_k,
-				min_p=sampling_min_p,
+				#min_temperature=sampling_min_temperature,
+				#top_p=sampling_top_p,
+				#top_k=sampling_top_k,
+				#min_p=sampling_min_p,
 				#repetition_penalty=sampling_repetition_penalty,
 				#repetition_penalty_decay=sampling_repetition_penalty_decay,
 				#length_penalty=sampling_length_penalty,
@@ -269,8 +269,8 @@ class AR_NAR(Base):
 		entropies = []
 
 		# ick
-		low_temperature = sampling_repetition_penalty == 1.0 and sampling_temperature < 0.5
-		low_temperature_range = cfg.dataset.frames_per_second * 3
+		low_temperature = False # sampling_repetition_penalty == 1.0 and sampling_temperature == 0.0 #
+		low_temperature_range = cfg.dataset.frames_per_second * 5
 
 		original_sampling_temperature = sampling_temperature
 		original_sampling_repetition_penalty = sampling_repetition_penalty
@@ -302,7 +302,7 @@
 			# to-do: tune these values, maybe have it factor based on confidence scores or something
 			if low_temperature:
 				enabled = n < low_temperature_range
-				sampling_repetition_penalty = 1.5 if enabled else original_sampling_repetition_penalty
+				sampling_repetition_penalty = 1.125 if enabled else original_sampling_repetition_penalty
 				sampling_repetition_penalty_decay = 0.0 if enabled else original_sampling_repetition_penalty_decay
 				sampling_temperature = original_sampling_temperature if enabled else 1.0
 
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 03b1f43..5297d01 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -1560,8 +1560,8 @@ class Base(nn.Module):
 			logits = [ logit / temperature for logit in logits ]
 
 		# do DRY sampling
-		if dry_multiplier > 0.0:
-			logits = [ dry_sampling(logit, previous=resps[:, -1].tolist(), factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, resps in zip( logits, prev_list ) ]
+		if dry_multiplier > 0.0 and prev_list is not None:
+			logits = [ dry_sampling(logit, previous=prevs[:, -1].tolist(), factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, prevs in zip( logits, prev_list ) ]
 
 		# do mirostat sampling
 		# currently incompatible with beam searching with the way the two are implemented, perhaps a night of brain bashing can make the two work
diff --git a/vall_e/utils/unsloth.py b/vall_e/utils/unsloth.py
new file mode 100644
index 0000000..5bbd2c7
--- /dev/null
+++ b/vall_e/utils/unsloth.py
@@ -0,0 +1,98 @@
+# lifted from https://gist.github.com/pszemraj/e88ff24ab296b6d89057376b299b368a
+# to-do: make this work with LoRAs, it complains
+
+# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import transformers
+import inspect
+
+
+class Unsloth_Offloaded_Gradient_Checkpointer(torch.autograd.Function):
+    """
+    Saves VRAM by smartly offloading to RAM.
+    Tiny hit to performance, since we mask the movement via non blocking calls.
+    """
+
+    @staticmethod
+    @torch.cuda.amp.custom_fwd
+    def forward(ctx, forward_function, hidden_states, *args):
+        saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
+        with torch.no_grad():
+            output = forward_function(hidden_states, *args)
+        ctx.save_for_backward(saved_hidden_states)
+        ctx.forward_function = forward_function
+        ctx.args = args
+
+        return output
+
+    pass
+
+    @staticmethod
+    @torch.cuda.amp.custom_bwd
+    def backward(ctx, dY):
+        (hidden_states,) = ctx.saved_tensors
+        hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
+        hidden_states.requires_grad = True
+        with torch.enable_grad():
+            (output,) = ctx.forward_function(hidden_states, *ctx.args)
+        torch.autograd.backward(output, dY)
+        return (
+            None,
+            hidden_states.grad,
+        ) + (
+            None,
+        ) * len(ctx.args)
+
+    pass
+
+
+pass
+
+
+def new_gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
+    #assert gradient_checkpointing_kwargs == None
+    gradient_checkpointing_kwargs = None
+    if not self.supports_gradient_checkpointing:
+        raise ValueError(
+            f"{self.__class__.__name__} does not support gradient checkpointing."
+        )
+
+    gradient_checkpointing_func = Unsloth_Offloaded_Gradient_Checkpointer.apply
+    # For old GC format (transformers < 4.35.0) for models that live on the Hub
+    # we will fall back to the overwritten `_set_gradient_checkpointing` method
+    _is_using_old_format = (
+        "value" in inspect.signature(self._set_gradient_checkpointing).parameters
+    )
+
+    if not _is_using_old_format:
+        self._set_gradient_checkpointing(
+            enable=True, gradient_checkpointing_func=gradient_checkpointing_func
+        )
+    else:
+        raise NotImplementedError()
+
+    if getattr(self, "_hf_peft_config_loaded", False):
+        # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
+        # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
+        # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate
+        # the gradients to make sure the gradient flows.
+        self.enable_input_require_grads()
+
+
+def apply_unsloth_offloaded_gradient_checkpoint_monkey_patch():
+    transformers.modeling_utils.PreTrainedModel.gradient_checkpointing_enable = (
+        new_gradient_checkpointing_enable
+    )
\ No newline at end of file
diff --git a/vall_e/utils/wrapper.py b/vall_e/utils/wrapper.py
index 2cbd14b..c9aa087 100755
--- a/vall_e/utils/wrapper.py
+++ b/vall_e/utils/wrapper.py
@@ -101,6 +101,14 @@ if cfg.optimizations.tensorrt:
 		_logger.warning(f'Error while importing TensorRT: {str(e)}')
 		pass
 
+if cfg.optimizations.unsloth:
+	try:
+		from .unsloth import apply_unsloth_offloaded_gradient_checkpoint_monkey_patch
+		#apply_unsloth_offloaded_gradient_checkpoint_monkey_patch()
+	except Exception as e:
+		_logger.warning(f'Error while importing Unsloth: {str(e)}')
+		pass
+
 def compile_model(model, backend="auto"):
 	if not backend or backend == "auto":
 		backend = AVAILABLE_COMPILE_BACKENDS[0]
diff --git a/vall_e/webui.py b/vall_e/webui.py
index 2a9a33f..0bb36e0 100644
--- a/vall_e/webui.py
+++ b/vall_e/webui.py
@@ -367,7 +367,7 @@ with ui:
 					layout["inference_tts"]["inputs"]["min-p"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Min P")
 					layout["inference_tts"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
 				with gr.Row():
-					layout["inference_tts"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
+					layout["inference_tts"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.125, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
 					layout["inference_tts"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
 					layout["inference_tts"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
 				with gr.Row():
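
Note on the unsloth wiring: the wrapper.py hunk only sets up the import; the call to
apply_unsloth_offloaded_gradient_checkpoint_monkey_patch() is left commented out, so
setting the new cfg.optimizations.unsloth flag currently imports the module without
activating the patch. Activating it by hand would look something like the following
sketch (`model` is a stand-in for any HF transformers PreTrainedModel, not a name
from this repo):

    from vall_e.utils.unsloth import apply_unsloth_offloaded_gradient_checkpoint_monkey_patch

    # replaces PreTrainedModel.gradient_checkpointing_enable with the
    # version backed by the CPU-offloaded checkpointer above
    apply_unsloth_offloaded_gradient_checkpoint_monkey_patch()

    # any subsequent enable call now saves hidden states to system RAM
    # on the forward pass and streams them back during backward
    model.gradient_checkpointing_enable()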