ugh
This commit is contained in:
parent
d17f0ebc7c
commit
bfbd19c14f
|
@ -1004,7 +1004,9 @@ class Base(nn.Module):
|
||||||
|
|
||||||
# store dropout mask
|
# store dropout mask
|
||||||
if "len" in self.capabilities and quant_level == 0:
|
if "len" in self.capabilities and quant_level == 0:
|
||||||
dropout_mask = _dropout_mask( resps_list[i], p=0.8 )
|
t = random.random()
|
||||||
|
p = math.cos(t * math.pi * 0.5)
|
||||||
|
dropout_mask = _dropout_mask( resps_list[i], p=p )
|
||||||
inputs[i].append( ("dropout_mask", dropout_mask ) )
|
inputs[i].append( ("dropout_mask", dropout_mask ) )
|
||||||
|
|
||||||
# Audio length prediction task
|
# Audio length prediction task
|
||||||
|
|
|
@ -527,7 +527,7 @@ def add_gumbel_noise(t, temperature, device):
|
||||||
return (t + torch.Tensor(temperature * np.random.gumbel(size=t.shape)).to(device))
|
return (t + torch.Tensor(temperature * np.random.gumbel(size=t.shape)).to(device))
|
||||||
|
|
||||||
# derived from https://github.com/LeapLabTHU/ImprovedNAT/blob/main/libs/nat_misc.py#L39
|
# derived from https://github.com/LeapLabTHU/ImprovedNAT/blob/main/libs/nat_misc.py#L39
|
||||||
# this
|
# this provides mostly poor output, but it might just be a matter of how I'm naively training the model for """diffusion"""
|
||||||
class SampleScheduler:
|
class SampleScheduler:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user