test trainer (vall_e.models.ar_nar) tests some SpeechX features
commit 39f961abcd
parent 83a0954f85
@@ -1,4 +1,4 @@
-sample_rate: 24_000 # 44_000 for dac
+sample_rate: 24_000 # 44_000 / 44_100 for dac
 audio_backend: "vocos" # or dac
 
 # model definitions to train
@@ -7,7 +7,7 @@ models:
   size: "full" # model dimensionality
   resp_levels: 8 # RVQ levels this model targets
   prom_levels: 8 # should always match the above
-  tasks: 8 # tasks this model can attend to, only tts is supported at the moment
+  tasks: 8 # tasks this model can attend to, only tts has guaranteed results at the moment
   langs: 2 # languages this model supports, semi-unused at the moment
   tones: 1 # tones this model supports, currently unused
   arch_type: llama # underlying LLM arch to use, currently focusing on llama
@@ -19,7 +19,7 @@ models:
   # factors for split loss values, remove to have a unified loss calculation
   loss_factors:
     text: 0.1 # text phoneme portion of the sequence
-    prom: 0.0 # input prompt portion of the sequence
+    prom: 0.5 # input prompt portion of the sequence
     resp: 1.0 # output audio portion of the sequence
 
   # experimental settings
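Bumping prom from 0.0 to 0.5 means the input-prompt portion of the sequence now contributes to the loss instead of being ignored, which matters for SpeechX-style tasks where the prompt itself carries task information. A minimal sketch of how such per-segment loss weighting works (the segment layout and the split_loss helper are illustrative, not the repo's actual API):

    import torch
    import torch.nn.functional as F

    # Illustrative only: score each segment of the interleaved sequence
    # separately, then scale by the loss_factors from the config above.
    loss_factors = {"text": 0.1, "prom": 0.5, "resp": 1.0}

    def split_loss(logits, targets, segments):
        # segments maps a segment name to the slice of positions it occupies
        total = torch.tensor(0.0)
        for name, span in segments.items():
            total = total + loss_factors.get(name, 1.0) * F.cross_entropy(logits[span], targets[span])
        return total

    # e.g. a 10-token sequence: 3 text, 3 prompt, 4 response positions
    logits  = torch.randn(10, 1024)
    targets = torch.randint(0, 1024, (10,))
    loss = split_loss(logits, targets, {"text": slice(0, 3), "prom": slice(3, 6), "resp": slice(6, 10)})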
@@ -28,7 +28,8 @@ models:
   interleave: False # interleaves RVQ levels, only works with above for now
   audio_embedding_mode: "" # "" | "inclusive" | "exclusive", whether to utilize the audio backend's embeddings with the input embeddings
   audio_embedding_sums: False # whether the input embeddings include all prior RVQ levels (sums) or only the current one, further experimentation is needed to see if this matters
-  p_rvq_levels: "equal" # "equal" | "auto", sets probabilities of which RVQ level to select during training, auto will have the next RVQ level half as likely as the previous one
+  p_rvq_levels: "auto" # "equal" | "auto", sets probabilities of which RVQ level to select during training, auto will have the next RVQ level half as likely as the previous one
+  unified_position_ids: False # specifies whether position IDs should be continuous across the whole sequence (if True, naive behavior) or restart at each segment of the sequence (if False)
 
   # hyperparameter settings (could be relegated to trainer settings)
   hyperparameters:
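With p_rvq_levels set to "auto", each training step picks one RVQ level to train on, and level k is drawn half as often as level k-1, front-loading capacity on the coarse levels that the AR pass depends on. A quick sketch of the resulting distribution (the helper name is hypothetical):

    import random

    def rvq_level_weights(n_levels=8, mode="auto"):
        # "equal": uniform over levels; "auto": each level half as likely as the previous
        if mode == "equal":
            return [1.0] * n_levels
        return [0.5 ** k for k in range(n_levels)]

    # level 0 is picked ~50% of the time, level 7 well under 1%
    level = random.choices(range(8), weights=rvq_level_weights(8, "auto"))[0]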
BIN data/noise.dac (new file; binary file not shown)
BIN data/noise.enc (new file; binary file not shown)
@@ -1063,11 +1063,8 @@ def create_datasets():
 def create_train_val_dataloader():
 	train_dataset, val_dataset = create_datasets()
 
-	# it'll cry about trying to pickle a torch._C_generator or something
-	try:
-		subtrain_dataset = copy.deepcopy(train_dataset)
-	except Exception as e:
-		subtrain_dataset = Dataset( training=True )
+	# deepcopy is slow
+	subtrain_dataset = Dataset( training=True )
 
 	if subtrain_dataset.sampler_type == "path":
 		subtrain_dataset.head_(cfg.evaluation.size)
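For context on the removed try/except: copy.deepcopy falls back to pickling objects it cannot copy natively, and a dataset holding a torch.Generator (a torch._C.Generator underneath) fails that pickling step, hence the old fallback; constructing a fresh Dataset is also simply faster than a deep copy. The failure mode is easy to reproduce:

    import copy
    import torch

    class Holder:
        def __init__(self):
            # torch.Generator is backed by torch._C.Generator, which cannot be pickled
            self.rng = torch.Generator()

    try:
        copy.deepcopy(Holder())
    except TypeError as e:
        print("deepcopy failed:", e)  # cannot pickle 'torch._C.Generator' object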
@@ -361,7 +361,7 @@ def example_usage():
 	from einops import repeat
 	from tqdm import tqdm
 
-	from ..emb.qnt import decode_to_file, unload_model
+	from ..emb.qnt import decode_to_file, unload_model, trim_random, repeat_extend_audio, concat_audio, merge_audio
 	from ..engines import Engine
 	from ..utils import wrapper as ml
 
@@ -385,7 +385,7 @@ def example_usage():
 		return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.resp_levels, :].t().to(torch.int16)
 
 	qnt = _load_quants(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
-
+	noise = _load_quants(f"./data/noise.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
 
 	text_list = [
 		tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
@@ -404,6 +404,8 @@ def example_usage():
 	proms_list = proms_list[:1]
 	resps_list = resps_list[:1]
 
+	batch_size = len(text_list)
+
 	# retnet-full is the only configuration with BitNet's BitLinear that converges despite the grad_norm saying otherwise
 	kwargs = {
 		'n_text_tokens': 256,
@@ -428,8 +430,11 @@ def example_usage():
 		pass
 	"""
 
+	bos_id, space_id, eos_id = cfg.tokenizer.encode( " " )
+	tasks = cfg.dataset.tasks_list
+
 	model = AR_NAR(**kwargs).to(device)
-	steps = 150
+	steps = 150 * len(tasks)
 
 	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
 	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
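Two small things worth unpacking here. Encoding a lone space appears to yield exactly three IDs (BOS, the space token, EOS), which is why the three-way unpack works; bos_id and eos_id are reused later to build an empty text prompt. And scaling steps by len(tasks) keeps the per-task step budget roughly constant once each step draws a random task. A toy illustration with made-up token IDs (the real values come from cfg.tokenizer):

    import torch

    BOS, SPACE, EOS = 1, 3, 2  # made-up token IDs

    def encode(text):
        # stand-in for cfg.tokenizer.encode: wraps tokens in <bos>...<eos>
        return [BOS] + ([SPACE] if text == " " else []) + [EOS]

    bos_id, space_id, eos_id = encode(" ")
    # an "empty" text prompt for text-unguided training is just <bos><eos>
    empty_text = torch.tensor([bos_id, eos_id], dtype=torch.uint8)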
@@ -497,22 +502,61 @@ def example_usage():
 
 	print(f"AR+NAR ({cfg.model.arch_type}, {cfg.audio_backend}) parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
 
-	@torch.inference_mode()
-	def sample( name, steps=1000 ):
-		if cfg.audio_backend == "dac" and name == "init":
-			return
-
+	@torch.no_grad()
+	def sample_data(task=None):
+		texts = []
+		proms = []
+		resps = []
+
+		for i in range(batch_size):
+			if task is None:
+				task = random.choice(tasks)
+
+			text = text_list[i]
+			prom = proms_list[i]
+			resp = resps_list[i]
+
+			# do nothing
+			if task == "tts":
+				...
+			elif task == "tts-c":
+				trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
+
+				prom = resp[:trim_length]
+				resp = resp[trim_length:]
+			elif task == "ns" or task == "sr":
+				# extend the noise to fill the target audio
+				noise_ext = repeat_extend_audio( noise, resp.shape[0] )
+				# create the input prompt by merging the target audio with the noise
+				prom = merge_audio( resp.cpu(), noise_ext, scale=[1, cfg.dataset.noise_scale], device=cfg.dataset.reencode_device )
+				# set the target to just be the noise if <sr>
+				if task == "sr":
+					resp = noise_ext
+
+				# set the text prompt to empty to train without a guided text prompt
+				if random.random() < 0.5:
+					text = torch.tensor([bos_id, eos_id]).to(device=device, dtype=torch.uint8)
+
+			texts.append( text.to(device) )
+			proms.append( prom.to(device) )
+			resps.append( resp.to(device) )
+
+		return texts, proms, resps
+
+	@torch.inference_mode()
+	def sample( name, steps=1000, task=None ):
 		engine.eval()
 
+		texts, proms, resps = sample_data( task )
+
 		if "ar" in cfg.model.capabilities:
-			resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
-		else:
-			resps_list = [ qnt[:, 0].to( device ) ]
+			resps = engine( texts, proms, max_steps=steps, sampling_temperature=0.95 )
 
 		if "nar" in cfg.model.capabilities:
-			resps_list = engine( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 )
+			resps = engine( texts, proms, resps, sampling_temperature=0.2 )
 
-		for i, o in enumerate(resps_list):
-			_ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{name}.wav", device=device)
+		for i, o in enumerate(resps):
+			_ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{task}.{name}.wav", device=device)
 
 		unload_model()
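The ns (noise suppression) and sr (speech removal) tasks build their input prompt by overlaying noise on the target audio, per SpeechX. The real repeat_extend_audio and merge_audio live in vall_e.emb.qnt and operate on quantized codes (merging via a decode/re-encode round trip, hence the reencode_device argument); the sketch below only illustrates the intended waveform-level semantics, not the actual implementation:

    import torch

    def repeat_extend_audio_wave(wave, target_len):
        # tile the (short) noise clip along time until it covers target_len samples
        reps = -(-target_len // wave.shape[0])  # ceiling division
        return wave.repeat(reps)[:target_len]

    def merge_audio_wave(a, b, scale=(1.0, 1.0)):
        # mix two waveforms with per-source gains, truncated to the shorter one
        n = min(a.shape[0], b.shape[0])
        return scale[0] * a[:n] + scale[1] * b[:n]

    speech = torch.randn(24_000)  # 1 s of fake speech at 24 kHz
    noise  = torch.randn(6_000)   # short noise clip
    noisy  = merge_audio_wave(speech, repeat_extend_audio_wave(noise, speech.shape[0]), scale=(1.0, 0.25))
    # ns: input = noisy, target = speech; sr: input = noisy, target = the noise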
@@ -520,8 +564,10 @@ def example_usage():
 		engine.train()
 		t = trange(steps)
 		for i in t:
+			texts, proms, resps = sample_data()
+
 			stats = {"step": i}
-			stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)
+			stats |= engine.traverse(text_list=texts, proms_list=proms, resps_list=resps)
 			stats |= {"grad_norm": engine.get_global_grad_norm()}
 
 			tqdm.write(f"{stats}")
@@ -534,7 +580,9 @@ def example_usage():
 
 	#sample("init", 5)
 	train()
-	sample("final")
+
+	for task in tasks:
+		sample("final", task=task)
 
 if __name__ == "__main__":
 	example_usage()
@@ -880,6 +880,7 @@ class Base(nn.Module):
 
 		# Base-line TTS task
 		# Sequence: <text><sep><rvq lvl><sep><prom><sep><resp>
+		# prom /may/ include <task> tokens inside to help guide things, per SpeechX
 		if f'<{task_type}>' in get_task_symmap():
 			# insert the text prompt
 			if text_list is not None:
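SpeechX conditions a shared model on the task by splicing special task tokens into the acoustic prompt itself rather than adding a separate conditioning input. A rough sketch of the idea (the token IDs and splice position here are illustrative, not the repo's actual get_task_symmap):

    import torch

    # hypothetical task symmap: special token IDs appended past the audio codebook
    task_symmap = {"<tts>": 1024, "<ns>": 1025, "<sr>": 1026}

    def inject_task_token(prom, task):
        # prom: (frames, levels) RVQ codes; prepend one frame of the task token
        # broadcast across every level so the prompt declares its own task
        tok = torch.full((1, prom.shape[1]), task_symmap[f"<{task}>"], dtype=prom.dtype)
        return torch.cat([tok, prom], dim=0)

    prom = torch.randint(0, 1024, (75, 8))  # ~1 s of 8-level codes
    prom_ns = inject_task_token(prom, "ns")  # prompt now carries the <ns> marker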
@@ -933,7 +934,6 @@ class Base(nn.Module):
 			# yes this could be encoded better
 			inputs[i].append( ( "len", torch.Tensor([ 0 ] + [ int(i) for i in str( resps_list[i].shape[0]) ] + [ 10 ]).to(device=device, dtype=torch.int16) ) )
 		else:
-
 			raise Exception(f'Unrecognized task: {task_type}')
 
 		return inputs