tts9 fix for alignment size
This commit is contained in:
parent
f563a8dd41
commit
0fc877cbc8
|
@ -187,6 +187,7 @@ class DiffusionTts(nn.Module):
|
||||||
self.super_sampling_max_noising_factor = super_sampling_max_noising_factor
|
self.super_sampling_max_noising_factor = super_sampling_max_noising_factor
|
||||||
self.unconditioned_percentage = unconditioned_percentage
|
self.unconditioned_percentage = unconditioned_percentage
|
||||||
self.enable_fp16 = use_fp16
|
self.enable_fp16 = use_fp16
|
||||||
|
self.alignment_size = 2 ** (len(channel_mult)+1)
|
||||||
padding = 1 if kernel_size == 3 else 2
|
padding = 1 if kernel_size == 3 else 2
|
||||||
down_kernel = 1 if efficient_convs else 3
|
down_kernel = 1 if efficient_convs else 3
|
||||||
|
|
||||||
|
@ -414,7 +415,7 @@ class DiffusionTts(nn.Module):
|
||||||
|
|
||||||
# Fix input size to the proper multiple of 2 so we don't get alignment errors going down and back up the U-net.
|
# Fix input size to the proper multiple of 2 so we don't get alignment errors going down and back up the U-net.
|
||||||
orig_x_shape = x.shape[-1]
|
orig_x_shape = x.shape[-1]
|
||||||
cm = ceil_multiple(x.shape[-1], 2048)
|
cm = ceil_multiple(x.shape[-1], self.alignment_size)
|
||||||
if cm != 0:
|
if cm != 0:
|
||||||
pc = (cm-x.shape[-1])/x.shape[-1]
|
pc = (cm-x.shape[-1])/x.shape[-1]
|
||||||
x = F.pad(x, (0,cm-x.shape[-1]))
|
x = F.pad(x, (0,cm-x.shape[-1]))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user