ugh
This commit is contained in:
parent
856545f8bb
commit
3774fcbdee
|
@ -343,7 +343,6 @@ class DeepSpeed:
|
||||||
inferencing: bool = False
|
inferencing: bool = False
|
||||||
|
|
||||||
amp: bool = False
|
amp: bool = False
|
||||||
fp16: bool = False
|
|
||||||
|
|
||||||
config: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
|
config: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
|
||||||
|
|
||||||
|
@ -373,7 +372,7 @@ class DeepSpeed:
|
||||||
autotune_params['exps_dir'] = str( cfg.relpath / "autotune" / "exps_" )
|
autotune_params['exps_dir'] = str( cfg.relpath / "autotune" / "exps_" )
|
||||||
|
|
||||||
# DeepSpeed fp16 is incompatible with its AMP
|
# DeepSpeed fp16 is incompatible with its AMP
|
||||||
if cfg.trainer.weight_dtype.lower() == "float16" and self.fp16:
|
if cfg.trainer.weight_dtype.lower() == "float16":
|
||||||
self.amp = False
|
self.amp = False
|
||||||
|
|
||||||
# disable local AMP
|
# disable local AMP
|
||||||
|
@ -393,7 +392,7 @@ class DeepSpeed:
|
||||||
} if not cfg.hyperparameters.torch_scheduler else None,
|
} if not cfg.hyperparameters.torch_scheduler else None,
|
||||||
"gradient_clipping": cfg.hyperparameters.gradient_clipping,
|
"gradient_clipping": cfg.hyperparameters.gradient_clipping,
|
||||||
"fp16": {
|
"fp16": {
|
||||||
"enabled": cfg.trainer.weight_dtype.lower() == "float16" and self.fp16,
|
"enabled": cfg.trainer.weight_dtype.lower() == "float16",
|
||||||
"auto_cast": True, # ???
|
"auto_cast": True, # ???
|
||||||
},
|
},
|
||||||
"bf16": {
|
"bf16": {
|
||||||
|
|
|
@ -91,14 +91,12 @@ def _load_paths_from_metadata(dataset_name, type="training", validate=False):
|
||||||
|
|
||||||
_fn = _get_hdf5_paths if cfg.dataset.use_hdf5 else _get_paths_of_extensions
|
_fn = _get_hdf5_paths if cfg.dataset.use_hdf5 else _get_paths_of_extensions
|
||||||
|
|
||||||
def _validate( entry ):
|
def key( id ):
|
||||||
phones = entry['phones'] if "phones" in entry else 0
|
if not cfg.dataset.use_hdf5:
|
||||||
duration = entry['duration'] if "duration" in entry else 0
|
return data_dir / id
|
||||||
if type not in _total_durations:
|
|
||||||
_total_durations[type] = 0
|
return f"/{type}/{_get_hdf5_path(data_dir)}/{id}"
|
||||||
_total_durations[type] += duration
|
|
||||||
|
|
||||||
return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
|
|
||||||
|
|
||||||
metadata_path = cfg.metadata_dir / f'{dataset_name}.json'
|
metadata_path = cfg.metadata_dir / f'{dataset_name}.json'
|
||||||
metadata = {}
|
metadata = {}
|
||||||
|
@ -110,13 +108,24 @@ def _load_paths_from_metadata(dataset_name, type="training", validate=False):
|
||||||
return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_quant_extension(), validate )
|
return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_quant_extension(), validate )
|
||||||
|
|
||||||
|
|
||||||
def key( dir, id ):
|
def _validate( id ):
|
||||||
if not cfg.dataset.use_hdf5:
|
entry = metadata[id]
|
||||||
return data_dir / id
|
|
||||||
|
|
||||||
return f"/{type}/{_get_hdf5_path(data_dir)}/{id}"
|
phones = entry['phones'] if "phones" in entry else 0
|
||||||
|
duration = entry['duration'] if "duration" in entry else 0
|
||||||
|
if type not in _total_durations:
|
||||||
|
_total_durations[type] = 0
|
||||||
|
|
||||||
return [ key(dir, id) for id in metadata.keys() if not validate or _validate(metadata[id]) ]
|
_total_durations[type] += duration
|
||||||
|
|
||||||
|
if cfg.dataset.use_hdf5:
|
||||||
|
k = key( id )
|
||||||
|
if k not in cfg.hdf5 or "audio" not in cfg.hdf5[k] or "text" not in cfg.hdf5[k]:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
|
||||||
|
|
||||||
|
return [ key(id) for id in metadata.keys() if not validate or _validate(id) ]
|
||||||
|
|
||||||
|
|
||||||
def _get_hdf5_path(path):
|
def _get_hdf5_path(path):
|
||||||
|
@ -377,7 +386,7 @@ class Dataset(_Dataset):
|
||||||
key = _get_hdf5_path(path)
|
key = _get_hdf5_path(path)
|
||||||
|
|
||||||
if "audio" not in cfg.hdf5[key]:
|
if "audio" not in cfg.hdf5[key]:
|
||||||
_logger.warning("MISSING AUDIO:", key)
|
_logger.warning(f'MISSING AUDIO: {key}')
|
||||||
continue
|
continue
|
||||||
|
|
||||||
qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
|
qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user