off by one bateman

This commit is contained in:
mrq 2025-03-18 08:40:43 -05:00
parent 0280e72257
commit 9a8a8e3195

View File

@ -167,7 +167,7 @@ class AudioEmbedding(nn.Module):
# sum all prior codebook levels if requested (as quant_level = 0 does not have any other codebooks to sum through)
if sums and quant_level > 0:
x = sum( [ self.embeddings[input_quant_level + offset]( xi[:, input_quant_level] ) for input_quant_level in range( quant_level ) ] )
x = sum( [ self.embeddings[input_quant_level + offset]( xi[:, input_quant_level] ) for input_quant_level in range( quant_level + 1 ) ] )
else:
input_quant_level = quant_level
x = self.embeddings[input_quant_level + offset]( xi if xi.dim() == 1 else xi[:, input_quant_level] )
@ -1719,4 +1719,4 @@ if __name__ == "__main__":
resp = generate( phn, prom, sequence=resp, mode=f"resp|NAR:{i-1}:{i}" )
print( f"NAR:{i-1}:{i}: ", resp[-1] )
decode_to_file( torch.tensor(resp, dtype=torch.int16, device=device).t(), "./data/test.wav" )
decode_to_file( torch.tensor(resp, dtype=torch.int16, device=device).t(), "./data/test.wav" )