fixes
This commit is contained in:
parent
54547b74d8
commit
413097f5f7
|
@ -187,6 +187,7 @@ def load_engines(training=True, **model_kwargs):
|
|||
uses_stop_token = 1 if "len" not in model.capabilities and model.causal_size > 0 else 0
|
||||
keys = [
|
||||
("text_emb.weight", model.config.text_tokens ),
|
||||
("tasks_emb.weight", model.config.tasks ),
|
||||
("rvq_l_emb.weight", model.config.resp_levels + (1 if "len" in model.config.capabilities else 0) ),
|
||||
("resps_emb.embeddings.0.weight", model.config.audio_tokens + uses_stop_token ),
|
||||
("model.embed_tokens.weight", model.config.audio_tokens + uses_stop_token ),
|
||||
|
|
|
@ -424,7 +424,7 @@ def example_usage():
|
|||
"""
|
||||
|
||||
bos_id, space_id, eos_id = cfg.tokenizer.encode( " " )
|
||||
available_tasks = ["tts", "stt"] # cfg.dataset.tasks_list
|
||||
available_tasks = cfg.dataset.tasks_list
|
||||
|
||||
model = AR_NAR(**kwargs).to(device)
|
||||
steps = 150 * len(available_tasks) # * cfg.model.experimental.causal_size
|
||||
|
|
|
@ -280,7 +280,7 @@ class Classifiers(nn.Module):
|
|||
xi = [
|
||||
#x if l == 0 else
|
||||
x if x.shape[-1] == max_size else
|
||||
torch.cat( [ x, torch.tensor( [[ -float("inf") ] for _ in range(x.shape[0])], device=device, dtype=dtype) ] * (max_size - x.shape[-1]), dim=-1 )
|
||||
torch.cat( [ x, torch.full( (x.shape[0], max_size - x.shape[-1]), -float("inf"), device=device, dtype=dtype) ], dim=-1 )
|
||||
for x, l in zip(xi, levels)
|
||||
]
|
||||
return torch.stack( xi )
|
||||
|
|
Loading…
Reference in New Issue
Block a user