Condition on full signal
This commit is contained in:
parent
e9dc37f19c
commit
83cccef9d8
|
@ -121,7 +121,8 @@ class DiffusionVocoderWithRef(nn.Module):
|
||||||
|
|
||||||
self.conditioning_enabled = conditioning_inputs_provided
|
self.conditioning_enabled = conditioning_inputs_provided
|
||||||
if conditioning_inputs_provided:
|
if conditioning_inputs_provided:
|
||||||
self.contextual_embedder = AudioMiniEncoder(conditioning_input_dim, time_embed_dim)
|
self.contextual_embedder = AudioMiniEncoder(in_channels, time_embed_dim, base_channels=32, depth=6, resnet_blocks=1,
|
||||||
|
attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5)
|
||||||
|
|
||||||
self.input_blocks = nn.ModuleList(
|
self.input_blocks = nn.ModuleList(
|
||||||
[
|
[
|
||||||
|
@ -329,7 +330,7 @@ if __name__ == '__main__':
|
||||||
clip = torch.randn(2, 1, 40960)
|
clip = torch.randn(2, 1, 40960)
|
||||||
#spec = torch.randint(8192, (2, 40,))
|
#spec = torch.randint(8192, (2, 40,))
|
||||||
spec = torch.randn(2,512,160)
|
spec = torch.randn(2,512,160)
|
||||||
cond = torch.randn(2, 80, 173)
|
cond = torch.randn(2, 1, 40960)
|
||||||
ts = torch.LongTensor([555, 556])
|
ts = torch.LongTensor([555, 556])
|
||||||
model = DiffusionVocoderWithRef(32, conditioning_inputs_provided=True, time_embed_dim_multiplier=8)
|
model = DiffusionVocoderWithRef(32, conditioning_inputs_provided=True, time_embed_dim_multiplier=8)
|
||||||
print(model(clip, ts, spec, cond).shape)
|
print(model(clip, ts, spec, cond).shape)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user