Просмотр исходного кода

llama.swiftui : fix infinite loop, ouput timings, buff UI (#4674)

* fix infinite loop

* slight UI simplification, clearer UX

* clearer UI text, add timings to completion log
Peter Sugihara 2 лет назад
Родитель
Сommit
afd997ab60

+ 2 - 0
examples/llama.swiftui/llama.cpp.swift/LibLlama.swift

@@ -1,5 +1,7 @@
 import Foundation
 import Foundation
 
 
+// To use this in your own project, add llama.cpp as a swift package dependency
+// and uncomment this import line.
 // import llama
 // import llama
 
 
 enum LlamaError: Error {
 enum LlamaError: Error {

+ 21 - 6
examples/llama.swiftui/llama.swiftui/Models/LlamaState.swift

@@ -4,6 +4,7 @@ import Foundation
 class LlamaState: ObservableObject {
 class LlamaState: ObservableObject {
     @Published var messageLog = ""
     @Published var messageLog = ""
     @Published var cacheCleared = false
     @Published var cacheCleared = false
+    let NS_PER_S = 1_000_000_000.0
 
 
     private var llamaContext: LlamaContext?
     private var llamaContext: LlamaContext?
     private var defaultModelUrl: URL? {
     private var defaultModelUrl: URL? {
@@ -20,12 +21,12 @@ class LlamaState: ObservableObject {
     }
     }
 
 
     func loadModel(modelUrl: URL?) throws {
     func loadModel(modelUrl: URL?) throws {
-        messageLog += "Loading model...\n"
         if let modelUrl {
         if let modelUrl {
+            messageLog += "Loading model...\n"
             llamaContext = try LlamaContext.create_context(path: modelUrl.path())
             llamaContext = try LlamaContext.create_context(path: modelUrl.path())
             messageLog += "Loaded model \(modelUrl.lastPathComponent)\n"
             messageLog += "Loaded model \(modelUrl.lastPathComponent)\n"
         } else {
         } else {
-            messageLog += "Could not locate model\n"
+            messageLog += "Load a model from the list below\n"
         }
         }
     }
     }
 
 
@@ -34,15 +35,29 @@ class LlamaState: ObservableObject {
             return
             return
         }
         }
 
 
+        let t_start = DispatchTime.now().uptimeNanoseconds
         await llamaContext.completion_init(text: text)
         await llamaContext.completion_init(text: text)
+        let t_heat_end = DispatchTime.now().uptimeNanoseconds
+        let t_heat = Double(t_heat_end - t_start) / NS_PER_S
+
         messageLog += "\(text)"
         messageLog += "\(text)"
 
 
-        while await llamaContext.n_cur <= llamaContext.n_len {
+        while await llamaContext.n_cur < llamaContext.n_len {
             let result = await llamaContext.completion_loop()
             let result = await llamaContext.completion_loop()
             messageLog += "\(result)"
             messageLog += "\(result)"
         }
         }
+
+        let t_end = DispatchTime.now().uptimeNanoseconds
+        let t_generation = Double(t_end - t_heat_end) / NS_PER_S
+        let tokens_per_second = Double(await llamaContext.n_len) / t_generation
+
         await llamaContext.clear()
         await llamaContext.clear()
-        messageLog += "\n\ndone\n"
+        messageLog += """
+            \n
+            Done
+            Heat up took \(t_heat)s
+            Generated \(tokens_per_second) t/s\n
+            """
     }
     }
 
 
     func bench() async {
     func bench() async {
@@ -56,10 +71,10 @@ class LlamaState: ObservableObject {
         messageLog += await llamaContext.model_info() + "\n"
         messageLog += await llamaContext.model_info() + "\n"
 
 
         let t_start = DispatchTime.now().uptimeNanoseconds
         let t_start = DispatchTime.now().uptimeNanoseconds
-        await llamaContext.bench(pp: 8, tg: 4, pl: 1) // heat up
+        let _ = await llamaContext.bench(pp: 8, tg: 4, pl: 1) // heat up
         let t_end = DispatchTime.now().uptimeNanoseconds
         let t_end = DispatchTime.now().uptimeNanoseconds
 
 
-        let t_heat = Double(t_end - t_start) / 1_000_000_000.0
+        let t_heat = Double(t_end - t_start) / NS_PER_S
         messageLog += "Heat up time: \(t_heat) seconds, please wait...\n"
         messageLog += "Heat up time: \(t_heat) seconds, please wait...\n"
 
 
         // if more than 5 seconds, then we're probably running on a slow device
         // if more than 5 seconds, then we're probably running on a slow device

+ 5 - 30
examples/llama.swiftui/llama.swiftui/UI/ContentView.swift

@@ -42,46 +42,27 @@ struct ContentView: View {
                 Button("Send") {
                 Button("Send") {
                     sendText()
                     sendText()
                 }
                 }
-                .padding(8)
-                .background(Color.blue)
-                .foregroundColor(.white)
-                .cornerRadius(8)
 
 
                 Button("Bench") {
                 Button("Bench") {
                     bench()
                     bench()
                 }
                 }
-                .padding(8)
-                .background(Color.blue)
-                .foregroundColor(.white)
-                .cornerRadius(8)
 
 
                 Button("Clear") {
                 Button("Clear") {
                     clear()
                     clear()
                 }
                 }
-                .padding(8)
-                .background(Color.blue)
-                .foregroundColor(.white)
-                .cornerRadius(8)
 
 
                 Button("Copy") {
                 Button("Copy") {
                     UIPasteboard.general.string = llamaState.messageLog
                     UIPasteboard.general.string = llamaState.messageLog
                 }
                 }
-                .padding(8)
-                .background(Color.blue)
-                .foregroundColor(.white)
-                .cornerRadius(8)
-            }
+            }.buttonStyle(.bordered)
 
 
-            VStack {
+            VStack(alignment: .leading) {
                 DownloadButton(
                 DownloadButton(
                     llamaState: llamaState,
                     llamaState: llamaState,
                     modelName: "TinyLlama-1.1B (Q4_0, 0.6 GiB)",
                     modelName: "TinyLlama-1.1B (Q4_0, 0.6 GiB)",
                     modelUrl: "https://huggingface.co/TheBloke/TinyLlama-1.1B-1T-OpenOrca-GGUF/resolve/main/tinyllama-1.1b-1t-openorca.Q4_0.gguf?download=true",
                     modelUrl: "https://huggingface.co/TheBloke/TinyLlama-1.1B-1T-OpenOrca-GGUF/resolve/main/tinyllama-1.1b-1t-openorca.Q4_0.gguf?download=true",
                     filename: "tinyllama-1.1b-1t-openorca.Q4_0.gguf"
                     filename: "tinyllama-1.1b-1t-openorca.Q4_0.gguf"
                 )
                 )
-                .font(.system(size: 12))
-                .padding(.top, 4)
-                .frame(maxWidth: .infinity, alignment: .leading)
 
 
                 DownloadButton(
                 DownloadButton(
                     llamaState: llamaState,
                     llamaState: llamaState,
@@ -89,7 +70,6 @@ struct ContentView: View {
                     modelUrl: "https://huggingface.co/TheBloke/TinyLlama-1.1B-1T-OpenOrca-GGUF/resolve/main/tinyllama-1.1b-1t-openorca.Q8_0.gguf?download=true",
                     modelUrl: "https://huggingface.co/TheBloke/TinyLlama-1.1B-1T-OpenOrca-GGUF/resolve/main/tinyllama-1.1b-1t-openorca.Q8_0.gguf?download=true",
                     filename: "tinyllama-1.1b-1t-openorca.Q8_0.gguf"
                     filename: "tinyllama-1.1b-1t-openorca.Q8_0.gguf"
                 )
                 )
-                .font(.system(size: 12))
 
 
                 DownloadButton(
                 DownloadButton(
                     llamaState: llamaState,
                     llamaState: llamaState,
@@ -97,8 +77,6 @@ struct ContentView: View {
                     modelUrl: "https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf?download=true",
                     modelUrl: "https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf?download=true",
                     filename: "tinyllama-1.1b-f16.gguf"
                     filename: "tinyllama-1.1b-f16.gguf"
                 )
                 )
-                .font(.system(size: 12))
-                .frame(maxWidth: .infinity, alignment: .leading)
 
 
                 DownloadButton(
                 DownloadButton(
                     llamaState: llamaState,
                     llamaState: llamaState,
@@ -106,7 +84,6 @@ struct ContentView: View {
                     modelUrl: "https://huggingface.co/ggml-org/models/resolve/main/phi-2/ggml-model-q4_0.gguf?download=true",
                     modelUrl: "https://huggingface.co/ggml-org/models/resolve/main/phi-2/ggml-model-q4_0.gguf?download=true",
                     filename: "phi-2-q4_0.gguf"
                     filename: "phi-2-q4_0.gguf"
                 )
                 )
-                .font(.system(size: 12))
 
 
                 DownloadButton(
                 DownloadButton(
                     llamaState: llamaState,
                     llamaState: llamaState,
@@ -114,8 +91,6 @@ struct ContentView: View {
                     modelUrl: "https://huggingface.co/ggml-org/models/resolve/main/phi-2/ggml-model-q8_0.gguf?download=true",
                     modelUrl: "https://huggingface.co/ggml-org/models/resolve/main/phi-2/ggml-model-q8_0.gguf?download=true",
                     filename: "phi-2-q8_0.gguf"
                     filename: "phi-2-q8_0.gguf"
                 )
                 )
-                .font(.system(size: 12))
-                .frame(maxWidth: .infinity, alignment: .leading)
 
 
                 DownloadButton(
                 DownloadButton(
                     llamaState: llamaState,
                     llamaState: llamaState,
@@ -123,15 +98,15 @@ struct ContentView: View {
                     modelUrl: "https://huggingface.co/TheBloke/Mistral-7B-v0.1-GGUF/resolve/main/mistral-7b-v0.1.Q4_0.gguf?download=true",
                     modelUrl: "https://huggingface.co/TheBloke/Mistral-7B-v0.1-GGUF/resolve/main/mistral-7b-v0.1.Q4_0.gguf?download=true",
                     filename: "mistral-7b-v0.1.Q4_0.gguf"
                     filename: "mistral-7b-v0.1.Q4_0.gguf"
                 )
                 )
-                .font(.system(size: 12))
 
 
                 Button("Clear downloaded models") {
                 Button("Clear downloaded models") {
                     ContentView.cleanupModelCaches()
                     ContentView.cleanupModelCaches()
                     llamaState.cacheCleared = true
                     llamaState.cacheCleared = true
                 }
                 }
-                .padding(8)
-                .font(.system(size: 12))
             }
             }
+            .padding(.top, 4)
+            .font(.system(size: 12))
+            .frame(maxWidth: .infinity, alignment: .leading)
         }
         }
         .padding()
         .padding()
     }
     }

+ 1 - 1
examples/llama.swiftui/llama.swiftui/UI/DownloadButton.swift

@@ -93,7 +93,7 @@ struct DownloadButton: View {
                         print("Error: \(err.localizedDescription)")
                         print("Error: \(err.localizedDescription)")
                     }
                     }
                 }) {
                 }) {
-                    Text("\(modelName) (Downloaded)")
+                    Text("Load \(modelName)")
                 }
                 }
             } else {
             } else {
                 Text("Unknown status")
                 Text("Unknown status")