tfd3 mods
This commit is contained in:
parent
bed3df4888
commit
0659fe3d1e
|
@ -97,6 +97,8 @@ class TransformerDiffusion(nn.Module):
|
||||||
use_rmsnorm=True,
|
use_rmsnorm=True,
|
||||||
ff_glu=True,
|
ff_glu=True,
|
||||||
rotary_pos_emb=True,
|
rotary_pos_emb=True,
|
||||||
|
zero_init_branch_output=True,
|
||||||
|
ff_mult=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
|
# Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
|
||||||
|
@ -118,6 +120,8 @@ class TransformerDiffusion(nn.Module):
|
||||||
use_rmsnorm=True,
|
use_rmsnorm=True,
|
||||||
ff_glu=True,
|
ff_glu=True,
|
||||||
rotary_pos_emb=True,
|
rotary_pos_emb=True,
|
||||||
|
zero_init_branch_output=True,
|
||||||
|
ff_mult=1,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.latent_fade = nn.Parameter(torch.zeros(1,1,model_channels))
|
self.latent_fade = nn.Parameter(torch.zeros(1,1,model_channels))
|
||||||
|
@ -130,6 +134,8 @@ class TransformerDiffusion(nn.Module):
|
||||||
use_rmsnorm=True,
|
use_rmsnorm=True,
|
||||||
ff_glu=True,
|
ff_glu=True,
|
||||||
rotary_pos_emb=True,
|
rotary_pos_emb=True,
|
||||||
|
zero_init_branch_output=True,
|
||||||
|
ff_mult=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,model_channels))
|
self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,model_channels))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user