// vall-e/vall_e.cpp/include/quantizer.h
// Repository viewer metadata: 112 lines, 3.5 KiB, C; history entry dated 2024-12-25 04:39:32 +00:00
#pragma once
#include <cassert>
#include <vector>
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "utils.h"
// One stage of the residual vector quantizer.
struct encodec_quant_block {
// Embedding (codebook) table for this stage. The encode/decode functions in
// this header select rows from it by code index with ggml_get_rows and, when
// encoding, multiply it against the residual with ggml_mul_mat.
// NOTE(review): presumably one row per codebook entry (n_bins rows) — confirm
// against the model loader.
struct ggml_tensor *embed;
};
// The full quantizer: an ordered list of quantization stages.
struct encodec_quantizer {
// Stage i is read as blocks[i] by the encode/decode functions in this header,
// for i in [0, n_q), where n_q comes from get_num_quantizers_for_bandwidth().
std::vector<encodec_quant_block> blocks;
};
struct ggml_tensor *encodec_forward_quantizer_encode(
const struct encodec_quantizer *quantizer, struct ggml_context *ctx0,
struct ggml_tensor *encoded_inp, const int n_bins, const int sr, const int bandwidth,
const int hop_length) {
if (!encoded_inp) {
fprintf(stderr, "%s: null input tensor\n", __func__);
return NULL;
}
const int frame_rate = (int)ceilf(sr / hop_length);
const int n_q = get_num_quantizers_for_bandwidth(n_bins, frame_rate, bandwidth);
const int seq_length = encoded_inp->ne[0];
struct ggml_tensor *codes = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, seq_length, n_q);
ggml_set_input(codes);
struct ggml_tensor *inpL = ggml_cont(ctx0, ggml_transpose(ctx0, encoded_inp));
struct ggml_tensor *residual = inpL;
struct ggml_tensor *indices;
for (int i = 0; i < n_q; i++) {
encodec_quant_block block = quantizer->blocks[i];
// compute distance
// [seq_length, n_bins]
struct ggml_tensor *dp = ggml_scale(
ctx0, ggml_mul_mat(ctx0, block.embed, residual), -2.0f);
// [n_bins]
struct ggml_tensor *sqr_embed = ggml_sqr(ctx0, block.embed);
struct ggml_tensor *sqr_embed_nrm = ggml_sum_rows(ctx0, sqr_embed);
// [seq_length]
struct ggml_tensor *sqr_inp = ggml_sqr(ctx0, residual);
struct ggml_tensor *sqr_inp_nrm = ggml_sum_rows(ctx0, sqr_inp);
// [seq_length, n_bins]
struct ggml_tensor *dist = ggml_add(ctx0, ggml_repeat(ctx0, sqr_inp_nrm, dp), dp);
dist = ggml_add(ctx0, ggml_repeat(ctx0, ggml_transpose(ctx0, sqr_embed_nrm), dist), dist);
dist = ggml_neg(ctx0, dist);
// take the argmax over the column dimension
// [seq_length]
indices = ggml_argmax(ctx0, dist);
// look up in embedding table
struct ggml_tensor *quantized = ggml_get_rows(ctx0, block.embed, indices);
residual = ggml_sub(ctx0, residual, quantized);
codes = ggml_set_1d(ctx0, codes, indices, i * codes->nb[1]);
}
return codes;
}
struct ggml_tensor *encodec_forward_quantizer_decode(
const struct encodec_quantizer *quantizer, struct ggml_context *ctx0,
struct ggml_tensor *codes, const int hidden_dim, const int n_bins, const int sr, const int bandwidth,
const int hop_length) {
if (!codes) {
fprintf(stderr, "%s: null input tensor\n", __func__);
return NULL;
}
const int seq_length = codes->ne[0];
const int frame_rate = (int)ceilf(sr / hop_length);
const int n_q = get_num_quantizers_for_bandwidth(n_bins, frame_rate, bandwidth);
assert(n_q == codes->ne[1]);
struct ggml_tensor *quantized_out = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hidden_dim, seq_length);
ggml_set_input(quantized_out);
ggml_set_name(quantized_out, "quantized_out");
for (int i = 0; i < n_q; i++) {
encodec_quant_block block = quantizer->blocks[i];
struct ggml_tensor *indices = ggml_view_1d(ctx0, codes, seq_length, i * codes->nb[1]);
struct ggml_tensor *quantized = ggml_get_rows(ctx0, block.embed, indices);
quantized_out = ggml_add(ctx0, quantized_out, quantized);
}
quantized_out = ggml_cont(ctx0, ggml_transpose(ctx0, quantized_out));
return quantized_out;
}