checkpointing bugs, smh

James Betker 2022-06-08 11:53:10 -06:00
parent c61cd64bc9
commit dee2b72786


@@ -27,7 +27,7 @@ class ConditioningEncoder(nn.Module):
def forward(self, x):
h = checkpoint(self.init, x)
-h = checkpoint(self.attn, h)
+h = self.attn(h)
return h.mean(dim=2)
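The hunk above drops the activation-checkpoint wrapper around the attention stack and calls the module directly. As a hedged illustration of one common failure mode behind such "checkpointing bugs" (assuming checkpoint here is torch.utils.checkpoint; the toy module and shapes below are illustrative, not the repo's attention block):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Toy stand-in for an attention stack; purely illustrative, not the repo's block.
attn = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 32))
x = torch.randn(4, 32)  # input does not require grad (e.g. raw features)

# With the re-entrant implementation, gradients only flow through the
# checkpoint's tensor inputs. Since x has requires_grad=False, the output
# does not require grad either, the segment is cut out of autograd, and
# attn's weights would never receive gradients (PyTorch warns about this case).
y_ckpt = checkpoint(attn, x, use_reentrant=True)
print(y_ckpt.requires_grad)  # False: nothing inside the segment can train

# Calling the module directly, as the commit does, builds the normal graph,
# trading some activation memory for correct gradient flow.
y_plain = attn(x)
print(y_plain.requires_grad)  # True: attn's parameters are reachable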
@@ -58,7 +58,7 @@ class UpperConditioningEncoder(nn.Module):
class GptMusicLower(nn.Module):
-def __init__(self, dim, layers, dropout=0, num_target_vectors=512, num_target_groups=2, num_upper_vectors=64, num_upper_groups=4):
+def __init__(self, dim, layers, dropout=0, num_target_vectors=512, num_target_groups=2, num_upper_vectors=64, num_upper_groups=4, fp16=True):
super().__init__()
self.internal_step = 0
self.num_groups = num_target_groups
@@ -73,6 +73,7 @@ class GptMusicLower(nn.Module):
max(512,dim-512),
max(512,dim-512)], codevector_dim=dim,
codebook_size=num_upper_vectors, codebook_groups=num_upper_groups, expressive_downsamples=True)
+self.fp16 = fp16
# Following are unused quantizer constructs we delete to avoid DDP errors (and to be efficient.. of course..)
del self.target_quantizer.decoder
del self.target_quantizer.up
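The in-code comment above alludes to DistributedDataParallel's handling of parameters that never receive gradients: by default DDP raises a reduction error for them, and the workaround, find_unused_parameters=True, adds a per-iteration graph traversal. Deleting the dead sub-modules removes their parameters from the picture. A minimal sketch of the same pattern with hypothetical module names:

import torch.nn as nn

# Hypothetical quantizer whose decoder/upsampling heads are never used in forward().
quantizer = nn.Module()
quantizer.codebook = nn.Linear(8, 8)   # actually used for encoding
quantizer.decoder = nn.Linear(8, 8)    # unused -> "unused parameters" under DDP
quantizer.up = nn.Linear(8, 8)         # unused as well

# Dropping the unused heads means DDP never sees their parameters, so there is
# no error and no need for find_unused_parameters=True.
del quantizer.decoder
del quantizer.up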
@@ -116,7 +117,7 @@ class GptMusicLower(nn.Module):
h = [embedding(inputs[:, :, i]) for i, embedding in enumerate(self.embeddings)]
h = torch.cat(h, dim=-1) + upper_vector
-with torch.autocast(mel.device.type):
+with torch.autocast(mel.device.type, enabled=self.fp16):
# Stick the conditioning embedding on the front of the input sequence.
# The transformer will learn how to integrate it.
# This statement also serves to pre-pad the inputs by one token, which is the basis of the next-token-prediction task. IOW: this is the "START" token.
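The new enabled=self.fp16 argument makes mixed precision switchable per instance: when enabled is False the autocast context is a no-op and the block runs in the default dtype (test_lower() below constructs the model with fp16=False). A minimal sketch of the same switch, using a plain matmul rather than the repo's GPT block:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(8, 512, device=device)
w = torch.randn(512, 512, device=device)

use_fp16 = False  # analogous to constructing the model with fp16=False

# With enabled=False the context manager does nothing and the matmul stays in
# float32; with enabled=True it would run in float16 on CUDA (bfloat16 on CPU).
with torch.autocast(device, enabled=use_fp16):
    y = x @ w

print(y.dtype)  # torch.float32 when use_fp16 is False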
@@ -162,10 +163,11 @@ class GptMusicLower(nn.Module):
class GptMusicUpper(nn.Module):
-def __init__(self, dim, layers, dropout=0, num_upper_vectors=64, num_upper_groups=4):
+def __init__(self, dim, layers, dropout=0, num_upper_vectors=64, num_upper_groups=4, fp16=True):
super().__init__()
self.internal_step = 0
self.num_groups = num_upper_groups
+self.fp16 = fp16
self.config = GPT2Config(vocab_size=1, n_positions=8192, n_embd=dim, n_layer=layers, n_head=dim//64,
n_inner=dim*2, attn_pdrop=dropout, resid_pdrop=dropout, gradient_checkpointing=True,
use_cache=False)
@@ -203,7 +205,7 @@ class GptMusicUpper(nn.Module):
h = [embedding(inputs[:, :, i]) for i, embedding in enumerate(self.embeddings)]
h = torch.cat(h, dim=-1)
-with torch.autocast(mel.device.type):
+with torch.autocast(mel.device.type, enabled=self.fp16):
# Stick the conditioning embedding on the front of the input sequence.
# The transformer will learn how to integrate it.
# This statement also serves to pre-pad the inputs by one token, which is the basis of the next-token-prediction task. IOW: this is the "START" token.
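The comment block describes the standard decoder-LM setup: prepending the conditioning embedding shifts the token embeddings right by one position, so the output at position t is trained to predict target token t and no explicit start-token id is needed. A toy illustration with made-up shapes, not the repo's actual tensors:

import torch

B, T, D = 2, 5, 16
token_emb = torch.randn(B, T, D)   # embeddings of the target token sequence
cond_emb = torch.randn(B, 1, D)    # conditioning vector, acting as "START"

# Prepending pre-pads the sequence by one position: the transformer output at
# position t sees cond + tokens[:t] and is supervised against tokens[t],
# which is exactly next-token prediction.
inputs = torch.cat([cond_emb, token_emb], dim=1)   # (B, T+1, D)

# The final position has no target, so its output is typically dropped, e.g.:
# logits = transformer(inputs_embeds=inputs).logits[:, :-1]   # hypothetical call
print(inputs.shape)  # torch.Size([2, 6, 16])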
@@ -251,9 +253,9 @@ def test_lower():
base_diff = TransformerDiffusionWithQuantizer(in_channels=256, out_channels=512, model_channels=2048, block_channels=1024,
prenet_channels=1024, prenet_layers=6, num_layers=16, input_vec_dim=1024,
dropout=.1, unconditioned_percentage=0, freeze_quantizer_until=6000)
-base_diff.load_state_dict(torch.load('x:/dlas/experiments/train_music_diffusion_tfd8/models/28000_generator.pth', map_location=torch.device('cpu')))
+base_diff.load_state_dict(torch.load('x:/dlas/experiments/train_music_diffusion_tfd8/models/47500_generator.pth', map_location=torch.device('cpu')))
-model = GptMusicLower(512, 12)
+model = GptMusicLower(512, 8, fp16=False)
model.target_quantizer.load_state_dict(base_diff.quantizer.state_dict(), strict=False)
torch.save(model.state_dict(), "sample.pth")
mel = torch.randn(2,256,400)
@@ -273,4 +275,4 @@ def test_upper():
if __name__ == '__main__':
-test_upper()
+test_lower()