(re) attempt diffusion checkpointing logic

This commit is contained in:
James Betker 2022-01-22 08:34:40 -07:00
parent 8f48848f91
commit b22eec8fe3
2 changed files with 7 additions and 9 deletions

View File

@ -66,11 +66,6 @@ class ResBlock(TimestepBlock):
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
:return: an [N x C x ...] Tensor of outputs.
"""
return checkpoint(
self._forward, x, emb
)
def _forward(self, x, emb):
h = self.in_layers(x)
emb_out = self.emb_layers(emb).type(h.dtype)
while len(emb_out.shape) < len(h.shape):
@ -318,12 +313,12 @@ class DiffusionTts(nn.Module):
h_tok = F.interpolate(module(tokens).permute(0,2,1), size=(h.shape[-1]), mode='nearest')
h = h + h_tok
else:
h = module(h, emb)
h = checkpoint(module, h, emb)
hs.append(h)
h = self.middle_block(h, emb)
h = checkpoint(self.middle_block, h, emb)
for module in self.output_blocks:
h = torch.cat([h, hs.pop()], dim=1)
h = module(h, emb)
h = checkpoint(module, h, emb)
h = h.type(x.dtype)
out = self.out(h)
return out[:, :, :orig_x_shape]
@ -377,6 +372,8 @@ if __name__ == '__main__':
model = DiffusionTts(64, channel_mult=[1,1.5,2, 3, 4, 6, 8, 8, 8, 8], num_res_blocks=[2, 2, 2, 2, 2, 2, 2, 4, 4, 4],
token_conditioning_resolutions=[1,4,16,64], attention_resolutions=[256,512], num_heads=4, kernel_size=3,
scale_factor=2, conditioning_inputs_provided=True, time_embed_dim_multiplier=4)
model(clip, ts, tok, cond)
"""
p, r = model.benchmark(clip, ts, tok, cond)
p = {k: v / 1000000000 for k, v in p.items()}
p = sorted(p.items(), key=operator.itemgetter(1))
@ -389,4 +386,5 @@ if __name__ == '__main__':
r = sorted(r.items(), key=operator.itemgetter(1))
print(r)
print(sum([j[1] for j in r]))
"""

View File

@ -300,7 +300,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_diffusion_tts.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_clip_cond_to_voice.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()