actually use the right embedding for nar-len

This commit is contained in:
mrq 2024-11-13 18:04:04 -06:00
parent 3ea8a610d6
commit c00fc18b62
2 changed files with 19 additions and 8 deletions

View File

@ -534,7 +534,7 @@ _durations_map = {}
def _get_duration_map( type="training" ):
return _durations_map[type] if type in _durations_map else {}
def _load_paths(dataset, type="training", silent=False, dataset_hash_key=None):
def _load_paths(dataset, type="training", silent=not is_global_leader(), dataset_hash_key=None):
if not dataset_hash_key:
dataset_hash_key = cfg.dataset.hash_key(sorted(dataset))
@ -750,11 +750,11 @@ class Dataset(_Dataset):
self.duration_buckets[bucket] = []
self.duration_buckets[bucket].append( ( Path(path), duration ) )
# ensure they're ordered
self.duration_buckets = dict(sorted(self.duration_buckets.items()))
# sort by duration
if self.sampler_order == "duration":
# ensure they're ordered
self.duration_buckets = dict(sorted(self.duration_buckets.items()))
flattened = {}
# sort and interleave
for bucket in self.duration_buckets:

View File

@ -1117,10 +1117,12 @@ class Base(nn.Module):
# pre-iterate
for name, input in batch_input:
if name == "dropout_mask":
dropout_mask = input
elif name == "classifier_level":
if name == "classifier_level":
classifier_level = input
elif name == "dropout_mask":
dropout_mask = input
elif name == "timestep":
timestep = input
for name, input in batch_input:
# technically can provide a map for input_name => embedding, but some embedding requires additional processing
@ -1169,6 +1171,15 @@ class Base(nn.Module):
#quant_level = 0,
name = classifier_level,
)
# NAR-len
elif classifier_level == "NAR:0:0":
embedding = self.resps_emb(
# if masked use masked token, else original token
input if input.dim() == 1 else input[:, 0],
#offset = -1 if self.masking_separate_embeddings else 0, # pick last
#quant_level = 0,
name = classifier_level,
)
# cheat-y way to handle performing STT across all levels
elif task_type in summed_embeddings_task:
# we do a manual sum because I trained it to use the AR embeddings + NAR embeddings for STT......
@ -1215,7 +1226,7 @@ class Base(nn.Module):
continue
embedding[i] = self.dropout_token
elif name == "timestep" and self.time_emb is not None:
elif name == "timestep" and self.time_emb is not None and False:
embedding = self.time_emb( input )
elif name == "len" and self.len_emb is not None:
embedding = self.len_emb( input )