This commit is contained in:
James Betker 2022-07-05 11:14:09 -06:00
parent 7440e43531
commit 5816a4595e

View File

@ -197,7 +197,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
if custom_conditioning_fetcher is not None: if custom_conditioning_fetcher is not None:
cs, ce = custom_conditioning_fetcher(self.conditioning_encoder, time_emb) cs, ce = custom_conditioning_fetcher(self.conditioning_encoder, time_emb)
else: else:
assert conditioning_input.shape[-1] - cond_start - N > 0, f'Some sort of conditioning misalignment, {conditioning_input.shape[-1], cond_start, N}' assert conditioning_input.shape[-1] - cond_start - N >= 0, f'Some sort of conditioning misalignment, {conditioning_input.shape[-1], cond_start, N}'
cond_pre = conditioning_input[:,:,:cond_start] cond_pre = conditioning_input[:,:,:cond_start]
cond_aligned = conditioning_input[:,:,cond_start:N+cond_start] cond_aligned = conditioning_input[:,:,cond_start:N+cond_start]
cond_post = conditioning_input[:,:,N+cond_start:] cond_post = conditioning_input[:,:,N+cond_start:]
@ -308,7 +308,7 @@ def inference_tfdpc5_with_cheater():
with torch.no_grad(): with torch.no_grad():
os.makedirs('results/tfdpc_v3', exist_ok=True) os.makedirs('results/tfdpc_v3', exist_ok=True)
#length = 40 * 22050 // 256 // 16 # length = 40 * 22050 // 256 // 16
samples = {'electronica1': load_audio('Y:\\split\\yt-music-eval\\00001.wav', 22050), samples = {'electronica1': load_audio('Y:\\split\\yt-music-eval\\00001.wav', 22050),
'electronica2': load_audio('Y:\\split\\yt-music-eval\\00272.wav', 22050), 'electronica2': load_audio('Y:\\split\\yt-music-eval\\00272.wav', 22050),
'e_guitar': load_audio('Y:\\split\\yt-music-eval\\00227.wav', 22050), 'e_guitar': load_audio('Y:\\split\\yt-music-eval\\00227.wav', 22050),