Add time_embed_dim_multiplier
This commit is contained in:
parent
ba6e46c02a
commit
43e389aac6
|
@ -90,6 +90,7 @@ class DiffusionVocoderWithRef(nn.Module):
|
||||||
scale_factor=2,
|
scale_factor=2,
|
||||||
conditioning_inputs_provided=True,
|
conditioning_inputs_provided=True,
|
||||||
conditioning_input_dim=80,
|
conditioning_input_dim=80,
|
||||||
|
time_embed_dim_multiplier=4,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@ -111,7 +112,7 @@ class DiffusionVocoderWithRef(nn.Module):
|
||||||
|
|
||||||
padding = 1 if kernel_size == 3 else 2
|
padding = 1 if kernel_size == 3 else 2
|
||||||
|
|
||||||
time_embed_dim = model_channels * 4
|
time_embed_dim = model_channels * time_embed_dim_multiplier
|
||||||
self.time_embed = nn.Sequential(
|
self.time_embed = nn.Sequential(
|
||||||
linear(model_channels, time_embed_dim),
|
linear(model_channels, time_embed_dim),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
|
@ -332,5 +333,5 @@ if __name__ == '__main__':
|
||||||
spec = torch.randn(2,512,160)
|
spec = torch.randn(2,512,160)
|
||||||
cond = torch.randn(2, 3, 80, 173)
|
cond = torch.randn(2, 3, 80, 173)
|
||||||
ts = torch.LongTensor([555, 556])
|
ts = torch.LongTensor([555, 556])
|
||||||
model = DiffusionVocoderWithRef(32, conditioning_inputs_provided=True)
|
model = DiffusionVocoderWithRef(32, conditioning_inputs_provided=True, time_embed_dim_multiplier=8)
|
||||||
print(model(clip, ts, spec, cond, 3).shape)
|
print(model(clip, ts, spec, cond, 3).shape)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user