From 3a6bd50322db39abdbaee1b1e0d5a94f58d74ced Mon Sep 17 00:00:00 2001 From: mrq Date: Sun, 3 Sep 2023 21:36:58 -0500 Subject: [PATCH] haha --- vall_e/models/interleaved_ar.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/vall_e/models/interleaved_ar.py b/vall_e/models/interleaved_ar.py index 08b7921..3d78576 100644 --- a/vall_e/models/interleaved_ar.py +++ b/vall_e/models/interleaved_ar.py @@ -16,14 +16,17 @@ from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, Mult from .retnet import RetNetDecoder, RetNetConfig from .transformer import SinusoidalEmbedding, Block as TransformerBlock -from ..ext.interleaver import ( - CodebooksPatternProvider, - DelayedPatternProvider, - MusicLMPattern, - ParallelPatternProvider, - UnrolledPatternProvider, - VALLEPattern, -) +try: + from ..ext.interleaver import ( + CodebooksPatternProvider, + DelayedPatternProvider, + MusicLMPattern, + ParallelPatternProvider, + UnrolledPatternProvider, + VALLEPattern, + ) +except Exception as e: + pass from ..config import cfg @@ -422,9 +425,8 @@ class Base(nn.Module): if shift_targ_list: targ_list = [*targ_list] for i in range(len(targ_list)): - targ_list[i] = targ_list[i].roll(-self.n_resp_levels, dims=0) - for j in range(self.n_resp_levels): - targ_list[i][-j-1] = self.stop_token + targ_list[i] = targ_list[i].roll(-1, dims=0) + targ_list[i][-1] = self.stop_token # create the new target sequence to compute the loss against y_list = self._samplewise_merge_tensors( text_prom_list, targ_list, sep=ignore_sep ) @@ -539,7 +541,7 @@ def example_usage(): 'n_tokens': 1024, 'd_model': 1024, 'n_heads': 16, - 'n_layers': 12, + 'n_layers': 18, } models = { "ar": Base(**kwargs).to(device) } @@ -563,7 +565,7 @@ def example_usage(): qnt.to(device), ] - def sample( filename, steps=400 ): + def sample( filename, steps=450 * 4 ): AR = None engines.eval()