dropped subtrain dataloader since it's useless to duplicate

mrq 2024-11-11 17:00:49 -06:00
parent cf9df71f2c
commit f7b8b1e825
5 changed files with 65 additions and 41 deletions

View File

@@ -61,7 +61,12 @@ The dataloader handles some simple yet effective features, such as:
* picking an input prompt from the same speaker as the sample, if the above is not requested
* preparing the input sequence for the given task (such as non-TTS tasks)
The initial list of paths is cached through `diskcache`, if `cfg.dataset.cache == True`. Be sure to delete the resultant `.cache` folder, as well as the `sampler.*` state dicts alongside checkpoints, if you plan to modify the dataloader settings between training sessions.
If `cfg.dataset.cache == True`, the initial list of paths and duration metadata (used for sorting/bucketing) is cached ~~through `diskcache`~~ under `{YAML_PATH}/.cache/{DATASET_HASH}/`. To allow for seamless modifications to the loaded dataset, the `DATASET_HASH` relies on:
* duration range
* folders/groups in the dataset
* whether HDF5 is used (due to the key format differing)
Be sure to delete the resultant `.cache` folder, as well as the `sampler.*` state dicts alongside checkpoints, if you plan to modify the dataloader settings between training sessions.
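The exact hashing scheme isn't spelled out here; the sketch below is only a hypothetical illustration of how a `DATASET_HASH` and its cache directory could be derived from the three inputs listed above. The helper name, digest choice, and example paths are assumptions and not the project's actual `cfg.dataset.hash_key`.

```python
# Hypothetical sketch only: the real cfg.dataset.hash_key may hash different fields, differently.
import hashlib
import json
from pathlib import Path

def derive_dataset_hash( dataset_folders, duration_range, use_hdf5 ):
    # fold everything that should invalidate the cache into one stable digest
    payload = json.dumps({
        "folders": sorted( str( folder ) for folder in dataset_folders ),
        "duration_range": list( duration_range ),
        "use_hdf5": use_hdf5,
    }, sort_keys=True)
    return hashlib.sha256( payload.encode("utf-8") ).hexdigest()[:16]

# {YAML_PATH}/.cache/{DATASET_HASH}/ layout, mirroring the cache paths in the code diff below
yaml_path = Path("./training/config.yaml")  # hypothetical YAML location
dataset_hash = derive_dataset_hash( ["LibriTTS/train-clean-100"], (3.0, 60.0), use_hdf5=False )
cached_dir = yaml_path.parent / ".cache" / dataset_hash
print( cached_dir / "dataloader[training].json" )
print( cached_dir / "durations[training].json" )
```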
## Tasks
@@ -95,7 +100,7 @@ This section may be covered elsewhere in the documentation, but coverage here sh
* `nse`: noisy speech editing.
* the above, but injects some noise throughout the sampled utterances.
A mystical `rvc` for emulating RVC speech-to-speech synthesis is possible, but requires a dataset to do so.
A mystical `vc` for performing voice conversion is possible, but it either requires a dedicated dataset or relies on abusing an emergent property.
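Purely as an illustration of the noise injection mentioned for `ns`/`nse` above, a waveform-level mix at a target SNR might look like the following. This is not the dataloader's actual task implementation (which builds these tasks from its own sample/prompt representations); the function and tensor shapes are assumptions.

```python
# Illustrative only: mix a noise clip into an utterance at a given SNR.
import torch

def mix_noise( utterance: torch.Tensor, noise: torch.Tensor, snr_db: float = 10.0 ) -> torch.Tensor:
    # tile/trim the noise to the utterance length
    if noise.shape[-1] < utterance.shape[-1]:
        repeats = utterance.shape[-1] // noise.shape[-1] + 1
        noise = noise.repeat( repeats )
    noise = noise[..., :utterance.shape[-1]]

    # scale the noise so the mixture hits the requested SNR
    speech_power = utterance.pow(2).mean().clamp(min=1e-8)
    noise_power = noise.pow(2).mean().clamp(min=1e-8)
    scale = torch.sqrt( speech_power / (noise_power * 10 ** (snr_db / 10)) )
    return utterance + scale * noise

utterance = torch.randn( 24_000 * 3 )  # 3 seconds at 24kHz
noise = torch.randn( 24_000 )
noisy = mix_noise( utterance, noise, snr_db=10.0 )
```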
## `__main__`

View File

@@ -60,7 +60,7 @@ Descript-Audio-Codec was thoroughly tested for promising much, much cleaner outp
However, due to the nature of the codec, simply throwing it at an attention-based transformer proves to be painful, as a unified AR+NAR model *heavily* suffers from noisy output in the NAR.
Ironically, testing through mal-encoded audio (feeding 24KHz audio without upsampling to 44.1KHz) proved to have "cleaner" but bad utterances.
Ironically, testing with erroneously encoded audio (feeding 24KHz audio without upsampling to 44.1KHz) produced "cleaner" but still bad utterances.
I'm uncertain how to remedy this, as my options are:
* train under a RetNet, if an attention-based transformer is simply the problem

View File

@@ -534,8 +534,11 @@ _durations_map = {}
def _get_duration_map( type="training" ):
return _durations_map[type] if type in _durations_map else {}
def _load_paths(dataset, type="training", silent=False):
cached_dir = cfg.cache_dir / cfg.dataset.hash_key(sorted(dataset))
def _load_paths(dataset, type="training", silent=False, dataset_hash_key=None):
if not dataset_hash_key:
dataset_hash_key = cfg.dataset.hash_key(sorted(dataset))
cached_dir = cfg.cache_dir / dataset_hash_key
cached_durations_path = cached_dir / f"durations[{type}].json"
cached_paths_path = cached_dir / f"dataloader[{type}].json"
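The hunk above only shows where the per-type cache files live; the sketch below is a hedged guess at the load-or-rebuild pattern they imply, not the actual body of `_load_paths` (the helper name and `build_fn` callback are illustrative).

```python
# Illustrative load-or-rebuild pattern for the per-type cache files shown above.
import json
from pathlib import Path

def load_or_build_paths( cached_dir: Path, type: str, build_fn ):
    cached_paths_path = cached_dir / f"dataloader[{type}].json"
    cached_durations_path = cached_dir / f"durations[{type}].json"

    # reuse the cached path list + durations when both files exist
    if cached_paths_path.exists() and cached_durations_path.exists():
        paths = json.loads( cached_paths_path.read_text() )
        durations = json.loads( cached_durations_path.read_text() )
        return paths, durations

    # otherwise walk the dataset, then persist the results for next time
    paths, durations = build_fn()
    cached_dir.mkdir( parents=True, exist_ok=True )
    cached_paths_path.write_text( json.dumps( paths ) )
    cached_durations_path.write_text( json.dumps( durations ) )
    return paths, durations
```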
@@ -691,11 +694,13 @@ class Dataset(_Dataset):
self.training = training
self.dataset_type = "training" if self.training else "validation"
self.dataset = cfg.dataset.training if self.training else cfg.dataset.validation
self.dataset = sorted(cfg.dataset.training if self.training else cfg.dataset.validation)
self.sampler_type = cfg.dataset.sample_type if self.dataset_type == "training" else "path"
self.sampler_order = cfg.dataset.sample_order
self.sampler_shuffle = cfg.dataset.sample_shuffle if self.dataset_type == "training" else True
self.dataset_hash_key = cfg.dataset.hash_key(sorted(self.dataset))
# to-do: do not do validation if there's nothing in the validation dataset
# this just keeps it happy for now
if len(self.dataset) == 0:
@@ -706,7 +711,7 @@ class Dataset(_Dataset):
raise Exception(f'Requesting sample_type={self.sampler_type} with sample_order={self.sampler_order}, yet combination will not give expected results.')
# dict of paths keyed by speaker names
self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type)
self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type, dataset_hash_key=self.dataset_hash_key)
# do it here due to the above
self.duration = 0
self.duration_map = _get_duration_map( self.dataset_type )
@@ -779,10 +784,6 @@ class Dataset(_Dataset):
else:
# just interleave
self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
# dereference buckets
self.duration_map = None
self.duration_buckets = None
# dict of speakers keyed by speaker group
self.spkrs_by_spkr_group = {}
@@ -830,16 +831,21 @@ class Dataset(_Dataset):
self.duration_buckets if not self.sampler_state_dict_path.exists() else {}, # pass nothing if we're just going to load from a state anyways
max_duration=cfg.dataset.sample_max_duration_batch,
max_batch_size=cfg.hyperparameters.batch_size if self.training else cfg.evaluation.batch_size,
shuffle=self.sampler_shuffle
shuffle=self.sampler_shuffle,
dataset_hash=self.dataset_hash_key,
)
else:
self.sampler = OrderedSampler( len(self) ) if not self.sampler_shuffle else RandomSampler( len(self) )
self.sampler = OrderedSampler( len(self), dataset_hash=self.dataset_hash_key ) if not self.sampler_shuffle else RandomSampler( len(self), dataset_hash=self.dataset_hash_key )
self.samplers = {}
self.spkr_samplers = {}
else:
self.sampler = RandomSampler( len(self) )
self.samplers = { name: PoolSampler( paths, keep_all=True, shuffle=self.sampler_shuffle ) for name, paths in self.paths_by_spkr_name.items() }
self.spkr_samplers = { name: PoolSampler( [*set(speakers)], keep_all=True, shuffle=self.sampler_shuffle ) for name, speakers in self.spkrs_by_spkr_group.items() }
self.sampler = RandomSampler( len(self), dataset_hash=self.dataset_hash_key )
self.samplers = { name: PoolSampler( paths, keep_all=True, shuffle=self.sampler_shuffle, dataset_hash=self.dataset_hash_key ) for name, paths in self.paths_by_spkr_name.items() }
self.spkr_samplers = { name: PoolSampler( [*set(speakers)], keep_all=True, shuffle=self.sampler_shuffle, dataset_hash=self.dataset_hash_key ) for name, speakers in self.spkrs_by_spkr_group.items() }
# dereference buckets
self.duration_map = None
self.duration_buckets = None
self.load_state_dict()
@@ -903,6 +909,7 @@ class Dataset(_Dataset):
return
state_dict = torch_load(path)
if self.sampler_type == "path":
state_dict = self.sampler.set_state(state_dict)
else:
@@ -1461,16 +1468,8 @@ def create_val_dataloader():
# to-do, use the above two, then create the subtrain dataset
def create_train_val_dataloader():
train_dataset, val_dataset = create_datasets()
# deepcopy is slow
subtrain_dataset = Dataset( training=True )
if subtrain_dataset.sampler_type == "path":
subtrain_dataset.head_(cfg.evaluation.size)
train_dl = _create_dataloader(train_dataset, training=True)
val_dl = _create_dataloader(val_dataset, training=False)
subtrain_dl = _create_dataloader(subtrain_dataset, training=False)
_logger.info(str(train_dataset.phone_symmap))
_logger.info(str(train_dataset.spkr_symmap))
@@ -1478,18 +1477,14 @@ def create_train_val_dataloader():
_logger.info(f"#samples (train): {len(train_dataset)}.")
_logger.info(f"#samples (val): {len(val_dataset)}.")
_logger.info(f"#samples (subtrain): {len(subtrain_dataset)}.")
_logger.info(f"#duration (train): {str(train_dataset.duration)}.")
_logger.info(f"#duration (val): {str(val_dataset.duration)}.")
_logger.info(f"#duration (subtrain): {str(subtrain_dataset.duration)}.")
assert isinstance(subtrain_dl.dataset, Dataset)
# remove duration map (it gets bloated); clear in place so the module-level dict is actually emptied
_durations_map.clear()
return train_dl, subtrain_dl, val_dl
return train_dl, val_dl
# parse metadata from a numpy file (.enc/.dac) and validate it
def process_artifact_metadata( artifact ):

View File

@@ -108,7 +108,15 @@ def run_eval(engines, eval_name, dl, args=None):
processed = 0
while processed < cfg.evaluation.size:
batch = to_device(next(iter(dl)), cfg.device)
# directly randomly sample
if eval_name == "subtrain":
# sample from dataset
# to-do: derive from current iteration
samples = [ to_device(dl.dataset[random.randint( 0, len( dl.dataset ) - 1 )], cfg.device) for _ in range( cfg.evaluation.batch_size ) ]
# collate manually
batch = {k: [s[k] for s in samples] for k in samples[0]}
else:
batch = to_device(next(iter(dl)), cfg.device)
# limit to eval batch size in the event we somehow have a weird dataloader
for key in batch.keys():
@@ -209,14 +217,14 @@ def train():
if cfg.yaml_path is not None and is_global_leader():
shutil.copy( cfg.yaml_path, cfg.log_dir / "config.yaml" )
# create dataloaders
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
train_dl, val_dl = create_train_val_dataloader()
# evaluation lambda
def eval_fn(engines):
do_gc()
engines.eval()
# wrapped in a try block because it's sometimes prone to breaking
try:
run_eval(engines, "subtrain", subtrain_dl, args)
run_eval(engines, "subtrain", train_dl, args)
run_eval(engines, "val", val_dl, args)
except Exception as e:
_logger.warning(f"Error occurred while performing eval: {str(e)}")

View File

@@ -9,11 +9,12 @@ from .distributed import global_rank, local_rank, world_size
# Randomly picks an index from an array of indices
class PoolSampler():
def __init__( self, pool = [], keep_all = False, shuffle = False ):
def __init__( self, pool = [], keep_all = False, shuffle = False, dataset_hash = None ):
self.length = len(pool)
self.shuffle = shuffle
self.global_pool = pool if keep_all else None
self.global_indices = [ i for i in range(self.length) ]
self.dataset_hash = dataset_hash
self.reset()
def reset(self):
@@ -45,21 +46,25 @@ class PoolSampler():
return self.sample(*args, **kwargs)
def get_state(self):
return { "length": self.length, "global_pool": self.global_pool, "global_indices": self.global_indices, "current_pool": self.current_pool }
return { "length": self.length, "global_pool": self.global_pool, "global_indices": self.global_indices, "current_pool": self.current_pool, "dataset_hash": self.dataset_hash }
def set_state(self, state):
self.length = state["length"]
self.global_pool = state["global_pool"]
self.global_indices = state["global_indices"]
self.current_pool = state["current_pool"]
# could .pop()
if "dataset_hash" in state:
self.dataset_hash = state["dataset_hash"]
# "Samples" through a fixed sequence from 0 to length
# Necessary for our "shuffle+sort by duration+interleave" sampling method
# Allows saving and loading state
class OrderedSampler(Sampler):
def __init__( self, length ):
def __init__( self, length, dataset_hash=None ):
self.position = 0
self.length = length
self.dataset_hash = dataset_hash
def __len__(self):
return self.length
@@ -73,18 +78,22 @@ class OrderedSampler(Sampler):
self.position += 1
def get_state(self):
return { "position": self.position, "length": self.length }
return { "position": self.position, "length": self.length, "dataset_hash": self.dataset_hash }
def set_state(self, state):
self.position = state["position"]
self.length = state["length"]
# could .pop()
if "dataset_hash" in state:
self.dataset_hash = state["dataset_hash"]
# Like the above, but will batch based on token count
class BatchedOrderedSampler(Sampler):
def __init__( self, buckets, max_duration=0, max_batch_size=0, shuffle=False ):
def __init__( self, buckets, max_duration=0, max_batch_size=0, shuffle=False, dataset_hash=None ):
self.position = 0
self.batches = []
self.shuffle = shuffle
self.dataset_hash = dataset_hash
assert max_duration != 0 and max_batch_size != 0, "neither max_duration nor max_batch_size can be 0"
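As a rough illustration of how duration buckets might be packed into batches under the two limits above, a greedy scheme could look like the sketch below; this is only a guess at the general idea, not the sampler's actual batching logic.

```python
# Rough illustration: greedily pack (index, duration) pairs into batches,
# closing a batch once adding another item would exceed either limit.
def pack_batches( durations, max_duration=120.0, max_batch_size=16 ):
    batches = []
    batch, batch_duration = [], 0.0
    for index, duration in sorted( enumerate( durations ), key=lambda pair: pair[1] ):
        exceeds_duration = max_duration and batch_duration + duration > max_duration
        exceeds_size = max_batch_size and len( batch ) >= max_batch_size
        if batch and ( exceeds_duration or exceeds_size ):
            batches.append( batch )
            batch, batch_duration = [], 0.0
        batch.append( index )
        batch_duration += duration
    if batch:
        batches.append( batch )
    return batches

print( pack_batches( [3.2, 5.0, 4.1, 12.7, 2.4], max_duration=10.0, max_batch_size=2 ) )
```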
@@ -126,18 +135,22 @@ class BatchedOrderedSampler(Sampler):
self.position += 1
def get_state(self):
return { "position": self.position, "batches": self.batches }
return { "position": self.position, "batches": self.batches, "dataset_hash": self.dataset_hash }
def set_state(self, state):
self.position = state["position"]
self.batches = state["batches"]
# could .pop()
if "dataset_hash" in state:
self.dataset_hash = state["dataset_hash"]
# Randomly samples indices from a given sequence from 0 to length
# Allows saving and loading state
class RandomSampler(Sampler):
def __init__( self, length ):
def __init__( self, length, dataset_hash=None ):
self.position = 0
self.length = length
self.dataset_hash = dataset_hash
self.generator = torch.Generator()
self.perm = torch.randperm(self.length, generator=self.generator)
@@ -155,10 +168,13 @@ class RandomSampler(Sampler):
self.position += 1
def get_state(self):
return { "position": self.position, "length": self.length, "perm": self.perm, "generator": self.generator.get_state() }
return { "position": self.position, "length": self.length, "perm": self.perm, "generator": self.generator.get_state(), "dataset_hash": self.dataset_hash }
def set_state(self, state):
self.position = state["position"]
self.length = state["length"]
self.perm = state["perm"]
self.generator.set_state(state["generator"])
self.generator.set_state(state["generator"])
# could .pop()
if "dataset_hash" in state:
self.dataset_hash = state["dataset_hash"]
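The diff stores `dataset_hash` in every sampler's state but never shows it being checked; one hypothetical way the caller could use it when restoring a checkpointed sampler is sketched below (the guard itself is an assumption, only the attribute names mirror the diff).

```python
# Hypothetical guard: refuse to restore a sampler state saved for a different dataset.
def safe_set_state( sampler, state ):
    saved_hash = state.get( "dataset_hash" )
    if saved_hash is not None and sampler.dataset_hash is not None and saved_hash != sampler.dataset_hash:
        # dataset composition changed since the checkpoint; start fresh instead of
        # resuming positions/permutations that no longer line up with the paths
        return False
    sampler.set_state( state )
    return True
```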