diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py
index d7f1bd49..7d61a19d 100644
--- a/codes/models/ExtensibleTrainer.py
+++ b/codes/models/ExtensibleTrainer.py
@@ -58,7 +58,7 @@ class ExtensibleTrainer(BaseModel):
                 new_net = None
             if net['type'] == 'generator':
                 if new_net is None:
-                    new_net = networks.define_G(net, None, opt['scale']).to(self.device)
+                    new_net = networks.define_G(opt, net, opt['scale']).to(self.device)
                 self.netsG[name] = new_net
             elif net['type'] == 'discriminator':
                 if new_net is None:
diff --git a/codes/models/archs/srflow_orig/FlowActNorms.py b/codes/models/archs/srflow_orig/FlowActNorms.py
index 3292aafa..e92dc642 100644
--- a/codes/models/archs/srflow_orig/FlowActNorms.py
+++ b/codes/models/archs/srflow_orig/FlowActNorms.py
@@ -1,7 +1,7 @@
 import torch
 from torch import nn as nn
 
-from models.modules import thops
+from models.archs.srflow_orig import thops
 
 
 class _ActNorm(nn.Module):
diff --git a/codes/models/archs/srflow_orig/FlowAffineCouplingsAblation.py b/codes/models/archs/srflow_orig/FlowAffineCouplingsAblation.py
index 5a94abe3..a3c9fb0d 100644
--- a/codes/models/archs/srflow_orig/FlowAffineCouplingsAblation.py
+++ b/codes/models/archs/srflow_orig/FlowAffineCouplingsAblation.py
@@ -1,8 +1,8 @@
 import torch
 from torch import nn as nn
 
-from models.modules import thops
-from models.modules.flow import Conv2d, Conv2dZeros
+from models.archs.srflow_orig import thops
+from models.archs.srflow_orig.flow import Conv2d, Conv2dZeros
 from utils.util import opt_get
 
 
@@ -15,10 +15,10 @@ class CondAffineSeparatedAndCond(nn.Module):
         self.kernel_hidden = 1
         self.affine_eps = 0.0001
         self.n_hidden_layers = 1
-        hidden_channels = opt_get(opt, ['network_G', 'flow', 'CondAffineSeparatedAndCond', 'hidden_channels'])
+        hidden_channels = opt_get(opt, ['networks', 'generator', 'flow', 'CondAffineSeparatedAndCond', 'hidden_channels'])
         self.hidden_channels = 64 if hidden_channels is None else hidden_channels
 
-        self.affine_eps = opt_get(opt, ['network_G', 'flow', 'CondAffineSeparatedAndCond', 'eps'],  0.0001)
+        self.affine_eps = opt_get(opt, ['networks', 'generator', 'flow', 'CondAffineSeparatedAndCond', 'eps'], 0.0001)
 
         self.channels_for_nn = self.in_channels // 2
         self.channels_for_co = self.in_channels - self.channels_for_nn
diff --git a/codes/models/archs/srflow_orig/FlowStep.py b/codes/models/archs/srflow_orig/FlowStep.py
index 41a867be..1c128c92 100644
--- a/codes/models/archs/srflow_orig/FlowStep.py
+++ b/codes/models/archs/srflow_orig/FlowStep.py
@@ -1,9 +1,8 @@
 import torch
 from torch import nn as nn
 
-import models.modules
-import models.modules.Permutations
-from models.modules import flow, thops, FlowAffineCouplingsAblation
+import models.archs.srflow_orig.Permutations
+from models.archs.srflow_orig import flow, thops, FlowAffineCouplingsAblation
 from utils.util import opt_get
 
 
@@ -47,16 +46,16 @@ class FlowStep(nn.Module):
         self.acOpt = acOpt
 
         # 1. actnorm
-        self.actnorm = models.modules.FlowActNorms.ActNorm2d(in_channels, actnorm_scale)
+        self.actnorm = models.archs.srflow_orig.FlowActNorms.ActNorm2d(in_channels, actnorm_scale)
 
         # 2. permute
         if flow_permutation == "invconv":
-            self.invconv = models.modules.Permutations.InvertibleConv1x1(
+            self.invconv = models.archs.srflow_orig.Permutations.InvertibleConv1x1(
                 in_channels, LU_decomposed=LU_decomposed)
 
         # 3. coupling
         if flow_coupling == "CondAffineSeparatedAndCond":
-            self.affine = models.modules.FlowAffineCouplingsAblation.CondAffineSeparatedAndCond(in_channels=in_channels,
+            self.affine = models.archs.srflow_orig.FlowAffineCouplingsAblation.CondAffineSeparatedAndCond(in_channels=in_channels,
                                                                                                 opt=opt)
         elif flow_coupling == "noCoupling":
             pass
diff --git a/codes/models/archs/srflow_orig/FlowUpsamplerNet.py b/codes/models/archs/srflow_orig/FlowUpsamplerNet.py
index 6bc4f5d8..0b0d4e23 100644
--- a/codes/models/archs/srflow_orig/FlowUpsamplerNet.py
+++ b/codes/models/archs/srflow_orig/FlowUpsamplerNet.py
@@ -2,11 +2,11 @@ import numpy as np
 import torch
 from torch import nn as nn
 
-import models.modules.Split
-from models.modules import flow, thops
-from models.modules.Split import Split2d
-from models.modules.glow_arch import f_conv2d_bias
-from models.modules.FlowStep import FlowStep
+import models.archs.srflow_orig.Split
+from models.archs.srflow_orig import flow, thops
+from models.archs.srflow_orig.Split import Split2d
+from models.archs.srflow_orig.glow_arch import f_conv2d_bias
+from models.archs.srflow_orig.FlowStep import FlowStep
 from utils.util import opt_get
 
 
@@ -21,8 +21,8 @@ class FlowUpsamplerNet(nn.Module):
 
         self.layers = nn.ModuleList()
         self.output_shapes = []
-        self.L = opt_get(opt, ['network_G', 'flow', 'L'])
-        self.K = opt_get(opt, ['network_G', 'flow', 'K'])
+        self.L = opt_get(opt, ['networks', 'generator', 'flow', 'L'])
+        self.K = opt_get(opt, ['networks', 'generator', 'flow', 'K'])
         if isinstance(self.K, int):
             self.K = [K for K in [K, ] * (self.L + 1)]
 
@@ -60,11 +60,11 @@ class FlowUpsamplerNet(nn.Module):
         affineInCh = self.get_affineInCh(opt_get)
         flow_permutation = self.get_flow_permutation(flow_permutation, opt)
 
-        normOpt = opt_get(opt, ['network_G', 'flow', 'norm'])
+        normOpt = opt_get(opt, ['networks', 'generator', 'flow', 'norm'])
 
         conditional_channels = {}
         n_rrdb = self.get_n_rrdb_channels(opt, opt_get)
-        n_bypass_channels = opt_get(opt, ['network_G', 'flow', 'levelConditional', 'n_channels'])
+        n_bypass_channels = opt_get(opt, ['networks', 'generator', 'flow', 'levelConditional', 'n_channels'])
         conditional_channels[0] = n_rrdb
         for level in range(1, self.L + 1):
             # Level 1 gets conditionals from 2, 3, 4 => L - level
@@ -88,7 +88,7 @@ class FlowUpsamplerNet(nn.Module):
             # Split
             self.arch_split(H, W, level, self.L, opt, opt_get)
 
-        if opt_get(opt, ['network_G', 'flow', 'split', 'enable']):
+        if opt_get(opt, ['networks', 'generator', 'flow', 'split', 'enable']):
             self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64 // 2 // 2)
         else:
             self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64)
@@ -99,7 +99,7 @@ class FlowUpsamplerNet(nn.Module):
         self.scaleW = 160 / W
 
     def get_n_rrdb_channels(self, opt, opt_get):
-        blocks = opt_get(opt, ['network_G', 'flow', 'stackRRDB', 'blocks'])
+        blocks = opt_get(opt, ['networks', 'generator', 'flow', 'stackRRDB', 'blocks'])
         n_rrdb = 64 if blocks is None else (len(blocks) + 1) * 64
         return n_rrdb
 
@@ -126,33 +126,33 @@ class FlowUpsamplerNet(nn.Module):
                 [-1, self.C, H, W])
 
     def get_condAffSetting(self, opt, opt_get):
-        condAff = opt_get(opt, ['network_G', 'flow', 'condAff']) or None
-        condAff = opt_get(opt, ['network_G', 'flow', 'condFtAffine']) or condAff
+        condAff = opt_get(opt, ['networks', 'generator', 'flow', 'condAff']) or None
+        condAff = opt_get(opt, ['networks', 'generator', 'flow', 'condFtAffine']) or condAff
         return condAff
 
     def arch_split(self, H, W, L, levels, opt, opt_get):
-        correct_splits = opt_get(opt, ['network_G', 'flow', 'split', 'correct_splits'], False)
+        correct_splits = opt_get(opt, ['networks', 'generator', 'flow', 'split', 'correct_splits'], False)
         correction = 0 if correct_splits else 1
-        if opt_get(opt, ['network_G', 'flow', 'split', 'enable']) and L < levels - correction:
-            logs_eps = opt_get(opt, ['network_G', 'flow', 'split', 'logs_eps']) or 0
-            consume_ratio = opt_get(opt, ['network_G', 'flow', 'split', 'consume_ratio']) or 0.5
+        if opt_get(opt, ['networks', 'generator', 'flow', 'split', 'enable']) and L < levels - correction:
+            logs_eps = opt_get(opt, ['networks', 'generator', 'flow', 'split', 'logs_eps']) or 0
+            consume_ratio = opt_get(opt, ['networks', 'generator', 'flow', 'split', 'consume_ratio']) or 0.5
             position_name = get_position_name(H, self.opt['scale'])
-            position = position_name if opt_get(opt, ['network_G', 'flow', 'split', 'conditional']) else None
-            cond_channels = opt_get(opt, ['network_G', 'flow', 'split', 'cond_channels'])
+            position = position_name if opt_get(opt, ['networks', 'generator', 'flow', 'split', 'conditional']) else None
+            cond_channels = opt_get(opt, ['networks', 'generator', 'flow', 'split', 'cond_channels'])
             cond_channels = 0 if cond_channels is None else cond_channels
 
-            t = opt_get(opt, ['network_G', 'flow', 'split', 'type'], 'Split2d')
+            t = opt_get(opt, ['networks', 'generator', 'flow', 'split', 'type'], 'Split2d')
 
             if t == 'Split2d':
-                split = models.modules.Split.Split2d(num_channels=self.C, logs_eps=logs_eps, position=position,
+                split = models.archs.srflow_orig.Split.Split2d(num_channels=self.C, logs_eps=logs_eps, position=position,
                                                      cond_channels=cond_channels, consume_ratio=consume_ratio, opt=opt)
             self.layers.append(split)
             self.output_shapes.append([-1, split.num_channels_pass, H, W])
             self.C = split.num_channels_pass
 
     def arch_additionalFlowAffine(self, H, LU_decomposed, W, actnorm_scale, hidden_channels, opt):
-        if 'additionalFlowNoAffine' in opt['network_G']['flow']:
-            n_additionalFlowNoAffine = int(opt['network_G']['flow']['additionalFlowNoAffine'])
+        if 'additionalFlowNoAffine' in opt['networks']['generator']['flow']:
+            n_additionalFlowNoAffine = int(opt['networks']['generator']['flow']['additionalFlowNoAffine'])
             for _ in range(n_additionalFlowNoAffine):
                 self.layers.append(
                     FlowStep(in_channels=self.C,
@@ -171,11 +171,11 @@ class FlowUpsamplerNet(nn.Module):
         return H, W
 
     def get_flow_permutation(self, flow_permutation, opt):
-        flow_permutation = opt['network_G']['flow'].get('flow_permutation', 'invconv')
+        flow_permutation = opt['networks']['generator']['flow'].get('flow_permutation', 'invconv')
         return flow_permutation
 
     def get_affineInCh(self, opt_get):
-        affineInCh = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or []
+        affineInCh = opt_get(self.opt, ['networks', 'generator', 'flow', 'stackRRDB', 'blocks']) or []
         affineInCh = (len(affineInCh) + 1) * 64
         return affineInCh
 
@@ -204,7 +204,7 @@ class FlowUpsamplerNet(nn.Module):
         level_conditionals = {}
         bypasses = {}
 
-        L = opt_get(self.opt, ['network_G', 'flow', 'L'])
+        L = opt_get(self.opt, ['networks', 'generator', 'flow', 'L'])
 
         for level in range(1, L + 1):
             bypasses[level] = torch.nn.functional.interpolate(gt, scale_factor=2 ** -level, mode='bilinear', align_corners=False)
@@ -255,7 +255,9 @@ class FlowUpsamplerNet(nn.Module):
         # debug.imwrite("fl_fea", fl_fea)
         bypasses = {}
         level_conditionals = {}
-        if not opt_get(self.opt, ['network_G', 'flow', 'levelConditional', 'conditional']) == True:
+        if opt_get(self.opt, ['networks', 'generator', 'flow', 'levelConditional', 'conditional']) != True:
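+            # No per-level conditionals configured: condition each flow level
+            # directly on the RRDB feature map at its resolution.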
             for level in range(self.L + 1):
                 level_conditionals[level] = rrdbResults[self.levelToName[level]]
 
diff --git a/codes/models/archs/srflow_orig/Permutations.py b/codes/models/archs/srflow_orig/Permutations.py
index 86584e58..9115dcc7 100644
--- a/codes/models/archs/srflow_orig/Permutations.py
+++ b/codes/models/archs/srflow_orig/Permutations.py
@@ -3,7 +3,7 @@ import torch
 from torch import nn as nn
 from torch.nn import functional as F
 
-from models.modules import thops
+from models.archs.srflow_orig import thops
 
 
 class InvertibleConv1x1(nn.Module):
diff --git a/codes/models/archs/srflow_orig/RRDBNet_arch.py b/codes/models/archs/srflow_orig/RRDBNet_arch.py
index f5cdb4d5..a6d75117 100644
--- a/codes/models/archs/srflow_orig/RRDBNet_arch.py
+++ b/codes/models/archs/srflow_orig/RRDBNet_arch.py
@@ -2,70 +2,149 @@ import functools
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import models.modules.module_util as mutil
+import models.archs.srflow_orig.module_util as mutil
+from models.archs.arch_util import default_init_weights, ConvGnSilu
 from utils.util import opt_get
 
 
-class ResidualDenseBlock_5C(nn.Module):
-    def __init__(self, nf=64, gc=32, bias=True):
-        super(ResidualDenseBlock_5C, self).__init__()
-        # gc: growth channel, i.e. intermediate channels
-        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
-        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
-        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
-        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
-        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
-        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+class ResidualDenseBlock(nn.Module):
+    """Residual Dense Block.
+
+    Used in RRDB block in ESRGAN.
+
+    Args:
+        mid_channels (int): Channel number of intermediate features.
+        growth_channels (int): Channels for each growth.
+    """
+
+    def __init__(self, mid_channels=64, growth_channels=32):
+        super(ResidualDenseBlock, self).__init__()
+        for i in range(5):
+            out_channels = mid_channels if i == 4 else growth_channels
+            self.add_module(
+                f'conv{i+1}',
+                nn.Conv2d(mid_channels + i * growth_channels, out_channels, 3,
+                          1, 1))
+        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+        for i in range(5):
+            default_init_weights(getattr(self, f'conv{i+1}'), 0.1)
 
-        # initialization
-        mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
-
     def forward(self, x):
+        """Forward function.
+
+        Args:
+            x (Tensor): Input tensor with shape (n, c, h, w).
+
+        Returns:
+            Tensor: Forward results.
+        """
         x1 = self.lrelu(self.conv1(x))
         x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
         x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
         x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
         x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+        # Empirically, we use 0.2 to scale the residual for better performance
         return x5 * 0.2 + x
 
 
 class RRDB(nn.Module):
-    '''Residual in Residual Dense Block'''
+    """Residual in Residual Dense Block.
 
-    def __init__(self, nf, gc=32):
+    Used in RRDB-Net in ESRGAN.
+
+    Args:
+        mid_channels (int): Channel number of intermediate features.
+        growth_channels (int): Channels for each growth.
+    """
+
+    def __init__(self, mid_channels, growth_channels=32):
         super(RRDB, self).__init__()
-        self.RDB1 = ResidualDenseBlock_5C(nf, gc)
-        self.RDB2 = ResidualDenseBlock_5C(nf, gc)
-        self.RDB3 = ResidualDenseBlock_5C(nf, gc)
+        self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels)
+        self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels)
+        self.rdb3 = ResidualDenseBlock(mid_channels, growth_channels)
 
     def forward(self, x):
-        out = self.RDB1(x)
-        out = self.RDB2(out)
-        out = self.RDB3(out)
+        """Forward function.
+
+        Args:
+            x (Tensor): Input tensor with shape (n, c, h, w).
+
+        Returns:
+            Tensor: Forward results.
+        """
+        out = self.rdb1(x)
+        out = self.rdb2(out)
+        out = self.rdb3(out)
+        # Empirically, we use 0.2 to scale the residual for better performance
         return out * 0.2 + x
 
 
+class RRDBWithBypass(nn.Module):
+    """Residual in Residual Dense Block.
+
+    Used in RRDB-Net in ESRGAN.
+
+    Args:
+        mid_channels (int): Channel number of intermediate features.
+        growth_channels (int): Channels for each growth.
+    """
+
+    def __init__(self, mid_channels, growth_channels=32):
+        super(RRDBWithBypass, self).__init__()
+        self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels)
+        self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels)
+        self.rdb3 = ResidualDenseBlock(mid_channels, growth_channels)
+        self.bypass = nn.Sequential(ConvGnSilu(mid_channels*2, mid_channels, kernel_size=3, bias=True, activation=True, norm=True),
+                                    ConvGnSilu(mid_channels, mid_channels//2, kernel_size=3, bias=False, activation=True, norm=False),
+                                    ConvGnSilu(mid_channels//2, 1, kernel_size=3, bias=False, activation=False, norm=False),
+                                    nn.Sigmoid())
+
+    def forward(self, x):
+        """Forward function.
+
+        Args:
+            x (Tensor): Input tensor with shape (n, c, h, w).
+
+        Returns:
+            Tensor: Forward results.
+        """
+        out = self.rdb1(x)
+        out = self.rdb2(out)
+        out = self.rdb3(out)
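+        # The gating head sees both the block input and its output and emits a
+        # per-pixel mask in [0, 1]; a detached copy is kept for visualization.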
+        bypass = self.bypass(torch.cat([x, out], dim=1))
+        self.bypass_map = bypass.detach().clone()
+        # Empirically, we use 0.2 to scale the residual for better performance
+        return out * 0.2 * bypass + x
+
+
 class RRDBNet(nn.Module):
     def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, opt=None):
         self.opt = opt
         super(RRDBNet, self).__init__()
-        RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
+
+        bypass = opt_get(self.opt, ['networks', 'generator', 'rrdb_bypass'])
+        if bypass:
+            RRDB_block_f = functools.partial(RRDBWithBypass, mid_channels=nf, growth_channels=gc)
+        else:
+            RRDB_block_f = functools.partial(RRDB, mid_channels=nf, growth_channels=gc)
         self.scale = scale
 
         self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
-        self.RRDB_trunk = mutil.make_layer(RRDB_block_f, nb)
-        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+        self.body = mutil.make_layer(RRDB_block_f, nb)
+        self.conv_body = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
         #### upsampling
-        self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
-        self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+        self.conv_up1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+        self.conv_up2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
         if self.scale >= 8:
-            self.upconv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+            self.conv_up3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
         if self.scale >= 16:
-            self.upconv4 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+            self.conv_up4 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
         if self.scale >= 32:
-            self.upconv5 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+            self.conv_up5 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
 
-        self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+        self.conv_hr = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
         self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
 
         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
@@ -73,23 +152,25 @@ class RRDBNet(nn.Module):
     def forward(self, x, get_steps=False):
         fea = self.conv_first(x)
 
-        block_idxs = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or []
+        block_idxs = opt_get(self.opt, ['networks', 'generator', 'flow', 'stackRRDB', 'blocks']) or []
         block_results = {}
 
-        for idx, m in enumerate(self.RRDB_trunk.children()):
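+        # Expose the outputs of selected RRDB blocks; SRFlowNet stacks them as
+        # extra flow-conditioning features (see rrdbPreprocessing).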
+        for idx, m in enumerate(self.body.children()):
             fea = m(fea)
             for b in block_idxs:
                 if b == idx:
                     block_results["block_{}".format(idx)] = fea
 
-        trunk = self.trunk_conv(fea)
+        trunk = self.conv_body(fea)
 
         last_lr_fea = fea + trunk
 
-        fea_up2 = self.upconv1(F.interpolate(last_lr_fea, scale_factor=2, mode='nearest'))
+        fea_up2 = self.conv_up1(F.interpolate(last_lr_fea, scale_factor=2, mode='nearest'))
         fea = self.lrelu(fea_up2)
 
-        fea_up4 = self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))
+        fea_up4 = self.conv_up2(F.interpolate(fea, scale_factor=2, mode='nearest'))
         fea = self.lrelu(fea_up4)
 
         fea_up8 = None
@@ -97,16 +178,16 @@ class RRDBNet(nn.Module):
         fea_up32 = None
 
         if self.scale >= 8:
-            fea_up8 = self.upconv3(F.interpolate(fea, scale_factor=2, mode='nearest'))
+            fea_up8 = self.conv_up3(F.interpolate(fea, scale_factor=2, mode='nearest'))
             fea = self.lrelu(fea_up8)
         if self.scale >= 16:
-            fea_up16 = self.upconv4(F.interpolate(fea, scale_factor=2, mode='nearest'))
+            fea_up16 = self.conv_up4(F.interpolate(fea, scale_factor=2, mode='nearest'))
             fea = self.lrelu(fea_up16)
         if self.scale >= 32:
-            fea_up32 = self.upconv5(F.interpolate(fea, scale_factor=2, mode='nearest'))
+            fea_up32 = self.conv_up5(F.interpolate(fea, scale_factor=2, mode='nearest'))
             fea = self.lrelu(fea_up32)
 
-        out = self.conv_last(self.lrelu(self.HRconv(fea)))
+        out = self.conv_last(self.lrelu(self.conv_hr(fea)))
 
         results = {'last_lr_fea': last_lr_fea,
                    'fea_up1': last_lr_fea,
@@ -117,10 +198,10 @@ class RRDBNet(nn.Module):
                    'fea_up32': fea_up32,
                    'out': out}
 
-        fea_up0_en = opt_get(self.opt, ['network_G', 'flow', 'fea_up0']) or False
+        fea_up0_en = opt_get(self.opt, ['networks', 'generator', 'flow', 'fea_up0']) or False
         if fea_up0_en:
             results['fea_up0'] = F.interpolate(last_lr_fea, scale_factor=1/2, mode='bilinear', align_corners=False, recompute_scale_factor=True)
-        fea_upn1_en = opt_get(self.opt, ['network_G', 'flow', 'fea_up-1']) or False
+        fea_upn1_en = opt_get(self.opt, ['networks', 'generator', 'flow', 'fea_up-1']) or False
         if fea_upn1_en:
             results['fea_up-1'] = F.interpolate(last_lr_fea, scale_factor=1/4, mode='bilinear', align_corners=False, recompute_scale_factor=True)
 
diff --git a/codes/models/archs/srflow_orig/SRFlowNet_arch.py b/codes/models/archs/srflow_orig/SRFlowNet_arch.py
index b95374ef..2bdab09d 100644
--- a/codes/models/archs/srflow_orig/SRFlowNet_arch.py
+++ b/codes/models/archs/srflow_orig/SRFlowNet_arch.py
@@ -4,10 +4,10 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np
-from models.modules.RRDBNet_arch import RRDBNet
-from models.modules.FlowUpsamplerNet import FlowUpsamplerNet
-import models.modules.thops as thops
-import models.modules.flow as flow
+from models.archs.srflow_orig.RRDBNet_arch import RRDBNet
+from models.archs.srflow_orig.FlowUpsamplerNet import FlowUpsamplerNet
+import models.archs.srflow_orig.thops as thops
+import models.archs.srflow_orig.flow as flow
 from utils.util import opt_get
 
 
@@ -19,27 +19,42 @@ class SRFlowNet(nn.Module):
         self.quant = 255 if opt_get(opt, ['datasets', 'train', 'quant']) is \
                             None else opt_get(opt, ['datasets', 'train', 'quant'])
         self.RRDB = RRDBNet(in_nc, out_nc, nf, nb, gc, scale, opt)
-        hidden_channels = opt_get(opt, ['network_G', 'flow', 'hidden_channels'])
+        if 'pretrain_rrdb' in opt['networks']['generator']:
+            rrdb_state_dict = torch.load(opt['networks']['generator']['pretrain_rrdb'])
+            self.RRDB.load_state_dict(rrdb_state_dict, strict=True)
+
+        hidden_channels = opt_get(opt, ['networks', 'generator', 'flow', 'hidden_channels'])
         hidden_channels = hidden_channels or 64
         self.RRDB_training = True  # Default is true
 
-        train_RRDB_delay = opt_get(self.opt, ['network_G', 'train_RRDB_delay'])
-        set_RRDB_to_train = False
-        if set_RRDB_to_train:
-            self.set_rrdb_training(True)
+        train_RRDB_delay = opt_get(self.opt, ['networks', 'generator', 'train_RRDB_delay'])
+        self.RRDB_training = False  # the RRDB encoder starts frozen; both flow passes run it under no_grad
 
         self.flowUpsamplerNet = \
             FlowUpsamplerNet((160, 160, 3), hidden_channels, K,
-                             flow_coupling=opt['network_G']['flow']['coupling'], opt=opt)
+                             flow_coupling=opt['networks']['generator']['flow']['coupling'], opt=opt)
         self.i = 0
 
-    def set_rrdb_training(self, trainable):
-        if self.RRDB_training != trainable:
-            for p in self.RRDB.parameters():
-                p.requires_grad = trainable
-            self.RRDB_training = trainable
-            return True
-        return False
+    def get_random_z(self, heat, seed=None, batch_size=1, lr_shape=None, device='cuda'):
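+        # Sample a latent z for the reverse (synthesis) pass. `heat` is the
+        # sampling temperature (Gaussian std); heat == 0 yields the zero latent.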
+        if seed is not None: torch.manual_seed(seed)
+        if opt_get(self.opt, ['networks', 'generator', 'flow', 'split', 'enable']):
+            C = self.flowUpsamplerNet.C
+            H = int(self.opt['scale'] * lr_shape[2] // self.flowUpsamplerNet.scaleH)
+            W = int(self.opt['scale'] * lr_shape[3] // self.flowUpsamplerNet.scaleW)
+
+            size = (batch_size, C, H, W)
+            if heat == 0:
+                z = torch.zeros(size)
+            else:
+                z = torch.normal(mean=0, std=heat, size=size)
+        else:
+            L = opt_get(self.opt, ['networks', 'generator', 'flow', 'L']) or 3
+            fac = 2 ** (L - 3)
+            z_size = int(self.lr_size // fac)
+            z = torch.normal(mean=0, std=heat, size=(batch_size, 3 * 8 * 8 * fac * fac, z_size, z_size))
+        return z.to(device)
 
     def forward(self, gt=None, lr=None, z=None, eps_std=None, reverse=False, epses=None, reverse_with_grad=False,
                 lr_enc=None,
@@ -48,14 +63,10 @@ class SRFlowNet(nn.Module):
             return self.normal_flow(gt, lr, epses=epses, lr_enc=lr_enc, add_gt_noise=add_gt_noise, step=step,
                                     y_onehot=y_label)
         else:
-            # assert lr.shape[0] == 1
             assert lr.shape[1] == 3
-            # assert lr.shape[2] == 20
-            # assert lr.shape[3] == 20
-            # assert z.shape[0] == 1
-            # assert z.shape[1] == 3 * 8 * 8
-            # assert z.shape[2] == 20
-            # assert z.shape[3] == 20
+            if z is None:
+                # Synthesize it.
+                z = self.get_random_z(eps_std, batch_size=lr.shape[0], lr_shape=lr.shape, device=lr.device)
             if reverse_with_grad:
                 return self.reverse_flow(lr, z, y_onehot=y_label, eps_std=eps_std, epses=epses, lr_enc=lr_enc,
                                          add_gt_noise=add_gt_noise)
@@ -66,7 +77,11 @@ class SRFlowNet(nn.Module):
 
     def normal_flow(self, gt, lr, y_onehot=None, epses=None, lr_enc=None, add_gt_noise=True, step=None):
         if lr_enc is None:
-            lr_enc = self.rrdbPreprocessing(lr)
+            if self.RRDB_training:
+                lr_enc = self.rrdbPreprocessing(lr)
+            else:
+                with torch.no_grad():
+                    lr_enc = self.rrdbPreprocessing(lr)
 
         logdet = torch.zeros_like(gt[:, 0, 0, 0])
         pixels = thops.pixels(gt)
@@ -75,7 +88,7 @@ class SRFlowNet(nn.Module):
 
         if add_gt_noise:
             # Setup
-            noiseQuant = opt_get(self.opt, ['network_G', 'flow', 'augmentation', 'noiseQuant'], True)
+            noiseQuant = opt_get(self.opt, ['networks', 'generator', 'flow', 'augmentation', 'noiseQuant'], True)
             if noiseQuant:
                 z = z + ((torch.rand(z.shape, device=z.device) - 0.5) / self.quant)
             logdet = logdet + float(-np.log(self.quant) * pixels)
@@ -101,11 +116,11 @@ class SRFlowNet(nn.Module):
 
     def rrdbPreprocessing(self, lr):
         rrdbResults = self.RRDB(lr, get_steps=True)
-        block_idxs = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or []
+        block_idxs = opt_get(self.opt, ['networks', 'generator', 'flow', 'stackRRDB', 'blocks']) or []
         if len(block_idxs) > 0:
             concat = torch.cat([rrdbResults["block_{}".format(idx)] for idx in block_idxs], dim=1)
 
-            if opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'concat']) or False:
+            if opt_get(self.opt, ['networks', 'generator', 'flow', 'stackRRDB', 'concat']) or False:
                 keys = ['last_lr_fea', 'fea_up1', 'fea_up2', 'fea_up4']
                 if 'fea_up0' in rrdbResults.keys():
                     keys.append('fea_up0')
@@ -134,7 +149,11 @@ class SRFlowNet(nn.Module):
             logdet = logdet - float(-np.log(self.quant) * pixels)
 
         if lr_enc is None:
-            lr_enc = self.rrdbPreprocessing(lr)
+            if self.RRDB_training:
+                lr_enc = self.rrdbPreprocessing(lr)
+            else:
+                with torch.no_grad():
+                    lr_enc = self.rrdbPreprocessing(lr)
 
         x, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, z=z, eps_std=eps_std, reverse=True, epses=epses,
                                           logdet=logdet)
diff --git a/codes/models/archs/srflow_orig/Split.py b/codes/models/archs/srflow_orig/Split.py
index 60897eb0..b7b1df98 100644
--- a/codes/models/archs/srflow_orig/Split.py
+++ b/codes/models/archs/srflow_orig/Split.py
@@ -1,9 +1,9 @@
 import torch
 from torch import nn as nn
 
-from models.modules import thops
-from models.modules.FlowStep import FlowStep
-from models.modules.flow import Conv2dZeros, GaussianDiag
+from models.archs.srflow_orig import thops
+from models.archs.srflow_orig.FlowStep import FlowStep
+from models.archs.srflow_orig.flow import Conv2dZeros, GaussianDiag
 from utils.util import opt_get
 
 
diff --git a/codes/models/archs/srflow_orig/flow.py b/codes/models/archs/srflow_orig/flow.py
index 5c0ae968..71e03646 100644
--- a/codes/models/archs/srflow_orig/flow.py
+++ b/codes/models/archs/srflow_orig/flow.py
@@ -3,7 +3,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np
 
-from models.modules.FlowActNorms import ActNorm2d
+from models.archs.srflow_orig.FlowActNorms import ActNorm2d
 from . import thops
 
 
diff --git a/codes/models/base_model.py b/codes/models/base_model.py
index c60afdc1..55371407 100644
--- a/codes/models/base_model.py
+++ b/codes/models/base_model.py
@@ -105,12 +105,27 @@ class BaseModel():
         if 'state_dict' in load_net:
             load_net = load_net['state_dict']
 
+        is_srflow = False
         load_net_clean = OrderedDict()  # remove unnecessary 'module.'
         for k, v in load_net.items():
             if k.startswith('module.'):
                 load_net_clean[k[7:]] = v
             if k.startswith('generator'):   # Hack to fix ESRGAN pretrained model.
                 load_net_clean[k[10:]] = v
+            if 'RRDB_trunk' in k or is_srflow:   # Hack to fix SRFlow imports, which use some legacy RDB layer names.
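+                # The renames mirror the RRDBNet_arch refactor in this commit:
+                # RRDB_trunk->body, RDB->rdb, upconv->conv_up, trunk_conv->conv_body, HRconv->conv_hr.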
+                is_srflow = True
+                fixed_key = k.replace('RRDB_trunk', 'body')
+                if '.RDB' in fixed_key:
+                    fixed_key = fixed_key.replace('.RDB', '.rdb')
+                elif '.upconv' in fixed_key:
+                    fixed_key = fixed_key.replace('.upconv', '.conv_up')
+                elif '.trunk_conv' in fixed_key:
+                    fixed_key = fixed_key.replace('.trunk_conv', '.conv_body')
+                elif '.HRconv' in fixed_key:
+                    fixed_key = fixed_key.replace('.HRconv', '.conv_hr')
+                load_net_clean[fixed_key] = v
             else:
                 load_net_clean[k] = v
         network.load_state_dict(load_net_clean, strict=strict)
diff --git a/codes/models/networks.py b/codes/models/networks.py
index fb1a0149..75130989 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -28,11 +28,7 @@ from models.archs.teco_resgen import TecoGen
 logger = logging.getLogger('base')
 
 # Generator
-def define_G(opt, net_key='network_G', scale=None):
-    if net_key is not None:
-        opt_net = opt[net_key]
-    else:
-        opt_net = opt
+def define_G(opt, opt_net, scale=None):
     if scale is None:
         scale = opt['scale']
     which_model = opt_net['which_model_G']
@@ -141,6 +137,19 @@ def define_G(opt, net_key='network_G', scale=None):
         netG = stylegan2.StyleGan2GeneratorWithLatent(image_size=opt_net['image_size'], latent_dim=opt_net['latent_dim'],
                                             style_depth=opt_net['style_depth'], structure_input=is_structured,
                                             attn_layers=attn)
+    elif which_model == 'srflow':
+        from models.archs.srflow import SRFlow_arch
+        netG = SRFlow_arch.SRFlowNet(in_nc=3, out_nc=3, nf=opt_net['nf'], nb=opt_net['nb'],
+                                     quant=opt_net['quant'], flow_block_maps=opt_net['rrdb_block_maps'],
+                                     noise_quant=opt_net['noise_quant'], hidden_channels=opt_net['nf'],
+                                     K=opt_net['K'], L=opt_net['L'], train_rrdb_at_step=opt_net['rrdb_train_step'],
+                                     hr_img_shape=opt_net['hr_shape'], scale=opt_net['scale'])
+    elif which_model == 'srflow_orig':
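+        # The ported SRFlow generator reads most of its flow hyperparameters
+        # from the full opt dict, so it takes `opt` alongside the arch kwargs.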
+        from models.archs.srflow_orig import SRFlowNet_arch
+        netG = SRFlowNet_arch.SRFlowNet(in_nc=3, out_nc=3, nf=opt_net['nf'], nb=opt_net['nb'], scale=opt['scale'],
+                                     K=opt_net['K'], opt=opt)
     else:
         raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
     return netG
diff --git a/codes/models/steps/injectors.py b/codes/models/steps/injectors.py
index 333601c5..3cb9366a 100644
--- a/codes/models/steps/injectors.py
+++ b/codes/models/steps/injectors.py
@@ -157,6 +157,9 @@ class AddNoiseInjector(Injector):
             scale = state[self.opt['scale']]
         else:
             scale = self.opt['scale']
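+            # A missing 'scale' option means the noise is left unscaled.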
+            if scale is None:
+                scale = 1
 
         ref = state[self.opt['in']]
         if self.mode == 'normal':
diff --git a/codes/models/steps/steps.py b/codes/models/steps/steps.py
index 483d9592..448d5c09 100644
--- a/codes/models/steps/steps.py
+++ b/codes/models/steps/steps.py
@@ -139,7 +139,8 @@ class ConfigurableStep(Module):
                 continue
-            # Don't do injections tagged with 'after' or 'before' when we are out of spec.
+            # Don't do injections tagged with 'after', 'before' or 'every' when we are out of spec.
             if 'after' in inj.opt.keys() and self.env['step'] < inj.opt['after'] or \
-               'before' in inj.opt.keys() and self.env['step'] > inj.opt['before']:
+               'before' in inj.opt.keys() and self.env['step'] > inj.opt['before'] or \
+               'every' in inj.opt.keys() and self.env['step'] % inj.opt['every'] != 0:
                 continue
             injected = inj(local_state)
             local_state.update(injected)
diff --git a/codes/train.py b/codes/train.py
index ecef64c3..bb0be9cb 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -291,7 +291,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_srg2_classic_proper_disc.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_srflow.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
diff --git a/codes/train2.py b/codes/train2.py
index 4eff3b79..ecef64c3 100644
--- a/codes/train2.py
+++ b/codes/train2.py
@@ -291,7 +291,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_srg2_grad_penalty.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_srg2_classic_proper_disc.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()