Codified fp16
parent 45ab444c04
commit d9f8f92840
@@ -219,7 +219,6 @@ class DiffusionTts(nn.Module):
         self.dropout = dropout
         self.channel_mult = channel_mult
         self.conv_resample = conv_resample
-        self.dtype = torch.float16 if use_fp16 else torch.float32
         self.num_heads = num_heads
         self.num_head_channels = num_head_channels
         self.num_heads_upsample = num_heads_upsample
@@ -229,6 +228,7 @@ class DiffusionTts(nn.Module):
         self.super_sampling_enabled = super_sampling
         self.super_sampling_max_noising_factor = super_sampling_max_noising_factor
         self.unconditioned_percentage = unconditioned_percentage
+        self.enable_fp16 = use_fp16
         padding = 1 if kernel_size == 3 else 2

         time_embed_dim = model_channels * time_embed_dim_multiplier
@@ -431,7 +431,7 @@ class DiffusionTts(nn.Module):
             lr_input = F.interpolate(lr_input, size=(x.shape[-1],), mode='nearest')
             x = torch.cat([x, lr_input], dim=1)

-        with autocast(x.device.type):
+        with autocast(x.device.type, enabled=self.enable_fp16):
             orig_x_shape = x.shape[-1]
             cm = ceil_multiple(x.shape[-1], 2048)
             if cm != 0:
@@ -482,7 +482,7 @@ class DiffusionTts(nn.Module):
                     h_tok = F.interpolate(module(code_emb), size=(h.shape[-1]), mode='nearest')
                     h = h + h_tok
                 else:
-                    with autocast(x.device.type, enabled=not first):
+                    with autocast(x.device.type, enabled=self.enable_fp16 and not first):
                         # First block has autocast disabled to allow a high precision signal to be properly vectorized.
                         h = module(h, time_emb)
                 hs.append(h)
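The pattern this commit codifies is: store the fp16 preference passed to the constructor (self.enable_fp16 = use_fp16), drop the manual self.dtype cast, and gate every autocast region on the stored flag, while still keeping the first block in full precision. Below is a minimal, self-contained sketch of that pattern; Fp16GatedModel and its parameters are illustrative names only, not code from this repository.

    import torch
    import torch.nn as nn
    from torch import autocast

    class Fp16GatedModel(nn.Module):
        def __init__(self, channels=64, num_blocks=4, use_fp16=False):
            super().__init__()
            # Same idea as DiffusionTts.enable_fp16: remember the preference once.
            self.enable_fp16 = use_fp16
            self.blocks = nn.ModuleList(
                [nn.Conv1d(channels, channels, kernel_size=3, padding=1) for _ in range(num_blocks)]
            )

        def forward(self, x):
            h = x
            for i, block in enumerate(self.blocks):
                first = (i == 0)
                # Keep the first block out of autocast so the high-precision input
                # signal is preserved, mirroring the `not first` guard in the diff.
                with autocast(x.device.type, enabled=self.enable_fp16 and not first):
                    h = block(h)
            return h.float()

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = Fp16GatedModel(use_fp16=(device == 'cuda')).to(device)
    y = model(torch.randn(1, 64, 256, device=device))

Gating autocast on a stored flag means mixed precision can be switched off at construction time (use_fp16=False) without editing the forward pass, which is what replacing the bare autocast(x.device.type) calls with enabled=self.enable_fp16 accomplishes here.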