modified default arguments (ar temp = 0 and rep pen = 1.125 seems to be stable, at least given the few things i tested), do not pass top k/top p/min p to NAR even though technically none of those things should matter when greedy sampling
This commit is contained in:
parent
1a02cd5bce
commit
8eb9a4056b
|
@ -31,7 +31,7 @@ def main():
|
|||
parser.add_argument("--top-p", type=float, default=1.0)
|
||||
parser.add_argument("--top-k", type=int, default=0)
|
||||
parser.add_argument("--min-p", type=float, default=0.0)
|
||||
parser.add_argument("--repetition-penalty", type=float, default=1.0)
|
||||
parser.add_argument("--repetition-penalty", type=float, default=1.125)
|
||||
parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
|
||||
parser.add_argument("--length-penalty", type=float, default=0.0)
|
||||
parser.add_argument("--beam-width", type=int, default=0)
|
||||
|
|
|
@ -692,6 +692,7 @@ class Optimizations:
|
|||
# | {"assign": [[ f'layers.{i}.' for i in range(0,6) ], [ f'layers.{i}.' for i in range(6,12) ]]} will assign layers 0-5 to device 1, and 6-12 to device 2
|
||||
|
||||
tensorrt: bool = False
|
||||
unsloth: bool = False # unsloth gradient checkpointing (it just offloads tensors to the CPU during backwards, I don't think it's significant enough to bother with on small models)
|
||||
|
||||
@dataclass()
|
||||
class Config(BaseConfig):
|
||||
|
|
|
@ -746,10 +746,13 @@ class Dataset(_Dataset):
|
|||
flattened[bucket] = [*_interleaved_reorder(flattened[bucket], self.get_speaker)]
|
||||
# flatten paths
|
||||
self.paths = list(itertools.chain.from_iterable(flattened.values()))
|
||||
elif self.sampler_order == "random":
|
||||
random.shuffle( self.paths )
|
||||
else:
|
||||
# just interleave
|
||||
self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
|
||||
|
||||
|
||||
|
||||
# dict of speakers keyed by speaker group
|
||||
self.spkrs_by_spkr_group = {}
|
||||
|
|
|
@ -46,6 +46,7 @@ def main():
|
|||
parser.add_argument("--demo-dir", type=Path, default=None)
|
||||
parser.add_argument("--skip-existing", action="store_true")
|
||||
parser.add_argument("--dataset-dir-name", type=str, default="dataset")
|
||||
parser.add_argument("--dataset-dir-name-prefix", type=str, default=None)
|
||||
parser.add_argument("--sample-from-dataset", action="store_true")
|
||||
parser.add_argument("--skip-loading-dataloader", action="store_true")
|
||||
parser.add_argument("--dataset-samples", type=int, default=0)
|
||||
|
@ -209,6 +210,7 @@ def main():
|
|||
if args.sample_from_dataset:
|
||||
cfg.dataset.cache = False
|
||||
cfg.dataset.sample_type = "path" if len(cfg.dataset.training) < cfg.evaluation.batch_size else "speaker"
|
||||
cfg.dataset.sample_order = "random"
|
||||
cfg.dataset.tasks_list = [ 'tts' ]
|
||||
|
||||
samples_dirs["dataset"] = args.demo_dir / args.dataset_dir_name
|
||||
|
@ -221,16 +223,20 @@ def main():
|
|||
num = args.dataset_samples if args.dataset_samples else length
|
||||
|
||||
for i in trange( num, desc="Sampling dataset for samples" ):
|
||||
index = i if not cfg.dataset.sample_shuffle else random.randint( i, length )
|
||||
index = i if not cfg.dataset.sample_shuffle else random.randint( 0, len( dataloader.dataset ) )
|
||||
batch = dataloader.dataset[i]
|
||||
|
||||
dir = args.demo_dir / args.dataset_dir_name / f'{i}'
|
||||
if args.dataset_dir_name_prefix:
|
||||
dir = args.demo_dir / args.dataset_dir_name / f'{args.dataset_dir_name_prefix}_{i}'
|
||||
else:
|
||||
dir = args.demo_dir / args.dataset_dir_name / f'{i}'
|
||||
|
||||
(dir / "out").mkdir(parents=True, exist_ok=True)
|
||||
|
||||
metadata = batch["metadata"]
|
||||
|
||||
text = get_random_prompt() if args.random_prompts else metadata["text"]
|
||||
#text = get_random_prompt() if args.random_prompts else metadata["text"]
|
||||
text = get_random_prompt() if i >= (num // 2) else metadata["text"]
|
||||
language = metadata["language"].lower()
|
||||
|
||||
prompt = dir / "prompt.wav"
|
||||
|
|
|
@ -232,10 +232,10 @@ class AR_NAR(Base):
|
|||
quant_levels=quant_levels,
|
||||
|
||||
temperature=sampling_temperature,
|
||||
min_temperature=sampling_min_temperature,
|
||||
top_p=sampling_top_p,
|
||||
top_k=sampling_top_k,
|
||||
min_p=sampling_min_p,
|
||||
#min_temperature=sampling_min_temperature,
|
||||
#top_p=sampling_top_p,
|
||||
#top_k=sampling_top_k,
|
||||
#min_p=sampling_min_p,
|
||||
#repetition_penalty=sampling_repetition_penalty,
|
||||
#repetition_penalty_decay=sampling_repetition_penalty_decay,
|
||||
#length_penalty=sampling_length_penalty,
|
||||
|
@ -269,8 +269,8 @@ class AR_NAR(Base):
|
|||
entropies = []
|
||||
|
||||
# ick
|
||||
low_temperature = sampling_repetition_penalty == 1.0 and sampling_temperature < 0.5
|
||||
low_temperature_range = cfg.dataset.frames_per_second * 3
|
||||
low_temperature = False # sampling_repetition_penalty == 1.0 and sampling_temperature == 0.0 #
|
||||
low_temperature_range = cfg.dataset.frames_per_second * 5
|
||||
|
||||
original_sampling_temperature = sampling_temperature
|
||||
original_sampling_repetition_penalty = sampling_repetition_penalty
|
||||
|
@ -302,7 +302,7 @@ class AR_NAR(Base):
|
|||
# to-do: tune these values, maybe have it factor based on confidence scores or something
|
||||
if low_temperature:
|
||||
enabled = n < low_temperature_range
|
||||
sampling_repetition_penalty = 1.5 if enabled else original_sampling_repetition_penalty
|
||||
sampling_repetition_penalty = 1.125 if enabled else original_sampling_repetition_penalty
|
||||
sampling_repetition_penalty_decay = 0.0 if enabled else original_sampling_repetition_penalty_decay
|
||||
sampling_temperature = original_sampling_temperature if enabled else 1.0
|
||||
|
||||
|
|
|
@ -1560,8 +1560,8 @@ class Base(nn.Module):
|
|||
logits = [ logit / temperature for logit in logits ]
|
||||
|
||||
# do DRY sampling
|
||||
if dry_multiplier > 0.0:
|
||||
logits = [ dry_sampling(logit, previous=resps[:, -1].tolist(), factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, resps in zip( logits, prev_list ) ]
|
||||
if dry_multiplier > 0.0 and prev_list is not None:
|
||||
logits = [ dry_sampling(logit, previous=prevs[:, -1].tolist(), factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, prevs in zip( logits, prev_list ) ]
|
||||
|
||||
# do mirostat sampling
|
||||
# currently incompatible with beam searching with the way the two are implemented, perhaps a night of brain bashing can make the two work
|
||||
|
|
98
vall_e/utils/unsloth.py
Normal file
98
vall_e/utils/unsloth.py
Normal file
|
@ -0,0 +1,98 @@
|
|||
# lifted from https://gist.github.com/pszemraj/e88ff24ab296b6d89057376b299b368a
|
||||
# to-do: make this work with LoRAs, it complains
|
||||
|
||||
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
import inspect
|
||||
|
||||
|
||||
class Unsloth_Offloaded_Gradient_Checkpointer(torch.autograd.Function):
|
||||
"""
|
||||
Saves VRAM by smartly offloading to RAM.
|
||||
Tiny hit to performance, since we mask the movement via non blocking calls.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@torch.cuda.amp.custom_fwd
|
||||
def forward(ctx, forward_function, hidden_states, *args):
|
||||
saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
|
||||
with torch.no_grad():
|
||||
output = forward_function(hidden_states, *args)
|
||||
ctx.save_for_backward(saved_hidden_states)
|
||||
ctx.forward_function = forward_function
|
||||
ctx.args = args
|
||||
|
||||
return output
|
||||
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@torch.cuda.amp.custom_bwd
|
||||
def backward(ctx, dY):
|
||||
(hidden_states,) = ctx.saved_tensors
|
||||
hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
|
||||
hidden_states.requires_grad = True
|
||||
with torch.enable_grad():
|
||||
(output,) = ctx.forward_function(hidden_states, *ctx.args)
|
||||
torch.autograd.backward(output, dY)
|
||||
return (
|
||||
None,
|
||||
hidden_states.grad,
|
||||
) + (
|
||||
None,
|
||||
) * len(ctx.args)
|
||||
|
||||
pass
|
||||
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def new_gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
|
||||
#assert gradient_checkpointing_kwargs == None
|
||||
gradient_checkpointing_kwargs = None
|
||||
if not self.supports_gradient_checkpointing:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} does not support gradient checkpointing."
|
||||
)
|
||||
|
||||
gradient_checkpointing_func = Unsloth_Offloaded_Gradient_Checkpointer.apply
|
||||
# For old GC format (transformers < 4.35.0) for models that live on the Hub
|
||||
# we will fall back to the overwritten `_set_gradient_checkpointing` method
|
||||
_is_using_old_format = (
|
||||
"value" in inspect.signature(self._set_gradient_checkpointing).parameters
|
||||
)
|
||||
|
||||
if not _is_using_old_format:
|
||||
self._set_gradient_checkpointing(
|
||||
enable=True, gradient_checkpointing_func=gradient_checkpointing_func
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
if getattr(self, "_hf_peft_config_loaded", False):
|
||||
# When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
|
||||
# we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
|
||||
# When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate
|
||||
# the gradients to make sure the gradient flows.
|
||||
self.enable_input_require_grads()
|
||||
|
||||
|
||||
def apply_unsloth_offloaded_gradient_checkpoint_monkey_patch():
|
||||
transformers.modeling_utils.PreTrainedModel.gradient_checkpointing_enable = (
|
||||
new_gradient_checkpointing_enable
|
||||
)
|
|
@ -101,6 +101,14 @@ if cfg.optimizations.tensorrt:
|
|||
_logger.warning(f'Error while importing TensorRT: {str(e)}')
|
||||
pass
|
||||
|
||||
if cfg.optimizations.unsloth:
|
||||
try:
|
||||
from .unsloth import apply_unsloth_offloaded_gradient_checkpoint_monkey_patch
|
||||
#apply_unsloth_offloaded_gradient_checkpoint_monkey_patch()
|
||||
except Exception as e:
|
||||
_logger.warning(f'Error while importing Unsloth: {str(e)}')
|
||||
pass
|
||||
|
||||
def compile_model(model, backend="auto"):
|
||||
if not backend or backend == "auto":
|
||||
backend = AVAILABLE_COMPILE_BACKENDS[0]
|
||||
|
|
|
@ -367,7 +367,7 @@ with ui:
|
|||
layout["inference_tts"]["inputs"]["min-p"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Min P")
|
||||
layout["inference_tts"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
|
||||
with gr.Row():
|
||||
layout["inference_tts"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
|
||||
layout["inference_tts"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.125, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
|
||||
layout["inference_tts"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
|
||||
layout["inference_tts"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
|
||||
with gr.Row():
|
||||
|
|
Loading…
Reference in New Issue
Block a user