set timestep tensor to whatever the time embedding's dtype is because it'll gripe under amp
This commit is contained in:
parent
5a09a5f6e9
commit
69b0b3b854
|
@ -1036,7 +1036,7 @@ class Base(nn.Module):
|
|||
p = math.cos(t * math.pi * 0.5)
|
||||
dropout_mask = _dropout_mask( resps_list[i], p=p )
|
||||
|
||||
inputs[i].append( ("timestep", torch.tensor(t, device=device) ) )
|
||||
inputs[i].append( ("timestep", torch.tensor(t, device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) )
|
||||
inputs[i].append( ("dropout_mask", dropout_mask ) )
|
||||
|
||||
# Audio length prediction task
|
||||
|
|
Loading…
Reference in New Issue
Block a user