checkpointing bugs, smh
commit dee2b72786
parent c61cd64bc9
@@ -27,7 +27,7 @@ class ConditioningEncoder(nn.Module):
 
     def forward(self, x):
         h = checkpoint(self.init, x)
-        h = checkpoint(self.attn, h)
+        h = self.attn(h)
         return h.mean(dim=2)
 
 
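The first hunk drops the checkpoint() wrapper around the attention call. For context, a minimal standalone sketch of the two call styles, assuming a recent PyTorch where torch.utils.checkpoint.checkpoint accepts use_reentrant; TinyEncoder is illustrative, not this repo's ConditioningEncoder:

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

class TinyEncoder(nn.Module):
    def __init__(self, dim=64):
        super().__init__()
        self.init = nn.Conv1d(dim, dim, kernel_size=1)
        self.attn = nn.MultiheadAttention(dim, num_heads=4, batch_first=True)

    def forward(self, x):
        # Checkpointed call: self.init's activations are recomputed in the backward
        # pass to save memory. Reentrant checkpointing (the old default) produces no
        # gradients when none of the tensors passed in require grad, a classic source
        # of silent checkpointing bugs.
        h = checkpoint(self.init, x, use_reentrant=False)
        # Plain call: ordinary autograd, nothing is recomputed.
        q = h.permute(0, 2, 1)        # (b, c, s) -> (b, s, c) for MultiheadAttention
        h, _ = self.attn(q, q, q)
        return h.mean(dim=1)          # average over the sequence dimension

print(TinyEncoder()(torch.randn(2, 64, 100)).shape)  # torch.Size([2, 64])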
@@ -58,7 +58,7 @@ class UpperConditioningEncoder(nn.Module):
 
 
 class GptMusicLower(nn.Module):
-    def __init__(self, dim, layers, dropout=0, num_target_vectors=512, num_target_groups=2, num_upper_vectors=64, num_upper_groups=4):
+    def __init__(self, dim, layers, dropout=0, num_target_vectors=512, num_target_groups=2, num_upper_vectors=64, num_upper_groups=4, fp16=True):
         super().__init__()
         self.internal_step = 0
         self.num_groups = num_target_groups
@@ -73,6 +73,7 @@ class GptMusicLower(nn.Module):
                                                  max(512,dim-512),
                                                  max(512,dim-512)], codevector_dim=dim,
                                                  codebook_size=num_upper_vectors, codebook_groups=num_upper_groups, expressive_downsamples=True)
+        self.fp16 = fp16
         # Following are unused quantizer constructs we delete to avoid DDP errors (and to be efficient.. of course..)
         del self.target_quantizer.decoder
         del self.target_quantizer.up
@@ -116,7 +117,7 @@ class GptMusicLower(nn.Module):
         h = [embedding(inputs[:, :, i]) for i, embedding in enumerate(self.embeddings)]
         h = torch.cat(h, dim=-1) + upper_vector
 
-        with torch.autocast(mel.device.type):
+        with torch.autocast(mel.device.type, enabled=self.fp16):
             # Stick the conditioning embedding on the front of the input sequence.
             # The transformer will learn how to integrate it.
             # This statement also serves to pre-pad the inputs by one token, which is the basis of the next-token-prediction task. IOW: this is the "START" token.
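The autocast change above makes mixed precision opt-in per model instance instead of always-on. A minimal sketch of the pattern on a recent PyTorch, using an illustrative module that is not taken from this repo:

import torch
from torch import nn

class AutocastToggle(nn.Module):
    def __init__(self, dim=128, fp16=True):
        super().__init__()
        self.fp16 = fp16
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        # enabled=True runs eligible ops in reduced precision (fp16 on CUDA, bf16 on
        # CPU); enabled=False keeps the whole region in fp32, e.g. for debugging.
        with torch.autocast(x.device.type, enabled=self.fp16):
            return self.proj(x)

x = torch.randn(4, 128)
print(AutocastToggle(fp16=False)(x).dtype)   # torch.float32
print(AutocastToggle(fp16=True)(x).dtype)    # reduced precision where supported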
@@ -162,10 +163,11 @@ class GptMusicLower(nn.Module):
 
 
 class GptMusicUpper(nn.Module):
-    def __init__(self, dim, layers, dropout=0, num_upper_vectors=64, num_upper_groups=4):
+    def __init__(self, dim, layers, dropout=0, num_upper_vectors=64, num_upper_groups=4, fp16=True):
         super().__init__()
         self.internal_step = 0
         self.num_groups = num_upper_groups
+        self.fp16 = fp16
         self.config = GPT2Config(vocab_size=1, n_positions=8192, n_embd=dim, n_layer=layers, n_head=dim//64,
                                  n_inner=dim*2, attn_pdrop=dropout, resid_pdrop=dropout, gradient_checkpointing=True,
                                  use_cache=False)
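For context on the gradient_checkpointing=True, use_cache=False pair in the config above: Hugging Face GPT-2 cannot combine activation checkpointing with its key/value cache, so the cache is turned off for training. A minimal sketch with illustrative sizes; newer transformers releases prefer model.gradient_checkpointing_enable() over the config flag used in this file:

import torch
from transformers import GPT2Config, GPT2Model

# Small config purely for illustration; use_cache=False because cached past key/values
# are incompatible with recomputing activations during the backward pass.
config = GPT2Config(vocab_size=1, n_positions=512, n_embd=256, n_layer=2, n_head=4,
                    use_cache=False)
model = GPT2Model(config)
model.gradient_checkpointing_enable()      # trade compute for activation memory

# Feed embeddings directly; the hunks above build h from embedding layers, which
# suggests an inputs_embeds-style path into the transformer.
h = torch.randn(2, 16, 256, requires_grad=True)
out = model(inputs_embeds=h).last_hidden_state
out.mean().backward()                      # gradients flow through the checkpointed blocks
print(out.shape)                           # torch.Size([2, 16, 256])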
@@ -203,7 +205,7 @@ class GptMusicUpper(nn.Module):
         h = [embedding(inputs[:, :, i]) for i, embedding in enumerate(self.embeddings)]
         h = torch.cat(h, dim=-1)
 
-        with torch.autocast(mel.device.type):
+        with torch.autocast(mel.device.type, enabled=self.fp16):
             # Stick the conditioning embedding on the front of the input sequence.
             # The transformer will learn how to integrate it.
             # This statement also serves to pre-pad the inputs by one token, which is the basis of the next-token-prediction task. IOW: this is the "START" token.
@@ -251,9 +253,9 @@ def test_lower():
     base_diff = TransformerDiffusionWithQuantizer(in_channels=256, out_channels=512, model_channels=2048, block_channels=1024,
                                                   prenet_channels=1024, prenet_layers=6, num_layers=16, input_vec_dim=1024,
                                                   dropout=.1, unconditioned_percentage=0, freeze_quantizer_until=6000)
-    base_diff.load_state_dict(torch.load('x:/dlas/experiments/train_music_diffusion_tfd8/models/28000_generator.pth', map_location=torch.device('cpu')))
+    base_diff.load_state_dict(torch.load('x:/dlas/experiments/train_music_diffusion_tfd8/models/47500_generator.pth', map_location=torch.device('cpu')))
 
-    model = GptMusicLower(512, 12)
+    model = GptMusicLower(512, 8, fp16=False)
     model.target_quantizer.load_state_dict(base_diff.quantizer.state_dict(), strict=False)
     torch.save(model.state_dict(), "sample.pth")
     mel = torch.randn(2,256,400)
@@ -273,4 +275,4 @@ def test_upper():
 
 
 if __name__ == '__main__':
-    test_upper()
+    test_lower()