ugh
This commit is contained in:
parent
856545f8bb
commit
3774fcbdee
|
@ -343,7 +343,6 @@ class DeepSpeed:
|
|||
inferencing: bool = False
|
||||
|
||||
amp: bool = False
|
||||
fp16: bool = False
|
||||
|
||||
config: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
|
||||
|
||||
|
@ -373,7 +372,7 @@ class DeepSpeed:
|
|||
autotune_params['exps_dir'] = str( cfg.relpath / "autotune" / "exps_" )
|
||||
|
||||
# DeepSpeed fp16 is incompatible with its AMP
|
||||
if cfg.trainer.weight_dtype.lower() == "float16" and self.fp16:
|
||||
if cfg.trainer.weight_dtype.lower() == "float16":
|
||||
self.amp = False
|
||||
|
||||
# disable local AMP
|
||||
|
@ -393,7 +392,7 @@ class DeepSpeed:
|
|||
} if not cfg.hyperparameters.torch_scheduler else None,
|
||||
"gradient_clipping": cfg.hyperparameters.gradient_clipping,
|
||||
"fp16": {
|
||||
"enabled": cfg.trainer.weight_dtype.lower() == "float16" and self.fp16,
|
||||
"enabled": cfg.trainer.weight_dtype.lower() == "float16",
|
||||
"auto_cast": True, # ???
|
||||
},
|
||||
"bf16": {
|
||||
|
|
|
@ -91,14 +91,12 @@ def _load_paths_from_metadata(dataset_name, type="training", validate=False):
|
|||
|
||||
_fn = _get_hdf5_paths if cfg.dataset.use_hdf5 else _get_paths_of_extensions
|
||||
|
||||
def _validate( entry ):
|
||||
phones = entry['phones'] if "phones" in entry else 0
|
||||
duration = entry['duration'] if "duration" in entry else 0
|
||||
if type not in _total_durations:
|
||||
_total_durations[type] = 0
|
||||
_total_durations[type] += duration
|
||||
def key( id ):
|
||||
if not cfg.dataset.use_hdf5:
|
||||
return data_dir / id
|
||||
|
||||
return f"/{type}/{_get_hdf5_path(data_dir)}/{id}"
|
||||
|
||||
return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
|
||||
|
||||
metadata_path = cfg.metadata_dir / f'{dataset_name}.json'
|
||||
metadata = {}
|
||||
|
@ -110,13 +108,24 @@ def _load_paths_from_metadata(dataset_name, type="training", validate=False):
|
|||
return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_quant_extension(), validate )
|
||||
|
||||
|
||||
def key( dir, id ):
|
||||
if not cfg.dataset.use_hdf5:
|
||||
return data_dir / id
|
||||
def _validate( id ):
|
||||
entry = metadata[id]
|
||||
|
||||
return f"/{type}/{_get_hdf5_path(data_dir)}/{id}"
|
||||
phones = entry['phones'] if "phones" in entry else 0
|
||||
duration = entry['duration'] if "duration" in entry else 0
|
||||
if type not in _total_durations:
|
||||
_total_durations[type] = 0
|
||||
|
||||
_total_durations[type] += duration
|
||||
|
||||
if cfg.dataset.use_hdf5:
|
||||
k = key( id )
|
||||
if k not in cfg.hdf5 or "audio" not in cfg.hdf5[k] or "text" not in cfg.hdf5[k]:
|
||||
return False
|
||||
|
||||
return [ key(dir, id) for id in metadata.keys() if not validate or _validate(metadata[id]) ]
|
||||
return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
|
||||
|
||||
return [ key(id) for id in metadata.keys() if not validate or _validate(id) ]
|
||||
|
||||
|
||||
def _get_hdf5_path(path):
|
||||
|
@ -377,7 +386,7 @@ class Dataset(_Dataset):
|
|||
key = _get_hdf5_path(path)
|
||||
|
||||
if "audio" not in cfg.hdf5[key]:
|
||||
_logger.warning("MISSING AUDIO:", key)
|
||||
_logger.warning(f'MISSING AUDIO: {key}')
|
||||
continue
|
||||
|
||||
qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
|
||||
|
|
Loading…
Reference in New Issue
Block a user