From 478aea0e8cf6ac62936bd4bb761f7271252e0be3 Mon Sep 17 00:00:00 2001
From: mrq
Date: Fri, 28 Mar 2025 19:49:54 -0500
Subject: [PATCH] tweaks

---
 vall_e/inference.py         |  3 +++
 vall_e/models/arch/llama.py | 11 ++++++++---
 vall_e/models/base_v2.py    |  3 ++-
 3 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/vall_e/inference.py b/vall_e/inference.py
index 7f7edb5..74fa64b 100644
--- a/vall_e/inference.py
+++ b/vall_e/inference.py
@@ -520,6 +520,9 @@ class TTS():
 		else:
 			len_list = model_len( **input_kwargs, task_list=["len"], **{"max_duration": 5} ) # "max_duration" is max tokens
 
+		# clamp
+		len_list = [ max( l, 1 * cfg.dataset.frames_per_second ) for l in len_list ]
+
 		# add an additional X seconds
 		len_list = [ int(l * duration_padding) for l in len_list ]
 
diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py
index 2700fce..98b021f 100644
--- a/vall_e/models/arch/llama.py
+++ b/vall_e/models/arch/llama.py
@@ -239,7 +239,7 @@ class Attention(nn.Module):
 			"default",
 			torch.nn.attention.SDPBackend.MATH,
 			torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
-			torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+			#torch.nn.attention.SDPBackend.FLASH_ATTENTION,
 			torch.nn.attention.SDPBackend.CUDNN_ATTENTION
 		]
 
@@ -365,6 +365,7 @@ class Attention(nn.Module):
 
 		# pain
 		# SDPA backends only sometimes allowing/disallowing some arguments...
+		"""
 		if isinstance( is_causal, list ):
 			count = sum( [ 1 if x else 0 for x in is_causal ] )
 			if count == 0:
@@ -378,6 +379,12 @@ class Attention(nn.Module):
 				x_mask = None
 		elif is_causal == True:
 			x_mask = None
+		"""
+		if self.attn_mode in [torch.nn.attention.SDPBackend.FLASH_ATTENTION]:
+			x_mask = None
+
+		if x_mask is not None:
+			is_causal = False
 
 		# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
 		# Reference: https://github.com/pytorch/pytorch/issues/112577.
@@ -519,10 +526,8 @@ class DecoderLayer(nn.Module):
 		hidden_states = self.input_layernorm(hidden_states)
 
 		# ugh
-		"""
 		if isinstance( is_causal, list ) and len(is_causal) == 1:
 			is_causal = is_causal[0]
-		"""
 
 		# Self Attention
 		if self.config.attn_mode == "sparse":
diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py
index f726cbb..3c1dc52 100644
--- a/vall_e/models/base_v2.py
+++ b/vall_e/models/base_v2.py
@@ -765,7 +765,8 @@ class Base_V2(nn.Module):
 
 			# needed, cringe
 			if task_type == "len":
-				batch[-1] = torch.cat( [ batch[-1], self.sep[None], self.sep[None] ] )
+				#batch[-1] = torch.cat( [ batch[-1], self.sep[None], self.sep[None] ] )
+				batch[-1] = torch.cat( [ batch[-1], self.sep[None] ] )
 
 			x_list.append( _join( batch, self.sep ) )
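
Notes (not part of the patch):

The inference.py hunk floors each predicted "len" at one second's worth of frames before the
duration padding multiplier is applied, so a degenerate length prediction can no longer collapse
to zero tokens. A tiny sketch of the effect, using hypothetical numbers (frames_per_second and
duration_padding come from the user's config):

    frames_per_second = 75
    duration_padding = 1.05
    len_list = [0, 12, 600]                                       # raw "len" predictions, in frames
    len_list = [max(l, 1 * frames_per_second) for l in len_list]  # clamp: at least one second
    len_list = [int(l * duration_padding) for l in len_list]      # then add the padding headroom
    # -> [78, 78, 630]

The llama.py hunks sidestep two SDPA constraints: scaled_dot_product_attention rejects an explicit
attn_mask combined with is_causal=True, and the flash backend only supports is_causal rather than
an arbitrary attn_mask. A minimal sketch of those constraints, assuming a recent torch (>=2.3) for
torch.nn.attention.sdpa_kernel; shapes and dtypes are illustrative only:

    import torch
    import torch.nn.functional as F
    from torch.nn.attention import sdpa_kernel, SDPBackend

    q = k = v = torch.randn(1, 4, 16, 32)         # (batch, heads, seq_len, head_dim)
    mask = torch.tril(torch.ones(16, 16)).bool()  # custom (here: causal) attention mask

    # a custom mask is accepted, but only with is_causal=False -- hence the patch forcing
    # is_causal = False whenever x_mask is kept
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=False)

    # the flash backend takes is_causal but no explicit attn_mask -- hence the patch dropping
    # x_mask when FLASH_ATTENTION is the selected mode (on CUDA it also expects fp16/bf16 inputs)
    if torch.cuda.is_available():
        q, k, v = (t.cuda().half() for t in (q, k, v))
        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
            out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)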