diff --git a/vall_e/config.py b/vall_e/config.py index b75f3cd..76bc03e 100644 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -20,6 +20,9 @@ class Config(ConfigBase): p_additional_prompt: float = 0.8 max_prompts: int = 3 + max_num_val: int = 20 + max_val_ar_steps: int = 300 + token_dim: int = 256 num_tokens: int = 1024 diff --git a/vall_e/data.py b/vall_e/data.py index db492bc..e25d4e6 100644 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -260,7 +260,7 @@ def create_datasets(): ) val_dataset.interleaved_reorder_(_get_spkr_name) - val_dataset.head_(200) + val_dataset.head_(cfg.max_num_val) test_dataset = VALLEDatset( test_paths, @@ -286,17 +286,17 @@ def create_train_val_dataloader(): _logger.info(f"#samples (val): {len(val_dataset)}.") _logger.info(f"#samples (test): {len(test_dataset)}.") - train200_dataset = copy.deepcopy(train_dataset) - train200_dataset.interleaved_reorder_(_get_spkr_name) - train200_dataset.head_(200) - train200_dataset.training_(False) - train200_dl = _create_dl(train200_dataset, training=False) - assert isinstance(train200_dl.dataset, VALLEDatset) + train_for_val_dataset = copy.deepcopy(train_dataset) + train_for_val_dataset.interleaved_reorder_(_get_spkr_name) + train_for_val_dataset.head_(cfg.max_num_val) + train_for_val_dataset.training_(False) + train_for_val_dl = _create_dl(train_for_val_dataset, training=False) + assert isinstance(train_for_val_dl.dataset, VALLEDatset) - return train_dl, train200_dl, val_dl, test_dl + return train_dl, train_for_val_dl, val_dl, test_dl if __name__ == "__main__": - train_dl, train200_dl, val_dl, test_dl = create_train_val_dataloader() + train_dl, train_for_val_dl, val_dl, test_dl = create_train_val_dataloader() sample = train_dl.dataset[0] print(sample) diff --git a/vall_e/train.py b/vall_e/train.py index c28c0d8..5dbcd06 100644 --- a/vall_e/train.py +++ b/vall_e/train.py @@ -72,7 +72,11 @@ def main(): batch = to_device(batch, cfg.device) if cfg.model.startswith("ar"): - resp_list = model(text_list=batch["text"], proms_list=batch["proms"]) + resp_list = model( + text_list=batch["text"], + proms_list=batch["proms"], + max_steps=cfg.max_val_ar_steps, + ) resps_list = [r.unsqueeze(-1) for r in resp_list] elif cfg.model.startswith("nar"): resps_list = model( @@ -98,6 +102,8 @@ def main(): if len(hyp) > 0: qnt.decode_to_file(hyp, hyp_path) + qnt.unload_model() + stats = {k: sum(v) / len(v) for k, v in stats.items()} stats["global_step"] = engines.global_step stats["name"] = name