support projecting to vectors

This commit is contained in:
James Betker 2022-05-28 22:27:45 -06:00
parent 86694aef4e
commit 6b43915eb8
3 changed files with 9 additions and 5 deletions

View File

@ -653,11 +653,14 @@ class ContrastiveTrainingWrapper(nn.Module):
} }
return groups return groups
def get_codes(self, mel): def get_codes(self, mel, project=False):
proj = self.m2v.input_blocks(mel).permute(0,2,1) proj = self.m2v.input_blocks(mel).permute(0,2,1)
_, proj = self.m2v.projector(proj) _, proj = self.m2v.projector(proj)
codes = self.quantizer.get_codes(proj) if project:
return codes proj, _ = self.quantizer(proj)
return proj
else:
return self.quantizer.get_codes(proj)
def reconstruct(self, mel): def reconstruct(self, mel):
proj = self.m2v.input_blocks(mel).permute(0,2,1) proj = self.m2v.input_blocks(mel).permute(0,2,1)

View File

@ -1,7 +1,7 @@
import os import os
if __name__ == '__main__': if __name__ == '__main__':
basepath = 'Y:/clips/podcasts-0' basepath = 'Y:\\bigasr_dataset\\hifi_tts'
english_file = os.path.join(basepath, 'transcribed-oco-realtext.tsv') english_file = os.path.join(basepath, 'transcribed-oco-realtext.tsv')
if not os.path.exists(english_file): if not os.path.exists(english_file):

View File

@ -330,13 +330,14 @@ class Mel2vecCodesInjector(Injector):
self.m2v = get_music_codegen() self.m2v = get_music_codegen()
del self.m2v.m2v.encoder # This is a big memory sink which will not get used. del self.m2v.m2v.encoder # This is a big memory sink which will not get used.
self.needs_move = True self.needs_move = True
self.inj_vector = opt_get(opt, ['vector'], False)
def forward(self, state): def forward(self, state):
mels = state[self.input] mels = state[self.input]
with torch.no_grad(): with torch.no_grad():
if self.needs_move: if self.needs_move:
self.m2v = self.m2v.to(mels.device) self.m2v = self.m2v.to(mels.device)
codes = self.m2v.get_codes(mels) codes = self.m2v.get_codes(mels, project=self.inj_vector)
return {self.output: codes} return {self.output: codes}