Faster validation
This commit is contained in:
parent
946c86e422
commit
85c6a11f26
@@ -20,6 +20,9 @@ class Config(ConfigBase):
|
|||||||
p_additional_prompt: float = 0.8
|
p_additional_prompt: float = 0.8
|
||||||
max_prompts: int = 3
|
max_prompts: int = 3
|
||||||
|
|
||||||
|
max_num_val: int = 20
|
||||||
|
max_val_ar_steps: int = 300
|
||||||
|
|
||||||
token_dim: int = 256
|
token_dim: int = 256
|
||||||
num_tokens: int = 1024
|
num_tokens: int = 1024
|
||||||
|
|
||||||
|
@@ -260,7 +260,7 @@ def create_datasets():
|
|||||||
)
|
)
|
||||||
|
|
||||||
val_dataset.interleaved_reorder_(_get_spkr_name)
|
val_dataset.interleaved_reorder_(_get_spkr_name)
|
||||||
val_dataset.head_(200)
|
val_dataset.head_(cfg.max_num_val)
|
||||||
|
|
||||||
test_dataset = VALLEDatset(
|
test_dataset = VALLEDatset(
|
||||||
test_paths,
|
test_paths,
|
||||||
@@ -286,17 +286,17 @@ def create_train_val_dataloader():
|
|||||||
_logger.info(f"#samples (val): {len(val_dataset)}.")
|
_logger.info(f"#samples (val): {len(val_dataset)}.")
|
||||||
_logger.info(f"#samples (test): {len(test_dataset)}.")
|
_logger.info(f"#samples (test): {len(test_dataset)}.")
|
||||||
|
|
||||||
train200_dataset = copy.deepcopy(train_dataset)
|
train_for_val_dataset = copy.deepcopy(train_dataset)
|
||||||
train200_dataset.interleaved_reorder_(_get_spkr_name)
|
train_for_val_dataset.interleaved_reorder_(_get_spkr_name)
|
||||||
train200_dataset.head_(200)
|
train_for_val_dataset.head_(cfg.max_num_val)
|
||||||
train200_dataset.training_(False)
|
train_for_val_dataset.training_(False)
|
||||||
train200_dl = _create_dl(train200_dataset, training=False)
|
train_for_val_dl = _create_dl(train_for_val_dataset, training=False)
|
||||||
assert isinstance(train200_dl.dataset, VALLEDatset)
|
assert isinstance(train_for_val_dl.dataset, VALLEDatset)
|
||||||
|
|
||||||
return train_dl, train200_dl, val_dl, test_dl
|
return train_dl, train_for_val_dl, val_dl, test_dl
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
train_dl, train200_dl, val_dl, test_dl = create_train_val_dataloader()
|
train_dl, train_for_val_dl, val_dl, test_dl = create_train_val_dataloader()
|
||||||
sample = train_dl.dataset[0]
|
sample = train_dl.dataset[0]
|
||||||
print(sample)
|
print(sample)
|
||||||
|
@@ -72,7 +72,11 @@ def main():
|
|||||||
batch = to_device(batch, cfg.device)
|
batch = to_device(batch, cfg.device)
|
||||||
|
|
||||||
if cfg.model.startswith("ar"):
|
if cfg.model.startswith("ar"):
|
||||||
resp_list = model(text_list=batch["text"], proms_list=batch["proms"])
|
resp_list = model(
|
||||||
|
text_list=batch["text"],
|
||||||
|
proms_list=batch["proms"],
|
||||||
|
max_steps=cfg.max_val_ar_steps,
|
||||||
|
)
|
||||||
resps_list = [r.unsqueeze(-1) for r in resp_list]
|
resps_list = [r.unsqueeze(-1) for r in resp_list]
|
||||||
elif cfg.model.startswith("nar"):
|
elif cfg.model.startswith("nar"):
|
||||||
resps_list = model(
|
resps_list = model(
|
||||||
@ -98,6 +102,8 @@ def main():
|
|||||||
if len(hyp) > 0:
|
if len(hyp) > 0:
|
||||||
qnt.decode_to_file(hyp, hyp_path)
|
qnt.decode_to_file(hyp, hyp_path)
|
||||||
|
|
||||||
|
qnt.unload_model()
|
||||||
|
|
||||||
stats = {k: sum(v) / len(v) for k, v in stats.items()}
|
stats = {k: sum(v) / len(v) for k, v in stats.items()}
|
||||||
stats["global_step"] = engines.global_step
|
stats["global_step"] = engines.global_step
|
||||||
stats["name"] = name
|
stats["name"] = name
|
||||||
|
Loading…
Reference in New Issue
Block a user