diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index a973fe2..ffe51ad 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -187,6 +187,7 @@ def load_engines(training=True, **model_kwargs):
 			uses_stop_token = 1 if "len" not in model.capabilities and model.causal_size > 0 else 0
 			keys = [
 				("text_emb.weight", model.config.text_tokens ),
+				("tasks_emb.weight", model.config.tasks ),
 				("rvq_l_emb.weight", model.config.resp_levels + (1 if "len" in model.config.capabilities else 0) ),
 				("resps_emb.embeddings.0.weight", model.config.audio_tokens + uses_stop_token ),
 				("model.embed_tokens.weight", model.config.audio_tokens + uses_stop_token ),
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 09ff810..1a6323b 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -424,7 +424,7 @@ def example_usage():
 	"""

 	bos_id, space_id, eos_id = cfg.tokenizer.encode( " " )
-	available_tasks = ["tts", "stt"] # cfg.dataset.tasks_list
+	available_tasks = cfg.dataset.tasks_list
 	model = AR_NAR(**kwargs).to(device)
 	steps = 150 * len(available_tasks) # * cfg.model.experimental.causal_size
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 05625d6..1607af8 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -280,7 +280,7 @@ class Classifiers(nn.Module):
 		xi = [
 			#x if l == 0 else
 			x if x.shape[-1] == max_size else
-			torch.cat( [ x, torch.tensor( [[ -float("inf") ] for _ in range(x.shape[0])], device=device, dtype=dtype) ] * (max_size - x.shape[-1]), dim=-1 )
+			torch.cat( [ x, torch.full( (x.shape[0], max_size - x.shape[-1]), -float("inf"), device=device, dtype=dtype) ], dim=-1 )
 			for x, l in zip(xi, levels)
 		]
 		return torch.stack( xi )