actually use the right embedding for nar-len

parent 3ea8a610d6
commit c00fc18b62
@@ -534,7 +534,7 @@ _durations_map = {}
 def _get_duration_map( type="training" ):
 	return _durations_map[type] if type in _durations_map else {}
 
-def _load_paths(dataset, type="training", silent=False, dataset_hash_key=None):
+def _load_paths(dataset, type="training", silent=not is_global_leader(), dataset_hash_key=None):
 	if not dataset_hash_key:
 		dataset_hash_key = cfg.dataset.hash_key(sorted(dataset))
 
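The only change here is the default for `silent`: a hard-coded `False` becomes `not is_global_leader()`, so under distributed training only the leader rank reports while paths are loaded. A minimal sketch of the intent, assuming `is_global_leader()` is the repo's distributed-rank helper and that `silent` only gates progress output (the body below is illustrative, not the actual implementation):

import os

def is_global_leader() -> bool:
    # assumed stand-in for the repo's helper: True only on the leader rank
    return int(os.environ.get("RANK", "0")) == 0

def _load_paths(dataset, type="training", silent=not is_global_leader(), dataset_hash_key=None):
    # non-leader ranks still load everything, they just do it quietly
    if not silent:
        print(f"Loading {type} paths: {len(dataset)} entries")
    return list(dataset)

Note that the default is evaluated once, when the function is defined, so it captures the rank at import time rather than at call time.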
@@ -750,11 +750,11 @@ class Dataset(_Dataset):
 					self.duration_buckets[bucket] = []
 				self.duration_buckets[bucket].append( ( Path(path), duration ) )
 
+		# sort by duration
+		if self.sampler_order == "duration":
 			# ensure they're ordered
 			self.duration_buckets = dict(sorted(self.duration_buckets.items()))
 
-		# sort by duration
-		if self.sampler_order == "duration":
 			flattened = {}
 			# sort and interleave
 			for bucket in self.duration_buckets:
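This hunk only moves code: the `# ensure they're ordered` key-sort of `self.duration_buckets` now lives under the `sampler_order == "duration"` branch instead of sitting before it, so buckets are only re-keyed when the sampler actually interleaves by duration. A toy sketch of what that sort does, with invented bucket contents:

from pathlib import Path

# toy buckets keyed by rounded duration, filled in arbitrary order
duration_buckets = {
    3: [(Path("c.wav"), 3.2)],
    1: [(Path("a.wav"), 1.1)],
    2: [(Path("b.wav"), 2.4)],
}

sampler_order = "duration"
# sort by duration
if sampler_order == "duration":
    # ensure they're ordered
    duration_buckets = dict(sorted(duration_buckets.items()))

print(list(duration_buckets))  # [1, 2, 3] once sorted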
@@ -1117,10 +1117,12 @@ class Base(nn.Module):
 
 		# pre-iterate
 		for name, input in batch_input:
-			if name == "dropout_mask":
-				dropout_mask = input
-			elif name == "classifier_level":
+			if name == "classifier_level":
 				classifier_level = input
+			elif name == "dropout_mask":
+				dropout_mask = input
+			elif name == "timestep":
+				timestep = input
 
 		for name, input in batch_input:
 			# technically can provide a map for input_name => embedding, but some embedding requires additional processing
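The pre-iterate pass now checks `classifier_level` first and also captures `timestep`, so the embedding loop that follows can branch on the classifier level (including the new NAR-len case in the next hunk) no matter where those entries appear in `batch_input`. A self-contained sketch of the pattern; the names mirror the diff, the values are invented:

# toy batch_input: (name, value) pairs in arbitrary order
batch_input = [
    ("text", [1, 2, 3]),
    ("classifier_level", "NAR:0:0"),
    ("timestep", 0.25),
    ("resp", [7, 8, 9]),
]

classifier_level, dropout_mask, timestep = None, None, None

# pre-iterate: pull out metadata before mapping the rest to embeddings
for name, value in batch_input:
    if name == "classifier_level":
        classifier_level = value
    elif name == "dropout_mask":
        dropout_mask = value
    elif name == "timestep":
        timestep = value

print(classifier_level, timestep)  # NAR:0:0 0.25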
@@ -1169,6 +1171,15 @@ class Base(nn.Module):
 						#quant_level = 0,
 						name = classifier_level,
 					)
+				# NAR-len
+				elif classifier_level == "NAR:0:0":
+					embedding = self.resps_emb(
+						# if masked use masked token, else original token
+						input if input.dim() == 1 else input[:, 0],
+						#offset = -1 if self.masking_separate_embeddings else 0, # pick last
+						#quant_level = 0,
+						name = classifier_level,
+					)
 				# cheat-y way to handle performing STT across all levels
 				elif task_type in summed_embeddings_task:
 					# we do a manual sum because I trained it to use the AR embeddings + NAR embeddings for STT......
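This is the hunk the commit message is about: when `classifier_level == "NAR:0:0"` (the NAR-len pass), responses are now routed through `self.resps_emb` using only RVQ level 0, and a 1-D input is treated as the already-masked level-0 sequence. A small torch sketch of just the level-selection expression, with a made-up tensor (shapes are illustrative):

import torch

# toy response codes: 6 timesteps x 4 RVQ levels
resp = torch.randint(0, 1024, (6, 4))

# mirrors `input if input.dim() == 1 else input[:, 0]`:
# a 2-D tensor keeps only level 0; a 1-D tensor is assumed to already be
# the (possibly masked) level-0 sequence
level0 = resp if resp.dim() == 1 else resp[:, 0]
print(level0.shape)  # torch.Size([6])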
@@ -1215,7 +1226,7 @@ class Base(nn.Module):
 							continue
 
 						embedding[i] = self.dropout_token
-			elif name == "timestep" and self.time_emb is not None:
+			elif name == "timestep" and self.time_emb is not None and False:
 				embedding = self.time_emb( input )
 			elif name == "len" and self.len_emb is not None:
 				embedding = self.len_emb( input )
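The final hunk appends `and False` to the timestep condition, which makes that branch unreachable: the dedicated time embedding is disabled without deleting the code. A trivial illustration of the idiom with toy stand-ins (not the model's real attributes):

time_emb = object()  # stand-in for self.time_emb
name, value = "timestep", 0.25

embedding = None
if name == "timestep" and time_emb is not None and False:
    embedding = ("time_emb", value)  # never reached: the trailing False wins

print(embedding)  # None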