modified default arguments (ar temp = 0 and rep pen = 1.125 seems to be stable, at least given the few things i tested), do not pass top k/top p/min p to NAR even though technically none of those things should matter when greedy sampling

This commit is contained in:
mrq 2024-10-22 18:12:39 -05:00
parent 1a02cd5bce
commit 8eb9a4056b
9 changed files with 130 additions and 14 deletions

View File

@ -31,7 +31,7 @@ def main():
parser.add_argument("--top-p", type=float, default=1.0)
parser.add_argument("--top-k", type=int, default=0)
parser.add_argument("--min-p", type=float, default=0.0)
parser.add_argument("--repetition-penalty", type=float, default=1.0)
parser.add_argument("--repetition-penalty", type=float, default=1.125)
parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
parser.add_argument("--length-penalty", type=float, default=0.0)
parser.add_argument("--beam-width", type=int, default=0)

View File

@ -692,6 +692,7 @@ class Optimizations:
# | {"assign": [[ f'layers.{i}.' for i in range(0,6) ], [ f'layers.{i}.' for i in range(6,12) ]]} will assign layers 0-5 to device 1, and 6-12 to device 2
tensorrt: bool = False
unsloth: bool = False # unsloth gradient checkpointing (it just offloads tensors to the CPU during backwards, I don't think it's significant enough to bother with on small models)
@dataclass()
class Config(BaseConfig):

View File

@ -746,10 +746,13 @@ class Dataset(_Dataset):
flattened[bucket] = [*_interleaved_reorder(flattened[bucket], self.get_speaker)]
# flatten paths
self.paths = list(itertools.chain.from_iterable(flattened.values()))
elif self.sampler_order == "random":
random.shuffle( self.paths )
else:
# just interleave
self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
# dict of speakers keyed by speaker group
self.spkrs_by_spkr_group = {}

View File

