forked from mrq/DL-Art-School
Use corner alignment for linear interpolation in TFDPC and TFD12
I noticed from experimentation that when this is not enabled, the interpolation edges are "sticky": there is more variance in the center of the interpolation than at the edges.
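A minimal sketch (not part of the commit) of what "sticky" means here, using the same style of F.interpolate call that the first hunk below patches:

import torch
import torch.nn.functional as F

# Two conditioning anchors stretched across N=5 positions.
cond_enc = torch.tensor([[[0., 1.]]])  # (batch, channels, 2)

# PyTorch's default for mode='linear' behaves like align_corners=False.
print(F.interpolate(cond_enc, size=(5,), mode='linear'))
# tensor([[[0.0000, 0.1000, 0.5000, 0.9000, 1.0000]]])
# Steps of 0.1 at the edges vs. 0.4 in the middle: most of the
# variation is concentrated in the center of the interpolation.

print(F.interpolate(cond_enc, size=(5,), mode='linear', align_corners=True))
# tensor([[[0.0000, 0.2500, 0.5000, 0.7500, 1.0000]]])
# Uniform steps of 0.25: the anchors sit exactly on the corners.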
parent 5816a4595e
commit 48270272e7
codes/models/audio/music
@@ -225,7 +225,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
             cond_right_enc = self.conditioning_encoder(cond_right_full, time_emb)
             ce = cond_right_enc[:,:,cond_right.shape[-1]-1]
         cond_enc = torch.cat([cs.unsqueeze(-1), ce.unsqueeze(-1)], dim=-1)
-        cond = F.interpolate(cond_enc, size=(N,), mode='linear').permute(0,2,1)
+        cond = F.interpolate(cond_enc, size=(N,), mode='linear', align_corners=True).permute(0,2,1)
         return cond

     def forward(self, x, timesteps, conditioning_input=None, conditioning_free=False, cond_start=0, custom_conditioning_fetcher=None):
@@ -304,6 +304,28 @@ def test_cheater_model():
         print(f'{k}: {prmsz(v)/1000000}')


+def test_conditioning_splitting_logic():
+    ts = torch.LongTensor([600])
+    class fake_conditioner(nn.Module):
+        def __init__(self):
+            super().__init__()
+        def forward(self, t, _):
+            print(t[:,0])
+            return t
+    model = TransformerDiffusionWithPointConditioning(in_channels=256, out_channels=512, model_channels=1024,
+                                                      contraction_dim=512, num_heads=8, num_layers=15, dropout=0,
+                                                      unconditioned_percentage=.4)
+    model.conditioning_encoder = fake_conditioner()
+    BASEDIM=30
+    for x in range(BASEDIM+1, BASEDIM+20):
+        start = random.randint(0,x-BASEDIM)
+        cl = torch.arange(1, x+1, 1).view(1,1,-1).float().repeat(1,256,1)
+        print("Effective input: " + str(cl[0, 0, start:BASEDIM+start]))
+        res = model.process_conditioning(cl, ts, BASEDIM, start, None)
+        print("Result: " + str(res[0,:,0]))
+        print()
+
+
 def inference_tfdpc5_with_cheater():
     with torch.no_grad():
         os.makedirs('results/tfdpc_v3', exist_ok=True)
@@ -384,5 +406,6 @@ def inference_tfdpc5_with_cheater():
         torchaudio.save(f'results/tfdpc_v3/{k}_ref.wav', sample.unsqueeze(0).cpu(), 22050)

 if __name__ == '__main__':
-    test_cheater_model()
+    #test_cheater_model()
+    test_conditioning_splitting_logic()
     #inference_tfdpc5_with_cheater()
@@ -97,6 +97,7 @@ class TransformerDiffusion(nn.Module):
             out_channels=512,  # mean and variance
             num_heads=4,
             dropout=0,
+            use_corner_alignment=False,  # This is an interpolation parameter only provided for backwards compatibility. ALL NEW TRAINS SHOULD SET THIS TO TRUE.
             use_fp16=False,
             new_code_expansion=False,
             permute_codes=False,
@@ -117,6 +118,7 @@ class TransformerDiffusion(nn.Module):
         self.enable_fp16 = use_fp16
         self.new_code_expansion = new_code_expansion
         self.permute_codes = permute_codes
+        self.use_corner_alignment = use_corner_alignment

         self.inp_block = conv_nd(1, in_channels, prenet_channels, 3, 1, 1)

@@ -189,7 +191,7 @@ class TransformerDiffusion(nn.Module):

     def timestep_independent(self, prior, expected_seq_len):
         if self.new_code_expansion:
-            prior = F.interpolate(prior.permute(0,2,1), size=expected_seq_len, mode='linear').permute(0,2,1)
+            prior = F.interpolate(prior.permute(0,2,1), size=expected_seq_len, mode='linear', align_corners=self.use_corner_alignment).permute(0,2,1)
         code_emb = self.input_converter(prior)
         code_emb = self.code_converter(code_emb)

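For reference, a standalone sketch of the gated expansion in timestep_independent() above (expand_prior is a hypothetical helper name, and the (batch, seq, channels) layout of prior is inferred from the permutes in the diff, not stated by it). Defaulting the flag to False is what keeps checkpoints trained before this commit loading with unchanged behavior; only new configurations opt in.

import torch
import torch.nn.functional as F

def expand_prior(prior: torch.Tensor, expected_seq_len: int,
                 use_corner_alignment: bool = True) -> torch.Tensor:
    # (batch, seq, channels) -> (batch, channels, seq) for 1-d linear
    # interpolation, then back. With align_corners=True the first and
    # last input frames land exactly on the first and last output frames.
    return F.interpolate(prior.permute(0, 2, 1), size=expected_seq_len,
                         mode='linear',
                         align_corners=use_corner_alignment).permute(0, 2, 1)

print(expand_prior(torch.randn(1, 10, 4), 50).shape)  # torch.Size([1, 50, 4])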