actually use the right embedding for nar-len
parent 3ea8a610d6
commit c00fc18b62
@@ -534,7 +534,7 @@ _durations_map = {}
 def _get_duration_map( type="training" ):
 	return _durations_map[type] if type in _durations_map else {}
 
-def _load_paths(dataset, type="training", silent=False, dataset_hash_key=None):
+def _load_paths(dataset, type="training", silent=not is_global_leader(), dataset_hash_key=None):
 	if not dataset_hash_key:
 		dataset_hash_key = cfg.dataset.hash_key(sorted(dataset))
 
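A minimal sketch of the intent behind the new default, silent=not is_global_leader(): in a multi-process run only the global leader (rank 0) should report path loading, so every other rank stays quiet. Only the call itself comes from the diff; the body of is_global_leader() below is an assumption (a typical rank-0 check), and _load_paths_stub is a hypothetical stand-in for _load_paths.

import os

def is_global_leader():
	# assumed implementation: the launcher exports RANK, and rank 0 is the leader
	return int(os.environ.get("RANK", 0)) == 0

def _load_paths_stub(dataset, silent=not is_global_leader()):
	# hypothetical stand-in; note the default is evaluated once at import time
	if not silent:
		print(f"loading paths for {len(dataset)} dataset groups")
	return {}

Because the default argument is evaluated at import time, the leader check happens once per process rather than per call.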
@@ -750,11 +750,11 @@ class Dataset(_Dataset):
 				self.duration_buckets[bucket] = []
 			self.duration_buckets[bucket].append( ( Path(path), duration ) )
 
-		# ensure they're ordered
-		self.duration_buckets = dict(sorted(self.duration_buckets.items()))
-
 		# sort by duration
 		if self.sampler_order == "duration":
+			# ensure they're ordered
+			self.duration_buckets = dict(sorted(self.duration_buckets.items()))
+
 			flattened = {}
 			# sort and interleave
 			for bucket in self.duration_buckets:
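A minimal sketch of the ordering this hunk relocates: buckets keyed by duration only need to be sorted when the sampler actually walks them shortest-to-longest, so the sort now lives inside the sampler_order == "duration" branch. The bucket contents below are made-up placeholders.

from pathlib import Path

duration_buckets = {12: [(Path("b.wav"), 12.3)], 5: [(Path("a.wav"), 5.1)]}
sampler_order = "duration"

if sampler_order == "duration":
	# ensure buckets are visited in duration order before interleaving
	duration_buckets = dict(sorted(duration_buckets.items()))

print(list(duration_buckets.keys()))  # [5, 12] when the duration sampler is used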
@@ -1117,10 +1117,12 @@ class Base(nn.Module):
 
 			# pre-iterate
 			for name, input in batch_input:
-				if name == "dropout_mask":
-					dropout_mask = input
-				elif name == "classifier_level":
+				if name == "classifier_level":
 					classifier_level = input
+				elif name == "dropout_mask":
+					dropout_mask = input
 				elif name == "timestep":
 					timestep = input
 
 			for name, input in batch_input:
 				# technically can provide a map for input_name => embedding, but some embedding requires additional processing
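A minimal sketch of the two-pass pattern this hunk reorders: metadata entries in batch_input (classifier_level, dropout_mask, timestep) are captured in a first pass so the second pass can choose embeddings that depend on them. The sample batch_input below is a placeholder, not the model's real inputs.

batch_input = [("classifier_level", "NAR:0:0"), ("dropout_mask", None), ("resp", "tokens")]

classifier_level = None
dropout_mask = None
timestep = None

# pre-iterate: pull out metadata before any embedding is picked
for name, input in batch_input:
	if name == "classifier_level":
		classifier_level = input
	elif name == "dropout_mask":
		dropout_mask = input
	elif name == "timestep":
		timestep = input

# second pass: embedding selection can now branch on classifier_level
for name, input in batch_input:
	if name == "resp":
		print(f"embedding {name} under classifier_level={classifier_level}")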
@@ -1169,6 +1171,15 @@ class Base(nn.Module):
 						#quant_level = 0,
 						name = classifier_level,
 					)
+				# NAR-len
+				elif classifier_level == "NAR:0:0":
+					embedding = self.resps_emb(
+						# if masked use masked token, else original token
+						input if input.dim() == 1 else input[:, 0],
+						#offset = -1 if self.masking_separate_embeddings else 0, # pick last
+						#quant_level = 0,
+						name = classifier_level,
+					)
 				# cheat-y way to handle performing STT across all levels
 				elif task_type in summed_embeddings_task:
 					# we do a manual sum because I trained it to use the AR embeddings + NAR embeddings for STT......
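A minimal sketch of the shape check in the new NAR:0:0 branch: a fully masked NAR-len response arrives as a 1-D tensor of (mask) tokens, while an unmasked response is 2-D (timesteps x RVQ levels), so only its first level is fed to the embedding. resps_emb below is a stand-in nn.Embedding, not the model's actual module, and the sizes are purely illustrative.

import torch
import torch.nn as nn

resps_emb = nn.Embedding(1025, 16)  # assumed vocab (1024 codes + mask token) and width

masked   = torch.full((8,), 1024)            # 1-D: already replaced by mask tokens
unmasked = torch.randint(0, 1024, (8, 4))    # 2-D: timesteps x quant levels

for input in (masked, unmasked):
	token = input if input.dim() == 1 else input[:, 0]  # mirror the diff's selection
	embedding = resps_emb(token)
	print(tuple(input.shape), "->", tuple(embedding.shape))  # both end up (8, 16)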
@@ -1215,7 +1226,7 @@ class Base(nn.Module):
 							continue
 
 						embedding[i] = self.dropout_token
-				elif name == "timestep" and self.time_emb is not None:
+				elif name == "timestep" and self.time_emb is not None and False:
 					embedding = self.time_emb( input )
 				elif name == "len" and self.len_emb is not None:
 					embedding = self.len_emb( input )
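A minimal sketch of what appending "and False" does to this elif chain: the timestep branch can never fire, so timestep inputs fall through to later handling (or none). The function below is hypothetical and only demonstrates the short-circuit.

def pick_branch(name, has_time_emb=True):
	# hypothetical stand-in for the embedding dispatch
	if name == "timestep" and has_time_emb and False:
		return "time_emb"   # unreachable: the literal False always wins
	elif name == "len":
		return "len_emb"
	return "skipped"

print(pick_branch("timestep"))  # -> "skipped"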