SRFlow weight converter
This commit is contained in:
parent
ba543d1152
commit
b84469e922
19
recipes/srflow/convert_official_weights.py
Normal file
19
recipes/srflow/convert_official_weights.py
Normal file
|
@ -0,0 +1,19 @@
|
|||
import torch
|
||||
|
||||
# Quick script that can be used to convert from pretrained SRFlow weights to the variants used in this repo. The only
|
||||
# differences between the two is the variable naming conventions used by the RRDBNet. (FWIW this repo is using the
|
||||
# more up-to-date names that conform to Python standards).
|
||||
|
||||
official_weight_file = 'SRFlow_CelebA_8X.pth'
|
||||
output = 'CelebA_converted.pth'
|
||||
|
||||
sd = torch.load(official_weight_file)
|
||||
sdp = {}
|
||||
for k,v in sd.items():
|
||||
k = k.replace('RRDB.RRDB_trunk', 'RRDB.body')
|
||||
k = k.replace('.RDB', '.rdb')
|
||||
k = k.replace('trunk_conv.', 'conv_body.')
|
||||
k = k.replace('.upconv', '.conv_up')
|
||||
k = k.replace('.HRconv', '.conv_hr')
|
||||
sdp[k] = v
|
||||
torch.save(sdp, output)
|
Loading…
Reference in New Issue
Block a user