This commit is contained in:
mrq 2024-05-11 22:58:38 -05:00
parent 856545f8bb
commit 3774fcbdee
2 changed files with 24 additions and 16 deletions

View File

@ -343,7 +343,6 @@ class DeepSpeed:
inferencing: bool = False inferencing: bool = False
amp: bool = False amp: bool = False
fp16: bool = False
config: dict = field(default_factory=lambda: {}) # to pass through deepspeed config config: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
@ -373,7 +372,7 @@ class DeepSpeed:
autotune_params['exps_dir'] = str( cfg.relpath / "autotune" / "exps_" ) autotune_params['exps_dir'] = str( cfg.relpath / "autotune" / "exps_" )
# DeepSpeed fp16 is incompatible with its AMP # DeepSpeed fp16 is incompatible with its AMP
if cfg.trainer.weight_dtype.lower() == "float16" and self.fp16: if cfg.trainer.weight_dtype.lower() == "float16":
self.amp = False self.amp = False
# disable local AMP # disable local AMP
@ -393,7 +392,7 @@ class DeepSpeed:
} if not cfg.hyperparameters.torch_scheduler else None, } if not cfg.hyperparameters.torch_scheduler else None,
"gradient_clipping": cfg.hyperparameters.gradient_clipping, "gradient_clipping": cfg.hyperparameters.gradient_clipping,
"fp16": { "fp16": {
"enabled": cfg.trainer.weight_dtype.lower() == "float16" and self.fp16, "enabled": cfg.trainer.weight_dtype.lower() == "float16",
"auto_cast": True, # ??? "auto_cast": True, # ???
}, },
"bf16": { "bf16": {

View File

@ -91,14 +91,12 @@ def _load_paths_from_metadata(dataset_name, type="training", validate=False):
_fn = _get_hdf5_paths if cfg.dataset.use_hdf5 else _get_paths_of_extensions _fn = _get_hdf5_paths if cfg.dataset.use_hdf5 else _get_paths_of_extensions
def _validate( entry ): def key( id ):
phones = entry['phones'] if "phones" in entry else 0 if not cfg.dataset.use_hdf5:
duration = entry['duration'] if "duration" in entry else 0 return data_dir / id
if type not in _total_durations:
_total_durations[type] = 0 return f"/{type}/{_get_hdf5_path(data_dir)}/{id}"
_total_durations[type] += duration
return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
metadata_path = cfg.metadata_dir / f'{dataset_name}.json' metadata_path = cfg.metadata_dir / f'{dataset_name}.json'
metadata = {} metadata = {}
@ -110,13 +108,24 @@ def _load_paths_from_metadata(dataset_name, type="training", validate=False):
return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_quant_extension(), validate ) return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_quant_extension(), validate )
def key( dir, id ): def _validate( id ):
if not cfg.dataset.use_hdf5: entry = metadata[id]
return data_dir / id
return f"/{type}/{_get_hdf5_path(data_dir)}/{id}" phones = entry['phones'] if "phones" in entry else 0
duration = entry['duration'] if "duration" in entry else 0
if type not in _total_durations:
_total_durations[type] = 0
return [ key(dir, id) for id in metadata.keys() if not validate or _validate(metadata[id]) ] _total_durations[type] += duration
if cfg.dataset.use_hdf5:
k = key( id )
if k not in cfg.hdf5 or "audio" not in cfg.hdf5[k] or "text" not in cfg.hdf5[k]:
return False
return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
return [ key(id) for id in metadata.keys() if not validate or _validate(id) ]
def _get_hdf5_path(path): def _get_hdf5_path(path):
@ -377,7 +386,7 @@ class Dataset(_Dataset):
key = _get_hdf5_path(path) key = _get_hdf5_path(path)
if "audio" not in cfg.hdf5[key]: if "audio" not in cfg.hdf5[key]:
_logger.warning("MISSING AUDIO:", key) _logger.warning(f'MISSING AUDIO: {key}')
continue continue
qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16) qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)