|
@@ -231,39 +231,39 @@ std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::stri
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
std::vector<gpt_vocab::id> llama_tokenize(const gpt_vocab & vocab, const std::string & text, bool bos) {
|
|
std::vector<gpt_vocab::id> llama_tokenize(const gpt_vocab & vocab, const std::string & text, bool bos) {
|
|
|
- auto res = gpt_tokenize(vocab, text);
|
|
|
|
|
|
|
+ //auto res = gpt_tokenize(vocab, text);
|
|
|
|
|
+
|
|
|
|
|
+ //if (bos) {
|
|
|
|
|
+ // res.insert(res.begin(), 1); // TODO: replace with vocab.bos
|
|
|
|
|
+ //}
|
|
|
|
|
+
|
|
|
|
|
+ std::vector<gpt_vocab::id> res;
|
|
|
|
|
|
|
|
if (bos) {
|
|
if (bos) {
|
|
|
- res.insert(res.begin(), 1); // TODO: replace with vocab.bos
|
|
|
|
|
|
|
+ res.push_back(1); // TODO: replace with vocab.bos
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- //std::vector<gpt_vocab::id> res;
|
|
|
|
|
|
|
+ //find the longest token that matches the text
|
|
|
|
|
+ int pos = 0;
|
|
|
|
|
+ while (true) {
|
|
|
|
|
+ int l = 0;
|
|
|
|
|
+ int t = 0;
|
|
|
|
|
+ for (const auto & kv : vocab.id_to_token) {
|
|
|
|
|
+ if (kv.second.size() < l) continue;
|
|
|
|
|
+ if (kv.second.size() > text.size() - pos) continue;
|
|
|
|
|
+ if (text.substr(pos, kv.second.size()) == kv.second) {
|
|
|
|
|
+ l = kv.second.size();
|
|
|
|
|
+ t = kv.first;
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- //if (bos) {
|
|
|
|
|
- // res.push_back(1); // TODO: replace with vocab.bos
|
|
|
|
|
- //}
|
|
|
|
|
|
|
+ if (l == 0 && t != 13) {
|
|
|
|
|
+ break;
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- // find the longest token that matches the text
|
|
|
|
|
- //int pos = 0;
|
|
|
|
|
- //while (true) {
|
|
|
|
|
- // int l = 0;
|
|
|
|
|
- // int t = 0;
|
|
|
|
|
- // for (const auto & kv : vocab.id_to_token) {
|
|
|
|
|
- // if (kv.second.size() < l) continue;
|
|
|
|
|
- // if (kv.second.size() > text.size() - pos) continue;
|
|
|
|
|
- // if (text.substr(pos, kv.second.size()) == kv.second) {
|
|
|
|
|
- // l = kv.second.size();
|
|
|
|
|
- // t = kv.first;
|
|
|
|
|
- // }
|
|
|
|
|
- // }
|
|
|
|
|
-
|
|
|
|
|
- // if (l == 0 && t != 13) {
|
|
|
|
|
- // break;
|
|
|
|
|
- // }
|
|
|
|
|
-
|
|
|
|
|
- // res.push_back(t);
|
|
|
|
|
- // pos += l;
|
|
|
|
|
- //}
|
|
|
|
|
|
|
+ res.push_back(t);
|
|
|
|
|
+ pos += l;
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
return res;
|
|
return res;
|
|
|
}
|
|
}
|