set timestep tensor to whatever the time embedding's dtype is because it'll gripe under amp

This commit is contained in:
mrq 2024-11-09 00:11:16 -06:00
parent 5a09a5f6e9
commit 69b0b3b854

View File

@ -1036,7 +1036,7 @@ class Base(nn.Module):
p = math.cos(t * math.pi * 0.5)
dropout_mask = _dropout_mask( resps_list[i], p=p )
inputs[i].append( ("timestep", torch.tensor(t, device=device) ) )
inputs[i].append( ("timestep", torch.tensor(t, device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) )
inputs[i].append( ("dropout_mask", dropout_mask ) )
# Audio length prediction task