forked from mrq/DL-Art-School

Support causal diffusion!

parent 78bba690de
commit 7b4dcbf136
@@ -81,11 +81,14 @@ class ConditioningEncoder(nn.Module):
                  attn_blocks=6,
                  num_attn_heads=8,
                  dropout=.1,
-                 do_checkpointing=False):
+                 do_checkpointing=False,
+                 time_proj=True):
         super().__init__()
         attn = []
         self.init = nn.Conv1d(cond_dim, embedding_dim, kernel_size=1)
-        self.time_proj = nn.Linear(time_embed_dim, embedding_dim)
+        self.time_proj = time_proj
+        if time_proj:
+            self.time_proj = nn.Linear(time_embed_dim, embedding_dim)
         self.attn = Encoder(
             dim=embedding_dim,
             depth=attn_blocks,

@@ -103,8 +106,9 @@ class ConditioningEncoder(nn.Module):

     def forward(self, x, time_emb):
         h = self.init(x).permute(0,2,1)
-        time_enc = self.time_proj(time_emb)
-        h = torch.cat([time_enc.unsqueeze(1), h], dim=1)
+        if self.time_proj:
+            time_enc = self.time_proj(time_emb)
+            h = torch.cat([time_enc.unsqueeze(1), h], dim=1)
         h = self.attn(h).permute(0,2,1)
         return h

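The time projection is now optional: `self.time_proj` is first assigned the boolean flag and only replaced with an `nn.Linear` when enabled, so the later `if self.time_proj:` truth-test passes for both a `True`-constructed module and fails for `False`. A minimal sketch of the pattern with illustrative names (not the repo's exact class; `None` is used here instead of the bool/module double duty in the diff):

```python
import torch
import torch.nn as nn

class CondEncoderSketch(nn.Module):
    """Optional time-projection pattern from the hunks above."""
    def __init__(self, cond_dim, embedding_dim, time_embed_dim, time_proj=True):
        super().__init__()
        self.init = nn.Conv1d(cond_dim, embedding_dim, kernel_size=1)
        self.time_proj = nn.Linear(time_embed_dim, embedding_dim) if time_proj else None

    def forward(self, x, time_emb):
        h = self.init(x).permute(0, 2, 1)        # (B, S, E)
        if self.time_proj is not None:
            time_enc = self.time_proj(time_emb)  # (B, E)
            # Prepend the projected time embedding as an extra sequence token.
            h = torch.cat([time_enc.unsqueeze(1), h], dim=1)
        return h

enc = CondEncoderSketch(256, 512, 512)
out = enc(torch.randn(2, 256, 100), torch.randn(2, 512))  # -> (2, 101, 512)
```

Note that the sequence grows by one token when the projection is active, so downstream consumers cannot assume the conditioning length equals the input length.
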
@@ -125,6 +129,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
                  input_cond_dim=1024,
                  num_heads=8,
                  dropout=0,
+                 time_proj=False,
                  use_fp16=False,
                  checkpoint_conditioning=True,  # This will need to be false for DDP training. :(
                  # Parameters for regularization.

@@ -141,7 +146,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
         self.enable_fp16 = use_fp16

         self.inp_block = conv_nd(1, in_channels, model_channels, 3, 1, 1)
-        self.conditioning_encoder = ConditioningEncoder(256, model_channels, time_embed_dim, do_checkpointing=checkpoint_conditioning)
+        self.conditioning_encoder = ConditioningEncoder(256, model_channels, time_embed_dim, do_checkpointing=checkpoint_conditioning, time_proj=time_proj)

         self.time_embed = nn.Sequential(
             linear(time_embed_dim, time_embed_dim),

@@ -210,11 +215,12 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
         cond_left = cond_aligned[:,:,:break_pt]
         cond_right = cond_aligned[:,:,break_pt:]

-        # Drop out a random amount of the aligned data. The network will need to figure out how to reconstruct this.
-        to_remove_left = random.randint(1, cond_left.shape[-1]-MIN_MARGIN)
-        cond_left = cond_left[:,:,:-to_remove_left]
-        to_remove_right = random.randint(1, cond_right.shape[-1]-MIN_MARGIN)
-        cond_right = cond_right[:,:,to_remove_right:]
+        if self.training:
+            # Drop out a random amount of the aligned data. The network will need to figure out how to reconstruct this.
+            to_remove_left = random.randint(1, cond_left.shape[-1]-MIN_MARGIN)
+            cond_left = cond_left[:,:,:-to_remove_left]
+            to_remove_right = random.randint(1, cond_right.shape[-1]-MIN_MARGIN)
+            cond_right = cond_right[:,:,to_remove_right:]

         # Concatenate the _pre and _post back on.
         cond_left_full = torch.cat([cond_pre, cond_left], dim=-1)

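Gating the trimming behind `self.training` keeps the full aligned conditioning intact at eval time. A standalone sketch of the trimming itself; the `MIN_MARGIN` value here is an assumption, its real definition lives elsewhere in the file:

```python
import random
import torch

MIN_MARGIN = 5  # assumed; the real constant is defined outside the shown hunk

def drop_aligned_cond(cond_left, cond_right, training=True):
    """Randomly trims the inner edges of the aligned conditioning so the
    network must learn to reconstruct the missing middle."""
    if training:
        to_remove_left = random.randint(1, cond_left.shape[-1] - MIN_MARGIN)
        cond_left = cond_left[:, :, :-to_remove_left]
        to_remove_right = random.randint(1, cond_right.shape[-1] - MIN_MARGIN)
        cond_right = cond_right[:, :, to_remove_right:]
    return cond_left, cond_right

l, r = drop_aligned_cond(torch.randn(1, 256, 50), torch.randn(1, 256, 50))
```
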
@@ -7,14 +7,15 @@ Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules.

 import enum
 import math
+import random

 import numpy as np
 import torch
 import torch as th
 from tqdm import tqdm

-from .nn import mean_flat
-from .losses import normal_kl, discretized_gaussian_log_likelihood
+from models.diffusion.nn import mean_flat
+from models.diffusion.losses import normal_kl, discretized_gaussian_log_likelihood


 def causal_timestep_adjustment(t, S, num_timesteps, causal_slope=1, add_jitter=True):

@@ -37,43 +38,18 @@ def causal_timestep_adjustment(t, S, num_timesteps, causal_slope=1, add_jitter=True):
     # This algorithm for adding causality does so by simply adding S_sloped additional timesteps. To make this
     # actually work, we map the existing t from the timescale specified to the model to the causal timescale:
     adj_t = t * (num_timesteps + S_sloped) // num_timesteps
+    adj_t = adj_t - S_sloped
     if add_jitter:
-        jitter = (random.random() - .5) * S_sloped
-        adj_t = (adj_t+jitter).clamp(0, num_timesteps+S_sloped)
+        t_gap = (num_timesteps + S_sloped) / num_timesteps
+        jitter = (2*random.random()-1) * t_gap
+        adj_t = (adj_t+jitter).clamp(-S_sloped, num_timesteps)

     # Now use the re-mapped adj_t to create a timestep vector that propagates across the sequence with the specified slope.
     t = adj_t.unsqueeze(1).repeat(1, S)
-    t = (t - torch.arange(0, S) * causal_slope).clamp(0, num_timesteps).long()
+    t = (t + torch.arange(0, S, device=t.device) * causal_slope).clamp(0, num_timesteps).long()
     return t


-def graph_causal_timestep_adjustment():
-    S = 400
-    slope=4
-    num_timesteps=4000
-    #for num_timesteps in range(100, 4000, 200):
-    t_res = []
-    for t in range(num_timesteps, -1, -num_timesteps//50):
-        T = causal_timestep_adjustment(torch.tensor([t]), S, num_timesteps, causal_slope=slope, add_jitter=False)[0]
-        t_res.append(T)
-        plt.plot(T.numpy())
-        plt.ylim(0,4000)
-        plt.xlim(0,500)
-        plt.savefig(f'{t}.png')
-        plt.clf()
-
-    for i in range(len(t_res)):
-        for j in range(len(t_res)):
-            if i == j:
-                continue
-            #assert not torch.all(t_res[i] == t_res[j])
-    plt.ylim(0,4000)
-    plt.xlim(0,500)
-    plt.ylabel('timestep')
-    plt.savefig(f'{num_timesteps}.png')
-    plt.clf()
-
-
 def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
     """
     Get a pre-defined beta schedule for the given name.

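Compared to the old mapping, `adj_t` is now shifted down by `S_sloped` and the per-position ramp is added rather than subtracted, so a single scalar `t` unrolls into a timestep vector that rises along the sequence (earlier positions sit at less-noisy timesteps). The jitter also shrinks from roughly half of `S_sloped` to about one timestep gap. A self-contained re-derivation with small numbers; `S_sloped = S * causal_slope` is a guess consistent with the comments, since its real definition sits just above the shown hunk:

```python
import torch

def causal_t_sketch(t, S, num_timesteps, causal_slope=1.0):
    # Assumption: S_sloped is the extra timestep budget added by the slope.
    S_sloped = int(S * causal_slope)
    adj_t = t * (num_timesteps + S_sloped) // num_timesteps
    adj_t = adj_t - S_sloped                  # let the ramp start below 0
    t = adj_t.unsqueeze(1).repeat(1, S)       # (B, S)
    ramp = torch.arange(0, S) * causal_slope  # rises along the sequence
    return (t + ramp).clamp(0, num_timesteps).long()

out = causal_t_sketch(torch.tensor([2000]), S=8, num_timesteps=4000, causal_slope=100.0)
print(out)  # tensor([[1600, 1700, 1800, 1900, 2000, 2100, 2200, 2300]])
```
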
@@ -319,7 +295,7 @@ class GaussianDiffusion:
             model_kwargs = {}

         B, C = x.shape[:2]
-        assert t.shape == (B,)
+        assert t.shape == (B,) or t.shape == (B,1,x.shape[-1])
         model_output = model(x, self._scale_timesteps(t), **model_kwargs)
         if self.conditioning_free:
             model_output_no_conditioning = model(x, self._scale_timesteps(t), conditioning_free=True, **model_kwargs)

@@ -844,7 +820,10 @@ class GaussianDiffusion:

         # At the first timestep return the decoder NLL,
         # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
-        output = th.where((t == 0).view(-1, 1, 1), decoder_nll, kl)
+        if len(t.shape) == 1:
+            output = th.where((t == 0).view(-1, 1, 1), decoder_nll, kl)
+        else:
+            output = th.where((t == 0), decoder_nll, kl)
         return {"output": output, "pred_xstart": out["pred_xstart"]}

     def causal_training_losses(self, model, x_start, t, causal_slope=1, model_kwargs=None, noise=None, channel_balancing_fn=None):

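Because causal `t` has shape `(B, 1, S)`, `t == 0` can hold for only part of a sequence, so the decoder-NLL/KL selection now happens elementwise in the 3D branch. A toy illustration of the broadcast:

```python
import torch as th

t = th.tensor([[[0, 0, 3, 7]]])          # (B=1, 1, S=4) per-position timesteps
kl = th.full((1, 2, 4), 2.0)             # (B, C, S) per-element KL
decoder_nll = th.full((1, 2, 4), 9.0)
out = th.where(t == 0, decoder_nll, kl)  # the (1,1,4) mask broadcasts over C
print(out)  # first two positions take decoder_nll, the rest take kl
```
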
@@ -852,8 +831,9 @@ class GaussianDiffusion:
         Compute training losses for a causal diffusion process.
         """
         assert len(x_start.shape) == 3, "causal_training_losses assumes a 1d sequence with the axis being the time axis."
-        t = causal_timestep_adjustment(t, x_start.shape[-1], self.num_timesteps, causal_slope, add_jitter=True)
-        return self.training_losses(model, x_start, t, model_kwargs, noise, channel_balancing_fn)
+        ct = causal_timestep_adjustment(t, x_start.shape[-1], self.num_timesteps, causal_slope, add_jitter=True)
+        ct = ct.unsqueeze(1)  # Necessary to make the output shape compatible with x_start.
+        return self.training_losses(model, x_start, ct, model_kwargs, noise, channel_balancing_fn)

     def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, channel_balancing_fn=None):
         """

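The `unsqueeze(1)` is what lets the `(B, S)` causal timestep map ride through code written for scalar timesteps: as `(B, 1, S)` it broadcasts against `(B, C, S)` activations inside `q_sample`. Roughly (shapes only; `q_sample`'s actual math is elided):

```python
import torch

B, C, S = 4, 256, 400
t = torch.tensor([500, 1000, 3000, 3500])  # (B,) scalar timesteps
ct = torch.randint(0, 4000, (B, S))        # stand-in for causal_timestep_adjustment(t, S, ...)
ct = ct.unsqueeze(1)                       # (B, 1, S)
x_start = torch.randn(B, C, S)
noisy = x_start * (ct / 4000.0)            # placeholder for the q_sample arithmetic
assert noisy.shape == (B, C, S)            # (B,1,S) broadcasts cleanly over C
```
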
@@ -872,6 +852,13 @@ class GaussianDiffusion:
             model_kwargs = {}
         if noise is None:
             noise = th.randn_like(x_start)
+
+        if len(t.shape) == 3:
+            t_mask = t != self.num_timesteps
+            t[t_mask.logical_not()] = self.num_timesteps-1
+        else:
+            t_mask = torch.ones_like(x_start)
+
         x_t = self.q_sample(x_start, t, noise=noise)

         terms = {}

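`causal_timestep_adjustment` clamps to the inclusive range `[0, num_timesteps]`, so positions that are still pure noise can carry `t == num_timesteps`; that index would overflow the schedule tables, so it is clamped back to `num_timesteps - 1` and excluded from the loss via `t_mask`. A compact sketch of that masking:

```python
import torch

num_timesteps = 4000
t = torch.tensor([[[3998, 3999, 4000, 4000]]])  # (B,1,S), last two overflow
t_mask = t != num_timesteps                     # False where the loss is void
t[~t_mask] = num_timesteps - 1                  # keep indices legal for lookups
loss = torch.randn(1, 2, 4) ** 2
masked = loss * t_mask                          # overflowed positions contribute 0
```
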
@@ -912,7 +899,7 @@ class GaussianDiffusion:
                 x_t=x_t,
                 t=t,
                 clip_denoised=False,
-            )["output"]
+            )["output"] * t_mask
             if self.loss_type == LossType.RESCALED_MSE:
                 # Divide by 1000 for equivalence with initial implementation.
                 # Without a factor of 1/1000, the VB term hurts the MSE term.

@@ -932,7 +919,7 @@ class GaussianDiffusion:
             else:
                 raise NotImplementedError(self.model_mean_type)
             assert model_output.shape == target.shape == x_start.shape
-            s_err = (target - model_output) ** 2
+            s_err = t_mask * (target - model_output) ** 2
             if channel_balancing_fn is not None:
                 s_err = channel_balancing_fn(s_err)
             terms["mse_by_batch"] = s_err.reshape(s_err.shape[0], -1).mean(dim=1)

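One subtlety worth noting with `s_err = t_mask * (...) ** 2`: masked entries are zeroed but still counted by the subsequent `.mean(dim=1)`, so sequences with many masked positions report a diluted MSE rather than a mask-aware average. A tiny demonstration of the difference:

```python
import torch

s_err = torch.tensor([[1.0, 1.0, 4.0, 4.0]])
t_mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])
masked = t_mask * s_err
print(masked.mean())                # 0.5: zeros still in the denominator
print(masked.sum() / t_mask.sum())  # 1.0: the mask-aware mean
```
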
@@ -1039,3 +1026,47 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
     while len(res.shape) < len(broadcast_shape):
         res = res[..., None]
     return res.expand(broadcast_shape)
+
+
+def test_causal_training_losses():
+    from models.diffusion.respace import SpacedDiffusion
+    from models.diffusion.respace import space_timesteps
+    diff = SpacedDiffusion(use_timesteps=space_timesteps(4000, [4000]), model_mean_type='epsilon',
+                           model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', 4000),
+                           conditioning_free=False, conditioning_free_k=1)
+    class IdentityTwoArg(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+        def forward(self, x, *args, **kwargs):
+            return x.repeat(1,2,1)
+
+    model = IdentityTwoArg()
+    diff.causal_training_losses(model, torch.randn(4,256,400), torch.tensor([500,1000,3000,3500]), causal_slope=4)
+
+
+def graph_causal_timestep_adjustment():
+    import matplotlib.pyplot as plt
+    S = 400
+    #slope=4
+    num_timesteps=4000
+    for slpe in range(10, 400, 10):
+        slope = slpe / 10
+        t_res = []
+        for t in range(num_timesteps, -1, -num_timesteps//50):
+            T = causal_timestep_adjustment(torch.tensor([t]), S, num_timesteps, causal_slope=slope, add_jitter=False)[0]
+            t_res.append(T)
+            plt.plot(T.numpy())
+        for i in range(len(t_res)):
+            for j in range(len(t_res)):
+                if i == j:
+                    continue
+                #assert not torch.all(t_res[i] == t_res[j])
+        plt.ylim(0,num_timesteps)
+        plt.xlim(0,4000)
+        plt.ylabel('timestep')
+        plt.savefig(f'{slpe}.png')
+        plt.clf()
+
+
+if __name__ == '__main__':
+    #test_causal_training_losses()
+    graph_causal_timestep_adjustment()

@@ -122,7 +122,10 @@ def timestep_embedding(timesteps, dim, max_period=10000):
     freqs = th.exp(
         -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
     ).to(device=timesteps.device)
-    args = timesteps[:, None].float() * freqs[None]
+    if len(timesteps.shape) == 1:
+        args = timesteps[:, None].float() * freqs[None]
+    else:
+        args = (timesteps.float() * freqs.view(1,half,1)).permute(0,2,1)
     embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
     if dim % 2:
         embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)

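`timestep_embedding` gains a branch for per-position timesteps: a `(B, 1, S)` input is multiplied against the frequency vector along the channel axis and permuted so the sinusoidal embedding comes out as `(B, S, dim)` instead of `(B, dim)`. A runnable sketch of both branches (the `dim % 2` padding is omitted):

```python
import math
import torch as th

def timestep_embedding_sketch(timesteps, dim, max_period=10000):
    half = dim // 2
    freqs = th.exp(-math.log(max_period) * th.arange(half, dtype=th.float32) / half)
    if len(timesteps.shape) == 1:
        args = timesteps[:, None].float() * freqs[None]               # (B, half)
    else:
        # (B, 1, S) * (1, half, 1) -> (B, half, S) -> (B, S, half)
        args = (timesteps.float() * freqs.view(1, half, 1)).permute(0, 2, 1)
    return th.cat([th.cos(args), th.sin(args)], dim=-1)

print(timestep_embedding_sketch(th.tensor([10, 20]), 8).shape)            # (2, 8)
print(timestep_embedding_sketch(th.randint(0, 100, (2, 1, 5)), 8).shape)  # (2, 5, 8)
```
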
@@ -365,8 +365,11 @@ class RMSScaleShiftNorm(nn.Module):
         norm = x / norm.clamp(min=self.eps) * self.g

         ss_emb = self.scale_shift_process(norm_scale_shift_inp)
-        scale, shift = torch.chunk(ss_emb, 2, dim=1)
-        h = norm * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+        scale, shift = torch.chunk(ss_emb, 2, dim=-1)
+        if len(scale.shape) == 2:
+            scale = scale.unsqueeze(1)
+            shift = shift.unsqueeze(1)
+        h = norm * (1 + scale) + shift
         return h

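Chunking on `dim=-1` plus the conditional `unsqueeze` lets RMSScaleShiftNorm accept conditioning either per batch item, `(B, 2E)`, or per position, `(B, S, 2E)`; the 2D case gains a sequence axis so both broadcast against `norm` of shape `(B, S, E)`. A toy check under those assumed shapes:

```python
import torch

norm = torch.randn(2, 10, 8)                  # (B, S, E)

def scale_shift(norm, ss_emb):
    scale, shift = torch.chunk(ss_emb, 2, dim=-1)
    if len(scale.shape) == 2:                 # per-batch conditioning: (B, E)
        scale = scale.unsqueeze(1)            # -> (B, 1, E), broadcasts over S
        shift = shift.unsqueeze(1)
    return norm * (1 + scale) + shift

print(scale_shift(norm, torch.randn(2, 16)).shape)      # (2, 10, 8)
print(scale_shift(norm, torch.randn(2, 10, 16)).shape)  # (2, 10, 8)
```
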
@@ -339,7 +339,7 @@ class Trainer:

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_tfd12_finetune_ar_outputs.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_cheater_gen.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)

@@ -1,4 +1,5 @@
 import functools
+import random

 import torch
 from torch.cuda.amp import autocast

@@ -44,6 +45,8 @@ class GaussianDiffusionInjector(Injector):
         self.extra_model_output_keys = opt_get(opt, ['extra_model_output_keys'], [])
         self.deterministic_timesteps_every = opt_get(opt, ['deterministic_timesteps_every'], 0)
         self.deterministic_sampler = DeterministicSampler(self.diffusion, opt_get(opt, ['deterministic_sampler_expected_batch_size'], 2048), env)
+        self.causal_mode = opt_get(opt, ['causal_mode'], False)
+        self.causal_slope_range = opt_get(opt, ['causal_slope_range'], [1,8])

         k = 0
         if 'channel_balancer_proportion' in opt.keys():

@@ -86,7 +89,16 @@ class GaussianDiffusionInjector(Injector):
         self.deterministic_sampler.reset()  # Keep this reset whenever it is not being used, so it is ready to use automatically.
         model_inputs = {k: state[v] if isinstance(v, str) else v for k, v in self.model_input_keys.items()}
         t, weights = sampler.sample(hq.shape[0], hq.device)
-        diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs, channel_balancing_fn=self.channel_balancing_fn)
+        if self.causal_mode:
+            cs, ce = self.causal_slope_range
+            slope = random.random() * (ce-cs) + cs
+            diffusion_outputs = self.diffusion.causal_training_losses(gen, hq, t, model_kwargs=model_inputs,
+                                                                      channel_balancing_fn=self.channel_balancing_fn,
+                                                                      causal_slope=slope)
+        else:
+            diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs,
+                                                               channel_balancing_fn=self.channel_balancing_fn)
+
         if isinstance(sampler, LossAwareSampler):
             sampler.update_with_local_losses(t, diffusion_outputs['loss'])
         if len(self.extra_model_output_keys) > 0:

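The injector switches into the causal path purely from configuration. A hedged sketch of just the two new option keys as the injector's `opt` dict would carry them (key names are taken from the `opt_get` calls above; every other required injector option is omitted):

```python
import random

# Options as the injector would receive them (other required keys omitted).
opt = {
    # ... existing GaussianDiffusionInjector options ...
    'causal_mode': True,          # route training through causal_training_losses
    'causal_slope_range': [1, 8], # slope drawn uniformly from this range per step
}

cs, ce = opt['causal_slope_range']
slope = random.random() * (ce - cs) + cs  # matches the sampling in the diff
```
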