From 23f3b56fda441f47b16d2b0b9f9dc28300729ff8 Mon Sep 17 00:00:00 2001
From: mrq
Date: Sun, 4 Aug 2024 08:18:57 -0500
Subject: [PATCH] oops

---
 vall_e/models/base.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 1771c32..895ab1a 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -257,7 +257,7 @@ class AudioClassifier(nn.Module):
 		xi = [
 			#x if l == 0 else
 			x if x.shape[-1] == max_size else
-			torch.cat( [ x, torch.tensor( [[ -float("inf") ] for _ in range(x.shape[0])] ).to(dtype=dtype, device=device) ] * (max_size - x.shape[-1]), dim=-1 )
+			torch.cat( [ x, torch.tensor( [[ -float("inf") ] for _ in range(x.shape[0])], device=device, dtype=dtype) ] * (max_size - x.shape[-1]), dim=-1 )
 			for x, l in zip(xi, levels)
 		]
 		return torch.stack( xi )
@@ -1098,7 +1098,7 @@ class Base(nn.Module):
 
 				delta = ids[batch_index].shape[0] - batch.shape[0]
 				if delta > 0:
-					batch = torch.cat( [ batch, torch.tensor([1] * delta) ] )
+					batch = torch.cat( [ batch, torch.tensor([1] * delta, device=device, dtype=torch.int32) ] )
 
 				x_list.append( batch )
 
@@ -1119,7 +1119,7 @@ class Base(nn.Module):
 		# handles tasks where the prompt has task tokens injected in the middle
 		def prompt_input_to_token( input, quant_level ):
 			if isinstance(input, str):
-				return torch.tensor( [ get_task_symmap()[f'<{input}>'] ] ).to(dtype=torch.int16, device=device)
+				return torch.tensor( [ get_task_symmap()[f'<{input}>'] ], device=device, dtype=torch.int16)
 
 			# ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens
 			if self.version < 4 or (self.version >= 5 and self.config and self.config.experimental.audio_embedding_sums):
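
For context, the three hunks share one pattern: pass device and dtype to torch.tensor() at construction instead of fixing them up afterward. The second hunk repairs an actual bug (torch.tensor([1] * delta) had no device at all, so it lands on the CPU and torch.cat fails against a CUDA-resident tensor); the first and third hunks replace construct-then-.to() with direct construction, avoiding an intermediate CPU tensor and a copy. A minimal standalone sketch of the failure mode and the fix, assuming a CUDA-resident batch tensor; the names below are illustrative, not taken from vall_e:

    import torch

    # Assumed setup: a tensor that already lives on the accelerator when one is present.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch = torch.zeros(4, dtype=torch.int32, device=device)

    # Before: with no explicit device, the padding tensor defaults to the CPU,
    # so on a CUDA device torch.cat raises
    # "Expected all tensors to be on the same device".
    #   pad = torch.tensor([1] * 3)
    #   batch = torch.cat([batch, pad])

    # After: construct directly on the target device with the target dtype,
    # which also skips the intermediate CPU tensor that .to() would copy from.
    pad = torch.tensor([1] * 3, device=device, dtype=torch.int32)
    batch = torch.cat([batch, pad])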