|
|
@@ -425,6 +425,33 @@ static void prompt_init(llama_tokens & prompt, const llama_vocab * vocab) {
|
|
|
prompt_add(prompt, vocab, "<|im_start|>\n", true, true);
|
|
|
}
|
|
|
|
|
|
+static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab, const std::string & str) {
|
|
|
+ const std::string delimiter = "<|text_sep|>";
|
|
|
+
|
|
|
+ std::vector<llama_token> result;
|
|
|
+ size_t start = 0;
|
|
|
+ size_t end = str.find(delimiter);
|
|
|
+
|
|
|
+ // the first token is always a newline, as it was not previously added to the guide tokens
|
|
|
+ result.push_back(common_tokenize(vocab, "\n", false, true)[0]);
|
|
|
+
|
|
|
+ while (end != std::string::npos) {
|
|
|
+ std::string current_word = str.substr(start, end - start);
|
|
|
+ auto tmp = common_tokenize(vocab, current_word, false, true);
|
|
|
+ if (!tmp.empty()) { result.push_back(tmp[0]); } // guard: consecutive separators tokenize to an empty vector
|
|
|
+ start = end + delimiter.length();
|
|
|
+ end = str.find(delimiter, start);
|
|
|
+ }
|
|
|
+
|
|
|
+ // Add the last part
|
|
|
+ std::string current_word = str.substr(start);
|
|
|
+ auto tmp = common_tokenize(vocab, current_word, false, true);
|
|
|
+ if (tmp.size() > 0) {
|
|
|
+ result.push_back(tmp[0]);
|
|
|
+ }
|
|
|
+ return result;
|
|
|
+}
|
|
|
+
|
|
|
int main(int argc, char ** argv) {
|
|
|
common_params params;
|
|
|
|
|
|
@@ -494,6 +521,7 @@ int main(int argc, char ** argv) {
|
|
|
const auto t_main_start = ggml_time_us();
|
|
|
|
|
|
std::vector<llama_token> codes;
|
|
|
+ std::vector<llama_token> guide_tokens;
|
|
|
|
|
|
// process prompt and generate voice codes
|
|
|
{
|
|
|
@@ -508,6 +536,9 @@ int main(int argc, char ** argv) {
|
|
|
// convert the input text into the necessary format expected by OuteTTS
|
|
|
{
|
|
|
std::string prompt_clean = process_text(params.prompt);
|
|
|
+ if (params.vocoder.use_guide_tokens) {
|
|
|
+ guide_tokens = prepare_guide_tokens(vocab, prompt_clean);
|
|
|
+ }
|
|
|
|
|
|
LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str());
|
|
|
|
|
|
@@ -717,6 +748,8 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
|
|
|
int n_past = batch.n_tokens;
|
|
|
int n_decode = 0;
|
|
|
|
|
|
+ bool next_token_uses_guide_token = true;
|
|
|
+
|
|
|
while (n_decode <= n_predict) {
|
|
|
// prepare the next batch
|
|
|
common_batch_clear(batch);
|
|
|
@@ -728,7 +761,17 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
|
|
|
continue;
|
|
|
}
|
|
|
|
|
|
- const llama_token new_token_id = common_sampler_sample(smpl[i], ctx_ttc, i_batch[i]);
|
|
|
+ llama_token new_token_id = common_sampler_sample(smpl[i], ctx_ttc, i_batch[i]);
|
|
|
+
|
|
|
+ //guide tokens help prevent hallucinations by forcing the TTS to use the correct word
|
|
|
+ if (!guide_tokens.empty() && next_token_uses_guide_token && !llama_vocab_is_control(vocab, new_token_id) && !llama_vocab_is_eog(vocab, new_token_id)) {
|
|
|
+ llama_token guide_token = guide_tokens[0];
|
|
|
+ guide_tokens.erase(guide_tokens.begin());
|
|
|
+ new_token_id = guide_token; //ensure correct word fragment is used
|
|
|
+ }
|
|
|
+
|
|
|
+ // 198 is assumed to be the token id ("\n") that always precedes a new word — TODO: derive it from the vocab instead of hard-coding
|
|
|
+ next_token_uses_guide_token = (new_token_id == 198);
|
|
|
|
|
|
common_sampler_accept(smpl[i], new_token_id, true);
|
|
|
|