dropped subtrain dataloader since it's useless to duplicate
parent cf9df71f2c
commit f7b8b1e825
@@ -61,7 +61,12 @@ The dataloader handles some simple yet effective features, such as:
 * picking an input prompt from the same speaker as the sample, if the above is not requested
 * preparing the input sequence for the given task (such as non-TTS tasks)
 
-The initial list of paths is cached through `diskcache`, if `cfg.dataset.cache == True`. Be sure to delete the resultant `.cache` folder, as well as the `sampler.*` state dicts alongside checkpoints, if you plan to modify the dataloader settings between training sessions.
+If `cfg.dataset.cache == True`, the initial list of paths and duration metadata (used for sorting/bucketing) is cached ~~through `diskcache`~~ under `{YAML_PATH}/.cache/{DATASET_HASH}/`. To allow for seamless modifications to the loaded dataset, the `DATASET_HASH` relies on:
+* duration range
+* folders/groups in the dataset
+* if using HDF5 (due to the key format differing)
 
+Be sure to delete the resultant `.cache` folder, as well as the `sampler.*` state dicts alongside checkpoints, if you plan to modify the dataloader settings between training sessions.
+
 ## Tasks
 
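As an aside on the hunk above: the actual `cfg.dataset.hash_key` implementation is not part of this diff, but a minimal sketch of how a `DATASET_HASH` could be derived from the listed inputs (duration range, folders/groups, HDF5 usage) might look like the following. The helper name, fields, and defaults are illustrative assumptions only.

```python
# Hypothetical sketch only: cfg.dataset.hash_key() is defined elsewhere and may differ.
import hashlib
import json

def hash_key(folders, duration_range=(0.0, 60.0), use_hdf5=False):
    # Serialize the inputs deterministically so the same dataset settings
    # always map to the same cache directory name.
    payload = json.dumps({
        "folders": sorted(str(f) for f in folders),
        "duration_range": list(duration_range),
        "use_hdf5": use_hdf5,
    }, sort_keys=True)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()[:16]

print(hash_key(["LibriTTS/train-clean-100", "LibriTTS/train-clean-360"]))
```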
@@ -95,7 +100,7 @@ This section may be covered elsewhere in the documentation, but coverage here sh
 * `nse`: noisy speech editing.
   * the above, but injects some noise throughout the sampled utterances.
 
-A mystical `rvc` for emulating RVC speech-to-speech synthesis is possible, but requires a dataset to do so.
+A mystical `vc` for performing voice conversion is possible, but it either requires a dataset to do so, or abuses an emergent property.
 
 ## `__main__`
 
@@ -60,7 +60,7 @@ Descript-Audio-Codec was thoroughly tested for promising much, much cleaner outp
 
 However, due to the nature of the codec, simply throwing it at an attention-based transformer proves to be painful, as a unified AR+NAR model *heavily* suffers from noisy output in the NAR.
 
-Ironically, testing through mal-encoded audio (feeding 24KHz audio without upsampling to 44.1KHz) proved to have "cleaner" but bad utterances.
+Ironically, testing through erroneously encoded audio (feeding 24KHz audio without upsampling to 44.1KHz) proved to have "cleaner" but bad utterances.
 
 I'm uncertain on how to remedy this, as my options are:
 * train under a RetNet, if an attention-based transformer is simply the problem
@@ -534,8 +534,11 @@ _durations_map = {}
 def _get_duration_map( type="training" ):
     return _durations_map[type] if type in _durations_map else {}
 
-def _load_paths(dataset, type="training", silent=False):
-    cached_dir = cfg.cache_dir / cfg.dataset.hash_key(sorted(dataset))
+def _load_paths(dataset, type="training", silent=False, dataset_hash_key=None):
+    if not dataset_hash_key:
+        dataset_hash_key = cfg.dataset.hash_key(sorted(dataset))
 
+    cached_dir = cfg.cache_dir / dataset_hash_key
+
     cached_durations_path = cached_dir / f"durations[{type}].json"
     cached_paths_path = cached_dir / f"dataloader[{type}].json"
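For illustration, the cache layout implied by the hunk above; the concrete directory and hash value below are made up, only the file naming mirrors the code.

```python
# Hypothetical values standing in for cfg.cache_dir and cfg.dataset.hash_key(...).
from pathlib import Path

cache_dir = Path(".cache")
dataset_hash_key = "1f3a9c0b2d4e5f67"

cached_dir = cache_dir / dataset_hash_key
print(cached_dir / "durations[training].json")   # .cache/1f3a9c0b2d4e5f67/durations[training].json
print(cached_dir / "dataloader[training].json")  # .cache/1f3a9c0b2d4e5f67/dataloader[training].json
```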
@@ -691,11 +694,13 @@ class Dataset(_Dataset):
 
         self.training = training
         self.dataset_type = "training" if self.training else "validation"
-        self.dataset = cfg.dataset.training if self.training else cfg.dataset.validation
+        self.dataset = sorted(cfg.dataset.training if self.training else cfg.dataset.validation)
         self.sampler_type = cfg.dataset.sample_type if self.dataset_type == "training" else "path"
         self.sampler_order = cfg.dataset.sample_order
         self.sampler_shuffle = cfg.dataset.sample_shuffle if self.dataset_type == "training" else True
 
+        self.dataset_hash_key = cfg.dataset.hash_key(sorted(self.dataset))
+
         # to-do: do not do validation if there's nothing in the validation
         # this just makes it be happy
         if len(self.dataset) == 0:
@@ -706,7 +711,7 @@ class Dataset(_Dataset):
             raise Exception(f'Requesting sample_type={self.sampler_type} with sample_order={self.sampler_order}, yet combination will not give expected results.')
 
         # dict of paths keyed by speaker names
-        self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type)
+        self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type, dataset_hash_key=self.dataset_hash_key)
         # do it here due to the above
         self.duration = 0
         self.duration_map = _get_duration_map( self.dataset_type )
@@ -780,10 +785,6 @@ class Dataset(_Dataset):
             # just interleave
             self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
 
-            # dereference buckets
-            self.duration_map = None
-            self.duration_buckets = None
-
         # dict of speakers keyed by speaker group
         self.spkrs_by_spkr_group = {}
         for data_dir in self.dataset:
@@ -830,16 +831,21 @@ class Dataset(_Dataset):
                     self.duration_buckets if not self.sampler_state_dict_path.exists() else {}, # pass nothing if we're just going to load from a state anyways
                     max_duration=cfg.dataset.sample_max_duration_batch,
                     max_batch_size=cfg.hyperparameters.batch_size if self.training else cfg.evaluation.batch_size,
-                    shuffle=self.sampler_shuffle
+                    shuffle=self.sampler_shuffle,
+                    dataset_hash=self.dataset_hash_key,
                 )
             else:
-                self.sampler = OrderedSampler( len(self) ) if not self.sampler_shuffle else RandomSampler( len(self) )
+                self.sampler = OrderedSampler( len(self), dataset_hash=self.dataset_hash_key ) if not self.sampler_shuffle else RandomSampler( len(self), dataset_hash=self.dataset_hash_key )
             self.samplers = {}
             self.spkr_samplers = {}
         else:
-            self.sampler = RandomSampler( len(self) )
-            self.samplers = { name: PoolSampler( paths, keep_all=True, shuffle=self.sampler_shuffle ) for name, paths in self.paths_by_spkr_name.items() }
-            self.spkr_samplers = { name: PoolSampler( [*set(speakers)], keep_all=True, shuffle=self.sampler_shuffle ) for name, speakers in self.spkrs_by_spkr_group.items() }
+            self.sampler = RandomSampler( len(self), dataset_hash=self.dataset_hash_key )
+            self.samplers = { name: PoolSampler( paths, keep_all=True, shuffle=self.sampler_shuffle, dataset_hash=self.dataset_hash_key ) for name, paths in self.paths_by_spkr_name.items() }
+            self.spkr_samplers = { name: PoolSampler( [*set(speakers)], keep_all=True, shuffle=self.sampler_shuffle, dataset_hash=self.dataset_hash_key ) for name, speakers in self.spkrs_by_spkr_group.items() }
 
+        # dereference buckets
+        self.duration_map = None
+        self.duration_buckets = None
+
         self.load_state_dict()
 
@@ -903,6 +909,7 @@ class Dataset(_Dataset):
             return
 
         state_dict = torch_load(path)
+
         if self.sampler_type == "path":
             state_dict = self.sampler.set_state(state_dict)
         else:
@@ -1461,16 +1468,8 @@ def create_val_dataloader():
 # to-do, use the above two, then create the subtrain dataset
 def create_train_val_dataloader():
     train_dataset, val_dataset = create_datasets()
 
-    # deepcopy is slow
-    subtrain_dataset = Dataset( training=True )
-
-    if subtrain_dataset.sampler_type == "path":
-        subtrain_dataset.head_(cfg.evaluation.size)
-
     train_dl = _create_dataloader(train_dataset, training=True)
     val_dl = _create_dataloader(val_dataset, training=False)
-    subtrain_dl = _create_dataloader(subtrain_dataset, training=False)
 
     _logger.info(str(train_dataset.phone_symmap))
     _logger.info(str(train_dataset.spkr_symmap))
@@ -1478,18 +1477,14 @@ def create_train_val_dataloader():
 
     _logger.info(f"#samples (train): {len(train_dataset)}.")
     _logger.info(f"#samples (val): {len(val_dataset)}.")
-    _logger.info(f"#samples (subtrain): {len(subtrain_dataset)}.")
 
     _logger.info(f"#duration (train): {str(train_dataset.duration)}.")
     _logger.info(f"#duration (val): {str(val_dataset.duration)}.")
-    _logger.info(f"#duration (subtrain): {str(subtrain_dataset.duration)}.")
 
-    assert isinstance(subtrain_dl.dataset, Dataset)
-
     # remove duration map (it gets bloated)
     _durations_map = {}
 
-    return train_dl, subtrain_dl, val_dl
+    return train_dl, val_dl
 
 # parse metadata from an numpy file (.enc/.dac) and validate it
 def process_artifact_metadata( artifact ):
@@ -108,6 +108,14 @@ def run_eval(engines, eval_name, dl, args=None):
 
     processed = 0
     while processed < cfg.evaluation.size:
-        batch = to_device(next(iter(dl)), cfg.device)
+        # directly randomly sample
+        if eval_name == "subtrain":
+            # sample from dataset
+            # to-do: derive from current iteration
+            samples = [ to_device(dl.dataset[random.randint( 0, len( dl.dataset ) - 1 )], cfg.device) for sample in range( cfg.evaluation.batch_size ) ]
+            # collate manually
+            batch = {k: [s[k] for s in samples] for k in samples[0]}
+        else:
+            batch = to_device(next(iter(dl)), cfg.device)
 
         # limit to eval batch size in the event we somehow have a weird dataloader
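The new branch above samples `subtrain` evaluation batches straight from the training dataset and collates them by hand. A self-contained sketch of that idea follows; the toy dataset and helper name are illustrative assumptions, not code from this commit.

```python
import random

def sample_subtrain_batch(dataset, batch_size):
    # randrange keeps indices in [0, len(dataset)), avoiding an off-by-one
    samples = [dataset[random.randrange(len(dataset))] for _ in range(batch_size)]
    # collate a list of dicts into a dict of lists, keyed like the first sample
    return {k: [s[k] for s in samples] for k in samples[0]}

toy_dataset = [{"text": f"utterance {i}", "duration": float(i)} for i in range(10)]
batch = sample_subtrain_batch(toy_dataset, batch_size=4)
print(sorted(batch.keys()), len(batch["text"]))
```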
@@ -209,14 +217,14 @@ def train():
     if cfg.yaml_path is not None and is_global_leader():
         shutil.copy( cfg.yaml_path, cfg.log_dir / "config.yaml" )
     # create dataloaders
-    train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
+    train_dl, val_dl = create_train_val_dataloader()
     # evaluation lambda
     def eval_fn(engines):
         do_gc()
         engines.eval()
         # wrapped in a try block because it's sometimes prone to breaking
         try:
-            run_eval(engines, "subtrain", subtrain_dl, args)
+            run_eval(engines, "subtrain", train_dl, args)
             run_eval(engines, "val", val_dl, args)
         except Exception as e:
             _logger.warning(f"Error occurred while performing eval: {str(e)}")
@@ -9,11 +9,12 @@ from .distributed import global_rank, local_rank, world_size
 
 # Randomly picks an index from an array of indices
 class PoolSampler():
-    def __init__( self, pool = [], keep_all = False, shuffle = False ):
+    def __init__( self, pool = [], keep_all = False, shuffle = False, dataset_hash = None ):
         self.length = len(pool)
         self.shuffle = shuffle
         self.global_pool = pool if keep_all else None
         self.global_indices = [ i for i in range(self.length) ]
+        self.dataset_hash = dataset_hash
         self.reset()
 
     def reset(self):
@@ -45,21 +46,25 @@ class PoolSampler():
         return self.sample(*args, **kwargs)
 
     def get_state(self):
-        return { "length": self.length, "global_pool": self.global_pool, "global_indices": self.global_indices, "current_pool": self.current_pool }
+        return { "length": self.length, "global_pool": self.global_pool, "global_indices": self.global_indices, "current_pool": self.current_pool, "dataset_hash": self.dataset_hash }
 
     def set_state(self, state):
         self.length = state["length"]
         self.global_pool = state["global_pool"]
         self.global_indices = state["global_indices"]
         self.current_pool = state["current_pool"]
+        # could .pop()
+        if "dataset_hash" in state:
+            self.dataset_hash = state["dataset_hash"]
 
 # "Samples" through a fixed sequence from 0 to length
 # Necessary for our "shuffle+sort by duration+interleave" sampling method
 # Allows saving and loading state
 class OrderedSampler(Sampler):
-    def __init__( self, length ):
+    def __init__( self, length, dataset_hash=None ):
         self.position = 0
         self.length = length
+        self.dataset_hash = dataset_hash
 
     def __len__(self):
         return self.length
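The `dataset_hash` now carried through `get_state()`/`set_state()` makes it possible to notice when a restored sampler state belongs to a different dataset. A hedged illustration of that intent follows; the mismatch warning below is an assumption for demonstration, not behavior added by this commit.

```python
class TinySampler:
    def __init__(self, length, dataset_hash=None):
        self.position = 0
        self.length = length
        self.dataset_hash = dataset_hash

    def get_state(self):
        return {"position": self.position, "length": self.length, "dataset_hash": self.dataset_hash}

    def set_state(self, state):
        # tolerate older states that lack the key, mirroring the `if "dataset_hash" in state` guard
        saved_hash = state.get("dataset_hash")
        if saved_hash is not None and saved_hash != self.dataset_hash:
            print("warning: sampler state was saved for a different dataset; consider discarding it")
        self.position = state["position"]
        self.length = state["length"]

old = TinySampler(100, dataset_hash="abc123")
new = TinySampler(120, dataset_hash="def456")
new.set_state(old.get_state())  # prints the warning
```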
@@ -73,18 +78,22 @@ class OrderedSampler(Sampler):
         self.position += 1
 
     def get_state(self):
-        return { "position": self.position, "length": self.length }
+        return { "position": self.position, "length": self.length, "dataset_hash": self.dataset_hash }
 
     def set_state(self, state):
         self.position = state["position"]
         self.length = state["length"]
+        # could .pop()
+        if "dataset_hash" in state:
+            self.dataset_hash = state["dataset_hash"]
 
 # Like the above, but will batch based on token count
 class BatchedOrderedSampler(Sampler):
-    def __init__( self, buckets, max_duration=0, max_batch_size=0, shuffle=False ):
+    def __init__( self, buckets, max_duration=0, max_batch_size=0, shuffle=False, dataset_hash=None ):
         self.position = 0
         self.batches = []
         self.shuffle = shuffle
+        self.dataset_hash = dataset_hash
 
         assert max_duration != 0 and max_batch_size != 0, "max_duration and max_batch_size cannot both be 0"
 
@@ -126,18 +135,22 @@ class BatchedOrderedSampler(Sampler):
         self.position += 1
 
     def get_state(self):
-        return { "position": self.position, "batches": self.batches }
+        return { "position": self.position, "batches": self.batches, "dataset_hash": self.dataset_hash }
 
     def set_state(self, state):
         self.position = state["position"]
         self.batches = state["batches"]
+        # could .pop()
+        if "dataset_hash" in state:
+            self.dataset_hash = state["dataset_hash"]
 
 # Randomly samples indices from a given sequence from 0 to length
 # Allows saving and loading state
 class RandomSampler(Sampler):
-    def __init__( self, length ):
+    def __init__( self, length, dataset_hash=None ):
         self.position = 0
         self.length = length
+        self.dataset_hash = dataset_hash
 
         self.generator = torch.Generator()
         self.perm = torch.randperm(self.length, generator=self.generator)
@@ -155,10 +168,13 @@ class RandomSampler(Sampler):
         self.position += 1
 
     def get_state(self):
-        return { "position": self.position, "length": self.length, "perm": self.perm, "generator": self.generator.get_state() }
+        return { "position": self.position, "length": self.length, "perm": self.perm, "generator": self.generator.get_state(), "dataset_hash": self.dataset_hash }
 
     def set_state(self, state):
         self.position = state["position"]
         self.length = state["length"]
         self.perm = state["perm"]
         self.generator.set_state(state["generator"])
+        # could .pop()
+        if "dataset_hash" in state:
+            self.dataset_hash = state["dataset_hash"]