mrq 2025-03-28 19:49:54 -05:00
parent 6ae282e090
commit 478aea0e8c
3 changed files with 13 additions and 4 deletions

View File

@@ -520,6 +520,9 @@ class TTS():
	else:
		len_list = model_len( **input_kwargs, task_list=["len"], **{"max_duration": 5} ) # "max_duration" is max tokens
	# clamp to at least one second's worth of frames
	len_list = [ max( l, 1 * cfg.dataset.frames_per_second ) for l in len_list ]
	# pad the predicted duration (multiplicative factor, not added seconds)
	len_list = [ int(l * duration_padding) for l in len_list ]
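Taken together, the new post-processing clamps the "len" task's raw prediction to a floor of one second and then scales it for headroom. A minimal sketch, with illustrative stand-ins for `cfg.dataset.frames_per_second`, `duration_padding`, and the raw predictions:

```python
# Sketch of the length post-processing above; all values are illustrative.
frames_per_second = 25   # stand-in for cfg.dataset.frames_per_second
duration_padding = 1.05  # multiplicative padding factor supplied by the caller

len_list = [0, 12, 80]   # raw duration-token predictions from the "len" task

# Clamp each prediction to at least one second's worth of frames...
len_list = [max(l, 1 * frames_per_second) for l in len_list]
# ...then scale by the padding factor and truncate back to an int.
len_list = [int(l * duration_padding) for l in len_list]

print(len_list)  # [26, 26, 84]
```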

View File

@@ -239,7 +239,7 @@ class Attention(nn.Module):
	"default",
	torch.nn.attention.SDPBackend.MATH,
	torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
	torch.nn.attention.SDPBackend.FLASH_ATTENTION,
	#torch.nn.attention.SDPBackend.FLASH_ATTENTION,
	torch.nn.attention.SDPBackend.CUDNN_ATTENTION
]
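For context, a backend-priority list like this is typically consumed via `torch.nn.attention.sdpa_kernel` (torch >= 2.3), which restricts which SDPA implementations may be dispatched. A minimal sketch mirroring the list above with FLASH_ATTENTION disabled:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim)

# Allow only the math/efficient/cuDNN backends, mirroring the list above
# with FLASH_ATTENTION commented out.
with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.CUDNN_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```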
@@ -365,6 +365,7 @@ class Attention(nn.Module):
# pain
# SDPA backends are inconsistent about which arguments they allow/disallow...
"""
if isinstance( is_causal, list ):
	count = sum( [ 1 if x else 0 for x in is_causal ] )
	if count == 0:
@@ -378,6 +379,12 @@ class Attention(nn.Module):
		x_mask = None
elif is_causal == True:
	x_mask = None
"""
if self.attn_mode in [torch.nn.attention.SDPBackend.FLASH_ATTENTION]:
	x_mask = None
if x_mask is not None:
	is_causal = False
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
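The constraint the new guard works around can be reproduced directly: `scaled_dot_product_attention` raises an error when given both an explicit `attn_mask` and `is_causal=True`, so whenever a mask is present `is_causal` has to be dropped. A minimal repro (shapes are illustrative):

```python
import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 8, 16, 64)
mask = torch.ones(16, 16, dtype=torch.bool).tril()

# Both at once is rejected, hence the guard above that forces
# is_causal = False whenever x_mask is not None.
try:
    F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=True)
except RuntimeError as e:
    print(e)  # "Explicit attn_mask should not be set when is_causal=True"

# Either one alone is fine:
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```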
@@ -519,10 +526,8 @@ class DecoderLayer(nn.Module):
hidden_states = self.input_layernorm(hidden_states)
# ugh
"""
if isinstance( is_causal, list ) and len(is_causal) == 1:
	is_causal = is_causal[0]
"""
# Self Attention
if self.config.attn_mode == "sparse":

View File

@@ -765,7 +765,8 @@ class Base_V2(nn.Module):
# needed, cringe
if task_type == "len":
	batch[-1] = torch.cat( [ batch[-1], self.sep[None], self.sep[None] ] )
	#batch[-1] = torch.cat( [ batch[-1], self.sep[None], self.sep[None] ] )
	batch[-1] = torch.cat( [ batch[-1], self.sep[None] ] )
x_list.append( _join( batch, self.sep ) )
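The effect of the change is that a "len" prompt now ends with one separator rather than two before the batch is joined. A toy illustration, with hypothetical stand-ins for `self.sep` and the model's `_join` helper:

```python
import torch

sep = torch.tensor([0])  # stand-in for self.sep[None]

def _join(segments, sep):
    # Hypothetical stand-in for the _join used above: interleave a
    # separator between the already-assembled segments.
    out = [segments[0]]
    for seg in segments[1:]:
        out += [sep, seg]
    return torch.cat(out)

batch = [torch.tensor([1, 2]), torch.tensor([3, 4])]

# Before: two trailing separators were appended for the "len" task.
# batch[-1] = torch.cat([batch[-1], sep, sep])
# After: just one; _join then adds the inter-segment separator.
batch[-1] = torch.cat([batch[-1], sep])

print(_join(batch, sep))  # tensor([1, 2, 0, 3, 4, 0])
```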