correctness
This commit is contained in:
parent
da473295b7
commit
e15c6c74c3
|
@ -146,7 +146,7 @@ class AR_NAR(Base):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
targ_list = [r[..., l] for r, l in zip(resps_list, quant_levels)] # ensures we only have 1 RVQ-bin (our target)
|
targ_list = [r[..., l] for r, l in zip(resps_list, quant_levels)] # ensures we only have 1 RVQ-bin (our target)
|
||||||
resps_list = [r if l == 0 else r[..., :l] for r, l in zip(resps_list, quant_levels)] # r[..., 0] is technically correct, but only r[:, 0] gets passed through the embedding
|
resps_list = [r[..., 0] if l == 0 else r[..., :l] for r, l in zip(resps_list, quant_levels)] # r if l == 0 is technically correct since only r[:, 0] is passed through the embedding, but this should save some VRAM
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if cfg.experimental:
|
if cfg.experimental:
|
||||||
|
@ -158,7 +158,7 @@ class AR_NAR(Base):
|
||||||
if quant_levels[i] > 0:
|
if quant_levels[i] > 0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
resps_list[i] = torch.cat([resps_list[i], torch.Tensor([[self.stop_token] * n_levels]).to(device=device, dtype=torch.int16) ])
|
resps_list[i] = torch.cat([resps_list[i], torch.Tensor([self.stop_token]).to(device=device, dtype=torch.int16) ])
|
||||||
targ_list[i] = torch.cat([targ_list[i], torch.Tensor([self.stop_token]).to(device=device, dtype=torch.int16) ])
|
targ_list[i] = torch.cat([targ_list[i], torch.Tensor([self.stop_token]).to(device=device, dtype=torch.int16) ])
|
||||||
|
|
||||||
inputs = self.inputs(
|
inputs = self.inputs(
|
||||||
|
|
|
@ -352,7 +352,7 @@ class AudioEmbedding(nn.Module):
|
||||||
x = self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1)
|
x = self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1)
|
||||||
# AR resp
|
# AR resp
|
||||||
elif quant_levels is None or quant_levels == 0:
|
elif quant_levels is None or quant_levels == 0:
|
||||||
x = self.embeddings[0]( xi[:, 0] )
|
x = self.embeddings[0]( xi if len(xi.shape) == 1 else xi[:, 0] )
|
||||||
# NAR resp
|
# NAR resp
|
||||||
else:
|
else:
|
||||||
if self.sums:
|
if self.sums:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user