tweaks
parent 6ae282e090
commit 478aea0e8c
@@ -520,6 +520,9 @@ class TTS():
        else:
            len_list = model_len( **input_kwargs, task_list=["len"], **{"max_duration": 5} ) # "max_duration" is max tokens

            # clamp
            len_list = [ max( l, 1 * cfg.dataset.frames_per_second ) for l in len_list ]

            # add an additional X seconds
            len_list = [ int(l * duration_padding) for l in len_list ]
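A minimal, runnable sketch of the clamp-then-pad step above; frames_per_second and duration_padding are stand-in example values here, not the repo's actual config:

    frames_per_second = 75    # assumption: audio token frames per second
    duration_padding = 1.05   # assumption: pad predicted durations by 5%

    def postprocess_len(len_list):
        # clamp each predicted length to at least one second of frames
        len_list = [max(l, 1 * frames_per_second) for l in len_list]
        # then scale by the padding factor to leave headroom at the end
        return [int(l * duration_padding) for l in len_list]

    print(postprocess_len([10, 200]))  # -> [78, 210]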
@@ -239,7 +239,7 @@ class Attention(nn.Module):
            "default",
            torch.nn.attention.SDPBackend.MATH,
            torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
            torch.nn.attention.SDPBackend.FLASH_ATTENTION,
            #torch.nn.attention.SDPBackend.FLASH_ATTENTION,
            torch.nn.attention.SDPBackend.CUDNN_ATTENTION
        ]
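For context, the backend names in that list line up with torch's SDPA dispatch control; a sketch of how such a list is typically applied (torch >= 2.3 API; the wiring below is an assumption, not the repo's code):

    import torch
    import torch.nn.functional as F
    from torch.nn.attention import SDPBackend, sdpa_kernel

    q = torch.randn(1, 8, 128, 64)
    k = torch.randn(1, 8, 128, 64)
    v = torch.randn(1, 8, 128, 64)

    # restrict dispatch to the math / memory-efficient kernels,
    # mirroring the list above with FLASH_ATTENTION commented out
    with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)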
@@ -365,6 +365,7 @@ class Attention(nn.Module):

        # pain
        # SDPA backends only sometimes allowing/disallowing some arguments...
        """
        if isinstance( is_causal, list ):
            count = sum( [ 1 if x else 0 for x in is_causal ] )
            if count == 0:
@@ -378,6 +379,12 @@ class Attention(nn.Module):
                x_mask = None
        elif is_causal == True:
            x_mask = None
        """
        if self.attn_mode in [torch.nn.attention.SDPBackend.FLASH_ATTENTION]:
            x_mask = None

        if x_mask is not None:
            is_causal = False

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
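The two guards above exist because SDPA rejects a custom attn_mask combined with is_causal=True, and, per the linked issue, the memory-efficient backend mishandled non-contiguous inputs with a custom mask around torch==2.1.2. A minimal sketch of the usual workaround; the helper name is mine, not the repo's:

    import torch
    import torch.nn.functional as F

    def sdpa_safe(q, k, v, attn_mask=None, is_causal=False):
        if attn_mask is not None:
            # an explicit mask and is_causal=True are mutually exclusive in SDPA
            is_causal = False
            # workaround for https://github.com/pytorch/pytorch/issues/112577:
            # force contiguous inputs when a custom mask is supplied
            q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
        return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=is_causal)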
@@ -519,10 +526,8 @@ class DecoderLayer(nn.Module):
        hidden_states = self.input_layernorm(hidden_states)

        # ugh
        """
        if isinstance( is_causal, list ) and len(is_causal) == 1:
            is_causal = is_causal[0]
        """

        # Self Attention
        if self.config.attn_mode == "sparse":
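The block commented out above would have collapsed a one-element list of causal flags into a plain bool, since F.scaled_dot_product_attention expects a single is_causal flag; a standalone sketch of that normalization:

    def normalize_is_causal(is_causal):
        # collapse [True] / [False] into a bare bool; leave longer lists alone
        if isinstance(is_causal, list) and len(is_causal) == 1:
            return is_causal[0]
        return is_causal

    assert normalize_is_causal([True]) is True
    assert normalize_is_causal(False) is False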
@@ -765,7 +765,8 @@ class Base_V2(nn.Module):

            # needed, cringe
            if task_type == "len":
                batch[-1] = torch.cat( [ batch[-1], self.sep[None], self.sep[None] ] )
                #batch[-1] = torch.cat( [ batch[-1], self.sep[None], self.sep[None] ] )
                batch[-1] = torch.cat( [ batch[-1], self.sep[None] ] )

            x_list.append( _join( batch, self.sep ) )
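The last hunk drops the doubled separator in favor of a single trailing one before the join; a self-contained sketch of the concatenation pattern (the scalar sep token id and this _join are stand-ins, not the repo's implementation):

    import torch

    sep = torch.tensor(0)   # assumption: scalar separator token id

    def _join(parts, sep):
        # stand-in for the repo's _join: interleave sep between segments
        out = [parts[0]]
        for p in parts[1:]:
            out += [sep[None], p]
        return torch.cat(out)

    batch = [torch.tensor([1, 2]), torch.tensor([3])]
    batch[-1] = torch.cat([batch[-1], sep[None]])  # single trailing sep, as in the new line
    print(_join(batch, sep))  # tensor([1, 2, 0, 3, 0])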