|
|
@@ -6,22 +6,6 @@
|
|
|
|
|
|
// #define GRIT_DEBUG
|
|
|
|
|
|
-static float dot_product(const std::vector<float> & v1, const std::vector<float> & v2) {
|
|
|
- float dot = 0.0f;
|
|
|
- for (uint64_t i = 0; i < v1.size(); ++i) {
|
|
|
- dot += v1[i] * v2[i];
|
|
|
- }
|
|
|
- return dot;
|
|
|
-}
|
|
|
-
|
|
|
-static float norm(const std::vector<float> & v) {
|
|
|
- return std::sqrt(dot_product(v, v));
|
|
|
-}
|
|
|
-
|
|
|
-static float cosine_similarity(const std::vector<float> & v1, const std::vector<float> & v2) {
|
|
|
- return dot_product(v1, v2) / (norm(v1) * norm(v2));
|
|
|
-}
|
|
|
-
|
|
|
static std::vector<std::vector<float>> encode(llama_context * ctx, const std::vector<std::string> & sentences, const std::string & instruction) {
|
|
|
std::vector<std::vector<float>> result;
|
|
|
|
|
|
@@ -203,10 +187,12 @@ int main(int argc, char * argv[]) {
|
|
|
const std::vector<std::vector<float>> d_rep = encode(ctx, documents, gritlm_instruction(""));
|
|
|
const std::vector<std::vector<float>> q_rep = encode(ctx, queries, gritlm_instruction(instruction));
|
|
|
|
|
|
- const float cosine_sim_q0_d0 = cosine_similarity(q_rep[0], d_rep[0]);
|
|
|
- const float cosine_sim_q0_d1 = cosine_similarity(q_rep[0], d_rep[1]);
|
|
|
- const float cosine_sim_q1_d0 = cosine_similarity(q_rep[1], d_rep[0]);
|
|
|
- const float cosine_sim_q1_d1 = cosine_similarity(q_rep[1], d_rep[1]);
|
|
|
+ const int n_embd = llama_n_embd(mdl);
|
|
|
+
|
|
|
+ const float cosine_sim_q0_d0 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[0].data(), n_embd);
|
|
|
+ const float cosine_sim_q0_d1 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[1].data(), n_embd);
|
|
|
+ const float cosine_sim_q1_d0 = llama_embd_similarity_cos(q_rep[1].data(), d_rep[0].data(), n_embd);
|
|
|
+ const float cosine_sim_q1_d1 = llama_embd_similarity_cos(q_rep[1].data(), d_rep[1].data(), n_embd);
|
|
|
|
|
|
std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[0].c_str(), documents[0].c_str(), cosine_sim_q0_d0);
|
|
|
std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[0].c_str(), documents[1].c_str(), cosine_sim_q0_d1);
|