checkpointing bugs, smh

This commit is contained in:
James Betker 2022-06-08 11:53:10 -06:00
parent c61cd64bc9
commit dee2b72786

View File

@ -27,7 +27,7 @@ class ConditioningEncoder(nn.Module):
def forward(self, x): def forward(self, x):
h = checkpoint(self.init, x) h = checkpoint(self.init, x)
h = checkpoint(self.attn, h) h = self.attn(h
return h.mean(dim=2) return h.mean(dim=2)
@ -58,7 +58,7 @@ class UpperConditioningEncoder(nn.Module):
class GptMusicLower(nn.Module): class GptMusicLower(nn.Module):
def __init__(self, dim, layers, dropout=0, num_target_vectors=512, num_target_groups=2, num_upper_vectors=64, num_upper_groups=4): def __init__(self, dim, layers, dropout=0, num_target_vectors=512, num_target_groups=2, num_upper_vectors=64, num_upper_groups=4, fp16=True):
super().__init__() super().__init__()
self.internal_step = 0 self.internal_step = 0
self.num_groups = num_target_groups self.num_groups = num_target_groups
@ -73,6 +73,7 @@ class GptMusicLower(nn.Module):
max(512,dim-512), max(512,dim-512),
max(512,dim-512)], codevector_dim=dim, max(512,dim-512)], codevector_dim=dim,
codebook_size=num_upper_vectors, codebook_groups=num_upper_groups, expressive_downsamples=True) codebook_size=num_upper_vectors, codebook_groups=num_upper_groups, expressive_downsamples=True)
self.fp16 = fp16
# Following are unused quantizer constructs we delete to avoid DDP errors (and to be efficient.. of course..) # Following are unused quantizer constructs we delete to avoid DDP errors (and to be efficient.. of course..)
del self.target_quantizer.decoder del self.target_quantizer.decoder
del self.target_quantizer.up del self.target_quantizer.up
@ -116,7 +117,7 @@ class GptMusicLower(nn.Module):
h = [embedding(inputs[:, :, i]) for i, embedding in enumerate(self.embeddings)] h = [embedding(inputs[:, :, i]) for i, embedding in enumerate(self.embeddings)]
h = torch.cat(h, dim=-1) + upper_vector h = torch.cat(h, dim=-1) + upper_vector
with torch.autocast(mel.device.type): with torch.autocast(mel.device.type, enabled=self.fp16):
# Stick the conditioning embedding on the front of the input sequence. # Stick the conditioning embedding on the front of the input sequence.
# The transformer will learn how to integrate it. # The transformer will learn how to integrate it.
# This statement also serves to pre-pad the inputs by one token, which is the basis of the next-token-prediction task. IOW: this is the "START" token. # This statement also serves to pre-pad the inputs by one token, which is the basis of the next-token-prediction task. IOW: this is the "START" token.
@ -162,10 +163,11 @@ class GptMusicLower(nn.Module):
class GptMusicUpper(nn.Module): class GptMusicUpper(nn.Module):
def __init__(self, dim, layers, dropout=0, num_upper_vectors=64, num_upper_groups=4): def __init__(self, dim, layers, dropout=0, num_upper_vectors=64, num_upper_groups=4, fp16=True):
super().__init__() super().__init__()
self.internal_step = 0 self.internal_step = 0
self.num_groups = num_upper_groups self.num_groups = num_upper_groups
self.fp16 = fp16
self.config = GPT2Config(vocab_size=1, n_positions=8192, n_embd=dim, n_layer=layers, n_head=dim//64, self.config = GPT2Config(vocab_size=1, n_positions=8192, n_embd=dim, n_layer=layers, n_head=dim//64,
n_inner=dim*2, attn_pdrop=dropout, resid_pdrop=dropout, gradient_checkpointing=True, n_inner=dim*2, attn_pdrop=dropout, resid_pdrop=dropout, gradient_checkpointing=True,
use_cache=False) use_cache=False)
@ -203,7 +205,7 @@ class GptMusicUpper(nn.Module):
h = [embedding(inputs[:, :, i]) for i, embedding in enumerate(self.embeddings)] h = [embedding(inputs[:, :, i]) for i, embedding in enumerate(self.embeddings)]
h = torch.cat(h, dim=-1) h = torch.cat(h, dim=-1)
with torch.autocast(mel.device.type): with torch.autocast(mel.device.type, enabled=self.fp16):
# Stick the conditioning embedding on the front of the input sequence. # Stick the conditioning embedding on the front of the input sequence.
# The transformer will learn how to integrate it. # The transformer will learn how to integrate it.
# This statement also serves to pre-pad the inputs by one token, which is the basis of the next-token-prediction task. IOW: this is the "START" token. # This statement also serves to pre-pad the inputs by one token, which is the basis of the next-token-prediction task. IOW: this is the "START" token.
@ -251,9 +253,9 @@ def test_lower():
base_diff = TransformerDiffusionWithQuantizer(in_channels=256, out_channels=512, model_channels=2048, block_channels=1024, base_diff = TransformerDiffusionWithQuantizer(in_channels=256, out_channels=512, model_channels=2048, block_channels=1024,
prenet_channels=1024, prenet_layers=6, num_layers=16, input_vec_dim=1024, prenet_channels=1024, prenet_layers=6, num_layers=16, input_vec_dim=1024,
dropout=.1, unconditioned_percentage=0, freeze_quantizer_until=6000) dropout=.1, unconditioned_percentage=0, freeze_quantizer_until=6000)
base_diff.load_state_dict(torch.load('x:/dlas/experiments/train_music_diffusion_tfd8/models/28000_generator.pth', map_location=torch.device('cpu'))) base_diff.load_state_dict(torch.load('x:/dlas/experiments/train_music_diffusion_tfd8/models/47500_generator.pth', map_location=torch.device('cpu')))
model = GptMusicLower(512, 12) model = GptMusicLower(512, 8, fp16=False)
model.target_quantizer.load_state_dict(base_diff.quantizer.state_dict(), strict=False) model.target_quantizer.load_state_dict(base_diff.quantizer.state_dict(), strict=False)
torch.save(model.state_dict(), "sample.pth") torch.save(model.state_dict(), "sample.pth")
mel = torch.randn(2,256,400) mel = torch.randn(2,256,400)
@ -273,4 +275,4 @@ def test_upper():
if __name__ == '__main__': if __name__ == '__main__':
test_upper() test_lower()