diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 3f183e9..e494e03 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -951,16 +951,16 @@ class Base(nn.Module): quant_levels: int | list[int] | Tensor | None = None ): - if text_list: + if text_list and text_list[0] is not None: device = text_list[0].device batch_size = len(text_list) - elif raw_text_list: + elif raw_text_list and raw_text_list[0] is not None: device = raw_text_list[0].device batch_size = len(raw_text_list) - elif proms_list: + elif proms_list and proms_list[0] is not None: device = proms_list[0].device batch_size = len(proms_list) - elif resps_list: + elif resps_list and resps_list[0] is not None: device = resps_list[0].device batch_size = len(resps_list)