diff --git a/codes/models/gpt_voice/unet_diffusion_tts_experimental.py b/codes/models/gpt_voice/unet_diffusion_tts_experimental.py
index 5f229613..aba7d926 100644
--- a/codes/models/gpt_voice/unet_diffusion_tts_experimental.py
+++ b/codes/models/gpt_voice/unet_diffusion_tts_experimental.py
@@ -66,6 +66,11 @@ class ResBlock(TimestepBlock):
         :param emb: an [N x emb_channels] Tensor of timestep embeddings.
         :return: an [N x C x ...] Tensor of outputs.
         """
+        return checkpoint(
+            self._forward, x, emb
+        )
+
+    def _forward(self, x, emb):
         h = self.in_layers(x)
         emb_out = self.emb_layers(emb).type(h.dtype)
         while len(emb_out.shape) < len(h.shape):
@@ -313,12 +318,12 @@ class DiffusionTts(nn.Module):
                     h_tok = F.interpolate(module(tokens).permute(0,2,1), size=(h.shape[-1]), mode='nearest')
                     h = h + h_tok
                 else:
-                    h = checkpoint(module, h, emb)
+                    h = module(h, emb)
                 hs.append(h)
-        h = checkpoint(self.middle_block, h, emb)
+        h = self.middle_block(h, emb)
         for module in self.output_blocks:
             h = torch.cat([h, hs.pop()], dim=1)
-            h = checkpoint(module, h, emb)
+            h = module(h, emb)
         h = h.type(x.dtype)
         out = self.out(h)
         return out[:, :, :orig_x_shape]
@@ -372,8 +377,6 @@ if __name__ == '__main__':
     model = DiffusionTts(64, channel_mult=[1,1.5,2, 3, 4, 6, 8, 8, 8, 8], num_res_blocks=[2, 2, 2, 2, 2, 2, 2, 4, 4, 4],
                          token_conditioning_resolutions=[1,4,16,64], attention_resolutions=[256,512], num_heads=4,
                          kernel_size=3, scale_factor=2, conditioning_inputs_provided=True, time_embed_dim_multiplier=4)
-    model(clip, ts, tok, cond)
-    """
     p, r = model.benchmark(clip, ts, tok, cond)
     p = {k: v / 1000000000 for k, v in p.items()}
     p = sorted(p.items(), key=operator.itemgetter(1))
@@ -386,5 +389,4 @@ if __name__ == '__main__':
     r = sorted(r.items(), key=operator.itemgetter(1))
     print(r)
     print(sum([j[1] for j in r]))
-    """
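
For context, the diff moves gradient checkpointing from the outer UNet loops into `ResBlock` itself: `forward()` becomes a thin wrapper that routes through `checkpoint()`, and the real computation lives in a new `_forward()` method, while the per-module `checkpoint(module, h, emb)` calls in `DiffusionTts.forward` are dropped so the same work is not checkpointed at two levels. Below is a minimal sketch of that pattern, assuming `torch.utils.checkpoint.checkpoint` semantics (the repo's `checkpoint` may be a local wrapper); `CheckpointedResBlock` and its layers are illustrative names, not the repo's code.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointedResBlock(nn.Module):
    """Illustrative residual block that checkpoints its own body."""

    def __init__(self, channels):
        super().__init__()
        self.in_layers = nn.Sequential(
            nn.GroupNorm(8, channels),
            nn.SiLU(),
            nn.Conv1d(channels, channels, 3, padding=1),
        )

    def forward(self, x):
        # As in the diff: forward() only routes through checkpoint(), and the
        # real computation lives in _forward(). Intermediate activations inside
        # _forward() are freed after the forward pass and recomputed on backward.
        return checkpoint(self._forward, x)

    def _forward(self, x):
        return self.in_layers(x) + x


if __name__ == '__main__':
    block = CheckpointedResBlock(16)
    x = torch.randn(2, 16, 64, requires_grad=True)
    block(x).sum().backward()
    print(x.grad.shape)  # torch.Size([2, 16, 64])
```

Pushing checkpointing down to the block level like this trades a second forward pass per block for a much smaller activation footprint, and gives finer-grained control than wrapping whole input/middle/output stages as the old code did.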