Fix loading new state dicts for RRDB

This commit is contained in:
James Betker 2020-06-11 08:25:57 -06:00
parent 5ca53e7786
commit d3b2cbfe7c

View File

@ -209,7 +209,8 @@ class RRDBNet(RRDBBase):
# The parameters in self.trunk used to be in this class. To support loading legacy saves, restore them.
t_state = self.trunk.state_dict()
for k in t_state.keys():
state_dict["trunk.%s" % (k,)] = state_dict.pop(k)
if k in state_dict.keys():
state_dict["trunk.%s" % (k,)] = state_dict.pop(k)
super(RRDBNet, self).load_state_dict(state_dict, strict)
# Variant of RRDBNet that is "assisted" by an external pretrained image classifier whose