support projecting to vectors

This commit is contained in:
James Betker 2022-05-28 22:27:45 -06:00
parent 86694aef4e
commit 6b43915eb8
3 changed files with 9 additions and 5 deletions

View File

@@ -653,11 +653,14 @@ class ContrastiveTrainingWrapper(nn.Module):
}
return groups
def get_codes(self, mel):
def get_codes(self, mel, project=False):
proj = self.m2v.input_blocks(mel).permute(0,2,1)
_, proj = self.m2v.projector(proj)
codes = self.quantizer.get_codes(proj)
return codes
if project:
proj, _ = self.quantizer(proj)
return proj
else:
return self.quantizer.get_codes(proj)
def reconstruct(self, mel):
proj = self.m2v.input_blocks(mel).permute(0,2,1)

View File

@@ -1,7 +1,7 @@
import os
if __name__ == '__main__':
basepath = 'Y:/clips/podcasts-0'
basepath = 'Y:\\bigasr_dataset\\hifi_tts'
english_file = os.path.join(basepath, 'transcribed-oco-realtext.tsv')
if not os.path.exists(english_file):

View File

@@ -330,13 +330,14 @@ class Mel2vecCodesInjector(Injector):
self.m2v = get_music_codegen()
del self.m2v.m2v.encoder # This is a big memory sink which will not get used.
self.needs_move = True
self.inj_vector = opt_get(opt, ['vector'], False)
def forward(self, state):
mels = state[self.input]
with torch.no_grad():
if self.needs_move:
self.m2v = self.m2v.to(mels.device)
codes = self.m2v.get_codes(mels)
codes = self.m2v.get_codes(mels, project=self.inj_vector)
return {self.output: codes}