From b84469e922b9d40302c52e3f86cc4e364c38a4eb Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 29 Dec 2020 16:08:31 -0700 Subject: [PATCH] SRFlow weight converter --- recipes/srflow/convert_official_weights.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 recipes/srflow/convert_official_weights.py diff --git a/recipes/srflow/convert_official_weights.py b/recipes/srflow/convert_official_weights.py new file mode 100644 index 00000000..2ce46bfe --- /dev/null +++ b/recipes/srflow/convert_official_weights.py @@ -0,0 +1,19 @@ +import torch + +# Quick script that can be used to convert from pretrained SRFlow weights to the variants used in this repo. The only +# differences between the two is the variable naming conventions used by the RRDBNet. (FWIW this repo is using the +# more up-to-date names that conform to Python standards). + +official_weight_file = 'SRFlow_CelebA_8X.pth' +output = 'CelebA_converted.pth' + +sd = torch.load(official_weight_file) +sdp = {} +for k,v in sd.items(): + k = k.replace('RRDB.RRDB_trunk', 'RRDB.body') + k = k.replace('.RDB', '.rdb') + k = k.replace('trunk_conv.', 'conv_body.') + k = k.replace('.upconv', '.conv_up') + k = k.replace('.HRconv', '.conv_hr') +sdp[k] = v +torch.save(sdp, output)