Faster validation

This commit is contained in:
enhuiz 2023-01-12 20:26:49 +08:00
parent 946c86e422
commit 85c6a11f26
3 changed files with 19 additions and 10 deletions

View File

@@ -20,6 +20,9 @@ class Config(ConfigBase):
p_additional_prompt: float = 0.8 p_additional_prompt: float = 0.8
max_prompts: int = 3 max_prompts: int = 3
max_num_val: int = 20
max_val_ar_steps: int = 300
token_dim: int = 256 token_dim: int = 256
num_tokens: int = 1024 num_tokens: int = 1024

View File

@@ -260,7 +260,7 @@ def create_datasets():
) )
val_dataset.interleaved_reorder_(_get_spkr_name) val_dataset.interleaved_reorder_(_get_spkr_name)
val_dataset.head_(200) val_dataset.head_(cfg.max_num_val)
test_dataset = VALLEDatset( test_dataset = VALLEDatset(
test_paths, test_paths,
@@ -286,17 +286,17 @@ def create_train_val_dataloader():
_logger.info(f"#samples (val): {len(val_dataset)}.") _logger.info(f"#samples (val): {len(val_dataset)}.")
_logger.info(f"#samples (test): {len(test_dataset)}.") _logger.info(f"#samples (test): {len(test_dataset)}.")
train200_dataset = copy.deepcopy(train_dataset) train_for_val_dataset = copy.deepcopy(train_dataset)
train200_dataset.interleaved_reorder_(_get_spkr_name) train_for_val_dataset.interleaved_reorder_(_get_spkr_name)
train200_dataset.head_(200) train_for_val_dataset.head_(cfg.max_num_val)
train200_dataset.training_(False) train_for_val_dataset.training_(False)
train200_dl = _create_dl(train200_dataset, training=False) train_for_val_dl = _create_dl(train_for_val_dataset, training=False)
assert isinstance(train200_dl.dataset, VALLEDatset) assert isinstance(train_for_val_dl.dataset, VALLEDatset)
return train_dl, train200_dl, val_dl, test_dl return train_dl, train_for_val_dl, val_dl, test_dl
if __name__ == "__main__": if __name__ == "__main__":
train_dl, train200_dl, val_dl, test_dl = create_train_val_dataloader() train_dl, train_for_val_dl, val_dl, test_dl = create_train_val_dataloader()
sample = train_dl.dataset[0] sample = train_dl.dataset[0]
print(sample) print(sample)

View File

@@ -72,7 +72,11 @@ def main():
batch = to_device(batch, cfg.device) batch = to_device(batch, cfg.device)
if cfg.model.startswith("ar"): if cfg.model.startswith("ar"):
resp_list = model(text_list=batch["text"], proms_list=batch["proms"]) resp_list = model(
text_list=batch["text"],
proms_list=batch["proms"],
max_steps=cfg.max_val_ar_steps,
)
resps_list = [r.unsqueeze(-1) for r in resp_list] resps_list = [r.unsqueeze(-1) for r in resp_list]
elif cfg.model.startswith("nar"): elif cfg.model.startswith("nar"):
resps_list = model( resps_list = model(
@@ -98,6 +102,8 @@ def main():
if len(hyp) > 0: if len(hyp) > 0:
qnt.decode_to_file(hyp, hyp_path) qnt.decode_to_file(hyp, hyp_path)
qnt.unload_model()
stats = {k: sum(v) / len(v) for k, v in stats.items()} stats = {k: sum(v) / len(v) for k, v in stats.items()}
stats["global_step"] = engines.global_step stats["global_step"] = engines.global_step
stats["name"] = name stats["name"] = name