test trainer (vall_e.models.ar_nar) tests some SpeechX features
This commit is contained in:
parent 83a0954f85
commit 39f961abcd
@@ -1,4 +1,4 @@
-sample_rate: 24_000 # 44_000 for dac
+sample_rate: 24_000 # 44_000 / 44_100 for dac
 audio_backend: "vocos" # or dac
 
 # model definitions to train
@@ -7,7 +7,7 @@ models:
   size: "full" # model dimensionality
   resp_levels: 8 # RVQ levels this model targets
   prom_levels: 8 # should always be the above
-  tasks: 8 # tasks this model can attend to, only tts is supported at the moment
+  tasks: 8 # tasks this model can attend to, only tts is guaranteed results at the moment
   langs: 2 # languages this model supports, semi-unused at the moment
   tones: 1 # tones this model supports, currently unused
   arch_type: llama # underlying LLM arch to use, currently focusing on llama
@@ -19,7 +19,7 @@ models:
   # factors for split loss values, remove to have a unified loss calculation
   loss_factors:
     text: 0.1 # text phoneme portion of the sequence
-    prom: 0.0 # input prompt portion of the sequence
+    prom: 0.5 # input prompt portion of the sequence
     resp: 1.0 # output audio portion of the sequence
 
   # experimental settings
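For context on the split-loss factors above: each portion of the sequence gets its own cross-entropy term, scaled by its factor. A minimal sketch of the idea, assuming per-segment logits and targets (names and reduction here are illustrative, not the repo's actual code):

import torch
import torch.nn.functional as F

# assumed per-segment loss factors, mirroring the config above
loss_factors = {"text": 0.1, "prom": 0.5, "resp": 1.0}

def split_loss(segments):
	# segments maps name -> (logits [T, V], targets [T]); each portion of
	# the sequence gets its own cross-entropy, scaled by its factor, then summed
	total = torch.tensor(0.0)
	for name, (logits, targets) in segments.items():
		factor = loss_factors.get(name, 0.0)
		if factor == 0.0:
			continue  # a factor of 0 drops that portion from the loss entirely
		total = total + factor * F.cross_entropy(logits, targets)
	return total

Raising prom from 0.0 to 0.5 means the input-prompt portion of the sequence now contributes to the loss rather than being ignored.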
@@ -28,7 +28,8 @@ models:
   interleave: False # interleaves RVQ levels, only works with above for now
   audio_embedding_mode: "" # "" | "inclusive" | "exclusive", whether to utilize the audio backend's embeddings with the input embeddings
   audio_embedding_sums: False # whether the input embeddings include all prior RVQ levels (sums) or only the current one, further experimentation is needed to see if this matters
-  p_rvq_levels: "equal" # "equal" | "auto", sets probabilities of which RVQ level to select during training, auto will have the next RVQ level half as likely as the previous one
+  p_rvq_levels: "auto" # "equal" | "auto", sets probabilities of which RVQ level to select during training, auto will have the next RVQ level half as likely as the previous one
+  unified_position_ids: False # whether position IDs are continuous across the whole sequence (True, naive behavior) or restart at each segment of the sequence (False)
 
   # hyperparameter settings (could be relegated to trainer settings)
   hyperparameters:
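The new "auto" default makes each successive RVQ level half as likely to be picked during training as the one before it, and unified_position_ids toggles whether position IDs run continuously or restart per segment. A rough sketch of both behaviors, with hypothetical helper names that are not in the repo:

import random

def sample_rvq_level(resp_levels=8, mode="auto"):
	# "equal": uniform over levels; "auto": level k is half as likely as
	# level k-1, i.e. weights proportional to 2**-k
	if mode == "equal":
		return random.randint(0, resp_levels - 1)
	weights = [2.0 ** -k for k in range(resp_levels)]
	return random.choices(range(resp_levels), weights=weights, k=1)[0]

def position_ids(segment_lengths, unified=False):
	# unified=True: positions run 0..N-1 across the whole sequence;
	# unified=False: positions restart at 0 for each segment
	ids, offset = [], 0
	for length in segment_lengths:
		start = offset if unified else 0
		ids.extend(range(start, start + length))
		offset += length
	return ids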
BIN  data/noise.dac  Normal file  (binary file not shown)
BIN  data/noise.enc  Normal file  (binary file not shown)
@@ -1063,10 +1063,7 @@ def create_datasets():
 def create_train_val_dataloader():
 	train_dataset, val_dataset = create_datasets()
 
-	# it'll cry about trying to pickle a torch._C_generator or something
-	try:
-		subtrain_dataset = copy.deepcopy(train_dataset)
-	except Exception as e:
-		subtrain_dataset = Dataset( training=True )
+	# deepcopy is slow
+	subtrain_dataset = Dataset( training=True )
 
 	if subtrain_dataset.sampler_type == "path":
@@ -361,7 +361,7 @@ def example_usage():
 	from einops import repeat
 	from tqdm import tqdm
 
-	from ..emb.qnt import decode_to_file, unload_model
+	from ..emb.qnt import decode_to_file, unload_model, trim_random, repeat_extend_audio, concat_audio, merge_audio
 	from ..engines import Engine
 	from ..utils import wrapper as ml
 
@@ -385,7 +385,7 @@ def example_usage():
 		return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.resp_levels, :].t().to(torch.int16)
 
 	qnt = _load_quants(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
-
+	noise = _load_quants(f"./data/noise.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
 
 	text_list = [
 		tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
@@ -404,6 +404,8 @@ def example_usage():
 	proms_list = proms_list[:1]
 	resps_list = resps_list[:1]
 
+	batch_size = len(text_list)
+
 	# rentet-full is the only configuration with BitNet's BitLinear that converges despite the grad_norm saying otherwise
 	kwargs = {
 		'n_text_tokens': 256,
@@ -428,8 +430,11 @@ def example_usage():
 		pass
 	"""
 
+	bos_id, space_id, eos_id = cfg.tokenizer.encode( " " )
+	tasks = cfg.dataset.tasks_list
+
 	model = AR_NAR(**kwargs).to(device)
-	steps = 150
+	steps = 150 * len(tasks)
 
 	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
 	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
@@ -497,22 +502,61 @@ def example_usage():
 
 	print(f"AR+NAR ({cfg.model.arch_type}, {cfg.audio_backend}) parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
 
-	@torch.inference_mode()
-	def sample( name, steps=1000 ):
-		if cfg.audio_backend == "dac" and name == "init":
-			return
+	@torch.no_grad()
+	def sample_data(task=None):
+		texts = []
+		proms = []
+		resps = []
+
+		for i in range(batch_size):
+			if task is None:
+				task = random.choice(tasks)
+
+			text = text_list[i]
+			prom = proms_list[i]
+			resp = resps_list[i]
+
+			# do nothing
+			if task == "tts":
+				...
+			elif task == "tts-c":
+				trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
+
+				prom = resp[:trim_length]
+				resp = resp[trim_length:]
+			elif task == "ns" or task == "sr":
+				# extend the noise to fill the target audio
+				noise_ext = repeat_extend_audio( noise, resp.shape[0] )
+				# create the input prompt by merging the target audio with the noise
+				prom = merge_audio( resp.cpu(), noise_ext, scale=[1, cfg.dataset.noise_scale], device=cfg.dataset.reencode_device )
+				# set the target to just be the noise if <sr>
+				if task == "sr":
+					resp = noise_ext
+
+				# set the text prompt to empty to train without a guided text prompt
+				if random.random() < 0.5:
+					text = torch.tensor([bos_id, eos_id]).to(device=device, dtype=torch.uint8)
+
+			texts.append( text.to(device) )
+			proms.append( prom.to(device) )
+			resps.append( resp.to(device) )
+
+		return texts, proms, resps
+
+	@torch.inference_mode()
+	def sample( name, steps=1000, task=None ):
 		engine.eval()
+
+		texts, proms, resps = sample_data( task )
+
 		if "ar" in cfg.model.capabilities:
-			resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
+			resps = engine( texts, proms, max_steps=steps, sampling_temperature=0.95 )
 		else:
 			resps_list = [ qnt[:, 0].to( device ) ]
 
 		if "nar" in cfg.model.capabilities:
-			resps_list = engine( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 )
+			resps = engine( texts, proms, resps, sampling_temperature=0.2 )
 
-		for i, o in enumerate(resps_list):
-			_ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{name}.wav", device=device)
+		for i, o in enumerate(resps):
+			_ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{task}.{name}.wav", device=device)
 
 		unload_model()
 
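The <ns> and <sr> branches above lean on repeat_extend_audio and merge_audio from vall_e.emb.qnt, whose implementations are not part of this diff. As a hedged sketch of the first helper, it plausibly tiles the noise codes along the time axis until they cover the target length (merge_audio, given its reencode_device argument, likely decodes both signals, mixes them at the given scales, and re-encodes):

import torch

def repeat_extend_audio_sketch(codes, target_len):
	# tile quantized codes [T, Q] along the time axis until they cover
	# target_len frames, then trim to exactly that length
	repeats = (target_len + codes.shape[0] - 1) // codes.shape[0]
	return codes.repeat(repeats, 1)[:target_len]

Per the branch logic above, <ns> keeps the clean audio as the target (noise suppression), while <sr> swaps the target to the noise itself (speech removal).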
@@ -520,8 +564,10 @@ def example_usage():
 		engine.train()
 		t = trange(steps)
 		for i in t:
+			texts, proms, resps = sample_data()
+
 			stats = {"step": i}
-			stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)
+			stats |= engine.traverse(text_list=texts, proms_list=proms, resps_list=resps)
 			stats |= {"grad_norm": engine.get_global_grad_norm()}
 
 			tqdm.write(f"{stats}")
@@ -534,7 +580,9 @@ def example_usage():
 
 	#sample("init", 5)
 	train()
-	sample("final")
+
+	for task in tasks:
+		sample("final", task=task)
 
 if __name__ == "__main__":
 	example_usage()
@@ -880,6 +880,7 @@ class Base(nn.Module):
 
 			# Base-line TTS task
 			# Sequence: <text><sep><rvq lvl><sep><prom><sep><resp>
+			# prom /may/ include <task> tokens inside to help guide things, per SpeechX
 			if f'<{task_type}>' in get_task_symmap():
 				# insert the text prompt
 				if text_list is not None:
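A toy illustration of the commented layout, with made-up token ids and a hypothetical separator; the real model composes these segments internally as embeddings, this only makes the ordering concrete:

import torch

sep     = torch.tensor([3])            # hypothetical separator token id
text    = torch.tensor([1, 42, 43, 2]) # hypothetical phoneme token ids
rvq_lvl = torch.tensor([0])            # which RVQ level this pass targets
prom    = torch.tensor([100, 101])     # input prompt codes; may carry <task> tokens
resp    = torch.tensor([200, 201])     # output audio codes

seq = torch.cat([text, sep, rvq_lvl, sep, prom, sep, resp])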
@@ -933,7 +934,6 @@ class Base(nn.Module):
 				# yes this could be encoded better
 				inputs[i].append( ( "len", torch.Tensor([ 0 ] + [ int(i) for i in str( resps_list[i].shape[0]) ] + [ 10 ]).to(device=device, dtype=torch.int16) ) )
 			else:
-
 				raise Exception(f'Unrecognized task: {task_type}')
 
 		return inputs
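The "len" input above encodes the target duration as decimal digits: token 0 as a start marker, one token per digit, and token 10 as a stop marker. A small round-trip sketch of that scheme (the decode helper is mine, not the repo's):

def encode_len(n_frames):
	# 0 acts as a start marker, 10 as a stop marker, with one token per
	# decimal digit of the duration in between
	return [0] + [int(d) for d in str(n_frames)] + [10]

def decode_len(tokens):
	digits = tokens[1:-1]  # strip the start/stop markers
	return int("".join(str(d) for d in digits))

assert encode_len(1234) == [0, 1, 2, 3, 4, 10]
assert decode_len(encode_len(1234)) == 1234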