This commit is contained in:
mrq 2024-11-07 19:58:47 -06:00
parent d17f0ebc7c
commit bfbd19c14f
2 changed files with 4 additions and 2 deletions

View File

@ -1004,7 +1004,9 @@ class Base(nn.Module):
# store dropout mask # store dropout mask
if "len" in self.capabilities and quant_level == 0: if "len" in self.capabilities and quant_level == 0:
dropout_mask = _dropout_mask( resps_list[i], p=0.8 ) t = random.random()
p = math.cos(t * math.pi * 0.5)
dropout_mask = _dropout_mask( resps_list[i], p=p )
inputs[i].append( ("dropout_mask", dropout_mask ) ) inputs[i].append( ("dropout_mask", dropout_mask ) )
# Audio length prediction task # Audio length prediction task

View File

@ -527,7 +527,7 @@ def add_gumbel_noise(t, temperature, device):
return (t + torch.Tensor(temperature * np.random.gumbel(size=t.shape)).to(device)) return (t + torch.Tensor(temperature * np.random.gumbel(size=t.shape)).to(device))
# derived from https://github.com/LeapLabTHU/ImprovedNAT/blob/main/libs/nat_misc.py#L39 # derived from https://github.com/LeapLabTHU/ImprovedNAT/blob/main/libs/nat_misc.py#L39
# this # this provides mostly poor output, but it might just be a matter of how I'm naively training the model for """diffusion"""
class SampleScheduler: class SampleScheduler:
def __init__( def __init__(
self, self,