ugh
This commit is contained in:
parent
da9e47ca0e
commit
6b1cfe8e66
|
@ -174,7 +174,7 @@ class TransformerDiffusion(nn.Module):
|
||||||
assert x.shape == x_prior.shape, f'{x.shape} {x_prior.shape}'
|
assert x.shape == x_prior.shape, f'{x.shape} {x_prior.shape}'
|
||||||
resolution = randrange(1, self.resolution_steps)
|
resolution = randrange(1, self.resolution_steps)
|
||||||
resolution_scale = 2 ** resolution
|
resolution_scale = 2 ** resolution
|
||||||
s = F.interpolate(x, scale_factor=1/resolution_scale, mode='nearest', align_corners=True)
|
s = F.interpolate(x, scale_factor=1/resolution_scale, mode='linear', align_corners=True)
|
||||||
s_prior = F.interpolate(x_prior, scale_factor=1/resolution_scale, mode='linear', align_corners=True)
|
s_prior = F.interpolate(x_prior, scale_factor=1/resolution_scale, mode='linear', align_corners=True)
|
||||||
s_diff = s.shape[-1] - self.max_window
|
s_diff = s.shape[-1] - self.max_window
|
||||||
if s_diff > 1:
|
if s_diff > 1:
|
||||||
|
@ -265,5 +265,5 @@ def remove_conditioning(sd_path):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
remove_conditioning('X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr\\models\\12000_generator.pth')
|
remove_conditioning('X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr_pre\\models\\12500_generator.pth')
|
||||||
test_tfd()
|
test_tfd()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user