// test-grammar-llguidance.cpp
// Keep assert() active even in release builds — this whole test suite relies on it.
#ifdef NDEBUG
# undef NDEBUG
#endif

#include "unicode.h"
#include "sampling.h"

#include <cassert>
#include <string>
#include <vector>

// Vocabulary shared by all tests; set once in main() after the model is loaded.
static const llama_vocab * vocab;
  10. static bool match_string(const std::string & input, llama_sampler * grammar) {
  11. llama_sampler_reset(grammar);
  12. auto tokens = common_tokenize(vocab, input, false, false);
  13. auto n_vocab = llama_vocab_n_tokens(vocab);
  14. std::vector<llama_token_data> cur;
  15. cur.reserve(n_vocab);
  16. for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) {
  17. cur.emplace_back(llama_token_data{ token_id, 0.0f, 0.0f });
  18. }
  19. auto tok_arr = llama_token_data_array{ cur.data(), cur.size(), -1, false };
  20. for (const auto token : tokens) {
  21. for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) {
  22. cur[token_id].logit = 0.0f;
  23. }
  24. llama_sampler_apply(grammar, &tok_arr);
  25. if (cur[token].logit < 0.0f) {
  26. return false;
  27. }
  28. llama_sampler_accept(grammar, token);
  29. }
  30. // do we allow EOS at the end? if so the grammar is accepting
  31. auto tok_eos = llama_vocab_eot(vocab);
  32. if (tok_eos == LLAMA_TOKEN_NULL) {
  33. tok_eos = llama_vocab_eos(vocab);
  34. }
  35. cur[tok_eos].logit = 0.0f;
  36. llama_sampler_apply(grammar, &tok_arr);
  37. return cur[tok_eos].logit >= 0.0f;
  38. }
  39. static void test(const std::string & test_desc, const std::string & grammar_str,
  40. const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) {
  41. fprintf(stderr, "⚫ Testing %s\n%s\n", test_desc.c_str(), grammar_str.c_str());
  42. fflush(stderr);
  43. auto * grammar = llama_sampler_init_llg(vocab, "lark", grammar_str.c_str());
  44. fprintf(stderr, " 🔵 Valid strings:\n");
  45. // Passing strings
  46. for (const auto & test_string : passing_strings) {
  47. fprintf(stderr, " \"%s\" ", test_string.c_str());
  48. fflush(stderr);
  49. bool matched = match_string(test_string, grammar);
  50. if (!matched) {
  51. fprintf(stderr, "❌ (failed to match)\n");
  52. // DEBUG: Write strings to files so that we can analyze more easily with gbnf-validator program to see exactly where things failed.
  53. // DEBUG: Write the grammar_str to test-grammar-integration.grammar.gbnf
  54. FILE * grammar_file = fopen("test-grammar-integration.grammar.gbnf", "w");
  55. if (grammar_file) {
  56. fprintf(grammar_file, "%s", grammar_str.c_str());
  57. fclose(grammar_file);
  58. }
  59. // DEBUG: Write the test string to test-grammar-integration.string.txt
  60. FILE * string_file = fopen("test-grammar-integration.string.txt", "w");
  61. if (string_file) {
  62. fprintf(string_file, "%s", test_string.c_str());
  63. fclose(string_file);
  64. }
  65. fprintf(stderr,
  66. "\n NOTE: Debug grammar file generated. To analyze this failure in detail, run the following "
  67. "command: ./llama-gbnf-validator test-grammar-integration.grammar.gbnf "
  68. "test-grammar-integration.string.txt\n\n");
  69. } else {
  70. fprintf(stdout, "✅︎\n");
  71. }
  72. assert(matched);
  73. }
  74. fprintf(stderr, " 🟠 Invalid strings:\n");
  75. // Failing strings
  76. for (const auto & test_string : failing_strings) {
  77. fprintf(stderr, " \"%s\" ", test_string.c_str());
  78. fflush(stderr);
  79. bool matched = match_string(test_string, grammar);
  80. if (matched) {
  81. fprintf(stderr, "❌ (incorrectly matched)\n");
  82. } else {
  83. fprintf(stdout, "✅︎\n");
  84. }
  85. assert(!matched);
  86. }
  87. llama_sampler_free(grammar);
  88. }
  89. static void test_grammar(const std::string & test_desc, const std::string & grammar_str,
  90. const std::vector<std::string> & passing_strings,
  91. const std::vector<std::string> & failing_strings) {
  92. test(test_desc + ". Grammar: " + grammar_str, grammar_str, passing_strings, failing_strings);
  93. }
  94. static void test_schema(const std::string & test_desc, const std::string & schema_str,
  95. const std::vector<std::string> & passing_strings,
  96. const std::vector<std::string> & failing_strings) {
  97. test(test_desc + ". Schema: " + schema_str, "%llguidance {}\nstart: %json " + schema_str, passing_strings,
  98. failing_strings);
  99. }
