Set the timestep tensor's dtype to match the time embedding's dtype, because AMP (automatic mixed precision) raises a dtype-mismatch complaint otherwise
This commit is contained in:
parent
5a09a5f6e9
commit
69b0b3b854
|
@ -1036,7 +1036,7 @@ class Base(nn.Module):
|
||||||
p = math.cos(t * math.pi * 0.5)
|
p = math.cos(t * math.pi * 0.5)
|
||||||
dropout_mask = _dropout_mask( resps_list[i], p=p )
|
dropout_mask = _dropout_mask( resps_list[i], p=p )
|
||||||
|
|
||||||
inputs[i].append( ("timestep", torch.tensor(t, device=device) ) )
|
inputs[i].append( ("timestep", torch.tensor(t, device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) )
|
||||||
inputs[i].append( ("dropout_mask", dropout_mask ) )
|
inputs[i].append( ("dropout_mask", dropout_mask ) )
|
||||||
|
|
||||||
# Audio length prediction task
|
# Audio length prediction task
|
||||||
|
|
Loading…
Reference in New Issue
Block a user