forked from mrq/DL-Art-School
Rescale causal scale along with timestep spacing
This commit is contained in:
parent
b432d7c7de
commit
3edca1a906
codes
|
@ -412,6 +412,9 @@ class GaussianDiffusion:
|
||||||
return t.float() * (1000.0 / self.num_timesteps)
|
return t.float() * (1000.0 / self.num_timesteps)
|
||||||
return t
|
return t
|
||||||
|
|
||||||
|
def _get_scale_ratio(self):
|
||||||
|
return 1
|
||||||
|
|
||||||
def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
|
def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
|
||||||
"""
|
"""
|
||||||
Compute the mean for the previous step, given a function cond_fn that
|
Compute the mean for the previous step, given a function cond_fn that
|
||||||
|
@ -586,7 +589,7 @@ class GaussianDiffusion:
|
||||||
t = th.tensor([i] * shape[0], device=device)
|
t = th.tensor([i] * shape[0], device=device)
|
||||||
mask = torch.zeros_like(img)
|
mask = torch.zeros_like(img)
|
||||||
if causal:
|
if causal:
|
||||||
t = causal_timestep_adjustment(t, shape[-1], self.num_timesteps, causal_slope, add_jitter=False).unsqueeze(1)
|
t = causal_timestep_adjustment(t, shape[-1], self.num_timesteps, causal_slope * self._get_scale_ratio(), add_jitter=False).unsqueeze(1)
|
||||||
t, mask = causal_mask_and_fix(t, self.num_timesteps)
|
t, mask = causal_mask_and_fix(t, self.num_timesteps)
|
||||||
mask = mask.repeat(img.shape[0], img.shape[1], 1)
|
mask = mask.repeat(img.shape[0], img.shape[1], 1)
|
||||||
with th.no_grad():
|
with th.no_grad():
|
||||||
|
@ -816,7 +819,7 @@ class GaussianDiffusion:
|
||||||
t = th.tensor([i] * shape[0], device=device)
|
t = th.tensor([i] * shape[0], device=device)
|
||||||
mask = torch.zeros_like(img)
|
mask = torch.zeros_like(img)
|
||||||
if causal:
|
if causal:
|
||||||
t = causal_timestep_adjustment(t, shape[-1], self.num_timesteps, causal_slope, add_jitter=False).unsqueeze(1)
|
t = causal_timestep_adjustment(t, shape[-1], self.num_timesteps, causal_slope * self._get_scale_ratio(), add_jitter=False).unsqueeze(1)
|
||||||
t, mask = causal_mask_and_fix(t, self.num_timesteps)
|
t, mask = causal_mask_and_fix(t, self.num_timesteps)
|
||||||
t[mask] = self.num_timesteps-1
|
t[mask] = self.num_timesteps-1
|
||||||
mask = mask.repeat(img.shape[0], img.shape[1], 1)
|
mask = mask.repeat(img.shape[0], img.shape[1], 1)
|
||||||
|
@ -880,7 +883,7 @@ class GaussianDiffusion:
|
||||||
Compute training losses for a causal diffusion process.
|
Compute training losses for a causal diffusion process.
|
||||||
"""
|
"""
|
||||||
assert len(x_start.shape) == 3, "causal_training_losses assumes a 1d sequence with the axis being the time axis."
|
assert len(x_start.shape) == 3, "causal_training_losses assumes a 1d sequence with the axis being the time axis."
|
||||||
ct = causal_timestep_adjustment(t, x_start.shape[-1], self.num_timesteps, causal_slope, add_jitter=True)
|
ct = causal_timestep_adjustment(t, x_start.shape[-1], self.num_timesteps, causal_slope * self._get_scale_ratio(), add_jitter=True)
|
||||||
ct = ct.unsqueeze(1) # Necessary to make the output shape compatible with x_start.
|
ct = ct.unsqueeze(1) # Necessary to make the output shape compatible with x_start.
|
||||||
return self.training_losses(model, x_start, ct, model_kwargs, noise, channel_balancing_fn)
|
return self.training_losses(model, x_start, ct, model_kwargs, noise, channel_balancing_fn)
|
||||||
|
|
||||||
|
@ -1095,19 +1098,18 @@ def test_causal_training_losses():
|
||||||
|
|
||||||
def graph_causal_timestep_adjustment():
|
def graph_causal_timestep_adjustment():
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
S = 2000
|
S = 400
|
||||||
#slope=4
|
#slope=4
|
||||||
num_timesteps=4000
|
num_timesteps=4000
|
||||||
for slpe in range(0, 200, 10):
|
for slpe in range(10, 400, 50):
|
||||||
slope = slpe / 10
|
slope = slpe / 10
|
||||||
t_res = []
|
t_res = []
|
||||||
for t in range(num_timesteps, -1, -num_timesteps//50):
|
for t in range(num_timesteps, -1, -num_timesteps//50):
|
||||||
T = causal_timestep_adjustment(torch.tensor([t]), S, num_timesteps, causal_slope=slope, add_jitter=False)[0]
|
T = causal_timestep_adjustment(torch.tensor([t]), S, num_timesteps, causal_slope=slope, add_jitter=False)[0]
|
||||||
|
|
||||||
# The following adjustment makes it easier to visualize the timestep regions where the model is actually working.
|
# The following adjustment makes it easier to visualize the timestep regions where the model is actually working.
|
||||||
T_adj = (T == num_timesteps).logical_or(T == -1)
|
#T_adj = (T == num_timesteps).logical_or(T == -1)
|
||||||
T[T_adj] = t
|
#T[T_adj] = t
|
||||||
print(t, T.float().mean())
|
|
||||||
|
|
||||||
t_res.append(T)
|
t_res.append(T)
|
||||||
plt.plot(T.numpy())
|
plt.plot(T.numpy())
|
||||||
|
@ -1116,13 +1118,32 @@ def graph_causal_timestep_adjustment():
|
||||||
for j in range(len(t_res)):
|
for j in range(len(t_res)):
|
||||||
if i == j:
|
if i == j:
|
||||||
continue
|
continue
|
||||||
#assert not torch.all(t_res[i] == t_res[j])
|
assert not torch.all(t_res[i] == t_res[j])
|
||||||
plt.ylim(0,num_timesteps)
|
plt.ylim(0,num_timesteps)
|
||||||
plt.xlim(0,4000)
|
plt.xlim(0,4000)
|
||||||
plt.ylabel('timestep')
|
plt.ylabel('timestep')
|
||||||
plt.savefig(f'{slpe}.png')
|
plt.savefig(f'{slpe}.png')
|
||||||
plt.clf()
|
plt.clf()
|
||||||
|
|
||||||
|
|
||||||
|
def graph_causal_timestep_adjustment_by_timestep():
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
S = 400
|
||||||
|
slope=10
|
||||||
|
num_timesteps=4000
|
||||||
|
t_res = []
|
||||||
|
for t in range(num_timesteps, -1, -num_timesteps//50):
|
||||||
|
T = causal_timestep_adjustment(torch.tensor([t]), S, num_timesteps, causal_slope=slope, add_jitter=False)[0]
|
||||||
|
|
||||||
|
t_res.append(T)
|
||||||
|
plt.plot(T.numpy())
|
||||||
|
plt.ylim(0,num_timesteps)
|
||||||
|
plt.xlim(0,4000)
|
||||||
|
plt.ylabel('timestep')
|
||||||
|
plt.savefig(f'{t}.png')
|
||||||
|
plt.clf()
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
#test_causal_training_losses()
|
#test_causal_training_losses()
|
||||||
graph_causal_timestep_adjustment()
|
#graph_causal_timestep_adjustment()
|
||||||
|
graph_causal_timestep_adjustment_by_timestep()
|
|
@ -113,6 +113,9 @@ class SpacedDiffusion(GaussianDiffusion):
|
||||||
# Scaling is done by the wrapped model.
|
# Scaling is done by the wrapped model.
|
||||||
return t
|
return t
|
||||||
|
|
||||||
|
def _get_scale_ratio(self):
|
||||||
|
return self.num_timesteps / self.original_num_steps
|
||||||
|
|
||||||
|
|
||||||
class _WrappedModel:
|
class _WrappedModel:
|
||||||
def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
|
def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
|
||||||
|
|
|
@ -424,11 +424,11 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_cheater_gen.yml', 'generator',
|
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_cheater_gen.yml', 'generator',
|
||||||
also_load_savepoint=False,
|
also_load_savepoint=False,
|
||||||
load_path='X:\\dlas\\experiments\\train_music_cheater_gen_v5_causal_retrain\\models\\18000_generator.pth'
|
load_path='X:\\dlas\\experiments\\train_music_cheater_gen_v5_causal_retrain\\models\\22000_generator_ema.pth'
|
||||||
).cuda()
|
).cuda()
|
||||||
opt_eval = {'path': 'Y:\\split\\yt-music-eval', # eval music, mostly electronica. :)
|
opt_eval = {'path': 'Y:\\split\\yt-music-eval', # eval music, mostly electronica. :)
|
||||||
#'path': 'E:\\music_eval', # this is music from the training dataset, including a lot more variety.
|
#'path': 'E:\\music_eval', # this is music from the training dataset, including a lot more variety.
|
||||||
'diffusion_steps': 256,
|
'diffusion_steps': 100,
|
||||||
'conditioning_free': True, 'conditioning_free_k': 1, 'use_ddim': False, 'clip_audio': False,
|
'conditioning_free': True, 'conditioning_free_k': 1, 'use_ddim': False, 'clip_audio': False,
|
||||||
'diffusion_schedule': 'linear', 'diffusion_type': 'cheater_gen',
|
'diffusion_schedule': 'linear', 'diffusion_type': 'cheater_gen',
|
||||||
'causal': True, 'causal_slope': 1,
|
'causal': True, 'causal_slope': 1,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user