@@ -45,8 +45,10 @@ class LLamaAndroid {
     private external fun free_context(context: Long)
     private external fun backend_init(numa: Boolean)
     private external fun backend_free()
-    private external fun free_batch(batch: Long)
     private external fun new_batch(nTokens: Int, embd: Int, nSeqMax: Int): Long
+    private external fun free_batch(batch: Long)
+    private external fun new_sampler(): Long
+    private external fun free_sampler(sampler: Long)
     private external fun bench_model(
         context: Long,
         model: Long,
@@ -69,6 +71,7 @@ class LLamaAndroid {
     private external fun completion_loop(
         context: Long,
         batch: Long,
+        sampler: Long,
         nLen: Int,
         ncur: IntVar
     ): String?
@@ -101,8 +104,11 @@ class LLamaAndroid {
                     val batch = new_batch(512, 0, 1)
                     if (batch == 0L) throw IllegalStateException("new_batch() failed")

+                    val sampler = new_sampler()
+                    if (sampler == 0L) throw IllegalStateException("new_sampler() failed")
+
                     Log.i(tag, "Loaded model $pathToModel")
-                    threadLocalState.set(State.Loaded(model, context, batch))
+                    threadLocalState.set(State.Loaded(model, context, batch, sampler))
                 }
                 else -> throw IllegalStateException("Model already loaded")
             }
@@ -114,7 +120,7 @@ class LLamaAndroid {
            is State.Loaded -> {
                val ncur = IntVar(completion_init(state.context, state.batch, message, nlen))
                while (ncur.value <= nlen) {
-                    val str = completion_loop(state.context, state.batch, nlen, ncur)
+                    val str = completion_loop(state.context, state.batch, state.sampler, nlen, ncur)
                    if (str == null) {
                        break
                    }
@@ -138,6 +144,7 @@ class LLamaAndroid {
                free_context(state.context)
                free_model(state.model)
                free_batch(state.batch)
+                free_sampler(state.sampler);

                threadLocalState.set(State.Idle)
            }
@@ -161,7 +168,7 @@ class LLamaAndroid {

     private sealed interface State {
         data object Idle: State
-        data class Loaded(val model: Long, val context: Long, val batch: Long): State
+        data class Loaded(val model: Long, val context: Long, val batch: Long, val sampler: Long): State
     }

     // Enforce only one instance of Llm.
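Taken together, the hunks thread a native sampler handle through the binding: it is allocated next to the batch when the model is loaded, carried in `State.Loaded`, passed to each `completion_loop()` call, and released alongside the batch on unload. The sketch below shows how the updated class might be driven from application code; it assumes a singleton `instance()` accessor (hinted at by the trailing "only one instance" comment), suspend `load()`/`unload()` functions, and a Flow-based `send()`. Those names, the model path, and the coroutine scaffolding are illustrative assumptions, not part of this diff.

```kotlin
// Hypothetical driver for the updated bindings; nothing below appears in the diff.
// Assumed API surface: LLamaAndroid.instance(), suspend load()/unload(), Flow-based send().
import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.runBlocking

fun main() = runBlocking<Unit> {
    val llama = LLamaAndroid.instance()          // assumed singleton accessor
    llama.load("/data/local/tmp/model.gguf")     // load now also calls new_sampler() and stores the handle
    llama.send("Hello, world")                   // each completion_loop() call receives state.sampler
        .collect { token -> print(token) }
    llama.unload()                               // free_batch() and free_sampler() release the native memory
}
```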