ugh
This commit is contained in:
parent
da9e47ca0e
commit
6b1cfe8e66
|
@ -174,7 +174,7 @@ class TransformerDiffusion(nn.Module):
|
|||
assert x.shape == x_prior.shape, f'{x.shape} {x_prior.shape}'
|
||||
resolution = randrange(1, self.resolution_steps)
|
||||
resolution_scale = 2 ** resolution
|
||||
s = F.interpolate(x, scale_factor=1/resolution_scale, mode='nearest', align_corners=True)
|
||||
s = F.interpolate(x, scale_factor=1/resolution_scale, mode='linear', align_corners=True)
|
||||
s_prior = F.interpolate(x_prior, scale_factor=1/resolution_scale, mode='linear', align_corners=True)
|
||||
s_diff = s.shape[-1] - self.max_window
|
||||
if s_diff > 1:
|
||||
|
@ -265,5 +265,5 @@ def remove_conditioning(sd_path):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
remove_conditioning('X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr\\models\\12000_generator.pth')
|
||||
remove_conditioning('X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr_pre\\models\\12500_generator.pth')
|
||||
test_tfd()
|
||||
|
|
Loading…
Reference in New Issue
Block a user