ugh
This commit is contained in:
parent
d17f0ebc7c
commit
bfbd19c14f
|
@ -1004,7 +1004,9 @@ class Base(nn.Module):
|
|||
|
||||
# store dropout mask
|
||||
if "len" in self.capabilities and quant_level == 0:
|
||||
dropout_mask = _dropout_mask( resps_list[i], p=0.8 )
|
||||
t = random.random()
|
||||
p = math.cos(t * math.pi * 0.5)
|
||||
dropout_mask = _dropout_mask( resps_list[i], p=p )
|
||||
inputs[i].append( ("dropout_mask", dropout_mask ) )
|
||||
|
||||
# Audio length prediction task
|
||||
|
|
|
@ -527,7 +527,7 @@ def add_gumbel_noise(t, temperature, device):
|
|||
return (t + torch.Tensor(temperature * np.random.gumbel(size=t.shape)).to(device))
|
||||
|
||||
# derived from https://github.com/LeapLabTHU/ImprovedNAT/blob/main/libs/nat_misc.py#L39
|
||||
# this
|
||||
# this provides mostly poor output, but it might just be a matter of how I'm naively training the model for """diffusion"""
|
||||
class SampleScheduler:
|
||||
def __init__(
|
||||
self,
|
||||
|
|
Loading…
Reference in New Issue
Block a user