diff --git a/data/config.yaml b/data/config.yaml
index 0adaffa..d841c00 100644
--- a/data/config.yaml
+++ b/data/config.yaml
@@ -1,4 +1,4 @@
-sample_rate: 24_000 # 44_000 for dac
+sample_rate: 24_000 # 44_000 / 44_100 for dac
 audio_backend: "vocos" # or dac
 
 # model definitions to train
@@ -7,7 +7,7 @@ models:
   size: "full" # model dimensionality
   resp_levels: 8 # RVQ levels this model targets
   prom_levels: 8 # should always be the above
-  tasks: 8 # tasks this model can attend to, only tts is supported at the moment
+  tasks: 8 # tasks this model can attend to, only tts is guaranteed results at the moment
   langs: 2 # languages this model supports, semi-unused at the moment
   tones: 1 # tones this model supports, currently unused
   arch_type: llama # underlying LLM arch to use, currently focusing on llama
@@ -19,7 +19,7 @@ models:
   # factors for split loss values, remove to have a unified loss calculation
   loss_factors:
     text: 0.1 # text phoneme portion of the sequence
-    prom: 0.0 # input prompt portion of the sequence
+    prom: 0.5 # input prompt portion of the sequence
     resp: 1.0 # output audio portin of the sequence
 
   # experimental settings
@@ -28,7 +28,8 @@ models:
   interleave: False # interleaves RVQ levels, only works with above for now
   audio_embedding_mode: "" # "" | "inclusive" | "exclusive", whether to utilize the audio backend's embeddings with the input embeddings
   audio_embedding_sums: False # whether the input embeddings include all prior RVQ levels (sums) or only the current one, further experimentation is needed to see if this matters
-  p_rvq_levels: "equal" # "equal" | "auto", sets probabilities of which RVQ level to select during training, auto will have the next RVQ level half as likely as the previous one
+  p_rvq_levels: "auto" # "equal" | "auto", sets probabilities of which RVQ level to select during training, auto will have the next RVQ level half as likely as the previous one
+  unified_position_ids: False # specifies whether or not position IDs should be continuous across the whole sequence (if True, naive behavior), or restart them at the next segment of the sequence (if False)
 
 # hyperparameter settings (could be relegated to trainer settings)
 hyperparameters:
diff --git a/data/noise.dac b/data/noise.dac
new file mode 100644
index 0000000..1a4d02e
Binary files /dev/null and b/data/noise.dac differ
diff --git a/data/noise.enc b/data/noise.enc
new file mode 100644
index 0000000..40636d2
Binary files /dev/null and b/data/noise.enc differ
diff --git a/vall_e/data.py b/vall_e/data.py
index d22a597..d9d5c2a 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -1063,11 +1063,8 @@ def create_datasets():
 def create_train_val_dataloader():
 	train_dataset, val_dataset = create_datasets()
 
-	# it'll cry about trying to pickle a torch._C_generator or something
-	try:
-		subtrain_dataset = copy.deepcopy(train_dataset)
-	except Exception as e:
-		subtrain_dataset = Dataset( training=True )
+	# deepcopy is slow
+	subtrain_dataset = Dataset( training=True )
 
 	if subtrain_dataset.sampler_type == "path":
 		subtrain_dataset.head_(cfg.evaluation.size)
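The `p_rvq_levels: "auto"` switch flipped on in data/config.yaml above weights RVQ level selection geometrically instead of uniformly. Below is a minimal, illustrative sketch of what the two modes imply; the helper name `pick_rvq_level` is made up here and this is not the project's actual sampling code.

```python
# Illustrative sketch only, not the project's actual sampler.
# "equal" picks any of the resp_levels RVQ levels uniformly; "auto" makes each
# successive level half as likely as the previous one, biasing training toward level 0.
import random

def pick_rvq_level(mode: str = "auto", resp_levels: int = 8) -> int:
    if mode == "equal":
        weights = [1.0] * resp_levels
    else:  # "auto": weights 1, 1/2, 1/4, ... for levels 0, 1, 2, ...
        weights = [0.5 ** level for level in range(resp_levels)]
    return random.choices(range(resp_levels), weights=weights, k=1)[0]
```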
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 2cdef3b..220e045 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -361,7 +361,7 @@ def example_usage():
 	from einops import repeat
 	from tqdm import tqdm
 
-	from ..emb.qnt import decode_to_file, unload_model
+	from ..emb.qnt import decode_to_file, unload_model, trim_random, repeat_extend_audio, concat_audio, merge_audio
 	from ..engines import Engine
 	from ..utils import wrapper as ml
 
@@ -385,7 +385,7 @@ def example_usage():
 		return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.resp_levels, :].t().to(torch.int16)
 
 	qnt = _load_quants(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
-	
+	noise = _load_quants(f"./data/noise.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
 
 	text_list = [
 		tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
@@ -404,6 +404,8 @@ def example_usage():
 		proms_list = proms_list[:1]
 		resps_list = resps_list[:1]
 
+	batch_size = len(text_list)
+
 	# rentet-full is the only configuration with BitNet's BitLinear that converges despite the grad_norm saying otherwise
 	kwargs = {
 		'n_text_tokens': 256,
@@ -428,8 +430,11 @@ def example_usage():
 		pass
 	"""
 
+	bos_id, space_id, eos_id = cfg.tokenizer.encode( " " )
+	tasks = cfg.dataset.tasks_list
+
 	model = AR_NAR(**kwargs).to(device)
-	steps = 150
+	steps = 150 * len(tasks)
 
 	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
 	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
@@ -497,22 +502,61 @@ def example_usage():
 
 	print(f"AR+NAR ({cfg.model.arch_type}, {cfg.audio_backend}) parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
 
-	@torch.inference_mode()
-	def sample( name, steps=1000 ):
-		if cfg.audio_backend == "dac" and name == "init":
-			return
+	@torch.no_grad()
+	def sample_data(task=None):
+		texts = []
+		proms = []
+		resps = []
+		for i in range(batch_size):
+			if task is None:
+				task = random.choice(tasks)
+
+			text = text_list[i]
+			prom = proms_list[i]
+			resp = resps_list[i]
+
+			# do nothing
+			if task == "tts":
+				...
+			elif task == "tts-c":
+				trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
+
+				prom = resp[:trim_length]
+				resp = resp[trim_length:]
+			elif task == "ns" or task == "sr":
+				# extend the noise to fill the target audio
+				noise_ext = repeat_extend_audio( noise, resp.shape[0] )
+				# create the input prompt by merging the target audio with the noise
+				prom = merge_audio( resp.cpu(), noise_ext, scale=[1, cfg.dataset.noise_scale], device=cfg.dataset.reencode_device )
+				# set the target to just be the noise if
+				if task == "sr":
+					resp = noise_ext
+
+			# set the text prompt to empty to train without a guided text prompt
+			if random.random() < 0.5:
+				text = torch.tensor([bos_id, eos_id]).to(device=device, dtype=torch.uint8)
+
+			texts.append( text.to(device) )
+			proms.append( prom.to(device) )
+			resps.append( resp.to(device) )
+
+		return texts, proms, resps
+
+	@torch.inference_mode()
+	def sample( name, steps=1000, task=None ):
 
 		engine.eval()
+
+		texts, proms, resps = sample_data( task )
+
 		if "ar" in cfg.model.capabilities:
-			resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
-		else:
-			resps_list = [ qnt[:, 0].to( device ) ]
+			resps = engine( texts, proms, max_steps=steps, sampling_temperature=0.95 )
 
 		if "nar" in cfg.model.capabilities:
-			resps_list = engine( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 )
+			resps = engine( texts, proms, resps, sampling_temperature=0.2 )
 
-		for i, o in enumerate(resps_list):
-			_ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{name}.wav", device=device)
+		for i, o in enumerate(resps):
+			_ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{task}.{name}.wav", device=device)
 
 		unload_model()
 
@@ -520,8 +564,10 @@ def example_usage():
 		engine.train()
 		t = trange(steps)
 		for i in t:
+			texts, proms, resps = sample_data()
+
 			stats = {"step": i}
-			stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)
+			stats |= engine.traverse(text_list=texts, proms_list=proms, resps_list=resps)
 			stats |= {"grad_norm": engine.get_global_grad_norm()}
 
 			tqdm.write(f"{stats}")
@@ -534,7 +580,9 @@ def example_usage():
 	#sample("init", 5)
 
 	train()
-	sample("final")
+
+	for task in tasks:
+		sample("final", task=task)
 
 if __name__ == "__main__":
 	example_usage()
\ No newline at end of file
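The new `sample_data()` helper above builds "ns" and "sr" batches by stretching the bundled noise clip to the target length and mixing it into the prompt with `repeat_extend_audio` / `merge_audio` from `vall_e.emb.qnt`. As a rough, waveform-level sketch of that idea only, assuming made-up names `repeat_extend` and `mix` (the real helpers operate on quantized codes and take different arguments, e.g. `merge_audio` re-encodes on `cfg.dataset.reencode_device`):

```python
# Illustrative, waveform-level sketch only; not the repo's implementation.
import torch

def repeat_extend(noise: torch.Tensor, target_len: int) -> torch.Tensor:
    """Tile the noise along the time axis (dim 0) until it covers target_len frames."""
    repeats = (target_len + noise.shape[0] - 1) // noise.shape[0]
    return noise.repeat(repeats, *([1] * (noise.dim() - 1)))[:target_len]

def mix(clean: torch.Tensor, noise: torch.Tensor, noise_scale: float = 0.25) -> torch.Tensor:
    """Additively mix a clean signal with equally-shaped, scaled noise."""
    return clean + noise_scale * noise

# "ns": prompt = noisy speech, target = clean speech; "sr": prompt = noisy speech, target = noise
clean = torch.randn(24_000)   # stand-in for the target utterance
noise = torch.randn(4_000)    # stand-in for the loaded noise clip
prompt = mix(clean, repeat_extend(noise, clean.shape[0]))
```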
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 97d27bb..8d0e3b0 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -880,6 +880,7 @@ class Base(nn.Module):
 
 			# Base-line TTS task
 			# Sequence:
+			# prom /may/ include <task> tokens inside to help guide things, per SpeechX
 			if f'<{task_type}>' in get_task_symmap():
 				# insert the text prompt
 				if text_list is not None:
@@ -933,7 +934,6 @@ class Base(nn.Module):
 				# yes this could be encoded better
 				inputs[i].append( ( "len", torch.Tensor([ 0 ] + [ int(i) for i in str( resps_list[i].shape[0]) ] + [ 10 ]).to(device=device, dtype=torch.int16) ) )
 			else:
-				raise Exception(f'Unrecognized task: {task_type}')
 
 		return inputs
 
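Finally, the `unified_position_ids: False` option introduced in data/config.yaml describes how the sequence assembled in base.py gets its position IDs. A small sketch of the two behaviors, assuming a made-up `build_position_ids` helper and a simplified text/prom/resp segment layout rather than the actual base.py code:

```python
# Illustrative sketch only, not the actual base.py implementation.
# unified=True numbers the whole concatenated sequence continuously (naive behavior);
# unified=False restarts the count at the start of each segment (text, prom, resp, ...).
import torch

def build_position_ids(segment_lengths: list[int], unified: bool = False) -> torch.Tensor:
    if unified:
        return torch.arange(sum(segment_lengths))
    return torch.cat([torch.arange(length) for length in segment_lengths])

# e.g. segments of length 4 (text), 3 (prom), 5 (resp):
#   unified=True  -> 0 1 2 3 4 5 6 7 8 9 10 11
#   unified=False -> 0 1 2 3 | 0 1 2 | 0 1 2 3 4
position_ids = build_position_ids([4, 3, 5], unified=False)
```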