// Exercises integer minimum/maximum JSON-schema constraints (inclusive and
// exclusive bounds, negative ranges, rejection of leading zeros) plus one
// small arithmetic Lark grammar.
static void test_simple_grammar() {
    test_schema("min 0",
        R"""({
"type": "integer",
"minimum": 0
})""",
        // Passing strings
        {
            "0",
            "10",
            "12",
            "10000",
        },
        // Failing strings
        {
            "-1",
            "-10",
            "-10000",
            "-100000000000000000000000000000000",
            // "100000000000000000000000000000000",
            "00",
            "01",
            "-0",
        });

    test_schema("min 2",
        // Schema
        R"""({
"type": "integer",
"minimum": 2
})""",
        // Passing strings
        {
            "2",
            "3",
            "4",
            "10",
            "20",
            "1234567890000000",
        },
        // Failing strings
        // NOTE(review): "0" and "1" appear twice — redundant but harmless
        {
            "0", "1", "-1", "-100", "0", "1", "01", "02",
            // "12345678900000000",
        });

    test_schema("min 456",
        R"""({
"type": "integer",
"minimum": 456
})""",
        // Passing strings
        {
            "456",
            "4560",
            "457",
            "460",
            "500",
        },
        // Failing strings
        {
            "455",
            "356",
            "50",
            "050",
            "-1",
            "-456",
        });

    test_schema("min -123",
        R"""({
"type": "integer",
"minimum": -123
})""",
        // Passing strings
        {
            "-123",
            "-122",
            "-11",
            "-1",
            "0",
            "1",
            "123",
            "1234",
            "2345",
        },
        // Failing strings
        {
            "-1234",
            "-124",
        });

    test_schema("max 9999",
        // Schema
        R"""({
"type": "integer",
"maximum": 9999
})""",
        // Passing strings
        {
            "-99999",
            "0",
            "9999",
        },
        // Failing strings
        {
            "10000",
            "99991",
        });

    test_schema("max -9999",
        // Schema
        R"""({
"type": "integer",
"maximum": -9999
})""",
        // Passing strings
        {
            "-10000",
            "-9999",
        },
        // Failing strings
        {
            "-9998",
            "0",
            "9999",
        });

    test_schema("min 5 max 30",
        // Schema
        R"""({
"type": "integer",
"minimum": 5,
"maximum": 30
})""",
        // Passing strings
        {
            "5",
            "10",
            "30",
        },
        // Failing strings
        {
            "05",
            "4",
            "-1",
            "31",
            "123",
            "0123",
        });

    test_schema("min -1 max 1",
        R"""({
"type": "integer",
"minimum": -1,
"maximum": 1
})""",
        // Passing strings
        {
            "-1",
            "0",
            "1",
        },
        // Failing strings
        {
            "-11",
            "-10",
            "-2",
            "2",
            "10",
            "11",
        });

    test_schema("min -123 max 42",
        R"""({
"type": "integer",
"minimum": -123,
"maximum": 42
})""",
        // Passing strings
        {
            "-123",
            "-122",
            "-13",
            "-11",
            "-2",
            "-1",
            "0",
            "1",
            "5",
            "10",
            "39",
            "40",
            "42",
        },
        // Failing strings
        {
            "-0123",
            "-124",
            "-1123",
            "-200",
            "43",
            "123",
            "0123",
        });

    test_schema("exclusive min / max",
        // Schema
        R"""({
"type": "integer",
"exclusiveMinimum": 0,
"exclusiveMaximum": 10000
})""",
        // Passing strings
        {
            "1",
            "9999",
        },
        // Failing strings
        {
            "0",
            "01",
            "10000",
            "99999",
        });

    // Test case for a simple grammar
    test_grammar("simple grammar",
        R"""(
start: expr
expr: term ("+" term)*
term: number
number: /[0-9]+/ )""",
        // Passing strings
        {
            "42",
            "1+2+3+4+5",
            "123+456",
        },
        // Failing strings
        {
            "+",
            "/ 3",
            "1+2+3+4+5+",
            "12a45",
        });
}
// A larger arithmetic-expression grammar: numbers, variables, parentheses,
// function calls, and at most one whitespace character between tokens (ws is
// /[ \t\n\r]?/, i.e. optional but single).
static void test_complex_grammar() {
    // Test case for a more complex grammar, with both failure strings and success strings
    test_grammar("medium complexity grammar",
        // Grammar
        R"""(
start: expression
expression: term ws (("+"|"-") ws term)*
term: factor ws (("*"|"/") ws factor)*
factor: number | variable | "(" expression ")" | function-call
number: /[0-9]+/
variable: /[a-zA-Z_][a-zA-Z0-9_]*/
function-call: variable ws "(" (expression ("," ws expression)*)? ")"
ws: /[ \t\n\r]?/ )""",
        // Passing strings
        { "42",
          "1*2*3*4*5",
          "x",
          "x+10",
          "x1+y2",
          "(a+b)*(c-d)",
          "func()",
          "func(x,y+2)",
          "a*(b+c)-d/e",
          "f(g(x),h(y,z))",
          "x + 10",
          "x1 + y2",
          "(a + b) * (c - d)",
          "func()",
          "func(x, y + 2)",
          "a * (b + c) - d / e",
          "f(g(x), h(y, z))",
          "123+456",
          "123*456*789-123/456+789*123",
          "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456" },
        // Failing strings
        {
            "+",
            "/ 3x",
            "x + + y",
            "a * / b",
            "func(,)",
            "func(x y)",
            "(a + b",
            "x + y)",
            "a + b * (c - d",
            "42 +",
            "x +",
            "x + 10 +",
            "(a + b) * (c - d",
            "func(",
            "func(x, y + 2",
            "a * (b + c) - d /",
            "f(g(x), h(y, z)",
            "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456/",
        });
}
// Verifies that the regex "." matches exactly one character, including
// multi-byte UTF-8 characters (each emoji must count as a single character,
// not one per byte).
static void test_special_chars() {
    // A collection of tests to exercise special characters such as "."
    test_grammar("special characters",
        // Grammar
        R"""(
start: /.../ "abc" /.../
)""",
        // Passing strings
        { "abcabcabc", "aaaabcccc",
          // NOTE: Also ensures that multi-byte characters still count as a single character
          "🔵🟠✅abc❌🟠🔵" },
        // Failing strings
        { "aaabcccc", "aaaaabcccc", "aaaabccc", "aaaabccccc", "🔵🟠✅❌abc❌✅🟠🔵", "🔵🟠abc🟠🔵" });
}
// Exercises the *, + and ? quantifiers, mixed quantifier combinations, and
// regex repetition ranges {n}, {n,} and {0,n}.
static void test_quantifiers() {
    // A collection of tests to exercise * + and ? quantifiers
    test_grammar(
        "* quantifier",
        // Grammar
        R"""(start: "a"*)""",
        // Passing strings
        { "", "a", "aaaaa", "aaaaaaaaaaaaaaaaaa", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" },
        // Failing strings
        { "b", "ab", "aab", "ba", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab" });

    test_grammar(
        "+ quantifier",
        // Grammar
        R"""(start: "a"+)""",
        // Passing strings
        { "a", "aaaaa", "aaaaaaaaaaaaaaaaaa", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" },
        // Failing strings
        { "", "b", "ab", "aab", "ba", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab" });

    test_grammar("? quantifier",
        // Grammar
        R"""(start: "a"?)""",
        // Passing strings
        { "", "a" },
        // Failing strings
        {
            "b",
            "ab",
            "aa",
            "ba",
        });

    test_grammar("mixed quantifiers",
        // Grammar
        R"""(
start: cons+ vowel* cons? (vowel cons)*
vowel: /[aeiouy]/
cons: /[bcdfghjklmnpqrstvwxyz]/
)""",
        // Passing strings
        {
            "yes",
            "no",
            "noyes",
            "crwth",
            "four",
            "bryyyy",
        },
        // Failing strings
        {
            "yess",
            "yesno",
            "forty",
            "catyyy",
        });

    test_grammar("simple exact repetition",
        // Grammar
        R"""(
start: /[ab]{4}/
)""",
        // Passing strings
        {
            "aaaa",
            "bbbb",
            "abab",
        },
        // Failing strings
        {
            "a",
            "b",
            "aaaaa",
        });

    test_grammar("simple min repetition",
        // Grammar
        R"""(
start: /[ab]{4,}/
)""",
        // Passing strings
        {
            "aaaa",
            "aaaaab",
            "bbbb",
            "ababab",
        },
        // Failing strings
        {
            "",
            "aba",
        });

    test_grammar("simple max repetition",
        // Grammar
        R"""(
start: /[ab]{0,4}/
)""",
        // Passing strings
        {
            "",
            "a",
            "aa",
            "aaa",
            "aaab",
        },
        // Failing strings
        {
            "aaaaa",
        });

    // NOTE(review): disabled — {m,n} repetition of a composite rule is kept
    // here for future re-enablement.
    // test_grammar("min / max repetition",
    //              // Grammar
    //              R"""(
    //              start: ("0x" /[A-F0-9]{2}/ " "?){3,5}
    //              )""",
    //              // Passing strings
    //              {
    //                  "0xFF 0x12 0xAB",
    //                  "0xFF 0x12 0xAB 0x00 0x00",
    //              },
    //              // Failing strings
    //              {
    //                  "",
    //                  "0xFF",
    //                  "0xFF 0x12",
    //                  "0xFF 0x12 0xAB 0x00 0x00 0x00",
    //              });
}
// JSON-schema → llguidance conversion tests: basic types, string length
// bounds, const/enum, regex patterns, arrays (item counts, typed items) and
// objects (property order, required vs optional, additionalProperties).
static void test_json_schema() {
    // Note that this is similar to the regular grammar tests,
    // but we convert each json schema to a grammar before parsing.
    // Otherwise, this test structure is the same.

    test_schema("empty schema (object)",
        // Schema
        R"""(
{"type":"object"}
)""",
        // Passing strings
        {
            R"""({})""",
            R"""({"foo": "bar"})""",
        },
        // Failing strings
        {
            "",
            "[]",
            "null",
            R"""("")""",
            "true",
        });

    test_schema(
        "exotic formats (list)",
        // Schema
        R"""({
"items": [
{ "format": "date" },
{ "format": "uuid" },
{ "format": "time" },
{ "format": "date-time" }
]
})""",
        // Passing strings
        {
            // "{}", // NOTE: This string passes for this schema on https://www.jsonschemavalidator.net/ -- should it?
            // "[]", // NOTE: This string passes for this schema on https://www.jsonschemavalidator.net/ -- should it?
            R"""(["2012-04-23", "12345678-1234-1234-1234-1234567890ab", "18:25:43.511Z", "2012-04-23T18:25:43.511Z"])""",
            //R"""(["2012-04-23","12345678-1234-1234-1234-1234567890ab"])""", // NOTE: This string passes for this schema on https://www.jsonschemavalidator.net/ -- should it?
            //R"""({"foo": "bar"})""", // NOTE: This string passes for this schema on https://www.jsonschemavalidator.net/ -- should it?
        },
        // Failing strings
        {
            R"""(["foo", "bar"])""",
            R"""(["12345678-1234-1234-1234-1234567890ab"])""",
        });

    test_schema("string",
        // Schema
        R"""({
"type": "string"
})""",
        // Passing strings
        {
            R"""("foo")""",
            R"""("bar")""",
            R"""("")""",
        },
        // Failing strings
        {
            R"""({})""",
            R"""("foo": "bar")""",
        });

    test_schema("string w/ min length 1",
        // Schema
        R"""({
"type": "string",
"minLength": 1
})""",
        // Passing strings
        {
            R"""("foo")""",
            R"""("bar")""",
        },
        // Failing strings
        {
            R"""("")""",
            R"""({})""",
            R"""("foo": "bar")""",
        });

    test_schema("string w/ min length 3",
        // Schema
        R"""({
"type": "string",
"minLength": 3
})""",
        // Passing strings
        {
            R"""("foo")""",
            R"""("bar")""",
            R"""("foobar")""",
        },
        // Failing strings
        {
            R"""("")""",
            R"""("f")""",
            R"""("fo")""",
        });

    test_schema("string w/ max length",
        // Schema
        R"""({
"type": "string",
"maxLength": 3
})""",
        // Passing strings
        {
            R"""("foo")""",
            R"""("bar")""",
            R"""("")""",
            R"""("f")""",
            R"""("fo")""",
        },
        // Failing strings
        {
            R"""("foobar")""",
        });

    test_schema("string w/ min & max length",
        // Schema
        R"""({
"type": "string",
"minLength": 1,
"maxLength": 4
})""",
        // Passing strings
        {
            R"""("foo")""",
            R"""("bar")""",
            R"""("f")""",
            R"""("barf")""",
        },
        // Failing strings
        {
            R"""("")""",
            R"""("barfo")""",
            R"""("foobar")""",
        });

    test_schema("boolean",
        // Schema
        R"""({
"type": "boolean"
})""",
        // Passing strings
        {
            "true",
            "false",
        },
        // Failing strings
        {
            R"""("")""",
            R"""("true")""",
            R"""(True)""",
            R"""(FALSE)""",
        });

    test_schema("integer",
        // Schema
        R"""({
"type": "integer"
})""",
        // Passing strings
        {
            R"""(0)""",
            R"""(12345)""",
            R"""(1234567890123456)""",
        },
        // Failing strings
        {
            R"""()""",
            R"""(01)""",
            R"""(007)""",
            R"""(12345678901234567 )""",
        });

    test_schema("string const",
        // Schema
        R"""({
"const": "foo"
})""",
        // Passing strings
        {
            R"""("foo")""",
        },
        // Failing strings
        {
            R"""(foo)""",
            R"""("bar")""",
        });

    test_schema("non-string const",
        // Schema
        R"""({
"const": true
})""",
        // Passing strings
        {
            R"""(true)""",
        },
        // Failing strings
        {
            R"""()""",
            R"""(foo)""",
            R"""("true")""",
        });

    // NOTE(review): description duplicates the previous test's name; this
    // case actually exercises "enum" with mixed-type members.
    test_schema("non-string const",
        // Schema
        R"""({
"enum": ["red", "amber", "green", null, 42, ["foo"]]
})""",
        // Passing strings
        {
            R"""("red")""",
            R"""(null)""",
            R"""(42)""",
            R"""(["foo"])""",
        },
        // Failing strings
        {
            R"""()""",
            R"""(420)""",
            R"""(true)""",
            R"""(foo)""",
        });

    test_schema("simple pattern",
        // Schema
        R"""({
"pattern": "^[a-zA-Z0-9_-]*$"
})""",
        // Passing strings
        {
            R"""("")""",
            R"""("He_llo-12")""",
        },
        // Failing strings
        {
            R"""("!")""",
            R"""("Hello World")""",
        });

    test_schema("pattern with escapes",
        // Schema
        R"""({
"pattern": "^a\\^\\$\\.\\[\\]\\(\\)\\|\\{\\}\\*\\+\\?b$"
})""",
        // Passing strings
        {
            R"""("a^$.[]()|{}*+?b")""",
        },
        // Failing strings
        {
            R"""("ab")""",
        });

    // NOTE(review): description is intentionally empty; the schema text in
    // the printed header identifies the case (array-of-strings or null).
    test_schema("",
        // Schema
        R"""(
{
"type": ["array", "null"],
"items": { "type": "string" }
}
)""",
        // Passing strings
        {
            "null",
            "[]",
            "[\"123\"]",
            "[\"foo\", \"bar\"]",
        },
        // Failing strings
        {
            "",
            "[123]",
            "\"foo\"",
            "[\"foo\", 42]",
        });

    test_schema("min+max items",
        // Schema
        R"""({
"items": {
"type": ["number", "integer"]
},
"minItems": 3,
"maxItems": 5
})""",
        // Passing strings
        {
            R"""([1, 2, 3])""",
            R"""([1, 2, 3, 4])""",
            R"""([1, 2, 3, 4, 5])""",
            // this is in fact correct; keyword do not apply if the type is wrong
            R"""(1)""",
        },
        // Failing strings
        {
            R"""([1, 2])""",
            R"""([1, 2, 3, 4, 5, 6])""",
        });

    // Properties (from: https://json-schema.org/understanding-json-schema/reference/object#properties)
    test_schema("object properties",
        // Schema
        R"""({
"type": "object",
"properties": {
"number": { "type": "number" },
"street_name": { "type": "string" },
"street_type": { "enum": ["Street", "Avenue", "Boulevard"] }
},
"additionalProperties": false
})""",
        // Passing strings
        {
            R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue"})""",
            // "By default, leaving out properties is valid"
            R"""({ "street_name": "Pennsylvania" })""",
            R"""({ "number": 1600, "street_name": "Pennsylvania" })""",
            // "By extension, even an empty object is valid"
            R"""({})""",
            R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type": "Avenue" })""",
        },
        // Failing strings
        {
            // Change datatype from number to string
            R"""({ "number": "1600", "street_name": "Pennsylvania", "street_type":"Avenue"})""",
            // Reorder properties
            R"""({ "street_name": "Pennsylvania", "number": 1600 })""",
            // NOTE(review): duplicate of the first failing string above (comment said "Reorder properties")
            R"""({ "number": "1600", "street_name": "Pennsylvania", "street_type":"Avenue"})""",
            // Additional properties set to false
            R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue", "direction":"NW"})""",
        });

    test_schema("additional properties can't override other properties",
        R"""({
"properties": {
"a": {"type": "integer"},
"b": {"type": "integer"}
},
"additionalProperties": true
})""",
        // Passing strings
        {
            R"""({"a": 42})""",
            R"""({"c": ""})""",
            R"""({"a": 42, "c": ""})""",
            R"""({"a_": ""})""",
        },
        // Failing strings
        {
            R"""()""",
            R"""({"a": ""})""",
            R"""({"a": "", "b": ""})""",
        });

    // Properties (from: https://json-schema.org/understanding-json-schema/reference/object#properties)
    test_schema("object properties, additionalProperties: true",
        // Schema
        R"""({
"type": "object",
"properties": {
"number": { "type": "number" },
"street_name": { "type": "string" },
"street_type": { "enum": ["Street", "Avenue", "Boulevard"] }
},
"additionalProperties": true
})""",
        // Passing strings
        {
            // "By extension, even an empty object is valid"
            R"""({})""",
            R"""({"number":1600,"street_name":"Pennsylvania","street_type":"Avenue"})""",
            // "By default, leaving out properties is valid"
            R"""({ "street_name": "Pennsylvania" })""",
            R"""({ "number": 1600, "street_name": "Pennsylvania" })""",
            // "By default, providing additional properties is valid"
            R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue", "direction":"NW"})""",
            R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type": "Avenue" })""",
        },
        // Failing strings
        {
            // Change datatype from number to string
            R"""({ "number": "1600", "street_name": "Pennsylvania", "street_type":"Avenue"})""",
            // Reorder properties
            R"""({ "street_name": "Pennsylvania", "number": 1600, "street_type":"Avenue"})""",
        });

    // Additional properties: false
    test_schema(
        "required + optional props each in original order",
        // Schema
        R"""({
"type": "object",
"properties": {
"number": { "type": "number" },
"street_name": { "type": "string" },
"street_type": { "enum": ["Street", "Avenue", "Boulevard"] }
},
"additionalProperties": false
})""",
        // Passing strings
        {
            R"""({ "street_name": "Pennsylvania" })""",
            R"""({ "number": 1600, "street_type":"Avenue"})""",
            R"""({ "number": 1600, "street_name": "Pennsylvania" })""",
            R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue"})""",
            // Spaces are permitted around enum values
            R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type": "Avenue" })""",
        },
        // Failing strings
        {
            // Reorder properties
            R"""({ "street_type": "Avenue", "number": 1600 })""",
            // Add "direction"
            R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type": "Avenue", "direction": "NW" })""",
        });

    test_schema("required + optional props each in original order",
        // Schema
        R"""({
"properties": {
"b": {"type": "string"},
"a": {"type": "string"},
"d": {"type": "string"},
"c": {"type": "string"}
},
"required": ["a", "b"],
"additionalProperties": false
})""",
        // Passing strings
        {
            R"""({"b": "foo", "a": "bar"})""",
            R"""({"b":"foo","a":"bar","d":"qux"})""",
            R"""({"b":"foo", "a":"bar", "d":"qux", "c":"baz"})""",
        },
        // Failing strings
        {
            R"""({"a": "foo", "b": "bar"})""",
            R"""({"b": "bar"})""",
            R"""({"a": "foo", "c": "baz"})""",
            R"""({"a":"foo", "b":"bar", "c":"baz", "d":"qux"})""",
        });

    // NOTE: Example from https://json-schema.org/learn/getting-started-step-by-step#define-required-properties
    test_schema(
        "required props",
        // Schema
        R"""({
"$schema": "https://json-schema.org/draft/2020-12/schema",
"$id": "https://example.com/product.schema.json",
"title": "Product",
"description": "A product from Acme's catalog",
"type": "object",
"properties": {
"productId": {
"description": "The unique identifier for a product",
"type": "integer"
},
"productName": {
"description": "Name of the product",
"type": "string"
},
"price": {
"description": "The price of the product",
"type": "number",
"exclusiveMinimum": 0
},
"tags": {
"description": "Tags for the product",
"type": "array",
"items": {
"type": "string"
},
"minItems": 1,
"DISABLED_uniqueItems": true
},
"dimensions": {
"type": "object",
"properties": {
"length": {
"type": "number"
},
"width": {
"type": "number"
},
"height": {
"type": "number"
}
},
"required": [ "length", "width", "height" ]
}
},
"required": [ "productId", "productName", "price" ]
})""",
        // Passing strings
        {
            R"""({"productId": 1, "productName": "A green door", "price": 12.50})""",
            R"""({"productId": 1, "productName": "A green door", "price": 12.50, "tags": ["home", "green"]})""",
            R"""({"productId": 1, "productName": "A green door", "price": 12.50, "tags": ["home", "green"], "dimensions": {"length": 785, "width": 250.5, "height": -0.359}})""",
        },
        // Failing strings
        {
            R"""({})""", // Missing all required properties
            R"""({"productName": "A green door", "price": 12.50, "productId": 1})""", // Out of order properties
            // `exclusiveMinimum` is OK for llg
            R"""({"productId": 1, "productName": "A green door", "price": -12.50})""",
            R"""({"productId": 1, "productName": "A green door"})""", // Missing required property (price)
            R"""({"productName": "A green door", "price": 12.50})""", // Missing required property (productId)
            R"""({"productId": 1, "productName": "A green door", "price": 12.50, "tags": []})""", // tags is empty, but minItems is 1
            R"""({"productId": 1, "productName": "A green door", "price": 12.50, "dimensions": {"length": 785, "width": 250.5, "height": -0.359}, "tags": ["home", "green"]})""", // Tags and dimensions are out of order
            // TODO: The following line should fail, but currently it passes. `uniqueItems` is not supported, as it would likely be too difficult to implement.
            // R"""({"productId": 1, "productName": "A green door", "price": 12.50, "tags": ["home", "green", "home"]})""",
        });
}
  1029. static void one_hot(llama_token_data_array & tok_arr, llama_token selected) {
  1030. auto n_vocab = tok_arr.size;
  1031. tok_arr.selected = -1;
  1032. tok_arr.sorted = false;
  1033. for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) {
  1034. tok_arr.data[token_id].id = token_id;
  1035. tok_arr.data[token_id].logit = 0.0f;
  1036. }
  1037. tok_arr.data[selected].logit = 100.0f;
  1038. }
  1039. static void test_sampler_chain(void) {
  1040. auto sparams = llama_sampler_chain_default_params();
  1041. sparams.no_perf = false;
  1042. llama_sampler * sampler = llama_sampler_chain_init(sparams);
  1043. const auto grammar_data = R"(%llguidance {}
  1044. start: /[A-Z ]*/)";
  1045. llama_sampler_chain_add(sampler, llama_sampler_init_llg(vocab, "lark", grammar_data));
  1046. llama_sampler_chain_add(sampler, llama_sampler_init_dist(42));
  1047. auto input = "ALL YOUR BASE ARE BELONG TO US";
  1048. auto tokens = common_tokenize(vocab, input, false, false);
  1049. auto n_vocab = llama_vocab_n_tokens(vocab);
  1050. std::vector<llama_token_data> cur;
  1051. cur.reserve(n_vocab);
  1052. for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) {
  1053. cur.emplace_back(llama_token_data{ token_id, 0.0f, 0.0f });
  1054. }
  1055. auto tok_arr = llama_token_data_array{ cur.data(), cur.size(), -1, false };
  1056. for (const auto token : tokens) {
  1057. one_hot(tok_arr, token);
  1058. fprintf(stderr, "applying token: %d\n", token);
  1059. llama_sampler_apply(sampler, &tok_arr);
  1060. auto idx = tok_arr.selected;
  1061. fprintf(stderr, " -> %d %f\n", cur[idx].id, cur[idx].logit);
  1062. assert(cur[tok_arr.selected].id == token);
  1063. llama_sampler_accept(sampler, token);
  1064. }
  1065. auto tok_eos = llama_vocab_eot(vocab);
  1066. if (tok_eos == LLAMA_TOKEN_NULL) {
  1067. tok_eos = llama_vocab_eos(vocab);
  1068. }
  1069. one_hot(tok_arr, tok_eos);
  1070. llama_sampler_apply(sampler, &tok_arr);
  1071. assert(cur[tok_arr.selected].id == tok_eos);
  1072. }
  1073. int main(int argc, const char ** argv) {
  1074. fprintf(stdout, "Running llguidance integration tests...\n");
  1075. if (argc != 2) {
  1076. fprintf(stderr, "Usage: %s <vocab-file>\n", argv[0]);
  1077. return 1;
  1078. }
  1079. const char * vocab_file = argv[1];
  1080. fprintf(stderr, "reading vocab from: '%s'\n", vocab_file);
  1081. llama_model * model;
  1082. llama_context * ctx;
  1083. llama_backend_init();
  1084. // load the vocab
  1085. {
  1086. auto mparams = llama_model_default_params();
  1087. mparams.vocab_only = true;
  1088. model = llama_model_load_from_file(vocab_file, mparams);
  1089. if (model == NULL) {
  1090. fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, vocab_file);
  1091. return 1;
  1092. }
  1093. // needed?
  1094. auto cparams = llama_context_default_params();
  1095. ctx = llama_init_from_model(model, cparams);
  1096. if (ctx == NULL) {
  1097. fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, vocab_file);
  1098. llama_model_free(model);
  1099. return 1;
  1100. }
  1101. }
  1102. vocab = llama_model_get_vocab(model);
  1103. test_simple_grammar();
  1104. test_complex_grammar();
  1105. test_special_chars();
  1106. test_quantifiers();
  1107. test_json_schema();
  1108. test_sampler_chain();
  1109. fprintf(stdout, "All tests passed.\n");
  1110. return 0;
  1111. }