forked from mrq/DL-Art-School
tfd3 mods
This commit is contained in:
parent
bed3df4888
commit
0659fe3d1e
|
@ -97,6 +97,8 @@ class TransformerDiffusion(nn.Module):
|
|||
use_rmsnorm=True,
|
||||
ff_glu=True,
|
||||
rotary_pos_emb=True,
|
||||
zero_init_branch_output=True,
|
||||
ff_mult=1,
|
||||
)
|
||||
|
||||
# Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
|
||||
|
@ -118,6 +120,8 @@ class TransformerDiffusion(nn.Module):
|
|||
use_rmsnorm=True,
|
||||
ff_glu=True,
|
||||
rotary_pos_emb=True,
|
||||
zero_init_branch_output=True,
|
||||
ff_mult=1,
|
||||
)
|
||||
)
|
||||
self.latent_fade = nn.Parameter(torch.zeros(1,1,model_channels))
|
||||
|
@ -130,6 +134,8 @@ class TransformerDiffusion(nn.Module):
|
|||
use_rmsnorm=True,
|
||||
ff_glu=True,
|
||||
rotary_pos_emb=True,
|
||||
zero_init_branch_output=True,
|
||||
ff_mult=1,
|
||||
)
|
||||
|
||||
self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,model_channels))
|
||||
|
|
Loading…
Reference in New Issue
Block a user