diff --git a/codes/models/archs/StructuredSwitchedGenerator.py b/codes/models/archs/StructuredSwitchedGenerator.py index c68b8619..714df7a6 100644 --- a/codes/models/archs/StructuredSwitchedGenerator.py +++ b/codes/models/archs/StructuredSwitchedGenerator.py @@ -159,7 +159,7 @@ class SSGr1(nn.Module): self.init_temperature = init_temperature self.final_temperature_step = 10000 - def forward(self, x, *args): + def forward(self, x, embedding): noise_stds = [] # The attention_maps debugger outputs . Save that here. self.lr = x.detach().cpu()