This commit is contained in:
mrq 2024-09-05 21:42:59 -05:00
parent 54547b74d8
commit 413097f5f7
3 changed files with 3 additions and 2 deletions

View File

@@ -187,6 +187,7 @@ def load_engines(training=True, **model_kwargs):
uses_stop_token = 1 if "len" not in model.capabilities and model.causal_size > 0 else 0
keys = [
("text_emb.weight", model.config.text_tokens ),
("tasks_emb.weight", model.config.tasks ),
("rvq_l_emb.weight", model.config.resp_levels + (1 if "len" in model.config.capabilities else 0) ),
("resps_emb.embeddings.0.weight", model.config.audio_tokens + uses_stop_token ),
("model.embed_tokens.weight", model.config.audio_tokens + uses_stop_token ),

View File

@@ -424,7 +424,7 @@ def example_usage():
"""
bos_id, space_id, eos_id = cfg.tokenizer.encode( " " )
available_tasks = ["tts", "stt"] # cfg.dataset.tasks_list
available_tasks = cfg.dataset.tasks_list
model = AR_NAR(**kwargs).to(device)
steps = 150 * len(available_tasks) # * cfg.model.experimental.causal_size

View File

@@ -280,7 +280,7 @@ class Classifiers(nn.Module):
xi = [
#x if l == 0 else
x if x.shape[-1] == max_size else
torch.cat( [ x, torch.tensor( [[ -float("inf") ] for _ in range(x.shape[0])], device=device, dtype=dtype) ] * (max_size - x.shape[-1]), dim=-1 )
torch.cat( [ x, torch.full( (x.shape[0], max_size - x.shape[-1]), -float("inf"), device=device, dtype=dtype) ], dim=-1 )
for x, l in zip(xi, levels)
]
return torch.stack( xi )