Faster validation
This commit is contained in:
parent
946c86e422
commit
85c6a11f26
@ -20,6 +20,9 @@ class Config(ConfigBase):
|
||||
p_additional_prompt: float = 0.8
|
||||
max_prompts: int = 3
|
||||
|
||||
max_num_val: int = 20
|
||||
max_val_ar_steps: int = 300
|
||||
|
||||
token_dim: int = 256
|
||||
num_tokens: int = 1024
|
||||
|
||||
|
@ -260,7 +260,7 @@ def create_datasets():
|
||||
)
|
||||
|
||||
val_dataset.interleaved_reorder_(_get_spkr_name)
|
||||
val_dataset.head_(200)
|
||||
val_dataset.head_(cfg.max_num_val)
|
||||
|
||||
test_dataset = VALLEDatset(
|
||||
test_paths,
|
||||
@ -286,17 +286,17 @@ def create_train_val_dataloader():
|
||||
_logger.info(f"#samples (val): {len(val_dataset)}.")
|
||||
_logger.info(f"#samples (test): {len(test_dataset)}.")
|
||||
|
||||
train200_dataset = copy.deepcopy(train_dataset)
|
||||
train200_dataset.interleaved_reorder_(_get_spkr_name)
|
||||
train200_dataset.head_(200)
|
||||
train200_dataset.training_(False)
|
||||
train200_dl = _create_dl(train200_dataset, training=False)
|
||||
assert isinstance(train200_dl.dataset, VALLEDatset)
|
||||
train_for_val_dataset = copy.deepcopy(train_dataset)
|
||||
train_for_val_dataset.interleaved_reorder_(_get_spkr_name)
|
||||
train_for_val_dataset.head_(cfg.max_num_val)
|
||||
train_for_val_dataset.training_(False)
|
||||
train_for_val_dl = _create_dl(train_for_val_dataset, training=False)
|
||||
assert isinstance(train_for_val_dl.dataset, VALLEDatset)
|
||||
|
||||
return train_dl, train200_dl, val_dl, test_dl
|
||||
return train_dl, train_for_val_dl, val_dl, test_dl
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_dl, train200_dl, val_dl, test_dl = create_train_val_dataloader()
|
||||
train_dl, train_for_val_dl, val_dl, test_dl = create_train_val_dataloader()
|
||||
sample = train_dl.dataset[0]
|
||||
print(sample)
|
||||
|
@ -72,7 +72,11 @@ def main():
|
||||
batch = to_device(batch, cfg.device)
|
||||
|
||||
if cfg.model.startswith("ar"):
|
||||
resp_list = model(text_list=batch["text"], proms_list=batch["proms"])
|
||||
resp_list = model(
|
||||
text_list=batch["text"],
|
||||
proms_list=batch["proms"],
|
||||
max_steps=cfg.max_val_ar_steps,
|
||||
)
|
||||
resps_list = [r.unsqueeze(-1) for r in resp_list]
|
||||
elif cfg.model.startswith("nar"):
|
||||
resps_list = model(
|
||||
@ -98,6 +102,8 @@ def main():
|
||||
if len(hyp) > 0:
|
||||
qnt.decode_to_file(hyp, hyp_path)
|
||||
|
||||
qnt.unload_model()
|
||||
|
||||
stats = {k: sum(v) / len(v) for k, v in stats.items()}
|
||||
stats["global_step"] = engines.global_step
|
||||
stats["name"] = name
|
||||
|
Loading…
Reference in New Issue
Block a user