diff --git a/vall_e/config.py b/vall_e/config.py index f9ced76..4bbc306 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -343,7 +343,6 @@ class DeepSpeed: inferencing: bool = False amp: bool = False - fp16: bool = False config: dict = field(default_factory=lambda: {}) # to pass through deepspeed config @@ -373,7 +372,7 @@ class DeepSpeed: autotune_params['exps_dir'] = str( cfg.relpath / "autotune" / "exps_" ) # DeepSpeed fp16 is incompatible with its AMP - if cfg.trainer.weight_dtype.lower() == "float16" and self.fp16: + if cfg.trainer.weight_dtype.lower() == "float16": self.amp = False # disable local AMP @@ -393,7 +392,7 @@ class DeepSpeed: } if not cfg.hyperparameters.torch_scheduler else None, "gradient_clipping": cfg.hyperparameters.gradient_clipping, "fp16": { - "enabled": cfg.trainer.weight_dtype.lower() == "float16" and self.fp16, + "enabled": cfg.trainer.weight_dtype.lower() == "float16", "auto_cast": True, # ??? }, "bf16": { diff --git a/vall_e/data.py b/vall_e/data.py index bb28d90..5257abf 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -91,14 +91,12 @@ def _load_paths_from_metadata(dataset_name, type="training", validate=False): _fn = _get_hdf5_paths if cfg.dataset.use_hdf5 else _get_paths_of_extensions - def _validate( entry ): - phones = entry['phones'] if "phones" in entry else 0 - duration = entry['duration'] if "duration" in entry else 0 - if type not in _total_durations: - _total_durations[type] = 0 - _total_durations[type] += duration + def key( id ): + if not cfg.dataset.use_hdf5: + return data_dir / id + + return f"/{type}/{_get_hdf5_path(data_dir)}/{id}" - return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones metadata_path = cfg.metadata_dir / f'{dataset_name}.json' metadata = {} @@ -110,13 +108,24 @@ def _load_paths_from_metadata(dataset_name, type="training", validate=False): return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_quant_extension(), validate ) - def key( dir, id ): - if not cfg.dataset.use_hdf5: - return data_dir / id + def _validate( id ): + entry = metadata[id] - return f"/{type}/{_get_hdf5_path(data_dir)}/{id}" + phones = entry['phones'] if "phones" in entry else 0 + duration = entry['duration'] if "duration" in entry else 0 + if type not in _total_durations: + _total_durations[type] = 0 + + _total_durations[type] += duration + + if cfg.dataset.use_hdf5: + k = key( id ) + if k not in cfg.hdf5 or "audio" not in cfg.hdf5[k] or "text" not in cfg.hdf5[k]: + return False - return [ key(dir, id) for id in metadata.keys() if not validate or _validate(metadata[id]) ] + return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones + + return [ key(id) for id in metadata.keys() if not validate or _validate(id) ] def _get_hdf5_path(path): @@ -377,7 +386,7 @@ class Dataset(_Dataset): key = _get_hdf5_path(path) if "audio" not in cfg.hdf5[key]: - _logger.warning("MISSING AUDIO:", key) + _logger.warning(f'MISSING AUDIO: {key}') continue qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)