tfd3 mods

This commit is contained in:
James Betker 2022-05-27 11:16:26 -06:00
parent bed3df4888
commit 0659fe3d1e

View File

@ -97,6 +97,8 @@ class TransformerDiffusion(nn.Module):
use_rmsnorm=True, use_rmsnorm=True,
ff_glu=True, ff_glu=True,
rotary_pos_emb=True, rotary_pos_emb=True,
zero_init_branch_output=True,
ff_mult=1,
) )
# Either code_converter or latent_converter is used, depending on what type of conditioning data is fed. # Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
@ -118,6 +120,8 @@ class TransformerDiffusion(nn.Module):
use_rmsnorm=True, use_rmsnorm=True,
ff_glu=True, ff_glu=True,
rotary_pos_emb=True, rotary_pos_emb=True,
zero_init_branch_output=True,
ff_mult=1,
) )
) )
self.latent_fade = nn.Parameter(torch.zeros(1,1,model_channels)) self.latent_fade = nn.Parameter(torch.zeros(1,1,model_channels))
@ -130,6 +134,8 @@ class TransformerDiffusion(nn.Module):
use_rmsnorm=True, use_rmsnorm=True,
ff_glu=True, ff_glu=True,
rotary_pos_emb=True, rotary_pos_emb=True,
zero_init_branch_output=True,
ff_mult=1,
) )
self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,model_channels)) self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,model_channels))