diff --git a/codes/models/audio/music/transformer_diffusion5.py b/codes/models/audio/music/transformer_diffusion5.py index 77723c09..2ae84345 100644 --- a/codes/models/audio/music/transformer_diffusion5.py +++ b/codes/models/audio/music/transformer_diffusion5.py @@ -196,7 +196,7 @@ class TransformerDiffusion(nn.Module): @register_model -def register_transformer_diffusion4(opt_net, opt): +def register_transformer_diffusion5(opt_net, opt): return TransformerDiffusion(**opt_net['kwargs'])