dropped subtrain dataloader since it's useless to duplicate
Commit f7b8b1e825 (parent cf9df71f2c)
@@ -61,7 +61,12 @@ The dataloader handles some simple yet effective features, such as:
 * picking an input prompt from the same speaker as the sample, if the above is not requested
 * preparing the input sequence for the given task (such as non-TTS tasks)

-The initial list of paths is cached through `diskcache`, if `cfg.dataset.cache == True`. Be sure to delete the resultant `.cache` folder, as well as the `sampler.*` state dicts alongside checkpoints, if you plan to modify the dataloader settings between training sessions.
+If `cfg.dataset.cache == True`, the initial list of paths and duration metadata (used for sorting/bucketing) is cached ~~through `diskcache`~~ under `{YAML_PATH}/.cache/{DATASET_HASH}/`. To allow for seamless modifications to the loaded dataset, the `DATASET_HASH` relies on:
+* duration range
+* folders/groups in the dataset
+* if using HDF5 (due to the key format differing)
+
+Be sure to delete the resultant `.cache` folder, as well as the `sampler.*` state dicts alongside checkpoints, if you plan to modify the dataloader settings between training sessions.

 ## Tasks

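The exact `cfg.dataset.hash_key` implementation is not shown in this hunk; the following is only a minimal sketch of the idea the docs describe, assuming the three listed inputs are serialized and hashed so that the cache key changes whenever any of them changes. Names here are illustrative, not the project's.

```python
# Sketch only: a cache key derived from the inputs the docs list
# (duration range, dataset folders/groups, HDF5 usage).
import hashlib
import json

def hash_key(sorted_paths, duration_range=(0.0, 0.0), use_hdf5=False):
    payload = json.dumps({
        "duration_range": list(duration_range),
        "paths": [str(p) for p in sorted_paths],
        "use_hdf5": use_hdf5,
    }, sort_keys=True)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()[:16]

# cached_dir = cfg.cache_dir / hash_key(sorted(dataset), ...)
```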
@@ -95,7 +100,7 @@ This section may be covered elsewhere in the documentation, but coverage here sh
 * `nse`: noisy speech editing.
   * the above, but injects some noise throughout the sampled utterances.

-A mystical `rvc` for emulating RVC speech-to-speech synthesis is possible, but requires a dataset to do so.
+A mystical `vc` for performing voice conversion is possible, but either requires a dataset to do so, or abuses an emergent property.

 ## `__main__`

@@ -60,7 +60,7 @@ Descript-Audio-Codec was thoroughly tested for promising much, much cleaner outp

 However, due to the nature of the codec, simply throwing it at an attention-based transformer proves to be painful, as a unified AR+NAR model *heavily* suffers from noisy output in the NAR.

-Ironically, testing through mal-encoded audio (feeding 24KHz audio without upsampling to 44.1KHz) proved to have "cleaner" but bad utterances.
+Ironically, testing through erroneously encoded audio (feeding 24KHz audio without upsampling to 44.1KHz) proved to have "cleaner" but bad utterances.

 I'm uncertain on how to remedy this, as my options are:
 * train under a RetNet, if an attention-based transformer is simply the problem
@@ -534,8 +534,11 @@ _durations_map = {}
 def _get_duration_map( type="training" ):
     return _durations_map[type] if type in _durations_map else {}

-def _load_paths(dataset, type="training", silent=False):
-    cached_dir = cfg.cache_dir / cfg.dataset.hash_key(sorted(dataset))
+def _load_paths(dataset, type="training", silent=False, dataset_hash_key=None):
+    if not dataset_hash_key:
+        dataset_hash_key = cfg.dataset.hash_key(sorted(dataset))
+
+    cached_dir = cfg.cache_dir / dataset_hash_key

     cached_durations_path = cached_dir / f"durations[{type}].json"
     cached_paths_path = cached_dir / f"dataloader[{type}].json"
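Given the paths above, each per-dataset cache directory holds one durations JSON and one dataloader JSON per split. A minimal sketch of reading them back if present; it assumes the files are plain JSON (implied by the `.json` suffix, but not confirmed by this hunk), and the helper name is hypothetical.

```python
# Hypothetical helper: load the cached split metadata if both files exist,
# otherwise signal the caller to fall back to walking the dataset folders.
import json
from pathlib import Path

def read_cached_split(cached_dir: Path, split: str = "training"):
    durations_path = cached_dir / f"durations[{split}].json"
    paths_path = cached_dir / f"dataloader[{split}].json"
    if not durations_path.exists() or not paths_path.exists():
        return None
    durations = json.loads(durations_path.read_text())
    paths = json.loads(paths_path.read_text())
    return durations, paths
```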
@@ -691,11 +694,13 @@ class Dataset(_Dataset):

         self.training = training
         self.dataset_type = "training" if self.training else "validation"
-        self.dataset = cfg.dataset.training if self.training else cfg.dataset.validation
+        self.dataset = sorted(cfg.dataset.training if self.training else cfg.dataset.validation)
         self.sampler_type = cfg.dataset.sample_type if self.dataset_type == "training" else "path"
         self.sampler_order = cfg.dataset.sample_order
         self.sampler_shuffle = cfg.dataset.sample_shuffle if self.dataset_type == "training" else True
+
+        self.dataset_hash_key = cfg.dataset.hash_key(sorted(self.dataset))

         # to-do: do not do validation if there's nothing in the validation
         # this just makes it be happy
         if len(self.dataset) == 0:
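The dataset list is sorted both when stored and again before hashing, so the cache key does not depend on the order the folders/groups are listed in the YAML. A small illustration of that stability property; `toy_hash_key` and the folder names are stand-ins, not the project's `cfg.dataset.hash_key` or real dataset paths.

```python
# Why sorting before hashing matters: the same set of folders listed in a
# different order should map to the same cache key.
import hashlib

def toy_hash_key(paths):
    return hashlib.sha256("\n".join(paths).encode()).hexdigest()[:8]

a = ["speakers/groupA", "speakers/groupB", "speakers/groupC"]
b = ["speakers/groupC", "speakers/groupA", "speakers/groupB"]

assert toy_hash_key(sorted(a)) == toy_hash_key(sorted(b))  # stable across orderings
assert toy_hash_key(a) != toy_hash_key(b)                  # unsorted input would not be
```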
@@ -706,7 +711,7 @@ class Dataset(_Dataset):
             raise Exception(f'Requesting sample_type={self.sampler_type} with sample_order={self.sampler_order}, yet combination will not give expected results.')

         # dict of paths keyed by speaker names
-        self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type)
+        self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type, dataset_hash_key=self.dataset_hash_key)
         # do it here due to the above
         self.duration = 0
         self.duration_map = _get_duration_map( self.dataset_type )
@@ -779,10 +784,6 @@ class Dataset(_Dataset):
         else:
             # just interleave
             self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]

-        # dereference buckets
-        self.duration_map = None
-        self.duration_buckets = None
-
         # dict of speakers keyed by speaker group
         self.spkrs_by_spkr_group = {}
@@ -830,16 +831,21 @@ class Dataset(_Dataset):
                     self.duration_buckets if not self.sampler_state_dict_path.exists() else {}, # pass nothing if we're just going to load from a state anyways
                     max_duration=cfg.dataset.sample_max_duration_batch,
                     max_batch_size=cfg.hyperparameters.batch_size if self.training else cfg.evaluation.batch_size,
-                    shuffle=self.sampler_shuffle
+                    shuffle=self.sampler_shuffle,
+                    dataset_hash=self.dataset_hash_key,
                 )
             else:
-                self.sampler = OrderedSampler( len(self) ) if not self.sampler_shuffle else RandomSampler( len(self) )
+                self.sampler = OrderedSampler( len(self), dataset_hash=self.dataset_hash_key ) if not self.sampler_shuffle else RandomSampler( len(self), dataset_hash=self.dataset_hash_key )
             self.samplers = {}
             self.spkr_samplers = {}
         else:
-            self.sampler = RandomSampler( len(self) )
-            self.samplers = { name: PoolSampler( paths, keep_all=True, shuffle=self.sampler_shuffle ) for name, paths in self.paths_by_spkr_name.items() }
-            self.spkr_samplers = { name: PoolSampler( [*set(speakers)], keep_all=True, shuffle=self.sampler_shuffle ) for name, speakers in self.spkrs_by_spkr_group.items() }
+            self.sampler = RandomSampler( len(self), dataset_hash=self.dataset_hash_key )
+            self.samplers = { name: PoolSampler( paths, keep_all=True, shuffle=self.sampler_shuffle, dataset_hash=self.dataset_hash_key ) for name, paths in self.paths_by_spkr_name.items() }
+            self.spkr_samplers = { name: PoolSampler( [*set(speakers)], keep_all=True, shuffle=self.sampler_shuffle, dataset_hash=self.dataset_hash_key ) for name, speakers in self.spkrs_by_spkr_group.items() }

+        # dereference buckets
+        self.duration_map = None
+        self.duration_buckets = None
+
         self.load_state_dict()

@@ -903,6 +909,7 @@ class Dataset(_Dataset):
             return
+
         state_dict = torch_load(path)

         if self.sampler_type == "path":
             state_dict = self.sampler.set_state(state_dict)
         else:
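The samplers now carry the dataset hash in their state dicts (see the sampler changes further down), though this diff does not show a consumer of it. A plausible use is to refuse to restore a sampler state that was saved for a different dataset; the guard below is a hypothetical sketch of that idea, not code from this commit.

```python
# Hypothetical guard, not part of this commit: skip restoring a sampler state
# whose recorded dataset_hash no longer matches the current dataset's hash.
def maybe_restore_sampler(sampler, state_dict, current_hash):
    saved_hash = state_dict.get("dataset_hash")
    if saved_hash is not None and saved_hash != current_hash:
        # the dataset changed since this state was saved; start fresh instead
        return False
    sampler.set_state(state_dict)
    return True
```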
@@ -1461,16 +1468,8 @@ def create_val_dataloader():
-# to-do, use the above two, then create the subtrain dataset
 def create_train_val_dataloader():
     train_dataset, val_dataset = create_datasets()

-    # deepcopy is slow
-    subtrain_dataset = Dataset( training=True )
-
-    if subtrain_dataset.sampler_type == "path":
-        subtrain_dataset.head_(cfg.evaluation.size)
-
     train_dl = _create_dataloader(train_dataset, training=True)
     val_dl = _create_dataloader(val_dataset, training=False)
-    subtrain_dl = _create_dataloader(subtrain_dataset, training=False)

     _logger.info(str(train_dataset.phone_symmap))
     _logger.info(str(train_dataset.spkr_symmap))
@@ -1478,18 +1477,14 @@ def create_train_val_dataloader():

     _logger.info(f"#samples (train): {len(train_dataset)}.")
     _logger.info(f"#samples (val): {len(val_dataset)}.")
-    _logger.info(f"#samples (subtrain): {len(subtrain_dataset)}.")

     _logger.info(f"#duration (train): {str(train_dataset.duration)}.")
     _logger.info(f"#duration (val): {str(val_dataset.duration)}.")
-    _logger.info(f"#duration (subtrain): {str(subtrain_dataset.duration)}.")

-    assert isinstance(subtrain_dl.dataset, Dataset)
-
     # remove duration map (it gets bloated)
     _durations_map = {}

-    return train_dl, subtrain_dl, val_dl
+    return train_dl, val_dl

 # parse metadata from an numpy file (.enc/.dac) and validate it
 def process_artifact_metadata( artifact ):
@@ -108,7 +108,15 @@ def run_eval(engines, eval_name, dl, args=None):

     processed = 0
     while processed < cfg.evaluation.size:
-        batch = to_device(next(iter(dl)), cfg.device)
+        # directly randomly sample
+        if eval_name == "subtrain":
+            # sample from dataset
+            # to-do: derive from current iteration
+            samples = [ to_device(dl.dataset[random.randint( 0, len( dl.dataset ) - 1 )], cfg.device) for sample in range( cfg.evaluation.batch_size ) ]
+            # collate manually
+            batch = {k: [s[k] for s in samples] for k in samples[0]}
+        else:
+            batch = to_device(next(iter(dl)), cfg.device)

         # limit to eval batch size in the event we somehow have a weird dataloader
         for key in batch.keys():
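The manual collation above simply turns a list of per-sample dicts into a dict of lists, keyed by the fields of the first sample. A tiny self-contained illustration with toy data (not from the repo):

```python
# Toy illustration of the manual collation used above.
samples = [
    {"text": "a", "duration": 1.0},
    {"text": "b", "duration": 2.5},
]
batch = {k: [s[k] for s in samples] for k in samples[0]}
assert batch == {"text": ["a", "b"], "duration": [1.0, 2.5]}
```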
@@ -209,14 +217,14 @@ def train():
     if cfg.yaml_path is not None and is_global_leader():
         shutil.copy( cfg.yaml_path, cfg.log_dir / "config.yaml" )
     # create dataloaders
-    train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
+    train_dl, val_dl = create_train_val_dataloader()
     # evaluation lambda
     def eval_fn(engines):
         do_gc()
         engines.eval()
         # wrapped in a try block because it's sometimes prone to breaking
         try:
-            run_eval(engines, "subtrain", subtrain_dl, args)
+            run_eval(engines, "subtrain", train_dl, args)
             run_eval(engines, "val", val_dl, args)
         except Exception as e:
             _logger.warning(f"Error occurred while performing eval: {str(e)}")
@@ -9,11 +9,12 @@ from .distributed import global_rank, local_rank, world_size

 # Randomly picks an index from an array of indices
 class PoolSampler():
-    def __init__( self, pool = [], keep_all = False, shuffle = False ):
+    def __init__( self, pool = [], keep_all = False, shuffle = False, dataset_hash = None ):
         self.length = len(pool)
         self.shuffle = shuffle
         self.global_pool = pool if keep_all else None
         self.global_indices = [ i for i in range(self.length) ]
+        self.dataset_hash = dataset_hash
         self.reset()

     def reset(self):
@@ -45,21 +46,25 @@ class PoolSampler():
         return self.sample(*args, **kwargs)

     def get_state(self):
-        return { "length": self.length, "global_pool": self.global_pool, "global_indices": self.global_indices, "current_pool": self.current_pool }
+        return { "length": self.length, "global_pool": self.global_pool, "global_indices": self.global_indices, "current_pool": self.current_pool, "dataset_hash": self.dataset_hash }

     def set_state(self, state):
         self.length = state["length"]
         self.global_pool = state["global_pool"]
         self.global_indices = state["global_indices"]
         self.current_pool = state["current_pool"]
+        # could .pop()
+        if "dataset_hash" in state:
+            self.dataset_hash = state["dataset_hash"]

 # "Samples" through a fixed sequence from 0 to length
 # Necessary for our "shuffle+sort by duration+interleave" sampling method
 # Allows saving and loading state
 class OrderedSampler(Sampler):
-    def __init__( self, length ):
+    def __init__( self, length, dataset_hash=None ):
         self.position = 0
         self.length = length
+        self.dataset_hash = dataset_hash

     def __len__(self):
         return self.length
@@ -73,18 +78,22 @@ class OrderedSampler(Sampler):
             self.position += 1

     def get_state(self):
-        return { "position": self.position, "length": self.length }
+        return { "position": self.position, "length": self.length, "dataset_hash": self.dataset_hash }

     def set_state(self, state):
         self.position = state["position"]
         self.length = state["length"]
+        # could .pop()
+        if "dataset_hash" in state:
+            self.dataset_hash = state["dataset_hash"]

 # Like the above, but will batch based on token count
 class BatchedOrderedSampler(Sampler):
-    def __init__( self, buckets, max_duration=0, max_batch_size=0, shuffle=False ):
+    def __init__( self, buckets, max_duration=0, max_batch_size=0, shuffle=False, dataset_hash=None ):
         self.position = 0
         self.batches = []
         self.shuffle = shuffle
+        self.dataset_hash = dataset_hash

         assert max_duration != 0 and max_batch_size != 0, "max_duration and max_batch_size cannot both be 0"

@@ -126,18 +135,22 @@ class BatchedOrderedSampler(Sampler):
             self.position += 1

     def get_state(self):
-        return { "position": self.position, "batches": self.batches }
+        return { "position": self.position, "batches": self.batches, "dataset_hash": self.dataset_hash }

     def set_state(self, state):
         self.position = state["position"]
         self.batches = state["batches"]
+        # could .pop()
+        if "dataset_hash" in state:
+            self.dataset_hash = state["dataset_hash"]

 # Randomly samples indices from a given sequence from 0 to length
 # Allows saving and loading state
 class RandomSampler(Sampler):
-    def __init__( self, length ):
+    def __init__( self, length, dataset_hash=None ):
         self.position = 0
         self.length = length
+        self.dataset_hash = dataset_hash

         self.generator = torch.Generator()
         self.perm = torch.randperm(self.length, generator=self.generator)
@@ -155,10 +168,13 @@ class RandomSampler(Sampler):
             self.position += 1

     def get_state(self):
-        return { "position": self.position, "length": self.length, "perm": self.perm, "generator": self.generator.get_state() }
+        return { "position": self.position, "length": self.length, "perm": self.perm, "generator": self.generator.get_state(), "dataset_hash": self.dataset_hash }

     def set_state(self, state):
         self.position = state["position"]
         self.length = state["length"]
         self.perm = state["perm"]
         self.generator.set_state(state["generator"])
+        # could .pop()
+        if "dataset_hash" in state:
+            self.dataset_hash = state["dataset_hash"]
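Taken together, every sampler's `get_state()`/`set_state()` now round-trips a `dataset_hash` alongside its position. A short usage sketch, assuming the `RandomSampler` defined above is in scope (the exact import path is omitted here); the hash value is a toy placeholder.

```python
# State round-trip with the dataset hash, per the get_state/set_state shown above.
sampler = RandomSampler( 10, dataset_hash="abc123" )

# ... iterate for a while, then checkpoint ...
state = sampler.get_state()
assert state["dataset_hash"] == "abc123"

# A fresh sampler restored from that state keeps the recorded hash; older
# states without the key are still accepted thanks to the `if "dataset_hash" in state` guard.
restored = RandomSampler( 10 )
restored.set_state(state)
assert restored.dataset_hash == "abc123"
```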