actually make the evaluation dataset shuffled for sample_type=speaker

This commit is contained in:
mrq 2023-08-17 15:04:45 -05:00
parent 18403a3523
commit ee58db746f
3 changed files with 19 additions and 7 deletions

View File

@ -129,6 +129,7 @@ Some additional flags you can pass are:
## To-Do
* reduce load time for creating / preparing dataloaders.
* properly pass in `modules` names to `weight_quantization` and `activation_quantization`.
* train and release a model.
* extend to multiple languages (VALL-E X) and extend to SpeechX features.

View File

@ -467,7 +467,7 @@ try:
# cached_property stopped working...
if cfg.dataset.use_hdf5:
try:
cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', 'r' if cfg.distributed else 'a')
cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', 'r' if cfg.distributed else 'a') # to-do, have an easy to set flag that determines if training or creating the dataset
except Exception as e:
print("Error while opening HDF5 file:", f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', str(e))
cfg.dataset.use_hdf5 = False

View File

@ -466,9 +466,9 @@ def create_train_val_dataloader():
train_dataset.sample_type = cfg.dataset.sample_type #"speaker"
subtrain_dataset = copy.deepcopy(train_dataset)
subtrain_dataset.head_(cfg.evaluation.size)
subtrain_dataset.interleaved_reorder_(cfg.get_spkr)
#subtrain_dataset.training_(False)
if subtrain_dataset.sample_type == "path":
subtrain_dataset.head_(cfg.evaluation.size)
subtrain_dataset.interleaved_reorder_(cfg.get_spkr)
train_dl = _create_dataloader(train_dataset, training=True)
val_dl = _create_dataloader(val_dataset, training=False)
@ -564,8 +564,19 @@ def create_dataset_hdf5():
hf.close()
if __name__ == "__main__":
create_dataset_hdf5()
import argparse
parser = argparse.ArgumentParser("Save trained model to path.")
parser.add_argument("--create-hdf5", action="store_true")
args = parser.parse_args()
if args.create_hdf5:
create_dataset_hdf5()
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
sample = train_dl.dataset[0]
print(sample)
print("Training DL:", next(iter(train_dl)))
print("Training DL:", next(iter(train_dl)))
print("Evaluation DL:", next(iter(subtrain_dl)))
print("Evaluation DL:", next(iter(subtrain_dl)))
print("Validation DL:", next(iter(val_dl)))
print("Validation DL:", next(iter(val_dl)))