@@ -1006,6 +1006,10 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
assert(nth > 0);
assert(ith < nth);
+ // only enable sgemm for prompt processing
+ if (n < 2)
+ return false;
+
if (Ctype != GGML_TYPE_F32)
return false;