diff --git a/codes/models/diffusion/diffusion_dvae.py b/codes/models/diffusion/diffusion_dvae.py index 48cb6e43..3d11a873 100644 --- a/codes/models/diffusion/diffusion_dvae.py +++ b/codes/models/diffusion/diffusion_dvae.py @@ -9,7 +9,7 @@ from models.gpt_voice.mini_encoder import AudioMiniEncoder, EmbeddingCombiner from models.vqvae.vqvae import Quantize from trainer.networks import register_model import models.gpt_voice.my_dvae as mdvae -from utils.util import checkpoint, get_mask_from_lengths +from utils.util import get_mask_from_lengths class DiscreteEncoder(nn.Module): @@ -248,22 +248,6 @@ class DiffusionDVAE(nn.Module): zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)), ) - def convert_to_fp16(self): - """ - Convert the torso of the model to float16. - """ - self.input_blocks.apply(convert_module_to_f16) - self.middle_block.apply(convert_module_to_f16) - self.output_blocks.apply(convert_module_to_f16) - - def convert_to_fp32(self): - """ - Convert the torso of the model to float32. - """ - self.input_blocks.apply(convert_module_to_f32) - self.middle_block.apply(convert_module_to_f32) - self.output_blocks.apply(convert_module_to_f32) - def _decode_continouous(self, x, timesteps, embeddings, conditioning_inputs, num_conditioning_signals): spec_hs = self.decoder(embeddings)[::-1] # Shape the spectrogram correctly. There is no guarantee it fits (though I probably should add an assertion here to make sure the resizing isn't too wacky.) diff --git a/codes/models/diffusion/unet_diffusion.py b/codes/models/diffusion/unet_diffusion.py index af5d17f4..86b3e56f 100644 --- a/codes/models/diffusion/unet_diffusion.py +++ b/codes/models/diffusion/unet_diffusion.py @@ -294,9 +294,11 @@ class AttentionBlock(nn.Module): num_heads=1, num_head_channels=-1, use_new_attention_order=False, + do_checkpoint=True, ): super().__init__() self.channels = channels + self.do_checkpoint = do_checkpoint if num_head_channels == -1: self.num_heads = num_heads else: @@ -316,7 +318,10 @@ class AttentionBlock(nn.Module): self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) def forward(self, x, mask=None): - return checkpoint(self._forward, x, mask) + if self.do_checkpoint: + return checkpoint(self._forward, x, mask) + else: + return self._forward(x, mask) def _forward(self, x, mask): b, c, *spatial = x.shape diff --git a/codes/models/gpt_voice/mini_encoder.py b/codes/models/gpt_voice/mini_encoder.py index 517d0e5c..22d6c3f4 100644 --- a/codes/models/gpt_voice/mini_encoder.py +++ b/codes/models/gpt_voice/mini_encoder.py @@ -21,7 +21,7 @@ class AudioMiniEncoder(nn.Module): res = [] for l in range(2): for r in range(resnet_blocks): - res.append(ResBlock(ch, dropout, dims=1)) + res.append(ResBlock(ch, dropout, dims=1, do_checkpoint=False)) res.append(Downsample(ch, use_conv=True, dims=1, out_channels=ch*2, factor=2)) ch *= 2 self.res = nn.Sequential(*res) @@ -32,7 +32,7 @@ class AudioMiniEncoder(nn.Module): ) attn = [] for a in range(attn_blocks): - attn.append(AttentionBlock(embedding_dim, num_attn_heads)) + attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=False)) self.attn = nn.Sequential(*attn) def forward(self, x): diff --git a/codes/models/gpt_voice/my_dvae.py b/codes/models/gpt_voice/my_dvae.py index 1cf78cc7..5443c609 100644 --- a/codes/models/gpt_voice/my_dvae.py +++ b/codes/models/gpt_voice/my_dvae.py @@ -41,6 +41,7 @@ class ResBlock(nn.Module): up=False, down=False, kernel_size=3, + do_checkpoint=True, ): super().__init__() self.channels = channels @@ -48,6 +49,7 @@ class ResBlock(nn.Module): self.out_channels = out_channels or channels self.use_conv = use_conv self.use_scale_shift_norm = use_scale_shift_norm + self.do_checkpoint = do_checkpoint padding = 1 if kernel_size == 3 else 2 self.in_layers = nn.Sequential( @@ -86,9 +88,12 @@ class ResBlock(nn.Module): self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) def forward(self, x): - return checkpoint( - self._forward, x - ) + if self.do_checkpoint: + return checkpoint( + self._forward, x + ) + else: + return self._forward(x) def _forward(self, x): if self.updown: