network updates
This commit is contained in:
parent
5a54d7db11
commit
c61cd64bc9
|
@ -369,7 +369,7 @@ class ResBlock(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
channels,
|
channels,
|
||||||
dropout,
|
dropout=0,
|
||||||
out_channels=None,
|
out_channels=None,
|
||||||
use_conv=False,
|
use_conv=False,
|
||||||
dims=2,
|
dims=2,
|
||||||
|
|
|
@ -3,12 +3,12 @@ from torch import nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from transformers import GPT2Config, GPT2Model
|
from transformers import GPT2Config, GPT2Model
|
||||||
|
|
||||||
from models.arch_util import AttentionBlock
|
from models.arch_util import AttentionBlock, ResBlock
|
||||||
from models.audio.music.music_quantizer import MusicQuantizer
|
from models.audio.music.music_quantizer import MusicQuantizer
|
||||||
from models.audio.music.music_quantizer2 import MusicQuantizer2
|
from models.audio.music.music_quantizer2 import MusicQuantizer2
|
||||||
from models.lucidrains.x_transformers import Encoder
|
from models.lucidrains.x_transformers import Encoder
|
||||||
from trainer.networks import register_model
|
from trainer.networks import register_model
|
||||||
from utils.util import opt_get
|
from utils.util import opt_get, checkpoint
|
||||||
|
|
||||||
|
|
||||||
class ConditioningEncoder(nn.Module):
|
class ConditioningEncoder(nn.Module):
|
||||||
|
@ -25,6 +25,32 @@ class ConditioningEncoder(nn.Module):
|
||||||
self.attn = nn.Sequential(*attn)
|
self.attn = nn.Sequential(*attn)
|
||||||
self.dim = embedding_dim
|
self.dim = embedding_dim
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
h = checkpoint(self.init, x)
|
||||||
|
h = checkpoint(self.attn, h)
|
||||||
|
return h.mean(dim=2)
|
||||||
|
|
||||||
|
|
||||||
|
class UpperConditioningEncoder(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
spec_dim,
|
||||||
|
embedding_dim,
|
||||||
|
attn_blocks=6,
|
||||||
|
num_attn_heads=4):
|
||||||
|
super().__init__()
|
||||||
|
attn = []
|
||||||
|
self.init = nn.Sequential(nn.Conv1d(spec_dim, min(spec_dim+128, embedding_dim), kernel_size=3, stride=2, padding=1),
|
||||||
|
nn.Conv1d(min(spec_dim+128, embedding_dim), min(spec_dim+256, embedding_dim), kernel_size=3, stride=2, padding=1),
|
||||||
|
nn.Conv1d(min(spec_dim+256, embedding_dim), min(spec_dim+384, embedding_dim), kernel_size=3, stride=2, padding=1),
|
||||||
|
nn.Conv1d(min(spec_dim+384, embedding_dim), min(spec_dim+512, embedding_dim), kernel_size=3, stride=2, padding=1),
|
||||||
|
ResBlock(min(spec_dim+512, embedding_dim), dims=1),
|
||||||
|
nn.Conv1d(min(spec_dim+512, embedding_dim), min(spec_dim+512, embedding_dim), kernel_size=3, stride=2, padding=1),
|
||||||
|
ResBlock(min(spec_dim+512, embedding_dim), dims=1))
|
||||||
|
for a in range(attn_blocks):
|
||||||
|
attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_activation=True))
|
||||||
|
self.attn = nn.Sequential(*attn)
|
||||||
|
self.dim = embedding_dim
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
h = self.init(x)
|
h = self.init(x)
|
||||||
h = self.attn(h)
|
h = self.attn(h)
|
||||||
|
@ -135,12 +161,92 @@ class GptMusicLower(nn.Module):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GptMusicUpper(nn.Module):
|
||||||
|
def __init__(self, dim, layers, dropout=0, num_upper_vectors=64, num_upper_groups=4):
|
||||||
|
super().__init__()
|
||||||
|
self.internal_step = 0
|
||||||
|
self.num_groups = num_upper_groups
|
||||||
|
self.config = GPT2Config(vocab_size=1, n_positions=8192, n_embd=dim, n_layer=layers, n_head=dim//64,
|
||||||
|
n_inner=dim*2, attn_pdrop=dropout, resid_pdrop=dropout, gradient_checkpointing=True,
|
||||||
|
use_cache=False)
|
||||||
|
self.upper_quantizer = MusicQuantizer2(inp_channels=256, inner_dim=[dim,
|
||||||
|
max(512,dim-128),
|
||||||
|
max(512,dim-256),
|
||||||
|
max(512,dim-384),
|
||||||
|
max(512,dim-512),
|
||||||
|
max(512,dim-512)], codevector_dim=dim,
|
||||||
|
codebook_size=num_upper_vectors, codebook_groups=num_upper_groups,
|
||||||
|
expressive_downsamples=True)
|
||||||
|
# Following are unused quantizer constructs we delete to avoid DDP errors (and to be efficient.. of course..)
|
||||||
|
del self.upper_quantizer.up
|
||||||
|
# Freeze the quantizer.
|
||||||
|
for p in self.upper_quantizer.parameters():
|
||||||
|
p.DO_NOT_TRAIN = True
|
||||||
|
p.requires_grad = False
|
||||||
|
|
||||||
|
self.conditioning_encoder = UpperConditioningEncoder(256, dim, attn_blocks=4, num_attn_heads=dim//64)
|
||||||
|
|
||||||
|
self.gpt = GPT2Model(self.config)
|
||||||
|
del self.gpt.wte # Unused, we'll do our own embeddings.
|
||||||
|
|
||||||
|
self.embeddings = nn.ModuleList([nn.Embedding(num_upper_vectors, dim // num_upper_groups) for _ in range(num_upper_groups)])
|
||||||
|
self.heads = nn.ModuleList([nn.Linear(dim, num_upper_vectors) for _ in range(num_upper_groups)])
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, mel, conditioning, return_latent=False):
|
||||||
|
with torch.no_grad():
|
||||||
|
self.upper_quantizer.eval()
|
||||||
|
codes = self.upper_quantizer.get_codes(mel)
|
||||||
|
|
||||||
|
inputs = codes[:, :-1]
|
||||||
|
targets = codes
|
||||||
|
h = [embedding(inputs[:, :, i]) for i, embedding in enumerate(self.embeddings)]
|
||||||
|
h = torch.cat(h, dim=-1)
|
||||||
|
|
||||||
|
with torch.autocast(mel.device.type):
|
||||||
|
# Stick the conditioning embedding on the front of the input sequence.
|
||||||
|
# The transformer will learn how to integrate it.
|
||||||
|
# This statement also serves to pre-pad the inputs by one token, which is the basis of the next-token-prediction task. IOW: this is the "START" token.
|
||||||
|
cond_emb = self.conditioning_encoder(conditioning).unsqueeze(1)
|
||||||
|
h = torch.cat([cond_emb, h], dim=1)
|
||||||
|
|
||||||
|
h = self.gpt(inputs_embeds=h, return_dict=True).last_hidden_state
|
||||||
|
|
||||||
|
if return_latent:
|
||||||
|
return h.float()
|
||||||
|
|
||||||
|
losses = 0
|
||||||
|
for i, head in enumerate(self.heads):
|
||||||
|
logits = head(h).permute(0,2,1)
|
||||||
|
loss = F.cross_entropy(logits, targets[:,:,i])
|
||||||
|
losses = losses + loss
|
||||||
|
|
||||||
|
return losses / self.num_groups
|
||||||
|
|
||||||
|
def get_grad_norm_parameter_groups(self):
|
||||||
|
groups = {
|
||||||
|
'gpt': list(self.gpt.parameters()),
|
||||||
|
'conditioning': list(self.conditioning_encoder.parameters()),
|
||||||
|
}
|
||||||
|
return groups
|
||||||
|
|
||||||
|
def get_debug_values(self, step, __):
|
||||||
|
if self.upper_quantizer.total_codes > 0:
|
||||||
|
return {'histogram_upper_codes': self.upper_quantizer.codes[:self.upper_quantizer.total_codes]}
|
||||||
|
else:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def register_music_gpt_lower(opt_net, opt):
|
def register_music_gpt_lower(opt_net, opt):
|
||||||
return GptMusicLower(**opt_get(opt_net, ['kwargs'], {}))
|
return GptMusicLower(**opt_get(opt_net, ['kwargs'], {}))
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def register_music_gpt_upper(opt_net, opt):
|
||||||
|
return GptMusicUpper(**opt_get(opt_net, ['kwargs'], {}))
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
|
def test_lower():
|
||||||
from models.audio.music.transformer_diffusion8 import TransformerDiffusionWithQuantizer
|
from models.audio.music.transformer_diffusion8 import TransformerDiffusionWithQuantizer
|
||||||
base_diff = TransformerDiffusionWithQuantizer(in_channels=256, out_channels=512, model_channels=2048, block_channels=1024,
|
base_diff = TransformerDiffusionWithQuantizer(in_channels=256, out_channels=512, model_channels=2048, block_channels=1024,
|
||||||
prenet_channels=1024, prenet_layers=6, num_layers=16, input_vec_dim=1024,
|
prenet_channels=1024, prenet_layers=6, num_layers=16, input_vec_dim=1024,
|
||||||
|
@ -153,3 +259,18 @@ if __name__ == '__main__':
|
||||||
mel = torch.randn(2,256,400)
|
mel = torch.randn(2,256,400)
|
||||||
model(mel, mel)
|
model(mel, mel)
|
||||||
model.get_grad_norm_parameter_groups()
|
model.get_grad_norm_parameter_groups()
|
||||||
|
|
||||||
|
|
||||||
|
def test_upper():
|
||||||
|
lower = GptMusicLower(512, 12)
|
||||||
|
lower.load_state_dict(torch.load('D:\\dlas\\experiments\\train_music_gpt\\models\\44500_generator_ema.pth'))
|
||||||
|
model = GptMusicUpper(512, 12)
|
||||||
|
model.upper_quantizer.load_state_dict(lower.upper_quantizer.state_dict())
|
||||||
|
torch.save(model.state_dict(), 'sample.pth')
|
||||||
|
mel = torch.randn(2,256,2500)
|
||||||
|
model(mel, mel)
|
||||||
|
model.get_grad_norm_parameter_groups()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_upper()
|
|
@ -73,6 +73,7 @@ class TransformerDiffusion(nn.Module):
|
||||||
out_channels=512, # mean and variance
|
out_channels=512, # mean and variance
|
||||||
dropout=0,
|
dropout=0,
|
||||||
use_fp16=False,
|
use_fp16=False,
|
||||||
|
ar_prior=False,
|
||||||
# Parameters for regularization.
|
# Parameters for regularization.
|
||||||
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
|
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
|
||||||
):
|
):
|
||||||
|
@ -95,6 +96,22 @@ class TransformerDiffusion(nn.Module):
|
||||||
)
|
)
|
||||||
prenet_heads = prenet_channels//64
|
prenet_heads = prenet_channels//64
|
||||||
|
|
||||||
|
self.ar_prior = ar_prior
|
||||||
|
if ar_prior:
|
||||||
|
self.ar_input = nn.Linear(input_vec_dim, prenet_channels)
|
||||||
|
self.ar_prior_intg = Encoder(
|
||||||
|
dim=prenet_channels,
|
||||||
|
depth=prenet_layers,
|
||||||
|
heads=prenet_heads,
|
||||||
|
ff_dropout=dropout,
|
||||||
|
attn_dropout=dropout,
|
||||||
|
use_rmsnorm=True,
|
||||||
|
ff_glu=True,
|
||||||
|
rotary_pos_emb=True,
|
||||||
|
zero_init_branch_output=True,
|
||||||
|
ff_mult=1,
|
||||||
|
)
|
||||||
|
else:
|
||||||
self.input_converter = nn.Linear(input_vec_dim, prenet_channels)
|
self.input_converter = nn.Linear(input_vec_dim, prenet_channels)
|
||||||
self.code_converter = Encoder(
|
self.code_converter = Encoder(
|
||||||
dim=prenet_channels,
|
dim=prenet_channels,
|
||||||
|
@ -130,16 +147,16 @@ class TransformerDiffusion(nn.Module):
|
||||||
}
|
}
|
||||||
return groups
|
return groups
|
||||||
|
|
||||||
def timestep_independent(self, codes, expected_seq_len):
|
def timestep_independent(self, prior, expected_seq_len):
|
||||||
code_emb = self.input_converter(codes)
|
code_emb = self.ar_input(prior) if self.ar_prior else self.input_converter(prior)
|
||||||
|
|
||||||
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
|
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
|
||||||
if self.training and self.unconditioned_percentage > 0:
|
if self.training and self.unconditioned_percentage > 0:
|
||||||
unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
|
unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
|
||||||
device=code_emb.device) < self.unconditioned_percentage
|
device=code_emb.device) < self.unconditioned_percentage
|
||||||
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(codes.shape[0], 1, 1),
|
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(prior.shape[0], 1, 1),
|
||||||
code_emb)
|
code_emb)
|
||||||
code_emb = self.code_converter(code_emb)
|
code_emb = self.ar_prior_intg(code_emb) if self.ar_prior else self.code_converter(code_emb)
|
||||||
|
|
||||||
expanded_code_emb = F.interpolate(code_emb.permute(0,2,1), size=expected_seq_len, mode='nearest').permute(0,2,1)
|
expanded_code_emb = F.interpolate(code_emb.permute(0,2,1), size=expected_seq_len, mode='nearest').permute(0,2,1)
|
||||||
return expanded_code_emb
|
return expanded_code_emb
|
||||||
|
@ -151,7 +168,6 @@ class TransformerDiffusion(nn.Module):
|
||||||
unused_params = []
|
unused_params = []
|
||||||
if conditioning_free:
|
if conditioning_free:
|
||||||
code_emb = self.unconditioned_embedding.repeat(x.shape[0], x.shape[-1], 1)
|
code_emb = self.unconditioned_embedding.repeat(x.shape[0], x.shape[-1], 1)
|
||||||
unused_params.extend(list(self.code_converter.parameters()))
|
|
||||||
else:
|
else:
|
||||||
if precomputed_code_embeddings is not None:
|
if precomputed_code_embeddings is not None:
|
||||||
code_emb = precomputed_code_embeddings
|
code_emb = precomputed_code_embeddings
|
||||||
|
@ -240,6 +256,47 @@ class TransformerDiffusionWithQuantizer(nn.Module):
|
||||||
return groups
|
return groups
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerDiffusionWithARPrior(nn.Module):
|
||||||
|
def __init__(self, freeze_diff=False, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.internal_step = 0
|
||||||
|
from models.audio.music.gpt_music import GptMusicLower
|
||||||
|
self.ar = GptMusicLower(dim=512, layers=12)
|
||||||
|
for p in self.ar.parameters():
|
||||||
|
p.DO_NOT_TRAIN = True
|
||||||
|
p.requires_grad = False
|
||||||
|
|
||||||
|
self.diff = TransformerDiffusion(ar_prior=True, **kwargs)
|
||||||
|
if freeze_diff:
|
||||||
|
for p in self.diff.parameters():
|
||||||
|
p.DO_NOT_TRAIN = True
|
||||||
|
p.requires_grad = False
|
||||||
|
for p in list(self.diff.ar_prior_intg.parameters()) + list(self.diff.ar_input.parameters()):
|
||||||
|
del p.DO_NOT_TRAIN
|
||||||
|
p.requires_grad = True
|
||||||
|
|
||||||
|
def get_grad_norm_parameter_groups(self):
|
||||||
|
groups = {
|
||||||
|
'attention_layers': list(itertools.chain.from_iterable([lyr.attn.parameters() for lyr in self.diff.layers])),
|
||||||
|
'ff_layers': list(itertools.chain.from_iterable([lyr.ff.parameters() for lyr in self.diff.layers])),
|
||||||
|
'rotary_embeddings': list(self.diff.rotary_embeddings.parameters()),
|
||||||
|
'out': list(self.diff.out.parameters()),
|
||||||
|
'x_proj': list(self.diff.inp_block.parameters()),
|
||||||
|
'layers': list(self.diff.layers.parameters()),
|
||||||
|
'ar_prior_intg': list(self.diff.ar_prior_intg.parameters()),
|
||||||
|
'time_embed': list(self.diff.time_embed.parameters()),
|
||||||
|
}
|
||||||
|
return groups
|
||||||
|
|
||||||
|
def forward(self, x, timesteps, truth_mel, disable_diversity=False, conditioning_input=None, conditioning_free=False):
|
||||||
|
with torch.no_grad():
|
||||||
|
prior = self.ar(truth_mel, conditioning_input, return_latent=True)
|
||||||
|
|
||||||
|
diff = self.diff(x, timesteps, prior, conditioning_free=conditioning_free)
|
||||||
|
return diff
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def register_transformer_diffusion8(opt_net, opt):
|
def register_transformer_diffusion8(opt_net, opt):
|
||||||
return TransformerDiffusion(**opt_net['kwargs'])
|
return TransformerDiffusion(**opt_net['kwargs'])
|
||||||
|
@ -250,24 +307,17 @@ def register_transformer_diffusion8_with_quantizer(opt_net, opt):
|
||||||
return TransformerDiffusionWithQuantizer(**opt_net['kwargs'])
|
return TransformerDiffusionWithQuantizer(**opt_net['kwargs'])
|
||||||
|
|
||||||
|
|
||||||
"""
|
@register_model
|
||||||
# For TFD5
|
def register_transformer_diffusion8_with_ar_prior(opt_net, opt):
|
||||||
if __name__ == '__main__':
|
return TransformerDiffusionWithARPrior(**opt_net['kwargs'])
|
||||||
clip = torch.randn(2, 256, 400)
|
|
||||||
aligned_sequence = torch.randn(2,100,512)
|
|
||||||
cond = torch.randn(2, 256, 400)
|
|
||||||
ts = torch.LongTensor([600, 600])
|
|
||||||
model = TransformerDiffusion(model_channels=3072, block_channels=1536, prenet_channels=1536)
|
|
||||||
torch.save(model, 'sample.pth')
|
|
||||||
print_network(model)
|
|
||||||
o = model(clip, ts, aligned_sequence, cond)
|
|
||||||
"""
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
|
def test_quant_model():
|
||||||
clip = torch.randn(2, 256, 400)
|
clip = torch.randn(2, 256, 400)
|
||||||
cond = torch.randn(2, 256, 400)
|
cond = torch.randn(2, 256, 400)
|
||||||
ts = torch.LongTensor([600, 600])
|
ts = torch.LongTensor([600, 600])
|
||||||
model = TransformerDiffusionWithQuantizer(model_channels=2048, block_channels=1024, prenet_channels=1024, input_vec_dim=1024, num_layers=16, prenet_layers=6)
|
model = TransformerDiffusionWithQuantizer(model_channels=2048, block_channels=1024, prenet_channels=1024,
|
||||||
|
input_vec_dim=1024, num_layers=16, prenet_layers=6)
|
||||||
model.get_grad_norm_parameter_groups()
|
model.get_grad_norm_parameter_groups()
|
||||||
|
|
||||||
quant_weights = torch.load('D:\\dlas\\experiments\\train_music_quant_r4\\models\\5000_generator.pth')
|
quant_weights = torch.load('D:\\dlas\\experiments\\train_music_quant_r4\\models\\5000_generator.pth')
|
||||||
|
@ -279,3 +329,28 @@ if __name__ == '__main__':
|
||||||
print_network(model)
|
print_network(model)
|
||||||
o = model(clip, ts, clip, cond)
|
o = model(clip, ts, clip, cond)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ar_model():
|
||||||
|
clip = torch.randn(2, 256, 400)
|
||||||
|
cond = torch.randn(2, 256, 400)
|
||||||
|
ts = torch.LongTensor([600, 600])
|
||||||
|
model = TransformerDiffusionWithARPrior(model_channels=2048, block_channels=1024, prenet_channels=1024,
|
||||||
|
input_vec_dim=512, num_layers=16, prenet_layers=6, freeze_diff=True)
|
||||||
|
model.get_grad_norm_parameter_groups()
|
||||||
|
|
||||||
|
ar_weights = torch.load('D:\\dlas\\experiments\\train_music_gpt\\models\\44500_generator_ema.pth')
|
||||||
|
model.ar.load_state_dict(ar_weights, strict=True)
|
||||||
|
diff_weights = torch.load('X:\\dlas\\experiments\\train_music_diffusion_tfd8\\models\\47500_generator_ema.pth')
|
||||||
|
pruned_diff_weights = {}
|
||||||
|
for k,v in diff_weights.items():
|
||||||
|
if k.startswith('diff.'):
|
||||||
|
pruned_diff_weights[k.replace('diff.', '')] = v
|
||||||
|
model.diff.load_state_dict(pruned_diff_weights, strict=False)
|
||||||
|
torch.save(model.state_dict(), 'sample.pth')
|
||||||
|
|
||||||
|
model(clip, ts, cond, conditioning_input=cond)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_ar_model()
|
||||||
|
|
|
@ -780,6 +780,15 @@ class UNetMusicModelARPrior(nn.Module):
|
||||||
del p.DO_NOT_TRAIN
|
del p.DO_NOT_TRAIN
|
||||||
p.requires_grad = True
|
p.requires_grad = True
|
||||||
|
|
||||||
|
def get_grad_norm_parameter_groups(self):
|
||||||
|
groups = {
|
||||||
|
'input_blocks': list(self.diff.input_blocks.parameters()),
|
||||||
|
'output_blocks': list(self.diff.output_blocks.parameters()),
|
||||||
|
'ar_prior_intg': list(self.diff.ar_prior_intg.parameters()),
|
||||||
|
'time_embed': list(self.diff.time_embed.parameters()),
|
||||||
|
}
|
||||||
|
return groups
|
||||||
|
|
||||||
def forward(self, x, timesteps, truth_mel, disable_diversity=False, conditioning_input=None, conditioning_free=False):
|
def forward(self, x, timesteps, truth_mel, disable_diversity=False, conditioning_input=None, conditioning_free=False):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
prior = self.ar(truth_mel, conditioning_input, return_latent=True)
|
prior = self.ar(truth_mel, conditioning_input, return_latent=True)
|
||||||
|
@ -805,6 +814,7 @@ if __name__ == '__main__':
|
||||||
attention_resolutions=(2,4), channel_mult=(1,2,3), dims=1,
|
attention_resolutions=(2,4), channel_mult=(1,2,3), dims=1,
|
||||||
use_scale_shift_norm=True, dropout=.1, num_heads=8, unconditioned_percentage=.4, freeze_unet=True)
|
use_scale_shift_norm=True, dropout=.1, num_heads=8, unconditioned_percentage=.4, freeze_unet=True)
|
||||||
print_network(model)
|
print_network(model)
|
||||||
|
model.get_grad_norm_parameter_groups()
|
||||||
|
|
||||||
ar_weights = torch.load('D:\\dlas\\experiments\\train_music_gpt\\models\\44500_generator_ema.pth')
|
ar_weights = torch.load('D:\\dlas\\experiments\\train_music_gpt\\models\\44500_generator_ema.pth')
|
||||||
model.ar.load_state_dict(ar_weights, strict=True)
|
model.ar.load_state_dict(ar_weights, strict=True)
|
||||||
|
|
|
@ -339,7 +339,7 @@ class Trainer:
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_gpt.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_gpt_upper.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
opt = option.parse(args.opt, is_train=True)
|
opt = option.parse(args.opt, is_train=True)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user