1
0

test-grammar-llguidance.cpp 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201
  1. #ifdef NDEBUG
  2. # undef NDEBUG
  3. #endif
  4. #include "sampling.h"
  5. #include <cassert>
  6. #include <string>
  7. #include <vector>
  8. static const llama_vocab * vocab;
  9. static bool match_string(const std::string & input, llama_sampler * grammar) {
  10. llama_sampler_reset(grammar);
  11. auto tokens = common_tokenize(vocab, input, false, false);
  12. auto n_vocab = llama_vocab_n_tokens(vocab);
  13. std::vector<llama_token_data> cur;
  14. cur.reserve(n_vocab);
  15. for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) {
  16. cur.emplace_back(llama_token_data{ token_id, 0.0f, 0.0f });
  17. }
  18. auto tok_arr = llama_token_data_array{ cur.data(), cur.size(), -1, false };
  19. for (const auto token : tokens) {
  20. for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) {
  21. cur[token_id].logit = 0.0f;
  22. }
  23. llama_sampler_apply(grammar, &tok_arr);
  24. if (cur[token].logit < 0.0f) {
  25. return false;
  26. }
  27. llama_sampler_accept(grammar, token);
  28. }
  29. // do we allow EOS at the end? if so the grammar is accepting
  30. auto tok_eos = llama_vocab_eot(vocab);
  31. if (tok_eos == LLAMA_TOKEN_NULL) {
  32. tok_eos = llama_vocab_eos(vocab);
  33. }
  34. cur[tok_eos].logit = 0.0f;
  35. llama_sampler_apply(grammar, &tok_arr);
  36. return cur[tok_eos].logit >= 0.0f;
  37. }
  38. static void test(const std::string & test_desc, const std::string & grammar_str,
  39. const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) {
  40. fprintf(stderr, "⚫ Testing %s\n%s\n", test_desc.c_str(), grammar_str.c_str());
  41. fflush(stderr);
  42. auto * grammar = llama_sampler_init_llg(vocab, "lark", grammar_str.c_str());
  43. fprintf(stderr, " 🔵 Valid strings:\n");
  44. // Passing strings
  45. for (const auto & test_string : passing_strings) {
  46. fprintf(stderr, " \"%s\" ", test_string.c_str());
  47. fflush(stderr);
  48. bool matched = match_string(test_string, grammar);
  49. if (!matched) {
  50. fprintf(stderr, "❌ (failed to match)\n");
  51. // DEBUG: Write strings to files so that we can analyze more easily with gbnf-validator program to see exactly where things failed.
  52. // DEBUG: Write the grammar_str to test-grammar-integration.grammar.gbnf
  53. FILE * grammar_file = fopen("test-grammar-integration.grammar.gbnf", "w");
  54. if (grammar_file) {
  55. fprintf(grammar_file, "%s", grammar_str.c_str());
  56. fclose(grammar_file);
  57. }
  58. // DEBUG: Write the test string to test-grammar-integration.string.txt
  59. FILE * string_file = fopen("test-grammar-integration.string.txt", "w");
  60. if (string_file) {
  61. fprintf(string_file, "%s", test_string.c_str());
  62. fclose(string_file);
  63. }
  64. fprintf(stderr,
  65. "\n NOTE: Debug grammar file generated. To analyze this failure in detail, run the following "
  66. "command: ./test-gbnf-validator test-grammar-integration.grammar.gbnf "
  67. "test-grammar-integration.string.txt\n\n");
  68. } else {
  69. fprintf(stdout, "✅︎\n");
  70. }
  71. assert(matched);
  72. }
  73. fprintf(stderr, " 🟠 Invalid strings:\n");
  74. // Failing strings
  75. for (const auto & test_string : failing_strings) {
  76. fprintf(stderr, " \"%s\" ", test_string.c_str());
  77. fflush(stderr);
  78. bool matched = match_string(test_string, grammar);
  79. if (matched) {
  80. fprintf(stderr, "❌ (incorrectly matched)\n");
  81. } else {
  82. fprintf(stdout, "✅︎\n");
  83. }
  84. assert(!matched);
  85. }
  86. llama_sampler_free(grammar);
  87. }
  88. static void test_grammar(const std::string & test_desc, const std::string & grammar_str,
  89. const std::vector<std::string> & passing_strings,
  90. const std::vector<std::string> & failing_strings) {
  91. test(test_desc + ". Grammar: " + grammar_str, grammar_str, passing_strings, failing_strings);
  92. }
  93. static void test_schema(const std::string & test_desc, const std::string & schema_str,
  94. const std::vector<std::string> & passing_strings,
  95. const std::vector<std::string> & failing_strings) {
  96. test(test_desc + ". Schema: " + schema_str, "%llguidance {}\nstart: %json " + schema_str, passing_strings,
  97. failing_strings);
  98. }
