forked from mrq/DL-Art-School
Another fix
This commit is contained in:
parent
83a4ef4149
commit
cf57c352c8
|
@ -190,7 +190,7 @@ class TransformerDiffusion(nn.Module):
|
||||||
s_prior = x_prior[:,:,start:start+self.max_window]
|
s_prior = x_prior[:,:,start:start+self.max_window]
|
||||||
s_prior = F.interpolate(s_prior, scale_factor=.25, mode='linear', align_corners=True)
|
s_prior = F.interpolate(s_prior, scale_factor=.25, mode='linear', align_corners=True)
|
||||||
s_prior = F.interpolate(s_prior, size=(s.shape[-1],), mode='linear', align_corners=True)
|
s_prior = F.interpolate(s_prior, size=(s.shape[-1],), mode='linear', align_corners=True)
|
||||||
self.preprocessed = (s_prior, resolution)
|
self.preprocessed = (s_prior, torch.tensor([resolution] * x.shape[0], dtype=torch.long, device=x.device))
|
||||||
return s
|
return s
|
||||||
|
|
||||||
def forward(self, x, timesteps, x_prior=None, resolution=None, conditioning_input=None, conditioning_free=False):
|
def forward(self, x, timesteps, x_prior=None, resolution=None, conditioning_input=None, conditioning_free=False):
|
||||||
|
@ -249,13 +249,14 @@ def register_transformer_diffusion13(opt_net, opt):
|
||||||
|
|
||||||
|
|
||||||
def test_tfd():
|
def test_tfd():
|
||||||
clip = torch.randn(2,256,2583)
|
clip = torch.randn(2,256,10336)
|
||||||
cond = torch.randn(2,256,2583)
|
cond = torch.randn(2,256,10336)
|
||||||
ts = torch.LongTensor([600, 600])
|
ts = torch.LongTensor([600, 600])
|
||||||
model = TransformerDiffusion(in_channels=256, model_channels=1024, contraction_dim=512,
|
model = TransformerDiffusion(in_channels=256, model_channels=1024, contraction_dim=512,
|
||||||
num_heads=512//64, input_vec_dim=256, num_layers=12, dropout=.1)
|
num_heads=512//64, input_vec_dim=256, num_layers=12, dropout=.1)
|
||||||
for k in range(100):
|
for k in range(100):
|
||||||
model(clip, ts, clip, conditioning_input=cond)
|
x = model.input_to_random_resolution_and_window(clip, x_prior=clip)
|
||||||
|
model(x, ts, clip)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
Loading…
Reference in New Issue
Block a user