validated that inferencing works, changed some defaults (NAR benefits from greedy sampling)

This commit is contained in:
mrq 2024-06-09 17:11:38 -05:00
parent 234f9efc6e
commit a7a6e0ac76
6 changed files with 57 additions and 34 deletions
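The headline change is treating a near-zero NAR temperature as greedy decoding, which the author found the NAR benefits from. As a rough standalone sketch of the distinction (not the repo's exact sampling code):

import torch

def sample_tokens(logits: torch.Tensor, temperature: float) -> torch.Tensor:
	# logits: [seq_len, vocab]; returns one token id per position
	if temperature <= 0.0:
		# greedy: take the single most likely token at each position
		return logits.argmax(dim=-1)
	# stochastic: sample from the temperature-scaled distribution
	probs = torch.softmax(logits / temperature, dim=-1)
	return torch.multinomial(probs, num_samples=1).squeeze(-1)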

View File

@@ -14,18 +14,18 @@ def main():
 	parser.add_argument("--yaml", type=Path, default=None)
-	parser.add_argument("--max-ar-steps", type=int, default=6 * 75)
+	parser.add_argument("--max-ar-steps", type=int, default=12 * cfg.dataset.frames_per_second)
 	parser.add_argument("--max-nar-levels", type=int, default=7)
 	parser.add_argument("--max-ar-context", type=int, default=-1)
 	parser.add_argument("--ar-temp", type=float, default=1.0)
-	parser.add_argument("--nar-temp", type=float, default=1.0)
+	parser.add_argument("--nar-temp", type=float, default=0.01)
 	parser.add_argument("--min-ar-temp", type=float, default=-1.0)
 	parser.add_argument("--min-nar-temp", type=float, default=-1.0)
 	parser.add_argument("--input-prompt-length", type=float, default=3.0)
 	parser.add_argument("--top-p", type=float, default=1.0)
-	parser.add_argument("--top-k", type=int, default=0)
+	parser.add_argument("--top-k", type=int, default=16)
 	parser.add_argument("--repetition-penalty", type=float, default=1.0)
 	parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
 	parser.add_argument("--length-penalty", type=float, default=0.0)

View File

@@ -87,7 +87,10 @@ class BaseConfig:
 	@classmethod
 	def from_yaml( cls, yaml_path ):
-		return cls.from_cli( [f'--yaml="{yaml_path}"'] )
+		state = {}
+		state = yaml.safe_load(open(yaml_path, "r", encoding="utf-8"))
+		state.setdefault("yaml_path", yaml_path)
+		return cls(**state)

 	@classmethod
 	def from_cli(cls, args=sys.argv):
@@ -100,13 +103,10 @@ class BaseConfig:
 		parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
 		args, unknown = parser.parse_known_args(args=args)

-		state = {}
 		if args.yaml:
-			yaml_path = args.yaml
-			state = yaml.safe_load(open(yaml_path, "r", encoding="utf-8"))
-			state.setdefault("yaml_path", yaml_path)
+			return cls.from_yaml( args.yaml )

-		return cls(**state)
+		return cls(**{})

 	def __repr__(self):
 		return str(self)
@@ -678,7 +678,7 @@ class Config(BaseConfig):
 			print("Error while opening HDF5 file:", f'{self.rel_path}/{self.dataset.hdf5_name}', str(e))
 			self.dataset.use_hdf5 = False

-	def format( self ):
+	def format( self, training=True ):
 		if isinstance(self.dataset, type):
 			self.dataset = dict()
@ -753,10 +753,24 @@ class Config(BaseConfig):
if self.trainer.activation_checkpointing is not None: if self.trainer.activation_checkpointing is not None:
self.trainer.gradient_checkpointing = self.trainer.activation_checkpointing self.trainer.gradient_checkpointing = self.trainer.activation_checkpointing
if not training:
self.dataset.use_hdf5 = False
# load our HDF5 file if requested here # load our HDF5 file if requested here
if self.dataset.use_hdf5: if self.dataset.use_hdf5:
self.load_hdf5() self.load_hdf5()
# load tokenizer
try:
from transformers import PreTrainedTokenizerFast
cfg.tokenizer = (cfg.rel_path if cfg.yaml_path is not None else Path("./data/")) / cfg.tokenizer
cfg.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(cfg.tokenizer))
except Exception as e:
cfg.tokenizer = NaiveTokenizer()
print("Error while parsing tokenizer:", e)
pass
# Preserves the old behavior # Preserves the old behavior
class NaiveTokenizer: class NaiveTokenizer:
def get_vocab( self ): def get_vocab( self ):
@@ -792,14 +806,5 @@ except Exception as e:
 	print("Error while parsing config YAML:")
 	raise e # throw an error because I'm tired of silent errors messing things up for me

-try:
-	from transformers import PreTrainedTokenizerFast
-	cfg.tokenizer = (cfg.rel_path if cfg.yaml_path is not None else Path("./data/")) / cfg.tokenizer
-	cfg.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(cfg.tokenizer))
-except Exception as e:
-	cfg.tokenizer = NaiveTokenizer()
-	print("Error while parsing tokenizer:", e)
-	pass
-
 if __name__ == "__main__":
 	print(cfg)
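Net effect of the config changes: BaseConfig.from_yaml now parses the YAML itself, from_cli simply delegates to it, and Config.format grows a training flag that skips HDF5 loading and pulls in the tokenizer for inference. A rough usage sketch, assuming the package layout is vall_e/config.py and using a made-up YAML path:

from pathlib import Path
from vall_e.config import Config  # assumed import path

cfg = Config.from_yaml( Path("./training/config.yaml") )  # hypothetical path
cfg.format( training=False )  # inference mode: no HDF5; tokenizer loaded, NaiveTokenizer on failure
print( cfg.tokenizer )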

View File

@@ -371,7 +371,7 @@ class Engines(dict[str, Engine]):
 	def set_lr(self, lr):
 		for engine in self.values():
-			if not engine.training:
+			if not engine._training:
 				continue
 			engine.set_lr(lr)
@@ -406,7 +406,7 @@ class Engines(dict[str, Engine]):
 		do_gc()

 		for name, engine in self.items():
-			if not engine.training:
+			if not engine._training:
 				continue

 			device = engine.device

View File

@@ -27,10 +27,10 @@ class TTS():
 		if config:
 			cfg.load_yaml( config )
-		cfg.dataset.use_hdf5 = False # could use cfg.load_hdf5(), but why would it ever need to be loaded for inferencing

 		try:
-			cfg.format()
+			cfg.format( training=False )
+			cfg.dataset.use_hdf5 = False # could use cfg.load_hdf5(), but why would it ever need to be loaded for inferencing
 		except Exception as e:
 			print("Error while parsing config YAML:")
 			raise e # throw an error because I'm tired of silent errors messing things up for me
@@ -161,7 +161,19 @@ class TTS():
 			phns = to_device(phns, self.device).to(torch.uint8 if len(self.symmap) < 256 else torch.int16)
 			lang = to_device(lang, self.device).to(torch.uint8)

+			text_list = [ phns ]
+			proms_list = [ prom ]
+
 			with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
+				# AR temp: 1
+				# NAR temp: 0.05
+				# prom size: 3
+				"""
+				resps_list = engine(text_list=text_list, proms_list=proms_list, max_steps=max_ar_steps, sampling_temperature=ar_temp)
+				resps_list = engine(text_list=text_list, proms_list=proms_list, resps_list=resps_list, sampling_temperature=nar_temp)
+				"""
 				resps_list = model_ar(
 					text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps, max_resp_context=max_ar_context,
 					sampling_temperature=ar_temp,
@@ -181,6 +193,8 @@ class TTS():
 					sampling_top_p=top_p, sampling_top_k=top_k,
 					sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
 				)
+				"""
+				"""

 			wav, sr = qnt.decode_to_file(resps_list[0], out_path, device=self.device)
 			wavs.append(wav)
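The commented-out block kept in the diff documents the intended two-pass flow: the AR model generates the first quantizer level from text and prompt, then the NAR model fills in the remaining levels conditioned on that output, and with the new defaults the NAR pass runs essentially greedily. A schematic of that call pattern, with model_ar/model_nar standing in for the loaded engines and signatures simplified from what the diff shows:

def synthesize(phns, prom, lang, model_ar, model_nar, max_ar_steps, ar_temp=0.95, nar_temp=0.0):
	# AR pass: autoregressively decode the first codebook level
	resps_list = model_ar(
		text_list=[phns], proms_list=[prom], lang_list=[lang],
		max_steps=max_ar_steps, sampling_temperature=ar_temp,
	)
	# NAR pass: fill the remaining quantizer levels, conditioned on the AR output;
	# nar_temp near 0 makes this an argmax pick at every position
	resps_list = model_nar(
		text_list=[phns], proms_list=[prom], lang_list=[lang],
		resps_list=resps_list, sampling_temperature=nar_temp,
	)
	return resps_list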

View File

@@ -365,7 +365,7 @@ class Base(nn.Module):
 			self.model = MistralModel(MistralConfig(
 				vocab_size=n_resp_tokens,
 				hidden_size=d_model,
-				max_position_embeddings=75 * 60, # max-length of 60 seconds
+				max_position_embeddings=75 * 60 * 5, # max-length of 60 seconds
 				intermediate_size=d_model*4,
 				num_hidden_layers=n_layers,
 				num_attention_heads=n_heads,
@@ -381,7 +381,7 @@ class Base(nn.Module):
 			self.model = MixtralModel(MixtralConfig(
 				vocab_size =n_resp_tokens,
 				hidden_size=d_model,
-				max_position_embeddings=75 * 60, # max-length of 60 seconds
+				max_position_embeddings=75 * 60 * 5, # max-length of 60 seconds
 				intermediate_size=d_model*4,
 				num_hidden_layers=n_layers,
 				num_attention_heads=n_heads,
@@ -410,7 +410,7 @@ class Base(nn.Module):
 			self.model = LlamaModel(LlamaConfig(
 				vocab_size=n_resp_tokens,
 				hidden_size=d_model,
-				max_position_embeddings=75 * 60, # max-length of 60 seconds
+				max_position_embeddings=75 * 60 * 5, # max-length of 60 seconds
 				intermediate_size=d_model*4,
 				num_hidden_layers=n_layers,
 				num_attention_heads=n_heads,
@@ -427,7 +427,7 @@ class Base(nn.Module):
 			self.model = MixtralModel(MixtralConfig(
 				vocab_size =n_resp_tokens,
 				hidden_size=d_model,
-				max_position_embeddings=75 * 60, # max-length of 60 seconds
+				max_position_embeddings=75 * 60 * 5, # max-length of 60 seconds
 				intermediate_size=d_model*4,
 				num_hidden_layers=n_layers,
 				num_attention_heads=n_heads,
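All four backbone configs quintuple their positional budget. The inline comment still reads 60 seconds, but at the 75 frames/second these configs assume, the new value covers roughly five minutes of frames:

frames_per_second = 75                        # frame rate assumed by these configs
old_positions = frames_per_second * 60        # 4,500 positions (~60 s)
new_positions = frames_per_second * 60 * 5    # 22,500 positions (~5 min)
print(old_positions, new_positions)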
@@ -984,6 +984,10 @@ class Base(nn.Module):
 		# perform repetition penalizing
 		logits = [ reptition_penalize(logit, previous=resps[:, -1], factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ]

+		# argmax instead
+		if temperature <= 0.0:
+			return [ logit.argmax(dim=1) for logit in logits ]
+
 		# (AR) perform length penalizing
 		if quant_levels is None and self.causal:
 			logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, resps_list) ) ]
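The new early return short-circuits sampling when the temperature is at or below zero: repetition penalties still apply, but the later length penalty and the actual sampling step are skipped, and each per-sequence logit tensor of shape [seq_len, n_tokens] collapses to its argmax. A toy shape check with dummy tensors (not the model's):

import torch

logits = [ torch.randn(5, 1024), torch.randn(7, 1024) ]  # two sequences in a batch, 1024-token vocab
tokens = [ logit.argmax(dim=1) for logit in logits ]     # same expression as the early return
print([ t.shape for t in tokens ])                       # [torch.Size([5]), torch.Size([7])]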

View File

@@ -11,7 +11,7 @@ import gradio as gr
 from time import perf_counter
 from pathlib import Path

-from .inference import TTS
+from .inference import TTS, cfg
 from .train import train

 tts = None
@@ -66,7 +66,7 @@ def init_tts(restart=False):
 def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 	if kwargs.pop("dynamic-sampling", False):
 		kwargs['min-ar-temp'] = 0.85 if kwargs['ar-temp'] > 0.85 else 0.0
-		kwargs['min-nar-temp'] = 0.2 if kwargs['nar-temp'] > 0.2 else 0.0
+		kwargs['min-nar-temp'] = 0.85 if kwargs['nar-temp'] > 0.85 else 0.0 # should probably disable it for the NAR
 	else:
 		kwargs['min-ar-temp'] = -1
 		kwargs['min-nar-temp'] = -1
@@ -77,8 +77,8 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 	parser.add_argument("--references", type=str, default=kwargs["reference"])
 	parser.add_argument("--language", type=str, default="en")
 	parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
-	parser.add_argument("--max-ar-steps", type=int, default=int(kwargs["max-seconds"]*75))
-	parser.add_argument("--max-ar-context", type=int, default=int(kwargs["max-seconds-context"]*75))
+	parser.add_argument("--max-ar-steps", type=int, default=int(kwargs["max-seconds"]*cfg.dataset.frames_per_second))
+	parser.add_argument("--max-ar-context", type=int, default=int(kwargs["max-seconds-context"]*cfg.dataset.frames_per_second))
 	parser.add_argument("--max-nar-levels", type=int, default=kwargs["max-nar-levels"])
 	parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"])
 	parser.add_argument("--nar-temp", type=float, default=kwargs["nar-temp"])
@@ -208,13 +208,13 @@ with ui:
 				layout["inference"]["buttons"]["inference"] = gr.Button(value="Inference")
 			with gr.Column(scale=7):
 				with gr.Row():
-					layout["inference"]["inputs"]["max-seconds"] = gr.Slider(value=6, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="Limits how many steps to perform in the AR pass.")
+					layout["inference"]["inputs"]["max-seconds"] = gr.Slider(value=12, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="Limits how many steps to perform in the AR pass.")
 					layout["inference"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
 					layout["inference"]["inputs"]["input-prompt-length"] = gr.Slider(value=3.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Trim Length", info="Trims the input prompt down to X seconds. Set 0 to disable.")
 					layout["inference"]["inputs"]["max-seconds-context"] = gr.Slider(value=0.0, minimum=0.0, maximum=12.0, step=0.05, label="Context Length", info="Amount of generated audio to keep in the context during inference, in seconds. Set 0 to disable.")
 				with gr.Row():
-					layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR.")
-					layout["inference"]["inputs"]["nar-temp"] = gr.Slider(value=0.25, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR.")
+					layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
+					layout["inference"]["inputs"]["nar-temp"] = gr.Slider(value=0.01, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)")
 				with gr.Row():
 					layout["inference"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")