// Exercises integer minimum/maximum JSON-schema constraints (inclusive and
// exclusive bounds, negative ranges, leading-zero rejection) plus one small
// arithmetic lark grammar.
static void test_simple_grammar() {
    test_schema("min 0",
                R"""({
"type": "integer",
"minimum": 0
})""",
                // Passing strings
                {
                    "0",
                    "10",
                    "12",
                    "10000",
                },
                // Failing strings
                {
                    "-1",
                    "-10",
                    "-10000",
                    "-100000000000000000000000000000000",
                    // "100000000000000000000000000000000",
                    "00",
                    "01",
                    "-0",
                });
    test_schema("min 2",
                // Schema
                R"""({
"type": "integer",
"minimum": 2
})""",
                // Passing strings
                {
                    "2",
                    "3",
                    "4",
                    "10",
                    "20",
                    "1234567890000000",
                },
                // Failing strings
                {
                    // NOTE(review): "0" and "1" appear twice below — presumably an
                    // oversight; harmless, those strings are just checked twice.
                    "0", "1", "-1", "-100", "0", "1", "01", "02",
                    // "12345678900000000",
                });
    test_schema("min 456",
                R"""({
"type": "integer",
"minimum": 456
})""",
                // Passing strings
                {
                    "456",
                    "4560",
                    "457",
                    "460",
                    "500",
                },
                // Failing strings
                {
                    "455",
                    "356",
                    "50",
                    "050",
                    "-1",
                    "-456",
                });
    test_schema("min -123",
                R"""({
"type": "integer",
"minimum": -123
})""",
                // Passing strings
                {
                    "-123",
                    "-122",
                    "-11",
                    "-1",
                    "0",
                    "1",
                    "123",
                    "1234",
                    "2345",
                },
                // Failing strings
                {
                    "-1234",
                    "-124",
                });
    test_schema("max 9999",
                // Schema
                R"""({
"type": "integer",
"maximum": 9999
})""",
                // Passing strings
                {
                    "-99999",
                    "0",
                    "9999",
                },
                // Failing strings
                {
                    "10000",
                    "99991",
                });
    test_schema("max -9999",
                // Schema
                R"""({
"type": "integer",
"maximum": -9999
})""",
                // Passing strings
                {
                    "-10000",
                    "-9999",
                },
                // Failing strings
                {
                    "-9998",
                    "0",
                    "9999",
                });
    test_schema("min 5 max 30",
                // Schema
                R"""({
"type": "integer",
"minimum": 5,
"maximum": 30
})""",
                // Passing strings
                {
                    "5",
                    "10",
                    "30",
                },
                // Failing strings
                {
                    "05",
                    "4",
                    "-1",
                    "31",
                    "123",
                    "0123",
                });
    test_schema("min -1 max 1",
                R"""({
"type": "integer",
"minimum": -1,
"maximum": 1
})""",
                // Passing strings
                {
                    "-1",
                    "0",
                    "1",
                },
                // Failing strings
                {
                    "-11",
                    "-10",
                    "-2",
                    "2",
                    "10",
                    "11",
                });
    test_schema("min -123 max 42",
                R"""({
"type": "integer",
"minimum": -123,
"maximum": 42
})""",
                // Passing strings
                {
                    "-123",
                    "-122",
                    "-13",
                    "-11",
                    "-2",
                    "-1",
                    "0",
                    "1",
                    "5",
                    "10",
                    "39",
                    "40",
                    "42",
                },
                // Failing strings
                {
                    "-0123",
                    "-124",
                    "-1123",
                    "-200",
                    "43",
                    "123",
                    "0123",
                });
    test_schema("exclusive min / max",
                // Schema
                R"""({
"type": "integer",
"exclusiveMinimum": 0,
"exclusiveMaximum": 10000
})""",
                // Passing strings
                {
                    "1",
                    "9999",
                },
                // Failing strings
                {
                    "0",
                    "01",
                    "10000",
                    "99999",
                });
    // Test case for a simple grammar
    test_grammar("simple grammar",
                 R"""(
start: expr
expr: term ("+" term)*
term: number
number: /[0-9]+/ )""",
                 // Passing strings
                 {
                     "42",
                     "1+2+3+4+5",
                     "123+456",
                 },
                 // Failing strings
                 {
                     "+",
                     "/ 3",
                     "1+2+3+4+5+",
                     "12a45",
                 });
}
// Exercises a recursive arithmetic-expression grammar: precedence via
// expression/term/factor, parenthesized sub-expressions, identifiers and
// function calls with comma-separated argument lists, and optional whitespace.
static void test_complex_grammar() {
    // Test case for a more complex grammar, with both failure strings and success strings
    test_grammar("medium complexity grammar",
                 // Grammar
                 R"""(
start: expression
expression: term ws (("+"|"-") ws term)*
term: factor ws (("*"|"/") ws factor)*
factor: number | variable | "(" expression ")" | function-call
number: /[0-9]+/
variable: /[a-zA-Z_][a-zA-Z0-9_]*/
function-call: variable ws "(" (expression ("," ws expression)*)? ")"
ws: /[ \t\n\r]?/ )""",
                 // Passing strings
                 { "42",
                   "1*2*3*4*5",
                   "x",
                   "x+10",
                   "x1+y2",
                   "(a+b)*(c-d)",
                   "func()",
                   "func(x,y+2)",
                   "a*(b+c)-d/e",
                   "f(g(x),h(y,z))",
                   "x + 10",
                   "x1 + y2",
                   "(a + b) * (c - d)",
                   "func()",
                   "func(x, y + 2)",
                   "a * (b + c) - d / e",
                   "f(g(x), h(y, z))",
                   "123+456",
                   "123*456*789-123/456+789*123",
                   "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456" },
                 // Failing strings (unbalanced parens, dangling operators, bad arg lists)
                 {
                     "+",
                     "/ 3x",
                     "x + + y",
                     "a * / b",
                     "func(,)",
                     "func(x y)",
                     "(a + b",
                     "x + y)",
                     "a + b * (c - d",
                     "42 +",
                     "x +",
                     "x + 10 +",
                     "(a + b) * (c - d",
                     "func(",
                     "func(x, y + 2",
                     "a * (b + c) - d /",
                     "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456/",
                 });
}
// Exercises the regex "." wildcard, and in particular that it matches one
// Unicode character (not one byte) around a fixed "abc" anchor.
static void test_special_chars() {
    // A collection of tests to exercise special characters such as "."
    test_grammar("special characters",
                 // Grammar
                 R"""(
start: /.../ "abc" /.../
)""",
                 // Passing strings
                 { "abcabcabc", "aaaabcccc",
                   // NOTE: Also ensures that multi-byte characters still count as a single character
                   "🔵🟠✅abc❌🟠🔵" },
                 // Failing strings
                 { "aaabcccc", "aaaaabcccc", "aaaabccc", "aaaabccccc", "🔵🟠✅❌abc❌✅🟠🔵", "🔵🟠abc🟠🔵" });
}
// Exercises *, + and ? quantifiers as well as exact ({n}), lower-bounded
// ({n,}) and upper-bounded ({0,n}) repetition ranges.
static void test_quantifiers() {
    // A collection of tests to exercise * + and ? quantifiers
    test_grammar(
        "* quantifier",
        // Grammar
        R"""(start: "a"*)""",
        // Passing strings
        { "", "a", "aaaaa", "aaaaaaaaaaaaaaaaaa", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" },
        // Failing strings
        { "b", "ab", "aab", "ba", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab" });
    test_grammar(
        "+ quantifier",
        // Grammar
        R"""(start: "a"+)""",
        // Passing strings ("+" requires at least one occurrence, unlike "*")
        { "a", "aaaaa", "aaaaaaaaaaaaaaaaaa", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" },
        // Failing strings
        { "", "b", "ab", "aab", "ba", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab" });
    test_grammar("? quantifier",
                 // Grammar
                 R"""(start: "a"?)""",
                 // Passing strings (zero or one occurrence)
                 { "", "a" },
                 // Failing strings
                 {
                     "b",
                     "ab",
                     "aa",
                     "ba",
                 });
    test_grammar("mixed quantifiers",
                 // Grammar
                 R"""(
start: cons+ vowel* cons? (vowel cons)*
vowel: /[aeiouy]/
cons: /[bcdfghjklmnpqrstvwxyz]/
)""",
                 // Passing strings
                 {
                     "yes",
                     "no",
                     "noyes",
                     "crwth",
                     "four",
                     "bryyyy",
                 },
                 // Failing strings
                 {
                     "yess",
                     "yesno",
                     "forty",
                     "catyyy",
                 });
    test_grammar("simple exact repetition",
                 // Grammar
                 R"""(
start: /[ab]{4}/
)""",
                 // Passing strings
                 {
                     "aaaa",
                     "bbbb",
                     "abab",
                 },
                 // Failing strings
                 {
                     "a",
                     "b",
                     "aaaaa",
                 });
    test_grammar("simple min repetition",
                 // Grammar
                 R"""(
start: /[ab]{4,}/
)""",
                 // Passing strings
                 {
                     "aaaa",
                     "aaaaab",
                     "bbbb",
                     "ababab",
                 },
                 // Failing strings
                 {
                     "",
                     "aba",
                 });
    test_grammar("simple max repetition",
                 // Grammar
                 R"""(
start: /[ab]{0,4}/
)""",
                 // Passing strings
                 {
                     "",
                     "a",
                     "aa",
                     "aaa",
                     "aaab",
                 },
                 // Failing strings
                 {
                     "aaaaa",
                 });
    // NOTE(review): the case below is disabled — presumably unsupported by the
    // llguidance backend at the time of writing; confirm before re-enabling.
    // test_grammar("min / max repetition",
    //              // Grammar
    //              R"""(
    //              start: ("0x" /[A-F0-9]{2}/ " "?){3,5}
    //              )""",
    //              // Passing strings
    //              {
    //                  "0xFF 0x12 0xAB",
    //                  "0xFF 0x12 0xAB 0x00 0x00",
    //              },
    //              // Failing strings
    //              {
    //                  "",
    //                  "0xFF",
    //                  "0xFF 0x12",
    //                  "0xFF 0x12 0xAB 0x00 0x00 0x00",
    //              });
}
// Exercises JSON-schema validation through llguidance: basic types, string
// length bounds, const/enum, regex patterns, arrays with item bounds, and
// object property / required / additionalProperties handling.
static void test_json_schema() {
    // Note that this is similar to the regular grammar tests,
    // but we convert each json schema to a grammar before parsing.
    // Otherwise, this test structure is the same.
    test_schema("empty schema (object)",
                // Schema
                R"""(
{"type":"object"}
)""",
                // Passing strings
                {
                    R"""({})""",
                    R"""({"foo": "bar"})""",
                },
                // Failing strings
                {
                    "",
                    "[]",
                    "null",
                    R"""("")""",
                    "true",
                });
    test_schema(
        "exotic formats (list)",
        // Schema
        R"""({
"items": [
{ "format": "date" },
{ "format": "uuid" },
{ "format": "time" },
{ "format": "date-time" }
]
})""",
        // Passing strings
        {
            // "{}", // NOTE: This string passes for this schema on https://www.jsonschemavalidator.net/ -- should it?
            // "[]", // NOTE: This string passes for this schema on https://www.jsonschemavalidator.net/ -- should it?
            R"""(["2012-04-23", "12345678-1234-1234-1234-1234567890ab", "18:25:43.511Z", "2012-04-23T18:25:43.511Z"])""",
            //R"""(["2012-04-23","12345678-1234-1234-1234-1234567890ab"])""", // NOTE: This string passes for this schema on https://www.jsonschemavalidator.net/ -- should it?
            //R"""({"foo": "bar"})""", // NOTE: This string passes for this schema on https://www.jsonschemavalidator.net/ -- should it?
        },
        // Failing strings
        {
            R"""(["foo", "bar"])""",
            R"""(["12345678-1234-1234-1234-1234567890ab"])""",
        });
    test_schema("string",
                // Schema
                R"""({
"type": "string"
})""",
                // Passing strings
                {
                    R"""("foo")""",
                    R"""("bar")""",
                    R"""("")""",
                },
                // Failing strings
                {
                    R"""({})""",
                    R"""("foo": "bar")""",
                });
    test_schema("string w/ min length 1",
                // Schema
                R"""({
"type": "string",
"minLength": 1
})""",
                // Passing strings
                {
                    R"""("foo")""",
                    R"""("bar")""",
                },
                // Failing strings
                {
                    R"""("")""",
                    R"""({})""",
                    R"""("foo": "bar")""",
                });
    test_schema("string w/ min length 3",
                // Schema
                R"""({
"type": "string",
"minLength": 3
})""",
                // Passing strings
                {
                    R"""("foo")""",
                    R"""("bar")""",
                    R"""("foobar")""",
                },
                // Failing strings
                {
                    R"""("")""",
                    R"""("f")""",
                    R"""("fo")""",
                });
    test_schema("string w/ max length",
                // Schema
                R"""({
"type": "string",
"maxLength": 3
})""",
                // Passing strings
                {
                    R"""("foo")""",
                    R"""("bar")""",
                    R"""("")""",
                    R"""("f")""",
                    R"""("fo")""",
                },
                // Failing strings
                {
                    R"""("foobar")""",
                });
    test_schema("string w/ min & max length",
                // Schema
                R"""({
"type": "string",
"minLength": 1,
"maxLength": 4
})""",
                // Passing strings
                {
                    R"""("foo")""",
                    R"""("bar")""",
                    R"""("f")""",
                    R"""("barf")""",
                },
                // Failing strings
                {
                    R"""("")""",
                    R"""("barfo")""",
                    R"""("foobar")""",
                });
    test_schema("boolean",
                // Schema
                R"""({
"type": "boolean"
})""",
                // Passing strings
                {
                    "true",
                    "false",
                },
                // Failing strings (JSON booleans are lowercase and unquoted)
                {
                    R"""("")""",
                    R"""("true")""",
                    R"""(True)""",
                    R"""(FALSE)""",
                });
    test_schema("integer",
                // Schema
                R"""({
"type": "integer"
})""",
                // Passing strings
                {
                    R"""(0)""",
                    R"""(12345)""",
                    R"""(1234567890123456)""",
                },
                // Failing strings (empty, leading zeros, trailing garbage)
                {
                    R"""()""",
                    R"""(01)""",
                    R"""(007)""",
                    R"""(12345678901234567 )""",
                });
    test_schema("string const",
                // Schema
                R"""({
"const": "foo"
})""",
                // Passing strings
                {
                    R"""("foo")""",
                },
                // Failing strings
                {
                    R"""(foo)""",
                    R"""("bar")""",
                });
    test_schema("non-string const",
                // Schema
                R"""({
"const": true
})""",
                // Passing strings
                {
                    R"""(true)""",
                },
                // Failing strings
                {
                    R"""()""",
                    R"""(foo)""",
                    R"""("true")""",
                });
    // NOTE(review): description duplicates the previous test's — this one is
    // really an enum test; presumably a copy/paste slip in the label only.
    test_schema("non-string const",
                // Schema
                R"""({
"enum": ["red", "amber", "green", null, 42, ["foo"]]
})""",
                // Passing strings
                {
                    R"""("red")""",
                    R"""(null)""",
                    R"""(42)""",
                    R"""(["foo"])""",
                },
                // Failing strings
                {
                    R"""()""",
                    R"""(420)""",
                    R"""(true)""",
                    R"""(foo)""",
                });
    test_schema("simple pattern",
                // Schema
                R"""({
"pattern": "^[a-zA-Z0-9_-]*$"
})""",
                // Passing strings
                {
                    R"""("")""",
                    R"""("He_llo-12")""",
                },
                // Failing strings
                {
                    R"""("!")""",
                    R"""("Hello World")""",
                });
    test_schema("pattern with escapes",
                // Schema
                R"""({
"pattern": "^a\\^\\$\\.\\[\\]\\(\\)\\|\\{\\}\\*\\+\\?b$"
})""",
                // Passing strings
                {
                    R"""("a^$.[]()|{}*+?b")""",
                },
                // Failing strings
                {
                    R"""("ab")""",
                });
    // NOTE(review): this case has an empty description — the schema text still
    // gets appended to the label by test_schema(), so the output stays usable.
    test_schema("",
                // Schema
                R"""(
{
"type": ["array", "null"],
"items": { "type": "string" }
}
)""",
                // Passing strings
                {
                    "null",
                    "[]",
                    "[\"123\"]",
                    "[\"foo\", \"bar\"]",
                },
                // Failing strings
                {
                    "",
                    "[123]",
                    "\"foo\"",
                    "[\"foo\", 42]",
                });
    test_schema("min+max items",
                // Schema
                R"""({
"items": {
"type": ["number", "integer"]
},
"minItems": 3,
"maxItems": 5
})""",
                // Passing strings
                {
                    R"""([1, 2, 3])""",
                    R"""([1, 2, 3, 4])""",
                    R"""([1, 2, 3, 4, 5])""",
                    // this is in fact correct; keyword do not apply if the type is wrong
                    R"""(1)""",
                },
                // Failing strings
                {
                    R"""([1, 2])""",
                    R"""([1, 2, 3, 4, 5, 6])""",
                });
    // Properties (from: https://json-schema.org/understanding-json-schema/reference/object#properties)
    test_schema("object properties",
                // Schema
                R"""({
"type": "object",
"properties": {
"number": { "type": "number" },
"street_name": { "type": "string" },
"street_type": { "enum": ["Street", "Avenue", "Boulevard"] }
},
"additionalProperties": false
})""",
                // Passing strings
                {
                    R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue"})""",
                    // "By default, leaving out properties is valid"
                    R"""({ "street_name": "Pennsylvania" })""",
                    R"""({ "number": 1600, "street_name": "Pennsylvania" })""",
                    // "By extension, even an empty object is valid"
                    R"""({})""",
                    R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type": "Avenue" })""",
                },
                // Failing strings
                {
                    // Change datatype from number to string
                    R"""({ "number": "1600", "street_name": "Pennsylvania", "street_type":"Avenue"})""",
                    // Reorder properties
                    R"""({ "street_name": "Pennsylvania", "number": 1600 })""",
                    // Reorder properties
                    R"""({ "number": "1600", "street_name": "Pennsylvania", "street_type":"Avenue"})""",
                    // Additional properties set to false
                    R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue", "direction":"NW"})""",
                });
    test_schema("additional properties can't override other properties",
                R"""({
"properties": {
"a": {"type": "integer"},
"b": {"type": "integer"}
},
"additionalProperties": true
})""",
                // Passing strings
                {
                    R"""({"a": 42})""",
                    R"""({"c": ""})""",
                    R"""({"a": 42, "c": ""})""",
                    R"""({"a_": ""})""",
                },
                // Failing strings
                {
                    R"""()""",
                    R"""({"a": ""})""",
                    R"""({"a": "", "b": ""})""",
                });
    // Properties (from: https://json-schema.org/understanding-json-schema/reference/object#properties)
    test_schema("object properties, additionalProperties: true",
                // Schema
                R"""({
"type": "object",
"properties": {
"number": { "type": "number" },
"street_name": { "type": "string" },
"street_type": { "enum": ["Street", "Avenue", "Boulevard"] }
},
"additionalProperties": true
})""",
                // Passing strings
                {
                    // "By extension, even an empty object is valid"
                    R"""({})""",
                    R"""({"number":1600,"street_name":"Pennsylvania","street_type":"Avenue"})""",
                    // "By default, leaving out properties is valid"
                    R"""({ "street_name": "Pennsylvania" })""",
                    R"""({ "number": 1600, "street_name": "Pennsylvania" })""",
                    // "By default, providing additional properties is valid"
                    R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue", "direction":"NW"})""",
                    R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type": "Avenue" })""",
                },
                // Failing strings
                {
                    // Change datatype from number to string
                    R"""({ "number": "1600", "street_name": "Pennsylvania", "street_type":"Avenue"})""",
                    // Reorder properties
                    R"""({ "street_name": "Pennsylvania", "number": 1600, "street_type":"Avenue"})""",
                });
    // Additional properties: false
    test_schema(
        "required + optional props each in original order",
        // Schema
        R"""({
"type": "object",
"properties": {
"number": { "type": "number" },
"street_name": { "type": "string" },
"street_type": { "enum": ["Street", "Avenue", "Boulevard"] }
},
"additionalProperties": false
})""",
        // Passing strings
        {
            R"""({ "street_name": "Pennsylvania" })""",
            R"""({ "number": 1600, "street_type":"Avenue"})""",
            R"""({ "number": 1600, "street_name": "Pennsylvania" })""",
            R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue"})""",
            // Spaces are permitted around enum values
            R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type": "Avenue" })""",
        },
        // Failing strings
        {
            // Reorder properties
            R"""({ "street_type": "Avenue", "number": 1600 })""",
            // Add "direction"
            R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type": "Avenue", "direction": "NW" })""",
        });
    test_schema("required + optional props each in original order",
                // Schema
                R"""({
"properties": {
"b": {"type": "string"},
"a": {"type": "string"},
"d": {"type": "string"},
"c": {"type": "string"}
},
"required": ["a", "b"],
"additionalProperties": false
})""",
                // Passing strings (schema declaration order b, a, d, c is enforced)
                {
                    R"""({"b": "foo", "a": "bar"})""",
                    R"""({"b":"foo","a":"bar","d":"qux"})""",
                    R"""({"b":"foo", "a":"bar", "d":"qux", "c":"baz"})""",
                },
                // Failing strings
                {
                    R"""({"a": "foo", "b": "bar"})""",
                    R"""({"b": "bar"})""",
                    R"""({"a": "foo", "c": "baz"})""",
                    R"""({"a":"foo", "b":"bar", "c":"baz", "d":"qux"})""",
                });
    // NOTE: Example from https://json-schema.org/learn/getting-started-step-by-step#define-required-properties
    test_schema(
        "required props",
        // Schema
        R"""({
"$schema": "https://json-schema.org/draft/2020-12/schema",
"$id": "https://example.com/product.schema.json",
"title": "Product",
"description": "A product from Acme's catalog",
"type": "object",
"properties": {
"productId": {
"description": "The unique identifier for a product",
"type": "integer"
},
"productName": {
"description": "Name of the product",
"type": "string"
},
"price": {
"description": "The price of the product",
"type": "number",
"exclusiveMinimum": 0
},
"tags": {
"description": "Tags for the product",
"type": "array",
"items": {
"type": "string"
},
"minItems": 1,
"DISABLED_uniqueItems": true
},
"dimensions": {
"type": "object",
"properties": {
"length": {
"type": "number"
},
"width": {
"type": "number"
},
"height": {
"type": "number"
}
},
"required": [ "length", "width", "height" ]
}
},
"required": [ "productId", "productName", "price" ]
})""",
        // Passing strings
        {
            R"""({"productId": 1, "productName": "A green door", "price": 12.50})""",
            R"""({"productId": 1, "productName": "A green door", "price": 12.50, "tags": ["home", "green"]})""",
            R"""({"productId": 1, "productName": "A green door", "price": 12.50, "tags": ["home", "green"], "dimensions": {"length": 785, "width": 250.5, "height": -0.359}})""",
        },
        // Failing strings
        {
            R"""({})""", // Missing all required properties
            R"""({"productName": "A green door", "price": 12.50, "productId": 1})""", // Out of order properties
            // `exclusiveMinimum` is OK for llg
            R"""({"productId": 1, "productName": "A green door", "price": -12.50})""",
            R"""({"productId": 1, "productName": "A green door"})""", // Missing required property (price)
            R"""({"productName": "A green door", "price": 12.50})""", // Missing required property (productId)
            R"""({"productId": 1, "productName": "A green door", "price": 12.50, "tags": []})""", // tags is empty, but minItems is 1
            R"""({"productId": 1, "productName": "A green door", "price": 12.50, "dimensions": {"length": 785, "width": 250.5, "height": -0.359}, "tags": ["home", "green"]})""", // Tags and dimensions are out of order
            // TODO: The following line should fail, but currently it passes. `uniqueItems` is not supported, as it would likely be too difficult to implement.
            // R"""({"productId": 1, "productName": "A green door", "price": 12.50, "tags": ["home", "green", "home"]})""",
        });
}
  1028. static void one_hot(llama_token_data_array & tok_arr, llama_token selected) {
  1029. auto n_vocab = tok_arr.size;
  1030. tok_arr.selected = -1;
  1031. tok_arr.sorted = false;
  1032. for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) {
  1033. tok_arr.data[token_id].id = token_id;
  1034. tok_arr.data[token_id].logit = 0.0f;
  1035. }
  1036. tok_arr.data[selected].logit = 100.0f;
  1037. }
// Drive a full sampler chain (llguidance grammar + dist sampler) with one-hot
// logit arrays: for each token of a grammar-conforming input the chain must
// select exactly that token, and it must finally select EOT/EOS.
static void test_sampler_chain(void) {
    auto sparams = llama_sampler_chain_default_params();
    sparams.no_perf = false;
    llama_sampler * sampler = llama_sampler_chain_init(sparams);

    // grammar: any string of uppercase letters and spaces
    const auto grammar_data = R"(%llguidance {}
start: /[A-Z ]*/)";

    llama_sampler_chain_add(sampler, llama_sampler_init_llg(vocab, "lark", grammar_data));
    llama_sampler_chain_add(sampler, llama_sampler_init_dist(42));

    auto input = "ALL YOUR BASE ARE BELONG TO US";
    auto tokens = common_tokenize(vocab, input, false, false);

    auto n_vocab = llama_vocab_n_tokens(vocab);

    std::vector<llama_token_data> cur;
    cur.reserve(n_vocab);
    for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) {
        cur.emplace_back(llama_token_data{ token_id, 0.0f, 0.0f });
    }
    auto tok_arr = llama_token_data_array{ cur.data(), cur.size(), -1, false };

    for (const auto token : tokens) {
        // make `token` the only high-logit candidate, then let the chain pick
        one_hot(tok_arr, token);

        fprintf(stderr, "applying token: %d\n", token);
        llama_sampler_apply(sampler, &tok_arr);

        auto idx = tok_arr.selected;
        fprintf(stderr, " -> %d %f\n", cur[idx].id, cur[idx].logit);

        // the grammar must not have vetoed the expected token
        assert(cur[tok_arr.selected].id == token);
        llama_sampler_accept(sampler, token);
    }

    // after the full input, the grammar should accept EOT (or EOS as fallback)
    auto tok_eos = llama_vocab_eot(vocab);
    if (tok_eos == LLAMA_TOKEN_NULL) {
        tok_eos = llama_vocab_eos(vocab);
    }

    one_hot(tok_arr, tok_eos);

    llama_sampler_apply(sampler, &tok_arr);
    assert(cur[tok_arr.selected].id == tok_eos);
}
  1072. int main(int argc, const char ** argv) {
  1073. fprintf(stdout, "Running llguidance integration tests...\n");
  1074. if (argc != 2) {
  1075. fprintf(stderr, "Usage: %s <vocab-file>\n", argv[0]);
  1076. return 1;
  1077. }
  1078. const char * vocab_file = argv[1];
  1079. fprintf(stderr, "reading vocab from: '%s'\n", vocab_file);
  1080. llama_model * model;
  1081. llama_context * ctx;
  1082. llama_backend_init();
  1083. // load the vocab
  1084. {
  1085. auto mparams = llama_model_default_params();
  1086. mparams.vocab_only = true;
  1087. model = llama_model_load_from_file(vocab_file, mparams);
  1088. if (model == NULL) {
  1089. fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, vocab_file);
  1090. return 1;
  1091. }
  1092. // needed?
  1093. auto cparams = llama_context_default_params();
  1094. ctx = llama_init_from_model(model, cparams);
  1095. if (ctx == NULL) {
  1096. fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, vocab_file);
  1097. llama_model_free(model);
  1098. return 1;
  1099. }
  1100. }
  1101. vocab = llama_model_get_vocab(model);
  1102. test_simple_grammar();
  1103. test_complex_grammar();
  1104. test_special_chars();
  1105. test_quantifiers();
  1106. test_json_schema();
  1107. test_sampler_chain();
  1108. fprintf(stdout, "All tests passed.\n");
  1109. return 0;
  1110. }