diff --git a/vall_e/data.py b/vall_e/data.py
index ab1ffdb..1a643c7 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -534,7 +534,7 @@ _durations_map = {}
 def _get_duration_map( type="training" ):
 	return _durations_map[type] if type in _durations_map else {}
 
-def _load_paths(dataset, type="training", silent=False, dataset_hash_key=None):
+def _load_paths(dataset, type="training", silent=not is_global_leader(), dataset_hash_key=None):
 	if not dataset_hash_key:
 		dataset_hash_key = cfg.dataset.hash_key(sorted(dataset))
 
@@ -750,11 +750,11 @@ class Dataset(_Dataset):
 				self.duration_buckets[bucket] = []
 			self.duration_buckets[bucket].append( ( Path(path), duration ) )
 
-		# ensure they're ordered
-		self.duration_buckets = dict(sorted(self.duration_buckets.items()))
-
 		# sort by duration
 		if self.sampler_order == "duration":
+			# ensure they're ordered
+			self.duration_buckets = dict(sorted(self.duration_buckets.items()))
+
 			flattened = {}
 			# sort and interleave
 			for bucket in self.duration_buckets:
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index f9f495b..905e639 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -1117,10 +1117,12 @@ class Base(nn.Module):
 
 		# pre-iterate
 		for name, input in batch_input:
-			if name == "dropout_mask":
-				dropout_mask = input
-			elif name == "classifier_level":
+			if name == "classifier_level":
 				classifier_level = input
+			elif name == "dropout_mask":
+				dropout_mask = input
+			elif name == "timestep":
+				timestep = input
 
 		for name, input in batch_input:
 			# technically can provide a map for input_name => embedding, but some embedding requires additional processing
@@ -1169,6 +1171,15 @@ class Base(nn.Module):
 					#quant_level = 0,
 					name = classifier_level,
 				)
+			# NAR-len
+			elif classifier_level == "NAR:0:0":
+				embedding = self.resps_emb(
+					# if masked use masked token, else original token
+					input if input.dim() == 1 else input[:, 0],
+					#offset = -1 if self.masking_separate_embeddings else 0, # pick last
+					#quant_level = 0,
+					name = classifier_level,
+				)
 			# cheat-y way to handle performing STT across all levels
 			elif task_type in summed_embeddings_task:
 				# we do a manual sum because I trained it to use the AR embeddings + NAR embeddings for STT......
@@ -1215,7 +1226,7 @@ class Base(nn.Module):
 						continue
 					embedding[i] = self.dropout_token
-			elif name == "timestep" and self.time_emb is not None:
+			elif name == "timestep" and self.time_emb is not None and False:
 				embedding = self.time_emb( input )
 			elif name == "len" and self.len_emb is not None:
 				embedding = self.len_emb( input )