test trainer (vall_e.models.ar_nar) tests some SpeechX features

This commit is contained in:
mrq 2024-07-18 18:46:45 -05:00
parent 83a0954f85
commit 39f961abcd
6 changed files with 71 additions and 25 deletions

View File

@@ -1,4 +1,4 @@
-sample_rate: 24_000 # 44_000 for dac
+sample_rate: 24_000 # 44_000 / 44_100 for dac
 audio_backend: "vocos" # or dac
 
 # model definitions to train
@@ -7,7 +7,7 @@ models:
   size: "full" # model dimensionality
   resp_levels: 8 # RVQ levels this model targets
   prom_levels: 8 # should always be the above
-  tasks: 8 # tasks this model can attend to, only tts is supported at the moment
+  tasks: 8 # tasks this model can attend to, only tts has guaranteed results at the moment
   langs: 2 # languages this model supports, semi-unused at the moment
   tones: 1 # tones this model supports, currently unused
   arch_type: llama # underlying LLM arch to use, currently focusing on llama
@@ -19,7 +19,7 @@ models:
   # factors for split loss values, remove to have a unified loss calculation
   loss_factors:
     text: 0.1 # text phoneme portion of the sequence
-    prom: 0.0 # input prompt portion of the sequence
+    prom: 0.5 # input prompt portion of the sequence
     resp: 1.0 # output audio portion of the sequence
 
   # experimental settings
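
The prom factor moving from 0.0 to 0.5 means the input-prompt span of the sequence now contributes to the training loss instead of being ignored. As a rough sketch of what split loss factors like these amount to (function and variable names are hypothetical, not the repository's actual loss code):

import torch.nn.functional as F

def split_loss(logits, targets, spans, factors={"text": 0.1, "prom": 0.5, "resp": 1.0}):
    # spans maps each segment name to the slice of the flattened sequence it occupies
    total = 0.0
    for name, factor in factors.items():
        seg = spans[name]
        total = total + factor * F.cross_entropy(logits[seg], targets[seg])
    return total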
@@ -28,7 +28,8 @@ models:
   interleave: False # interleaves RVQ levels, only works with above for now
   audio_embedding_mode: "" # "" | "inclusive" | "exclusive", whether to utilize the audio backend's embeddings with the input embeddings
   audio_embedding_sums: False # whether the input embeddings include all prior RVQ levels (sums) or only the current one, further experimentation is needed to see if this matters
-  p_rvq_levels: "equal" # "equal" | "auto", sets probabilities of which RVQ level to select during training, auto will have the next RVQ level half as likely as the previous one
+  p_rvq_levels: "auto" # "equal" | "auto", sets probabilities of which RVQ level to select during training, auto will have the next RVQ level half as likely as the previous one
+  unified_position_ids: False # specifies whether or not position IDs should be continuous across the whole sequence (if True, naive behavior), or restart them at the next segment of the sequence (if False)
 
 # hyperparameter settings (could be relegated to trainer settings)
 hyperparameters:
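
Switching p_rvq_levels to "auto" changes how the RVQ level trained on each step is drawn. A minimal sketch of the behavior described in the comment, where each successive level is half as likely to be picked as the one before it (the helper name is made up):

import random

def pick_rvq_level(resp_levels=8, mode="auto"):
    if mode == "equal":
        return random.randrange(resp_levels)
    # "auto": level 0 has weight 1, level 1 weight 1/2, level 2 weight 1/4, ...
    weights = [0.5 ** level for level in range(resp_levels)]
    return random.choices(range(resp_levels), weights=weights, k=1)[0]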

BIN
data/noise.dac Normal file

Binary file not shown.

BIN
data/noise.enc Normal file

Binary file not shown.

View File

@@ -1063,11 +1063,8 @@ def create_datasets():
 
 def create_train_val_dataloader():
     train_dataset, val_dataset = create_datasets()
 
-    # it'll cry about trying to pickle a torch._C_generator or something
-    try:
-        subtrain_dataset = copy.deepcopy(train_dataset)
-    except Exception as e:
-        subtrain_dataset = Dataset( training=True )
+    # deepcopy is slow
+    subtrain_dataset = Dataset( training=True )
 
     if subtrain_dataset.sampler_type == "path":
         subtrain_dataset.head_(cfg.evaluation.size)
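
Here the evaluation subset is built by constructing a fresh Dataset and trimming it with head_ instead of deep-copying the training dataset. A guess at what a head_-style trim could look like if the dataset keeps a flat list of sample paths (illustrative only; the real Dataset class is more involved):

class PathDataset:
    def __init__(self, paths):
        self.paths = list(paths)
        self.sampler_type = "path"

    def head_(self, n):
        # keep only the first n samples, in place
        self.paths = self.paths[:n]

    def __len__(self):
        return len(self.paths)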

View File

@@ -361,7 +361,7 @@ def example_usage():
     from einops import repeat
     from tqdm import tqdm
 
-    from ..emb.qnt import decode_to_file, unload_model
+    from ..emb.qnt import decode_to_file, unload_model, trim_random, repeat_extend_audio, concat_audio, merge_audio
     from ..engines import Engine
     from ..utils import wrapper as ml
@@ -385,7 +385,7 @@ def example_usage():
         return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.resp_levels, :].t().to(torch.int16)
 
     qnt = _load_quants(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
+    noise = _load_quants(f"./data/noise.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
 
     text_list = [
         tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
@@ -404,6 +404,8 @@ def example_usage():
     proms_list = proms_list[:1]
     resps_list = resps_list[:1]
 
+    batch_size = len(text_list)
+
     # retnet-full is the only configuration with BitNet's BitLinear that converges despite the grad_norm saying otherwise
     kwargs = {
         'n_text_tokens': 256,
@@ -428,8 +430,11 @@ def example_usage():
         pass
     """
 
+    bos_id, space_id, eos_id = cfg.tokenizer.encode( " " )
+    tasks = cfg.dataset.tasks_list
+
     model = AR_NAR(**kwargs).to(device)
-    steps = 150
+    steps = 150 * len(tasks)
 
     optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
     scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
@@ -497,22 +502,61 @@ def example_usage():
     print(f"AR+NAR ({cfg.model.arch_type}, {cfg.audio_backend}) parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
 
-    @torch.inference_mode()
-    def sample( name, steps=1000 ):
-        if cfg.audio_backend == "dac" and name == "init":
-            return
+    @torch.no_grad()
+    def sample_data(task=None):
+        texts = []
+        proms = []
+        resps = []
+
+        for i in range(batch_size):
+            if task is None:
+                task = random.choice(tasks)
+
+            text = text_list[i]
+            prom = proms_list[i]
+            resp = resps_list[i]
+
+            # do nothing
+            if task == "tts":
+                ...
+            elif task == "tts-c":
+                trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
+
+                prom = resp[:trim_length]
+                resp = resp[trim_length:]
+            elif task == "ns" or task == "sr":
+                # extend the noise to fill the target audio
+                noise_ext = repeat_extend_audio( noise, resp.shape[0] )
+                # create the input prompt by merging the target audio with the noise
+                prom = merge_audio( resp.cpu(), noise_ext, scale=[1, cfg.dataset.noise_scale], device=cfg.dataset.reencode_device )
+                # set the target to just be the noise if <sr>
+                if task == "sr":
+                    resp = noise_ext
+
+                # set the text prompt to empty to train without a guided text prompt
+                if random.random() < 0.5:
+                    text = torch.tensor([bos_id, eos_id]).to(device=device, dtype=torch.uint8)
+
+            texts.append( text.to(device) )
+            proms.append( prom.to(device) )
+            resps.append( resp.to(device) )
+
+        return texts, proms, resps
+
+    @torch.inference_mode()
+    def sample( name, steps=1000, task=None ):
         engine.eval()
+
+        texts, proms, resps = sample_data( task )
+
         if "ar" in cfg.model.capabilities:
-            resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
-        else:
-            resps_list = [ qnt[:, 0].to( device ) ]
+            resps = engine( texts, proms, max_steps=steps, sampling_temperature=0.95 )
 
         if "nar" in cfg.model.capabilities:
-            resps_list = engine( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 )
+            resps = engine( texts, proms, resps, sampling_temperature=0.2 )
 
-        for i, o in enumerate(resps_list):
-            _ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{name}.wav", device=device)
+        for i, o in enumerate(resps):
+            _ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{task}.{name}.wav", device=device)
 
         unload_model()
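
For the new <ns> (noise suppression) and <sr> (speech removal) cases above, the prompt becomes the target audio mixed with noise, and for <sr> the target itself becomes the noise. A toy illustration of that intent with stand-in helpers; the real repeat_extend_audio and merge_audio in vall_e.emb.qnt operate on encoded audio and are not reproduced here:

import torch

def tile_to_length(codes, target_len):
    # repeat the noise codes along the time axis until they cover target_len frames
    reps = (target_len + codes.shape[0] - 1) // codes.shape[0]
    return codes.repeat(reps, 1)[:target_len]

resp = torch.randint(0, 1024, (300, 8))   # stand-in target utterance codes [frames, levels]
noise = torch.randint(0, 1024, (75, 8))   # stand-in noise codes
noise_ext = tile_to_length(noise, resp.shape[0])
# <ns>: prompt = clean mixed with noise, target stays the clean audio
# <sr>: prompt = the same mix, target becomes the noise itself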
@@ -520,8 +564,10 @@ def example_usage():
         engine.train()
         t = trange(steps)
         for i in t:
+            texts, proms, resps = sample_data()
+
             stats = {"step": i}
-            stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)
+            stats |= engine.traverse(text_list=texts, proms_list=proms, resps_list=resps)
             stats |= {"grad_norm": engine.get_global_grad_norm()}
 
             tqdm.write(f"{stats}")
@@ -534,7 +580,9 @@ def example_usage():
     #sample("init", 5)
 
     train()
-    sample("final")
+
+    for task in tasks:
+        sample("final", task=task)
 
 if __name__ == "__main__":
     example_usage()

View File

@@ -880,6 +880,7 @@ class Base(nn.Module):
             # Base-line TTS task
             # Sequence: <text><sep><rvq lvl><sep><prom><sep><resp>
+            # prom /may/ include <task> tokens inside to help guide things, per SpeechX
             if f'<{task_type}>' in get_task_symmap():
                 # insert the text prompt
                 if text_list is not None:
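
The comment added above names the input ordering for the baseline TTS task. Purely as an illustration of that layout (token values are placeholders, not the model's actual embedding pipeline), the sequence is assembled as:

def build_sequence(text_ids, rvq_level_id, prom_ids, resp_ids, sep_id):
    # <text><sep><rvq lvl><sep><prom><sep><resp>; a <task> token may be spliced into prom
    return text_ids + [sep_id, rvq_level_id, sep_id] + prom_ids + [sep_id] + resp_ids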
@@ -933,7 +934,6 @@ class Base(nn.Module):
                 # yes this could be encoded better
                 inputs[i].append( ( "len", torch.Tensor([ 0 ] + [ int(i) for i in str( resps_list[i].shape[0]) ] + [ 10 ]).to(device=device, dtype=torch.int16) ) )
             else:
                 raise Exception(f'Unrecognized task: {task_type}')
 
         return inputs
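
The "len" input above spells the response length out digit by digit, bracketed by a 0 start marker and a 10 stop marker. The same encoding, shown standalone:

import torch

def encode_len(n_frames):
    digits = [int(d) for d in str(n_frames)]
    return torch.tensor([0] + digits + [10], dtype=torch.int16)

print(encode_len(456))  # tensor([ 0,  4,  5,  6, 10], dtype=torch.int16)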