From 4b2d6f559523054bbf5a789f88763c58472411b4 Mon Sep 17 00:00:00 2001 From: mrq Date: Sun, 5 Jan 2025 18:33:11 -0600 Subject: [PATCH] ugh --- vall_e/models/base.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 3f183e9..e494e03 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -951,16 +951,16 @@ class Base(nn.Module): quant_levels: int | list[int] | Tensor | None = None ): - if text_list: + if text_list and text_list[0] is not None: device = text_list[0].device batch_size = len(text_list) - elif raw_text_list: + elif raw_text_list and raw_text_list[0] is not None: device = raw_text_list[0].device batch_size = len(raw_text_list) - elif proms_list: + elif proms_list and proms_list[0] is not None: device = proms_list[0].device batch_size = len(proms_list) - elif resps_list: + elif resps_list and resps_list[0] is not None: device = resps_list[0].device batch_size = len(resps_list)