forked from mrq/DL-Art-School
Revert "(re) attempt diffusion checkpointing logic"
This reverts commit b22eec8fe3.

parent b22eec8fe3
commit d18aec793a
@@ -66,6 +66,11 @@ class ResBlock(TimestepBlock):
         :param emb: an [N x emb_channels] Tensor of timestep embeddings.
         :return: an [N x C x ...] Tensor of outputs.
         """
+        return checkpoint(
+            self._forward, x, emb
+        )
+
+    def _forward(self, x, emb):
         h = self.in_layers(x)
         emb_out = self.emb_layers(emb).type(h.dtype)
         while len(emb_out.shape) < len(h.shape):
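The hunk above restores the gradient-checkpoint wrapper on ResBlock: forward defers to checkpoint(self._forward, x, emb), so the block's intermediate activations are recomputed during the backward pass instead of being kept in memory. A minimal, self-contained sketch of that pattern, assuming checkpoint is torch.utils.checkpoint.checkpoint and using a simplified stand-in block (the real ResBlock's layers differ):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class TinyResBlock(nn.Module):
    # Simplified stand-in for ResBlock: a conv path plus a timestep-embedding projection.
    def __init__(self, channels, emb_channels):
        super().__init__()
        self.in_layers = nn.Conv1d(channels, channels, 3, padding=1)
        self.emb_layers = nn.Linear(emb_channels, channels)
        self.out_layers = nn.Conv1d(channels, channels, 3, padding=1)

    def forward(self, x, emb):
        # Wrap the real computation so activations inside the block are
        # recomputed on backward instead of stored (less memory, more compute).
        return checkpoint(self._forward, x, emb)

    def _forward(self, x, emb):
        h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        return self.out_layers(h + emb_out) + x

x = torch.randn(2, 8, 32, requires_grad=True)
emb = torch.randn(2, 16)
out = TinyResBlock(8, 16)(x, emb)
out.sum().backward()  # gradients still flow through the checkpointed segment

The trade-off is a second forward pass through each block during backward in exchange for lower peak activation memory.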
@@ -313,12 +318,12 @@ class DiffusionTts(nn.Module):
                 h_tok = F.interpolate(module(tokens).permute(0,2,1), size=(h.shape[-1]), mode='nearest')
                 h = h + h_tok
             else:
-                h = checkpoint(module, h, emb)
+                h = module(h, emb)
             hs.append(h)
-        h = checkpoint(self.middle_block, h, emb)
+        h = self.middle_block(h, emb)
         for module in self.output_blocks:
             h = torch.cat([h, hs.pop()], dim=1)
-            h = checkpoint(module, h, emb)
+            h = module(h, emb)
         h = h.type(x.dtype)
         out = self.out(h)
         return out[:, :, :orig_x_shape]
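The hunk above drops the block-level checkpointing the reverted commit had introduced in DiffusionTts.forward: the input blocks, middle block, and output blocks are called directly again, leaving the memory savings to the per-ResBlock wrapper restored in the first hunk. Purely for illustration, a sketch of the loop-level pattern being removed, again assuming torch.utils.checkpoint.checkpoint and using placeholder modules rather than the model's real blocks:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Block(nn.Module):
    # Placeholder for a top-level block that takes (h, emb), like the model's input/output blocks.
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv1d(channels, channels, 3, padding=1)

    def forward(self, h, emb):
        return self.conv(h) + emb[..., None]

blocks = nn.ModuleList([Block(8) for _ in range(3)])
h = torch.randn(2, 8, 32, requires_grad=True)
emb = torch.randn(2, 8)

# Pattern removed by this revert: checkpoint each top-level block in the loop.
for module in blocks:
    h = checkpoint(module, h, emb)

# Pattern restored by this revert: call the blocks directly.
# for module in blocks:
#     h = module(h, emb)

h.sum().backward()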
@@ -372,8 +377,6 @@ if __name__ == '__main__':
     model = DiffusionTts(64, channel_mult=[1,1.5,2, 3, 4, 6, 8, 8, 8, 8], num_res_blocks=[2, 2, 2, 2, 2, 2, 2, 4, 4, 4],
                          token_conditioning_resolutions=[1,4,16,64], attention_resolutions=[256,512], num_heads=4, kernel_size=3,
                          scale_factor=2, conditioning_inputs_provided=True, time_embed_dim_multiplier=4)
-    model(clip, ts, tok, cond)
-    """
     p, r = model.benchmark(clip, ts, tok, cond)
     p = {k: v / 1000000000 for k, v in p.items()}
     p = sorted(p.items(), key=operator.itemgetter(1))
@@ -386,5 +389,4 @@ if __name__ == '__main__':
     r = sorted(r.items(), key=operator.itemgetter(1))
     print(r)
     print(sum([j[1] for j in r]))
-    """
