diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 23d2a71..28f2359 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -197,7 +197,7 @@ try: except Exception as e: print("Error creating `LLamaXformersAttention`:", e) -def replace_attention( model, impl, verbose=True ): +def replace_attention( model, impl, verbose=Valse ): device = next(model.parameters()).device dtype = next(model.parameters()).dtype attentions = [k.split('.') for k, m in model.named_modules() if isinstance(m, LlamaAttention)]