|
|
@@ -554,14 +554,19 @@ struct decode_embd_batch {
|
|
|
llama_batch get_view(int offset, int n_tokens) {
|
|
|
llama_pos * pos_ptr;
|
|
|
pos_view.clear();
|
|
|
- pos_view.resize(n_tokens * n_pos_per_embd);
|
|
|
+ pos_view.reserve(n_tokens * n_pos_per_embd);
|
|
|
if (n_pos_per_embd > 1) {
|
|
|
// mrope
|
|
|
// for example, with layout of src: 1234...1234...1234...1234...
|
|
|
// offset 2 will give us dst: 34...34...34...34...
|
|
|
for (int i = 0; i < n_pos_per_embd; i++) {
|
|
|
- auto src = pos.begin() + i * batch.n_tokens + offset;
|
|
|
- pos_view.insert(pos_view.end(), src, src + n_tokens);
|
|
|
+ // assume n_tokens is less than or equal to batch.n_tokens
|
|
|
+ // batch.n_tokens is the **total** number of tokens
|
|
|
+ // n_tokens is the number of tokens in the view
|
|
|
+ size_t src_idx = i * batch.n_tokens + offset;
|
|
|
+ pos_view.insert(pos_view.end(),
|
|
|
+ pos.data() + src_idx,
|
|
|
+ pos.data() + src_idx + n_tokens);
|
|
|
}
|
|
|
pos_ptr = pos_view.data();
|
|
|
} else {
|