This commit is contained in:
mrq 2024-05-11 22:58:38 -05:00
parent 856545f8bb
commit 3774fcbdee
2 changed files with 24 additions and 16 deletions

View File

@ -343,7 +343,6 @@ class DeepSpeed:
inferencing: bool = False
amp: bool = False
fp16: bool = False
config: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
@ -373,7 +372,7 @@ class DeepSpeed:
autotune_params['exps_dir'] = str( cfg.relpath / "autotune" / "exps_" )
# DeepSpeed fp16 is incompatible with its AMP
if cfg.trainer.weight_dtype.lower() == "float16" and self.fp16:
if cfg.trainer.weight_dtype.lower() == "float16":
self.amp = False
# disable local AMP
@ -393,7 +392,7 @@ class DeepSpeed:
} if not cfg.hyperparameters.torch_scheduler else None,
"gradient_clipping": cfg.hyperparameters.gradient_clipping,
"fp16": {
"enabled": cfg.trainer.weight_dtype.lower() == "float16" and self.fp16,
"enabled": cfg.trainer.weight_dtype.lower() == "float16",
"auto_cast": True, # ???
},
"bf16": {

View File

@ -91,14 +91,12 @@ def _load_paths_from_metadata(dataset_name, type="training", validate=False):
_fn = _get_hdf5_paths if cfg.dataset.use_hdf5 else _get_paths_of_extensions
def _validate( entry ):
phones = entry['phones'] if "phones" in entry else 0
duration = entry['duration'] if "duration" in entry else 0
if type not in _total_durations:
_total_durations[type] = 0
_total_durations[type] += duration
def key( id ):
if not cfg.dataset.use_hdf5:
return data_dir / id
return f"/{type}/{_get_hdf5_path(data_dir)}/{id}"
return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
metadata_path = cfg.metadata_dir / f'{dataset_name}.json'
metadata = {}
@ -110,13 +108,24 @@ def _load_paths_from_metadata(dataset_name, type="training", validate=False):
return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_quant_extension(), validate )
def key( dir, id ):
if not cfg.dataset.use_hdf5:
return data_dir / id
def _validate( id ):
entry = metadata[id]
return f"/{type}/{_get_hdf5_path(data_dir)}/{id}"
phones = entry['phones'] if "phones" in entry else 0
duration = entry['duration'] if "duration" in entry else 0
if type not in _total_durations:
_total_durations[type] = 0
_total_durations[type] += duration
if cfg.dataset.use_hdf5:
k = key( id )
if k not in cfg.hdf5 or "audio" not in cfg.hdf5[k] or "text" not in cfg.hdf5[k]:
return False
return [ key(dir, id) for id in metadata.keys() if not validate or _validate(metadata[id]) ]
return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
return [ key(id) for id in metadata.keys() if not validate or _validate(id) ]
def _get_hdf5_path(path):
@ -377,7 +386,7 @@ class Dataset(_Dataset):
key = _get_hdf5_path(path)
if "audio" not in cfg.hdf5[key]:
_logger.warning("MISSING AUDIO:", key)
_logger.warning(f'MISSING AUDIO: {key}')
continue
qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)