validated that inferencing works, changed some defaults (NAR benefits from greedy sampling)
This commit is contained in:
parent
234f9efc6e
commit
a7a6e0ac76
|
@ -14,18 +14,18 @@ def main():
|
|||
|
||||
parser.add_argument("--yaml", type=Path, default=None)
|
||||
|
||||
parser.add_argument("--max-ar-steps", type=int, default=6 * 75)
|
||||
parser.add_argument("--max-ar-steps", type=int, default=12 * cfg.dataset.frames_per_second)
|
||||
parser.add_argument("--max-nar-levels", type=int, default=7)
|
||||
parser.add_argument("--max-ar-context", type=int, default=-1)
|
||||
|
||||
parser.add_argument("--ar-temp", type=float, default=1.0)
|
||||
parser.add_argument("--nar-temp", type=float, default=1.0)
|
||||
parser.add_argument("--nar-temp", type=float, default=0.01)
|
||||
parser.add_argument("--min-ar-temp", type=float, default=-1.0)
|
||||
parser.add_argument("--min-nar-temp", type=float, default=-1.0)
|
||||
parser.add_argument("--input-prompt-length", type=float, default=3.0)
|
||||
|
||||
parser.add_argument("--top-p", type=float, default=1.0)
|
||||
parser.add_argument("--top-k", type=int, default=0)
|
||||
parser.add_argument("--top-k", type=int, default=16)
|
||||
parser.add_argument("--repetition-penalty", type=float, default=1.0)
|
||||
parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
|
||||
parser.add_argument("--length-penalty", type=float, default=0.0)
|
||||
|
|
|
@ -87,7 +87,10 @@ class BaseConfig:
|
|||
|
||||
@classmethod
|
||||
def from_yaml( cls, yaml_path ):
|
||||
return cls.from_cli( [f'--yaml="{yaml_path}"'] )
|
||||
state = {}
|
||||
state = yaml.safe_load(open(yaml_path, "r", encoding="utf-8"))
|
||||
state.setdefault("yaml_path", yaml_path)
|
||||
return cls(**state)
|
||||
|
||||
@classmethod
|
||||
def from_cli(cls, args=sys.argv):
|
||||
|
@ -100,13 +103,10 @@ class BaseConfig:
|
|||
parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
|
||||
args, unknown = parser.parse_known_args(args=args)
|
||||
|
||||
state = {}
|
||||
if args.yaml:
|
||||
yaml_path = args.yaml
|
||||
state = yaml.safe_load(open(yaml_path, "r", encoding="utf-8"))
|
||||
state.setdefault("yaml_path", yaml_path)
|
||||
return cls.from_yaml( args.yaml )
|
||||
|
||||
return cls(**state)
|
||||
return cls(**{})
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
|
@ -678,7 +678,7 @@ class Config(BaseConfig):
|
|||
print("Error while opening HDF5 file:", f'{self.rel_path}/{self.dataset.hdf5_name}', str(e))
|
||||
self.dataset.use_hdf5 = False
|
||||
|
||||
def format( self ):
|
||||
def format( self, training=True ):
|
||||
if isinstance(self.dataset, type):
|
||||
self.dataset = dict()
|
||||
|
||||
|
@ -753,10 +753,24 @@ class Config(BaseConfig):
|
|||
if self.trainer.activation_checkpointing is not None:
|
||||
self.trainer.gradient_checkpointing = self.trainer.activation_checkpointing
|
||||
|
||||
if not training:
|
||||
self.dataset.use_hdf5 = False
|
||||
|
||||
# load our HDF5 file if requested here
|
||||
if self.dataset.use_hdf5:
|
||||
self.load_hdf5()
|
||||
|
||||
# load tokenizer
|
||||
try:
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
cfg.tokenizer = (cfg.rel_path if cfg.yaml_path is not None else Path("./data/")) / cfg.tokenizer
|
||||
cfg.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(cfg.tokenizer))
|
||||
except Exception as e:
|
||||
cfg.tokenizer = NaiveTokenizer()
|
||||
print("Error while parsing tokenizer:", e)
|
||||
pass
|
||||
|
||||
|
||||
# Preserves the old behavior
|
||||
class NaiveTokenizer:
|
||||
def get_vocab( self ):
|
||||
|
@ -792,14 +806,5 @@ except Exception as e:
|
|||
print("Error while parsing config YAML:")
|
||||
raise e # throw an error because I'm tired of silent errors messing things up for me
|
||||
|
||||
try:
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
cfg.tokenizer = (cfg.rel_path if cfg.yaml_path is not None else Path("./data/")) / cfg.tokenizer
|
||||
cfg.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(cfg.tokenizer))
|
||||
except Exception as e:
|
||||
cfg.tokenizer = NaiveTokenizer()
|
||||
print("Error while parsing tokenizer:", e)
|
||||
pass
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(cfg)
|
||||
|
|
|
@ -371,7 +371,7 @@ class Engines(dict[str, Engine]):
|
|||
|
||||
def set_lr(self, lr):
|
||||
for engine in self.values():
|
||||
if not engine.training:
|
||||
if not engine._training:
|
||||
continue
|
||||
engine.set_lr(lr)
|
||||
|
||||
|
@ -406,7 +406,7 @@ class Engines(dict[str, Engine]):
|
|||
do_gc()
|
||||
|
||||
for name, engine in self.items():
|
||||
if not engine.training:
|
||||
if not engine._training:
|
||||
continue
|
||||
|
||||
device = engine.device
|
||||
|
|
|
@ -27,10 +27,10 @@ class TTS():
|
|||
|
||||
if config:
|
||||
cfg.load_yaml( config )
|
||||
cfg.dataset.use_hdf5 = False # could use cfg.load_hdf5(), but why would it ever need to be loaded for inferencing
|
||||
|
||||
try:
|
||||
cfg.format()
|
||||
cfg.format( training=False )
|
||||
cfg.dataset.use_hdf5 = False # could use cfg.load_hdf5(), but why would it ever need to be loaded for inferencing
|
||||
except Exception as e:
|
||||
print("Error while parsing config YAML:")
|
||||
raise e # throw an error because I'm tired of silent errors messing things up for me
|
||||
|
@ -161,7 +161,19 @@ class TTS():
|
|||
phns = to_device(phns, self.device).to(torch.uint8 if len(self.symmap) < 256 else torch.int16)
|
||||
lang = to_device(lang, self.device).to(torch.uint8)
|
||||
|
||||
text_list = [ phns ]
|
||||
proms_list = [ prom ]
|
||||
|
||||
with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
|
||||
# AR temp: 1
|
||||
# NAR temp: 0.05
|
||||
# prom size: 3
|
||||
|
||||
"""
|
||||
resps_list = engine(text_list=text_list, proms_list=proms_list, max_steps=max_ar_steps, sampling_temperature=ar_temp)
|
||||
resps_list = engine(text_list=text_list, proms_list=proms_list, resps_list=resps_list, sampling_temperature=nar_temp)
|
||||
"""
|
||||
|
||||
resps_list = model_ar(
|
||||
text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps, max_resp_context=max_ar_context,
|
||||
sampling_temperature=ar_temp,
|
||||
|
@ -181,6 +193,8 @@ class TTS():
|
|||
sampling_top_p=top_p, sampling_top_k=top_k,
|
||||
sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
|
||||
)
|
||||
"""
|
||||
"""
|
||||
|
||||
wav, sr = qnt.decode_to_file(resps_list[0], out_path, device=self.device)
|
||||
wavs.append(wav)
|
||||
|
|
|
@ -365,7 +365,7 @@ class Base(nn.Module):
|
|||
self.model = MistralModel(MistralConfig(
|
||||
vocab_size=n_resp_tokens,
|
||||
hidden_size=d_model,
|
||||
max_position_embeddings=75 * 60, # max-length of 60 seconds
|
||||
max_position_embeddings=75 * 60 * 5, # max-length of 60 seconds
|
||||
intermediate_size=d_model*4,
|
||||
num_hidden_layers=n_layers,
|
||||
num_attention_heads=n_heads,
|
||||
|
@ -381,7 +381,7 @@ class Base(nn.Module):
|
|||
self.model = MixtralModel(MixtralConfig(
|
||||
vocab_size =n_resp_tokens,
|
||||
hidden_size=d_model,
|
||||
max_position_embeddings=75 * 60, # max-length of 60 seconds
|
||||
max_position_embeddings=75 * 60 * 5, # max-length of 60 seconds
|
||||
intermediate_size=d_model*4,
|
||||
num_hidden_layers=n_layers,
|
||||
num_attention_heads=n_heads,
|
||||
|
@ -410,7 +410,7 @@ class Base(nn.Module):
|
|||
self.model = LlamaModel(LlamaConfig(
|
||||
vocab_size=n_resp_tokens,
|
||||
hidden_size=d_model,
|
||||
max_position_embeddings=75 * 60, # max-length of 60 seconds
|
||||
max_position_embeddings=75 * 60 * 5, # max-length of 60 seconds
|
||||
intermediate_size=d_model*4,
|
||||
num_hidden_layers=n_layers,
|
||||
num_attention_heads=n_heads,
|
||||
|
@ -427,7 +427,7 @@ class Base(nn.Module):
|
|||
self.model = MixtralModel(MixtralConfig(
|
||||
vocab_size =n_resp_tokens,
|
||||
hidden_size=d_model,
|
||||
max_position_embeddings=75 * 60, # max-length of 60 seconds
|
||||
max_position_embeddings=75 * 60 * 5, # max-length of 60 seconds
|
||||
intermediate_size=d_model*4,
|
||||
num_hidden_layers=n_layers,
|
||||
num_attention_heads=n_heads,
|
||||
|
@ -984,6 +984,10 @@ class Base(nn.Module):
|
|||
# perform repetition penalizing
|
||||
logits = [ reptition_penalize(logit, previous=resps[:, -1], factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ]
|
||||
|
||||
# argmax instead
|
||||
if temperature <= 0.0:
|
||||
return [ logit.argmax(dim=1) for logit in logits ]
|
||||
|
||||
# (AR) perform length penalizing
|
||||
if quant_levels is None and self.causal:
|
||||
logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, resps_list) ) ]
|
||||
|
|
|
@ -11,7 +11,7 @@ import gradio as gr
|
|||
from time import perf_counter
|
||||
from pathlib import Path
|
||||
|
||||
from .inference import TTS
|
||||
from .inference import TTS, cfg
|
||||
from .train import train
|
||||
|
||||
tts = None
|
||||
|
@ -66,7 +66,7 @@ def init_tts(restart=False):
|
|||
def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
||||
if kwargs.pop("dynamic-sampling", False):
|
||||
kwargs['min-ar-temp'] = 0.85 if kwargs['ar-temp'] > 0.85 else 0.0
|
||||
kwargs['min-nar-temp'] = 0.2 if kwargs['nar-temp'] > 0.2 else 0.0
|
||||
kwargs['min-nar-temp'] = 0.85 if kwargs['nar-temp'] > 0.85 else 0.0 # should probably disable it for the NAR
|
||||
else:
|
||||
kwargs['min-ar-temp'] = -1
|
||||
kwargs['min-nar-temp'] = -1
|
||||
|
@ -77,8 +77,8 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
|||
parser.add_argument("--references", type=str, default=kwargs["reference"])
|
||||
parser.add_argument("--language", type=str, default="en")
|
||||
parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
|
||||
parser.add_argument("--max-ar-steps", type=int, default=int(kwargs["max-seconds"]*75))
|
||||
parser.add_argument("--max-ar-context", type=int, default=int(kwargs["max-seconds-context"]*75))
|
||||
parser.add_argument("--max-ar-steps", type=int, default=int(kwargs["max-seconds"]*cfg.dataset.frames_per_second))
|
||||
parser.add_argument("--max-ar-context", type=int, default=int(kwargs["max-seconds-context"]*cfg.dataset.frames_per_second))
|
||||
parser.add_argument("--max-nar-levels", type=int, default=kwargs["max-nar-levels"])
|
||||
parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"])
|
||||
parser.add_argument("--nar-temp", type=float, default=kwargs["nar-temp"])
|
||||
|
@ -208,13 +208,13 @@ with ui:
|
|||
layout["inference"]["buttons"]["inference"] = gr.Button(value="Inference")
|
||||
with gr.Column(scale=7):
|
||||
with gr.Row():
|
||||
layout["inference"]["inputs"]["max-seconds"] = gr.Slider(value=6, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="Limits how many steps to perform in the AR pass.")
|
||||
layout["inference"]["inputs"]["max-seconds"] = gr.Slider(value=12, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="Limits how many steps to perform in the AR pass.")
|
||||
layout["inference"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
|
||||
layout["inference"]["inputs"]["input-prompt-length"] = gr.Slider(value=3.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Trim Length", info="Trims the input prompt down to X seconds. Set 0 to disable.")
|
||||
layout["inference"]["inputs"]["max-seconds-context"] = gr.Slider(value=0.0, minimum=0.0, maximum=12.0, step=0.05, label="Context Length", info="Amount of generated audio to keep in the context during inference, in seconds. Set 0 to disable.")
|
||||
with gr.Row():
|
||||
layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR.")
|
||||
layout["inference"]["inputs"]["nar-temp"] = gr.Slider(value=0.25, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR.")
|
||||
layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
|
||||
layout["inference"]["inputs"]["nar-temp"] = gr.Slider(value=0.01, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)")
|
||||
with gr.Row():
|
||||
layout["inference"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user