@ -46,6 +46,7 @@ def main():
parser.add_argument("--demo-dir", type=Path, default=None)
parser.add_argument("--skip-existing", action="store_true")
parser.add_argument("--dataset-dir-name", type=str, default="dataset")
parser.add_argument("--dataset-dir-name-prefix", type=str, default=None)
parser.add_argument("--sample-from-dataset", action="store_true")
parser.add_argument("--skip-loading-dataloader", action="store_true")
parser.add_argument("--dataset-samples", type=int, default=0)
@ -209,6 +210,7 @@ def main():
if args.sample_from_dataset:
cfg.dataset.cache = False
cfg.dataset.sample_type = "path" if len(cfg.dataset.training) < cfg.evaluation.batch_size else "speaker"
cfg.dataset.sample_order = "random"
cfg.dataset.tasks_list = [ 'tts' ]
samples_dirs["dataset"] = args.demo_dir / args.dataset_dir_name
@ -221,16 +223,20 @@ def main():
num = args.dataset_samples if args.dataset_samples else length
for i in trange( num, desc="Sampling dataset for samples" ):
index = i if not cfg.dataset.sample_shuffle else random.randint( i, length )
index = i if not cfg.dataset.sample_shuffle else random.randint( 0, len( dataloader.dataset ) )
batch = dataloader.dataset[i]
dir = args.demo_dir / args.dataset_dir_name / f'{i}'
if args.dataset_dir_name_prefix:
dir = args.demo_dir / args.dataset_dir_name / f'{args.dataset_dir_name_prefix}_{i}'
else:
dir = args.demo_dir / args.dataset_dir_name / f'{i}'
(dir / "out").mkdir(parents=True, exist_ok=True)
metadata = batch["metadata"]
text = get_random_prompt() if args.random_prompts else metadata["text"]
#text = get_random_prompt() if args.random_prompts else metadata["text"]
text = get_random_prompt() if i >= (num // 2) else metadata["text"]
language = metadata["language"].lower()
prompt = dir / "prompt.wav"

View File

@ -232,10 +232,10 @@ class AR_NAR(Base):
quant_levels=quant_levels,
temperature=sampling_temperature,
min_temperature=sampling_min_temperature,
top_p=sampling_top_p,
top_k=sampling_top_k,
min_p=sampling_min_p,
#min_temperature=sampling_min_temperature,
#top_p=sampling_top_p,
#top_k=sampling_top_k,
#min_p=sampling_min_p,
#repetition_penalty=sampling_repetition_penalty,
#repetition_penalty_decay=sampling_repetition_penalty_decay,
#length_penalty=sampling_length_penalty,
@ -269,8 +269,8 @@ class AR_NAR(Base):
entropies = []
# ick
low_temperature = sampling_repetition_penalty == 1.0 and sampling_temperature < 0.5
low_temperature_range = cfg.dataset.frames_per_second * 3
low_temperature = False # sampling_repetition_penalty == 1.0 and sampling_temperature == 0.0 #
low_temperature_range = cfg.dataset.frames_per_second * 5
original_sampling_temperature = sampling_temperature
original_sampling_repetition_penalty = sampling_repetition_penalty
@ -302,7 +302,7 @@ class AR_NAR(Base):
# to-do: tune these values, maybe have it factor based on confidence scores or something
if low_temperature:
enabled = n < low_temperature_range
sampling_repetition_penalty = 1.5 if enabled else original_sampling_repetition_penalty
sampling_repetition_penalty = 1.125 if enabled else original_sampling_repetition_penalty
sampling_repetition_penalty_decay = 0.0 if enabled else original_sampling_repetition_penalty_decay
sampling_temperature = original_sampling_temperature if enabled else 1.0

View File

@ -1560,8 +1560,8 @@ class Base(nn.Module):
logits = [ logit / temperature for logit in logits ]
# do DRY sampling
if dry_multiplier > 0.0:
logits = [ dry_sampling(logit, previous=resps[:, -1].tolist(), factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, resps in zip( logits, prev_list ) ]
if dry_multiplier > 0.0 and prev_list is not None:
logits = [ dry_sampling(logit, previous=prevs[:, -1].tolist(), factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, prevs in zip( logits, prev_list ) ]
# do mirostat sampling
# currently incompatible with beam searching with the way the two are implemented, perhaps a night of brain bashing can make the two work

98
vall_e/utils/unsloth.py Normal file
View File

@ -0,0 +1,98 @@
# lifted from https://gist.github.com/pszemraj/e88ff24ab296b6d89057376b299b368a
# to-do: make this work with LoRAs, it complains
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import transformers
import inspect
class Unsloth_Offloaded_Gradient_Checkpointer(torch.autograd.Function):
"""
Saves VRAM by smartly offloading to RAM.
Tiny hit to performance, since we mask the movement via non blocking calls.
"""
@staticmethod
@torch.cuda.amp.custom_fwd
def forward(ctx, forward_function, hidden_states, *args):
saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
with torch.no_grad():
output = forward_function(hidden_states, *args)
ctx.save_for_backward(saved_hidden_states)
ctx.forward_function = forward_function
ctx.args = args
return output
pass
@staticmethod
@torch.cuda.amp.custom_bwd
def backward(ctx, dY):
(hidden_states,) = ctx.saved_tensors
hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
hidden_states.requires_grad = True
with torch.enable_grad():
(output,) = ctx.forward_function(hidden_states, *ctx.args)
torch.autograd.backward(output, dY)
return (
None,
hidden_states.grad,
) + (
None,
) * len(ctx.args)
pass
pass
def new_gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
#assert gradient_checkpointing_kwargs == None
gradient_checkpointing_kwargs = None
if not self.supports_gradient_checkpointing:
raise ValueError(
f"{self.__class__.__name__} does not support gradient checkpointing."
)
gradient_checkpointing_func = Unsloth_Offloaded_Gradient_Checkpointer.apply
# For old GC format (transformers < 4.35.0) for models that live on the Hub
# we will fall back to the overwritten `_set_gradient_checkpointing` method
_is_using_old_format = (
"value" in inspect.signature(self._set_gradient_checkpointing).parameters
)
if not _is_using_old_format:
self._set_gradient_checkpointing(
enable=True, gradient_checkpointing_func=gradient_checkpointing_func
)
else:
raise NotImplementedError()
if getattr(self, "_hf_peft_config_loaded", False):
# When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
# we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
# When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate
# the gradients to make sure the gradient flows.
self.enable_input_require_grads()
def apply_unsloth_offloaded_gradient_checkpoint_monkey_patch():
transformers.modeling_utils.PreTrainedModel.gradient_checkpointing_enable = (
new_gradient_checkpointing_enable
)

View File

@ -101,6 +101,14 @@ if cfg.optimizations.tensorrt:
_logger.warning(f'Error while importing TensorRT: {str(e)}')
pass
if cfg.optimizations.unsloth:
try:
from .unsloth import apply_unsloth_offloaded_gradient_checkpoint_monkey_patch
#apply_unsloth_offloaded_gradient_checkpoint_monkey_patch()
except Exception as e:
_logger.warning(f'Error while importing Unsloth: {str(e)}')
pass
def compile_model(model, backend="auto"):
if not backend or backend == "auto":
backend = AVAILABLE_COMPILE_BACKENDS[0]

View File

@ -367,7 +367,7 @@ with ui:
layout["inference_tts"]["inputs"]["min-p"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Min P")
layout["inference_tts"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
with gr.Row():
layout["inference_tts"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
layout["inference_tts"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.125, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
layout["inference_tts"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
layout["inference_tts"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
with gr.Row():