forked from mrq/DL-Art-School
Fix for tts6
This commit is contained in:
parent
5ae816bead
commit
65a546c4d7
|
@ -364,23 +364,6 @@ class DiffusionTts(nn.Module):
|
||||||
zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)),
|
zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)),
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]',
|
|
||||||
strict: bool = True):
|
|
||||||
# Temporary hack to allow the addition of nil-guidance token embeddings to the existing guidance embeddings.
|
|
||||||
lsd = self.state_dict()
|
|
||||||
revised = 0
|
|
||||||
for i, blk in enumerate(self.input_blocks):
|
|
||||||
if isinstance(blk, nn.Embedding):
|
|
||||||
key = f'input_blocks.{i}.weight'
|
|
||||||
if state_dict[key].shape[0] != lsd[key].shape[0]:
|
|
||||||
t = torch.randn_like(lsd[key]) * .02
|
|
||||||
t[:state_dict[key].shape[0]] = state_dict[key]
|
|
||||||
state_dict[key] = t
|
|
||||||
revised += 1
|
|
||||||
print(f"Loaded experimental unet_diffusion_net with {revised} modifications.")
|
|
||||||
return super().load_state_dict(state_dict, strict)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x, timesteps, tokens=None, conditioning_input=None, lr_input=None):
|
def forward(self, x, timesteps, tokens=None, conditioning_input=None, lr_input=None):
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue
Block a user