Mods to tecogan to allow use of embeddings as input
This commit is contained in:
parent
b10bcf6436
commit
f6098155cd
|
@ -6,7 +6,7 @@ import torch.nn.functional as F
|
|||
import torchvision
|
||||
from torch.utils.checkpoint import checkpoint_sequential
|
||||
|
||||
from models.archs.arch_util import make_layer, default_init_weights, ConvGnSilu
|
||||
from models.archs.arch_util import make_layer, default_init_weights, ConvGnSilu, ConvGnLelu
|
||||
|
||||
|
||||
class ResidualDenseBlock(nn.Module):
|
||||
|
@ -60,11 +60,16 @@ class RRDB(nn.Module):
|
|||
growth_channels (int): Channels for each growth.
|
||||
"""
|
||||
|
||||
def __init__(self, mid_channels, growth_channels=32):
|
||||
def __init__(self, mid_channels, growth_channels=32, reduce_to=None):
|
||||
super(RRDB, self).__init__()
|
||||
self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels)
|
||||
self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels)
|
||||
self.rdb3 = ResidualDenseBlock(mid_channels, growth_channels)
|
||||
if reduce_to is not None:
|
||||
self.reducer = ConvGnLelu(mid_channels, reduce_to, kernel_size=3, activation=False, norm=False, bias=True)
|
||||
self.recover_ch = mid_channels - reduce_to
|
||||
else:
|
||||
self.reducer = None
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function.
|
||||
|
@ -78,6 +83,10 @@ class RRDB(nn.Module):
|
|||
out = self.rdb1(x)
|
||||
out = self.rdb2(out)
|
||||
out = self.rdb3(out)
|
||||
if self.reducer is not None:
|
||||
out = self.reducer(out)
|
||||
b, f, h, w = out.shape
|
||||
out = torch.cat([out, torch.zeros((b, self.recover_ch, h, w), device=out.device)], dim=1)
|
||||
# Emperically, we use 0.2 to scale the residual for better performance
|
||||
return out * 0.2 + x
|
||||
|
||||
|
@ -92,12 +101,19 @@ class RRDBWithBypass(nn.Module):
|
|||
growth_channels (int): Channels for each growth.
|
||||
"""
|
||||
|
||||
def __init__(self, mid_channels, growth_channels=32):
|
||||
def __init__(self, mid_channels, growth_channels=32, reduce_to=None):
|
||||
super(RRDBWithBypass, self).__init__()
|
||||
self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels)
|
||||
self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels)
|
||||
self.rdb3 = ResidualDenseBlock(mid_channels, growth_channels)
|
||||
self.bypass = nn.Sequential(ConvGnSilu(mid_channels*2, mid_channels, kernel_size=3, bias=True, activation=True, norm=True),
|
||||
if reduce_to is not None:
|
||||
self.reducer = ConvGnLelu(mid_channels, reduce_to, kernel_size=3, activation=False, norm=False, bias=True)
|
||||
self.recover_ch = mid_channels - reduce_to
|
||||
bypass_channels = mid_channels + reduce_to
|
||||
else:
|
||||
self.reducer = None
|
||||
bypass_channels = mid_channels * 2
|
||||
self.bypass = nn.Sequential(ConvGnSilu(bypass_channels, mid_channels, kernel_size=3, bias=True, activation=True, norm=True),
|
||||
ConvGnSilu(mid_channels, mid_channels//2, kernel_size=3, bias=False, activation=True, norm=False),
|
||||
ConvGnSilu(mid_channels//2, 1, kernel_size=3, bias=False, activation=False, norm=False),
|
||||
nn.Sigmoid())
|
||||
|
@ -114,8 +130,15 @@ class RRDBWithBypass(nn.Module):
|
|||
out = self.rdb1(x)
|
||||
out = self.rdb2(out)
|
||||
out = self.rdb3(out)
|
||||
|
||||
if self.reducer is not None:
|
||||
out = self.reducer(out)
|
||||
b, f, h, w = out.shape
|
||||
out = torch.cat([out, torch.zeros((b, self.recover_ch, h, w), device=out.device)], dim=1)
|
||||
|
||||
bypass = self.bypass(torch.cat([x, out], dim=1))
|
||||
self.bypass_map = bypass.detach().clone()
|
||||
|
||||
# Empirically, we use 0.2 to scale the residual for better performance
|
||||
return out * 0.2 * bypass + x
|
||||
|
||||
|
@ -143,30 +166,45 @@ class RRDBNet(nn.Module):
|
|||
num_blocks=23,
|
||||
growth_channels=32,
|
||||
body_block=RRDB,
|
||||
blocks_per_checkpoint=4,
|
||||
blocks_per_checkpoint=1,
|
||||
scale=4,
|
||||
additive_mode="not_additive" # Options: "not", "additive", "additive_enforced"
|
||||
additive_mode="not", # Options: "not", "additive", "additive_enforced"
|
||||
headless=False,
|
||||
feature_channels=64, # Only applicable when headless=True. How many channels are used at the trunk level.
|
||||
output_mode="hq_only", # Options: "hq_only", "hq+features", "features_only"
|
||||
):
|
||||
super(RRDBNet, self).__init__()
|
||||
assert output_mode in ['hq_only', 'hq+features', 'features_only']
|
||||
assert additive_mode in ['not', 'additive', 'additive_enforced']
|
||||
self.num_blocks = num_blocks
|
||||
self.blocks_per_checkpoint = blocks_per_checkpoint
|
||||
self.scale = scale
|
||||
self.in_channels = in_channels
|
||||
self.output_mode = output_mode
|
||||
first_conv_stride = 1 if in_channels <= 4 else scale
|
||||
first_conv_ksize = 3 if first_conv_stride == 1 else 7
|
||||
first_conv_padding = 1 if first_conv_stride == 1 else 3
|
||||
self.conv_first = nn.Conv2d(in_channels, mid_channels, first_conv_ksize, first_conv_stride, first_conv_padding)
|
||||
if headless:
|
||||
self.conv_first = None
|
||||
self.reduce_ch = feature_channels
|
||||
reduce_to = feature_channels
|
||||
self.conv_ref_first = ConvGnLelu(3, feature_channels, 7, stride=2, norm=False, activation=False, bias=True)
|
||||
else:
|
||||
self.conv_first = nn.Conv2d(in_channels, mid_channels, first_conv_ksize, first_conv_stride, first_conv_padding)
|
||||
self.reduce_ch = mid_channels
|
||||
reduce_to = None
|
||||
self.body = make_layer(
|
||||
body_block,
|
||||
num_blocks,
|
||||
mid_channels=mid_channels,
|
||||
growth_channels=growth_channels)
|
||||
self.conv_body = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
|
||||
growth_channels=growth_channels,
|
||||
reduce_to=reduce_to)
|
||||
self.conv_body = nn.Conv2d(self.reduce_ch, self.reduce_ch, 3, 1, 1)
|
||||
# upsample
|
||||
self.conv_up1 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
|
||||
self.conv_up2 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
|
||||
self.conv_hr = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
|
||||
self.conv_last = nn.Conv2d(mid_channels, out_channels, 3, 1, 1)
|
||||
self.conv_up1 = nn.Conv2d(self.reduce_ch, self.reduce_ch, 3, 1, 1)
|
||||
self.conv_up2 = nn.Conv2d(self.reduce_ch, self.reduce_ch, 3, 1, 1)
|
||||
self.conv_hr = nn.Conv2d(self.reduce_ch, self.reduce_ch, 3, 1, 1)
|
||||
self.conv_last = nn.Conv2d(self.reduce_ch, out_channels, 3, 1, 1)
|
||||
|
||||
self.additive_mode = additive_mode
|
||||
if additive_mode == "additive_enforced":
|
||||
|
@ -178,7 +216,8 @@ class RRDBNet(nn.Module):
|
|||
self.conv_first, self.conv_body, self.conv_up1,
|
||||
self.conv_up2, self.conv_hr, self.conv_last
|
||||
]:
|
||||
default_init_weights(m, 0.1)
|
||||
if m is not None:
|
||||
default_init_weights(m, 0.1)
|
||||
|
||||
def forward(self, x, ref=None):
|
||||
"""Forward function.
|
||||
|
@ -189,25 +228,39 @@ class RRDBNet(nn.Module):
|
|||
Returns:
|
||||
Tensor: Forward results.
|
||||
"""
|
||||
if self.in_channels > 4:
|
||||
x_lg = F.interpolate(x, scale_factor=self.scale, mode="bicubic")
|
||||
if ref is None:
|
||||
ref = torch.zeros_like(x_lg)
|
||||
x_lg = torch.cat([x_lg, ref], dim=1)
|
||||
if self.conv_first is None:
|
||||
# Headless mode -> embedding inputs.
|
||||
if ref is not None:
|
||||
ref = self.conv_ref_first(ref)
|
||||
feat = torch.cat([x, ref], dim=1)
|
||||
else:
|
||||
feat = x
|
||||
else:
|
||||
x_lg = x
|
||||
feat = self.conv_first(x_lg)
|
||||
body_feat = self.conv_body(checkpoint_sequential(self.body, self.num_blocks // self.blocks_per_checkpoint, feat))
|
||||
# "Normal" mode -> image input.
|
||||
if self.in_channels > 4:
|
||||
x_lg = F.interpolate(x, scale_factor=self.scale, mode="bicubic")
|
||||
if ref is None:
|
||||
ref = torch.zeros_like(x_lg)
|
||||
x_lg = torch.cat([x_lg, ref], dim=1)
|
||||
else:
|
||||
x_lg = x
|
||||
feat = self.conv_first(x_lg)
|
||||
feat = checkpoint_sequential(self.body, self.num_blocks // self.blocks_per_checkpoint, feat)
|
||||
feat = feat[:, :self.reduce_ch]
|
||||
body_feat = self.conv_body(feat)
|
||||
feat = feat + body_feat
|
||||
if self.output_mode == "features_only":
|
||||
return feat
|
||||
|
||||
# upsample
|
||||
feat = self.lrelu(
|
||||
out = self.lrelu(
|
||||
self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
|
||||
if self.scale == 4:
|
||||
feat = self.lrelu(
|
||||
self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
|
||||
out = self.lrelu(
|
||||
self.conv_up2(F.interpolate(out, scale_factor=2, mode='nearest')))
|
||||
else:
|
||||
feat = self.lrelu(self.conv_up2(feat))
|
||||
out = self.conv_last(self.lrelu(self.conv_hr(feat)))
|
||||
out = self.lrelu(self.conv_up2(out))
|
||||
out = self.conv_last(self.lrelu(self.conv_hr(out)))
|
||||
if "additive" in self.additive_mode:
|
||||
x_interp = F.interpolate(x, scale_factor=self.scale, mode='bilinear')
|
||||
if self.additive_mode == 'additive':
|
||||
|
@ -216,9 +269,14 @@ class RRDBNet(nn.Module):
|
|||
out_pooled = self.add_enforced_pool(out)
|
||||
out = out - F.interpolate(out_pooled, scale_factor=self.scale, mode='nearest')
|
||||
out = out + x_interp
|
||||
|
||||
if self.output_mode == "hq+features":
|
||||
return out, feat
|
||||
return out
|
||||
|
||||
def visual_dbg(self, step, path):
|
||||
for i, bm in enumerate(self.body):
|
||||
if hasattr(bm, 'bypass_map'):
|
||||
torchvision.utils.save_image(bm.bypass_map.cpu().float(), os.path.join(path, "%i_bypass_%i.png" % (step, i+1)))
|
||||
|
||||
|
||||
|
|
|
@ -39,14 +39,17 @@ def define_G(opt, opt_net, scale=None):
|
|||
nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale'])
|
||||
elif which_model == 'RRDBNet':
|
||||
additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys() else 'not_additive'
|
||||
output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only'
|
||||
netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
|
||||
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode)
|
||||
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode,
|
||||
output_mode=output_mode)
|
||||
elif which_model == 'RRDBNetBypass':
|
||||
additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys() else 'not'
|
||||
output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only'
|
||||
netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
|
||||
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], body_block=RRDBNet_arch.RRDBWithBypass,
|
||||
blocks_per_checkpoint=opt_net['blocks_per_checkpoint'], scale=opt_net['scale'],
|
||||
additive_mode=additive_mode)
|
||||
additive_mode=additive_mode, output_mode=output_mode)
|
||||
elif which_model == 'rcan':
|
||||
#args: n_resgroups, n_resblocks, res_scale, reduction, scale, n_feats
|
||||
opt_net['rgb_range'] = 255
|
||||
|
@ -110,8 +113,6 @@ def define_G(opt, opt_net, scale=None):
|
|||
netG = SwitchedGen_arch.BackboneResnet()
|
||||
elif which_model == "tecogen":
|
||||
netG = TecoGen(opt_net['nf'], opt_net['scale'])
|
||||
elif which_model == "basic_resampling_flow_predictor":
|
||||
netG = BasicResamplingFlowNet(opt_net['nf'], resample_scale=opt_net['resample_scale'])
|
||||
elif which_model == "rrdb_with_latent":
|
||||
netG = RRDBNetWithLatent(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
|
||||
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'],
|
||||
|
@ -153,6 +154,11 @@ def define_G(opt, opt_net, scale=None):
|
|||
netG = RRDBLatentWrapper(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
||||
nf=opt_net['nf'], nb=opt_net['nb'], with_bypass=opt_net['with_bypass'],
|
||||
blocks=opt_net['blocks_for_latent'], scale=opt_net['scale'], pretrain_rrdb_path=opt_net['pretrain_path'])
|
||||
elif which_model == 'rrdb_centipede':
|
||||
output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only'
|
||||
netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
|
||||
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], scale=opt_net['scale'],
|
||||
headless=True, output_mode=output_mode)
|
||||
else:
|
||||
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
|
||||
return netG
|
||||
|
|
|
@ -285,6 +285,7 @@ class ForEachInjector(Injector):
|
|||
o['in'] = '_in'
|
||||
o['out'] = '_out'
|
||||
self.injector = create_injector(o, self.env)
|
||||
self.aslist = opt['aslist'] if 'aslist' in opt.keys() else False
|
||||
|
||||
def forward(self, state):
|
||||
injs = []
|
||||
|
@ -293,7 +294,10 @@ class ForEachInjector(Injector):
|
|||
for i in range(inputs.shape[1]):
|
||||
st['_in'] = inputs[:, i]
|
||||
injs.append(self.injector(st)['_out'])
|
||||
return {self.output: torch.stack(injs, dim=1)}
|
||||
if self.aslist:
|
||||
return {self.output: injs}
|
||||
else:
|
||||
return {self.output: torch.stack(injs, dim=1)}
|
||||
|
||||
|
||||
class ConstantInjector(Injector):
|
||||
|
|
|
@ -140,7 +140,7 @@ class ConfigurableStep(Module):
|
|||
# Don't do injections tagged with 'after' or 'before' when we are out of spec.
|
||||
if 'after' in inj.opt.keys() and self.env['step'] < inj.opt['after'] or \
|
||||
'before' in inj.opt.keys() and self.env['step'] > inj.opt['before'] or \
|
||||
'every' in inj.opt.keys() and self.env['step'] % inj.opt['every'] != 0:
|
||||
'every' in inj.opt.keys() and self.env['step'] % inj.opt['every'] != 0:
|
||||
continue
|
||||
injected = inj(local_state)
|
||||
local_state.update(injected)
|
||||
|
|
|
@ -44,10 +44,12 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
|||
super(RecurrentImageGeneratorSequenceInjector, self).__init__(opt, env)
|
||||
self.flow = opt['flow_network']
|
||||
self.input_lq_index = opt['input_lq_index'] if 'input_lq_index' in opt.keys() else 0
|
||||
self.output_hq_index = opt['output_hq_index'] if 'output_hq_index' in opt.keys() else 0
|
||||
self.recurrent_index = opt['recurrent_index']
|
||||
self.output_hq_index = opt['output_hq_index'] if 'output_hq_index' in opt.keys() else 0
|
||||
self.output_recurrent_index = opt['output_recurrent_index'] if 'output_recurrent_index' in opt.keys() else self.output_hq_index
|
||||
self.scale = opt['scale']
|
||||
self.resample = Resample2d()
|
||||
self.flow_key = opt['flow_input_key'] if 'flow_input_key' in opt.keys() else None
|
||||
self.first_inputs = opt['first_inputs'] if 'first_inputs' in opt.keys() else opt['in'] # Use this to specify inputs that will be used in the first teco iteration, the rest will use 'in'.
|
||||
self.do_backwards = opt['do_backwards'] if 'do_backwards' in opt.keys() else True
|
||||
self.hq_recurrent = opt['hq_recurrent'] if 'hq_recurrent' in opt.keys() else False # When True, recurrent_index is not touched for the first iteration, allowing you to specify what is fed in. When False, zeros are fed into the recurrent index.
|
||||
|
@ -82,20 +84,21 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
|||
else:
|
||||
input = extract_inputs_index(inputs, i)
|
||||
with torch.no_grad() and autocast(enabled=False):
|
||||
# This is a hack to workaround the fact that flownet2 cannot operate at resolutions < 64px. An assumption is
|
||||
# made here that if you are operating at 4x scale, your inputs are 32px x 32px
|
||||
if self.scale >= 4:
|
||||
flow_input = F.interpolate(input[self.input_lq_index], scale_factor=self.scale//2, mode='bicubic')
|
||||
if self.flow_key is not None:
|
||||
flow_input = state[self.flow_key][:, i]
|
||||
else:
|
||||
flow_input = input[self.input_lq_index]
|
||||
reduced_recurrent = F.interpolate(recurrent_input, scale_factor=.5, mode='bicubic')
|
||||
reduced_recurrent = F.interpolate(hq_recurrent, scale_factor=1/self.scale, mode='bicubic')
|
||||
flow_input = torch.stack([flow_input, reduced_recurrent], dim=2).float()
|
||||
flowfield = F.interpolate(flow(flow_input), scale_factor=2, mode='bicubic')
|
||||
flowfield = flow(flow_input)
|
||||
if recurrent_input.shape[-1] != flow_input.shape[-1]:
|
||||
flowfield = F.interpolate(flowfield, scale_factor=self.scale, mode='bicubic')
|
||||
recurrent_input = self.resample(recurrent_input.float(), flowfield)
|
||||
input[self.recurrent_index] = recurrent_input
|
||||
if self.env['step'] % 50 == 0:
|
||||
self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.recurrent_index], debug_index)
|
||||
debug_index += 1
|
||||
if input[self.input_lq_index].shape[1] == 3: # Only debug this if we're dealing with images.
|
||||
self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.hq_recurrent], debug_index)
|
||||
debug_index += 1
|
||||
|
||||
with autocast(enabled=self.env['opt']['fp16']):
|
||||
gen_out = gen(*input)
|
||||
|
@ -104,7 +107,8 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
|||
gen_out = [gen_out]
|
||||
for i, out_key in enumerate(self.output):
|
||||
results[out_key].append(gen_out[i])
|
||||
recurrent_input = gen_out[self.output_hq_index]
|
||||
hq_recurrent = gen_out[self.output_hq_index]
|
||||
recurrent_input = gen_out[self.output_recurrent_index]
|
||||
|
||||
# Now go backwards, skipping the last element (it's already stored in recurrent_input)
|
||||
if self.do_backwards:
|
||||
|
@ -113,20 +117,21 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
|||
input = extract_inputs_index(inputs, i)
|
||||
with torch.no_grad():
|
||||
with autocast(enabled=False):
|
||||
# This is a hack to workaround the fact that flownet2 cannot operate at resolutions < 64px. An assumption is
|
||||
# made here that if you are operating at 4x scale, your inputs are 32px x 32px
|
||||
if self.scale >= 4:
|
||||
flow_input = F.interpolate(input[self.input_lq_index], scale_factor=self.scale//2, mode='bicubic')
|
||||
if self.flow_key is not None:
|
||||
flow_input = state[self.flow_key][:, i]
|
||||
else:
|
||||
flow_input = input[self.input_lq_index]
|
||||
reduced_recurrent = F.interpolate(recurrent_input, scale_factor=.5, mode='bicubic')
|
||||
reduced_recurrent = F.interpolate(hq_recurrent, scale_factor=1/self.scale, mode='bicubic')
|
||||
flow_input = torch.stack([flow_input, reduced_recurrent], dim=2).float()
|
||||
flowfield = F.interpolate(flow(flow_input), scale_factor=2, mode='bicubic')
|
||||
flowfield = flow(flow_input)
|
||||
if recurrent_input.shape[-1] != flow_input.shape[-1]:
|
||||
flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic')
|
||||
recurrent_input = self.resample(recurrent_input.float(), flowfield)
|
||||
input[self.recurrent_index] = recurrent_input
|
||||
if self.env['step'] % 50 == 0:
|
||||
self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.recurrent_index], debug_index)
|
||||
debug_index += 1
|
||||
if input[self.input_lq_index].shape[1] == 3: # Only debug this if we're dealing with images.
|
||||
self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.recurrent_index], debug_index)
|
||||
debug_index += 1
|
||||
|
||||
with autocast(enabled=self.env['opt']['fp16']):
|
||||
gen_out = gen(*input)
|
||||
|
@ -135,7 +140,8 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
|||
gen_out = [gen_out]
|
||||
for i, out_key in enumerate(self.output):
|
||||
results[out_key].append(gen_out[i])
|
||||
recurrent_input = gen_out[self.output_hq_index]
|
||||
hq_recurrent = gen_out[self.output_hq_index]
|
||||
recurrent_input = gen_out[self.output_recurrent_index]
|
||||
|
||||
final_results = {}
|
||||
# Include 'hq_batched' here - because why not... Don't really need a separate injector for this.
|
||||
|
|
|
@ -291,7 +291,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_stylegan2_for_sr_v2.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_using_rrdb_features.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
Loading…
Reference in New Issue
Block a user