LibLlama.swift

import Foundation
import llama // the llama.cpp module must be exposed to this target (e.g. via an xcframework or module map)
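
/// Errors thrown while setting up the llama.cpp model or context.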
enum LlamaError: Error {
    case couldNotInitializeContext
}
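
/// Wraps a llama.cpp model/context pair and drives token-by-token text
/// completion. Being an actor, it serializes all access to the underlying
/// C state, which is not thread-safe.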
actor LlamaContext {
    private var model: OpaquePointer
    private var context: OpaquePointer
    private var batch: llama_batch
    private var tokens_list: [llama_token]

    /// Buffers the bytes of tokens that do not yet form valid UTF-8 (e.g. a
    /// partial multi-byte character) until later tokens complete the sequence.
    private var temporary_invalid_cchars: [CChar]

    var n_len: Int32 = 512
    var n_cur: Int32 = 0
    var n_decode: Int32 = 0

    init(model: OpaquePointer, context: OpaquePointer) {
        self.model = model
        self.context = context
        self.tokens_list = []
        self.batch = llama_batch_init(512, 0, 1) // up to 512 tokens, no embeddings, one sequence
        self.temporary_invalid_cchars = []
    }
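
    // Free the context, the model, and finally the backend
    // (this assumes the app creates only one LlamaContext).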
    deinit {
        llama_free(context)
        llama_free_model(model)
        llama_backend_free()
    }
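
    /// Loads the model at `path` and creates an inference context for it,
    /// with a fixed seed, a 2048-token context window, and 8 threads.
    /// - throws: `LlamaError.couldNotInitializeContext` if either step fails.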
    static func createContext(path: String) throws -> LlamaContext {
        llama_backend_init(false)

        let model_params = llama_model_default_params()
        let model = llama_load_model_from_file(path, model_params)
        guard let model else {
            print("Could not load model at \(path)")
            throw LlamaError.couldNotInitializeContext
        }

        var ctx_params = llama_context_default_params()
        ctx_params.seed = 1234
        ctx_params.n_ctx = 2048
        ctx_params.n_threads = 8
        ctx_params.n_threads_batch = 8

        let context = llama_new_context_with_model(model, ctx_params)
        guard let context else {
            print("Could not load context!")
            throw LlamaError.couldNotInitializeContext
        }

        return LlamaContext(model: model, context: context)
    }
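
    /// The number of tokens currently staged in the batch.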
    func get_n_tokens() -> Int32 {
        return batch.n_tokens
    }
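
    /// Tokenizes `text`, stages the whole prompt in the batch, and runs the
    /// initial `llama_decode` so the KV cache is primed for generation.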
    func completion_init(text: String) {
        print("attempting to complete \"\(text)\"")

        tokens_list = tokenize(text: text, add_bos: true)
        temporary_invalid_cchars = []

        let n_ctx = llama_n_ctx(context)
        // mirrors the C example: the prompt plus (n_len - prompt) generated tokens, i.e. n_len in total
        let n_kv_req = tokens_list.count + (Int(n_len) - tokens_list.count)

        print("\n n_len = \(n_len), n_ctx = \(n_ctx), n_kv_req = \(n_kv_req)")

        if n_kv_req > n_ctx {
            print("error: n_kv_req > n_ctx, the required KV cache size is not big enough")
        }

        for id in tokens_list {
            print(String(cString: token_to_piece(token: id) + [0]))
        }

        // stage the whole prompt in the batch (the batch itself was allocated in init())
        batch.n_tokens = Int32(tokens_list.count)

        for i1 in 0..<batch.n_tokens {
            let i = Int(i1)
            batch.token[i] = tokens_list[i]
            batch.pos[i] = i1
            batch.n_seq_id[i] = 1
            batch.seq_id[i]![0] = 0
            batch.logits[i] = 0
        }
        // request logits only for the last prompt token
        batch.logits[Int(batch.n_tokens) - 1] = 1 // true

        if llama_decode(context, batch) != 0 {
            print("llama_decode() failed")
        }

        n_cur = batch.n_tokens
    }
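
    /// Samples the next token greedily, decodes it, and returns the text that
    /// became printable this step (possibly empty while a multi-byte UTF-8
    /// character is still incomplete, or the remaining buffered text on EOS).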
    func completion_loop() -> String {
        var new_token_id: llama_token = 0

        let n_vocab = llama_n_vocab(model)
        let logits = llama_get_logits_ith(context, batch.n_tokens - 1)

        // build a candidate list over the full vocabulary from the last logits
        var candidates = Array<llama_token_data>()
        candidates.reserveCapacity(Int(n_vocab))
        for token_id in 0..<n_vocab {
            candidates.append(llama_token_data(id: token_id, logit: logits![Int(token_id)], p: 0.0))
        }

        candidates.withUnsafeMutableBufferPointer { buffer in
            var candidates_p = llama_token_data_array(data: buffer.baseAddress, size: buffer.count, sorted: false)
            new_token_id = llama_sample_token_greedy(context, &candidates_p)
        }

        if new_token_id == llama_token_eos(context) || n_cur == n_len {
            // generation finished: flush whatever bytes are still buffered
            print("\n")
            let new_token_str = String(cString: temporary_invalid_cchars + [0])
            temporary_invalid_cchars.removeAll()
            return new_token_str
        }

        let new_token_cchars = token_to_piece(token: new_token_id)
        temporary_invalid_cchars.append(contentsOf: new_token_cchars)
        let new_token_str: String
        if let string = String(validatingUTF8: temporary_invalid_cchars + [0]) {
            // the buffered bytes form valid UTF-8: emit them
            temporary_invalid_cchars.removeAll()
            new_token_str = string
        } else if (0 ..< temporary_invalid_cchars.count).contains(where: { $0 != 0 && String(validatingUTF8: Array(temporary_invalid_cchars.suffix($0)) + [0]) != nil }) {
            // at least a suffix of temporary_invalid_cchars is valid UTF-8: flush the buffer
            let string = String(cString: temporary_invalid_cchars + [0])
            temporary_invalid_cchars.removeAll()
            new_token_str = string
        } else {
            // still mid-way through a multi-byte character: emit nothing yet
            new_token_str = ""
        }
        print(new_token_str)

        // reuse the batch to decode just the newly sampled token
        batch.n_tokens = 0
        batch.token[Int(batch.n_tokens)] = new_token_id
        batch.pos[Int(batch.n_tokens)] = n_cur
        batch.n_seq_id[Int(batch.n_tokens)] = 1
        batch.seq_id[Int(batch.n_tokens)]![0] = 0
        batch.logits[Int(batch.n_tokens)] = 1 // true
        batch.n_tokens += 1

        n_decode += 1
        n_cur += 1

        if llama_decode(context, batch) != 0 {
            print("failed to evaluate llama!")
        }

        return new_token_str
    }
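
    /// Discards the current prompt and any buffered partial UTF-8 bytes.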
    func clear() {
        tokens_list.removeAll()
        temporary_invalid_cchars.removeAll()
    }
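
    /// Converts `text` to llama tokens, optionally prepending BOS.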
    private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
        let utf8Count = text.utf8.count
        // one token per UTF-8 byte is an upper bound, plus room for BOS
        let n_tokens = utf8Count + (add_bos ? 1 : 0)
        let tokens = UnsafeMutablePointer<llama_token>.allocate(capacity: n_tokens)
        let tokenCount = llama_tokenize(model, text, Int32(utf8Count), tokens, Int32(n_tokens), add_bos, false)

        var swiftTokens: [llama_token] = []
        for i in 0..<tokenCount {
            swiftTokens.append(tokens[Int(i)])
        }
        tokens.deallocate()
        return swiftTokens
    }

    /// Converts a single token to the bytes of its text piece, retrying with a
    /// larger buffer when the initial 8 bytes are too small.
    /// - note: The result does not contain a null terminator.
    private func token_to_piece(token: llama_token) -> [CChar] {
        let result = UnsafeMutablePointer<Int8>.allocate(capacity: 8)
        result.initialize(repeating: Int8(0), count: 8)
        defer {
            result.deallocate()
        }
        let nTokens = llama_token_to_piece(model, token, result, 8)

        if nTokens < 0 {
            // a negative return value is the required buffer size: allocate that and retry
            let newResult = UnsafeMutablePointer<Int8>.allocate(capacity: Int(-nTokens))
            newResult.initialize(repeating: Int8(0), count: Int(-nTokens))
            defer {
                newResult.deallocate()
            }
            let nNewTokens = llama_token_to_piece(model, token, newResult, -nTokens)
            let bufferPointer = UnsafeBufferPointer(start: newResult, count: Int(nNewTokens))
            return Array(bufferPointer)
        } else {
            let bufferPointer = UnsafeBufferPointer(start: result, count: Int(nTokens))
            return Array(bufferPointer)
        }
    }
}
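
// A minimal usage sketch (hypothetical caller, not part of this file),
// assuming `modelPath` points at a GGUF model bundled with the app:
//
//     let llamaContext = try LlamaContext.createContext(path: modelPath)
//     await llamaContext.completion_init(text: "The quick brown fox")
//     while await llamaContext.n_cur < llamaContext.n_len {
//         let piece = await llamaContext.completion_loop()
//         print(piece, terminator: "")
//     }
//
// (A production caller would also detect EOS, after which n_cur stops advancing.)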