// LibLlama.swift

import Foundation

// import llama

enum LlamaError: Error {
    case couldNotInitializeContext
}

actor LlamaContext {
    private var model: OpaquePointer
    private var context: OpaquePointer
    private var batch: llama_batch
    private var tokens_list: [llama_token]

    var n_len: Int32 = 512
    var n_cur: Int32 = 0
    var n_decode: Int32 = 0

    init(model: OpaquePointer, context: OpaquePointer) {
        self.model = model
        self.context = context
        self.tokens_list = []
        self.batch = llama_batch_init(512, 0, 1)
    }

    deinit {
        llama_free(context)
        llama_free_model(model)
        llama_backend_free()
    }

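    /// Initializes the llama backend, loads the model at `path`, and creates a
    /// context with a fixed seed, a 2048-token context window, and 8 threads.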
    static func createContext(path: String) throws -> LlamaContext {
        llama_backend_init(false)
        let model_params = llama_model_default_params()

        let model = llama_load_model_from_file(path, model_params)
        guard let model else {
            print("Could not load model at \(path)")
            throw LlamaError.couldNotInitializeContext
        }

        var ctx_params = llama_context_default_params()
        ctx_params.seed = 1234
        ctx_params.n_ctx = 2048
        ctx_params.n_threads = 8
        ctx_params.n_threads_batch = 8

        let context = llama_new_context_with_model(model, ctx_params)
        guard let context else {
            print("Could not load context!")
            throw LlamaError.couldNotInitializeContext
        }

        return LlamaContext(model: model, context: context)
    }

    func get_n_tokens() -> Int32 {
        return batch.n_tokens
    }

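    /// Tokenizes the prompt, loads the tokens into the batch (requesting logits
    /// only for the last one), and runs the initial llama_decode.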
    func completion_init(text: String) {
        print("attempting to complete \"\(text)\"")

        tokens_list = tokenize(text: text, add_bos: true)

        let n_ctx = llama_n_ctx(context)
        // required KV cache size: the prompt plus the tokens still to be generated (== n_len)
        let n_kv_req = tokens_list.count + (Int(n_len) - tokens_list.count)

        print("\n n_len = \(n_len), n_ctx = \(n_ctx), n_kv_req = \(n_kv_req)")

        if n_kv_req > n_ctx {
            print("error: n_kv_req > n_ctx, the required KV cache size is not big enough")
        }

        for id in tokens_list {
            print(token_to_piece(token: id))
        }

        // batch = llama_batch_init(512, 0) // done in init()
        batch.n_tokens = Int32(tokens_list.count)

        for i1 in 0..<batch.n_tokens {
            let i = Int(i1)
            batch.token[i] = tokens_list[i]
            batch.pos[i] = i1
            batch.n_seq_id[i] = 1
            batch.seq_id[i]![0] = 0
            batch.logits[i] = 0
        }
        // logits are only needed for the last prompt token
        batch.logits[Int(batch.n_tokens) - 1] = 1 // true

        if llama_decode(context, batch) != 0 {
            print("llama_decode() failed")
        }

        n_cur = batch.n_tokens
    }

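    /// Greedily samples the next token from the latest logits and returns its text
    /// piece, or an empty string once EOS is sampled or n_len tokens have been reached.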
    func completion_loop() -> String {
        var new_token_id: llama_token = 0

        let n_vocab = llama_n_vocab(model)
        let logits = llama_get_logits_ith(context, batch.n_tokens - 1)

        // collect the logits of every vocabulary entry and pick the most likely token (greedy sampling)
        var candidates = Array<llama_token_data>()
        candidates.reserveCapacity(Int(n_vocab))

        for token_id in 0..<n_vocab {
            candidates.append(llama_token_data(id: token_id, logit: logits![Int(token_id)], p: 0.0))
        }
        candidates.withUnsafeMutableBufferPointer() { buffer in
            var candidates_p = llama_token_data_array(data: buffer.baseAddress, size: buffer.count, sorted: false)
            new_token_id = llama_sample_token_greedy(context, &candidates_p)
        }

        // stop on end-of-sequence or once n_len tokens have been generated
        if new_token_id == llama_token_eos(context) || n_cur == n_len {
            print("\n")
            return ""
        }

        let new_token_str = token_to_piece(token: new_token_id)
        print(new_token_str)
        // tokens_list.append(new_token_id)

        // reuse the batch for a single-token decode of the newly sampled token
        batch.n_tokens = 0

        batch.token[Int(batch.n_tokens)] = new_token_id
        batch.pos[Int(batch.n_tokens)] = n_cur
        batch.n_seq_id[Int(batch.n_tokens)] = 1
        batch.seq_id[Int(batch.n_tokens)]![0] = 0
        batch.logits[Int(batch.n_tokens)] = 1 // true
        batch.n_tokens += 1

        n_decode += 1
        n_cur += 1

        if llama_decode(context, batch) != 0 {
            print("failed to evaluate llama!")
        }

        return new_token_str
    }

    func clear() {
        tokens_list.removeAll()
    }

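    /// Tokenizes `text` with the model's vocabulary, optionally prepending a BOS token.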
    private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
        // upper bound: one token per UTF-8 byte, plus one for the optional BOS token
        let utf8Count = text.utf8.count
        let n_tokens = utf8Count + (add_bos ? 1 : 0)
        let tokens = UnsafeMutablePointer<llama_token>.allocate(capacity: n_tokens)
        let tokenCount = llama_tokenize(model, text, Int32(utf8Count), tokens, Int32(n_tokens), add_bos, false)

        var swiftTokens: [llama_token] = []
        for i in 0..<tokenCount {
            swiftTokens.append(tokens[Int(i)])
        }
        tokens.deallocate()
        return swiftTokens
    }

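    /// Converts a single token id back into its text piece.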
    private func token_to_piece(token: llama_token) -> String {
        // assumes the decoded piece fits in 8 bytes (the return value of
        // llama_token_to_piece is not checked)
        let result = UnsafeMutablePointer<Int8>.allocate(capacity: 8)
        result.initialize(repeating: Int8(0), count: 8)

        let _ = llama_token_to_piece(model, token, result, 8)

        let resultStr = String(cString: result)
        result.deallocate()
        return resultStr
    }
}
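
// Usage sketch (illustrative, not part of the original file): a minimal driver
// that creates a context, feeds a prompt, and streams pieces until
// completion_loop() returns an empty string (EOS or n_len reached). The model
// path is a placeholder.
//
// Task {
//     do {
//         let llama = try LlamaContext.createContext(path: "/path/to/model.gguf")
//         await llama.completion_init(text: "Once upon a time")
//         var result = ""
//         while true {
//             let piece = await llama.completion_loop()
//             if piece.isEmpty { break }
//             result += piece
//         }
//         print(result)
//     } catch {
//         print("failed to initialize llama: \(error)")
//     }
// }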