diff --git a/codes/models/gpt_voice/mini_encoder.py b/codes/models/gpt_voice/mini_encoder.py index bc04b1a5..fd289252 100644 --- a/codes/models/gpt_voice/mini_encoder.py +++ b/codes/models/gpt_voice/mini_encoder.py @@ -7,7 +7,7 @@ from models.diffusion.unet_diffusion import Downsample, AttentionBlock, QKVAtten # Combined resnet & full-attention encoder for converting an audio clip into an embedding. from trainer.networks import register_model -from utils.util import checkpoint, opt_get +from utils.util import checkpoint, opt_get, sequential_checkpoint class ResBlock(nn.Module): @@ -100,14 +100,14 @@ class AudioMiniEncoder(nn.Module): num_attn_heads=4, dropout=0, downsample_factor=2, - kernel_size=3, - do_checkpointing=False): + kernel_size=3): super().__init__() self.init = nn.Sequential( conv_nd(1, spec_dim, base_channels, 3, padding=1) ) ch = base_channels res = [] + self.layers = depth for l in range(depth): for r in range(resnet_blocks): res.append(ResBlock(ch, dropout, dims=1, do_checkpoint=False, kernel_size=kernel_size)) @@ -124,16 +124,13 @@ class AudioMiniEncoder(nn.Module): attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=False)) self.attn = nn.Sequential(*attn) self.dim = embedding_dim - self.do_checkpointing = do_checkpointing def forward(self, x): h = self.init(x) - h = self.res(h) + h = sequential_checkpoint(self.res, self.layers, h) h = self.final(h) - if self.do_checkpointing: - h = checkpoint(self.attn, h) - else: - h = self.attn(h) + for blk in self.attn: + h = checkpoint(blk, h) return h[:, :, 0]