Fix
This commit is contained in:
parent
6833048bf7
commit
e24c619387
|
@ -251,11 +251,7 @@ class DiffusionDVAE(nn.Module):
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_debug_values(self, step, __):
|
def get_debug_values(self, step, __):
|
||||||
if self.record_codes:
|
|
||||||
# Report annealing schedule
|
|
||||||
return {'histogram_codes': self.codes}
|
return {'histogram_codes': self.codes}
|
||||||
else:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@eval_decorator
|
@eval_decorator
|
||||||
|
@ -379,10 +375,8 @@ class DiffusionDVAE(nn.Module):
|
||||||
|
|
||||||
# Test for ~4 second audio clip at 22050Hz
|
# Test for ~4 second audio clip at 22050Hz
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
spec = torch.randn(4, 80, 416)
|
spec = torch.randn(4, 80, 161)
|
||||||
cond = torch.randn(4, 5, 80, 200)
|
|
||||||
num_cond = torch.tensor([2,4,5,3], dtype=torch.long)
|
|
||||||
ts = torch.LongTensor([432, 234, 100, 555])
|
ts = torch.LongTensor([432, 234, 100, 555])
|
||||||
model = DiffusionDVAE(model_channels=128, num_res_blocks=1, in_channels=80, out_channels=160, spectrogram_conditioning_levels=[1,2],
|
model = DiffusionDVAE(model_channels=128, num_res_blocks=1, in_channels=80, out_channels=160, spectrogram_conditioning_levels=[1,2],
|
||||||
channel_mult=(1,2,4), attention_resolutions=[4], num_heads=4, kernel_size=3, scale_steps=2, conditioning_inputs_provided=False)
|
channel_mult=(1,2,4), attention_resolutions=[4], num_heads=4, kernel_size=3, scale_steps=2, conditioning_inputs_provided=False)
|
||||||
print(model(torch.randn_like(spec), ts, spec, cond, num_cond)[0].shape)
|
print(model(torch.randn_like(spec), ts, spec)[0].shape)
|
||||||
|
|
|
@ -565,5 +565,5 @@ class RandomAudioCropInjector(Injector):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
inj = DecomposeDimensionInjector({'dim':2, 'in': 'x', 'out': 'y'}, None)
|
inj = MelSpectrogramInjector({'in': 'x', 'out': 'y'}, None)
|
||||||
print(inj({'x':torch.randn(10,3,64,64)})['y'].shape)
|
print(inj({'x':torch.rand(10,1,40800)})['y'].shape)
|
Loading…
Reference in New Issue
Block a user