|
|
@@ -576,6 +576,10 @@ void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+bool ggml_opt_static_graphs(ggml_opt_context_t opt_ctx) {
|
|
|
+ return opt_ctx->static_graphs;
|
|
|
+}
|
|
|
+
|
|
|
struct ggml_tensor * ggml_opt_inputs(ggml_opt_context_t opt_ctx) {
|
|
|
return opt_ctx->inputs;
|
|
|
}
|
|
|
@@ -842,6 +846,7 @@ void ggml_opt_epoch(
|
|
|
int64_t idata_split,
|
|
|
ggml_opt_epoch_callback callback_train,
|
|
|
ggml_opt_epoch_callback callback_eval) {
|
|
|
+ GGML_ASSERT(ggml_opt_static_graphs(opt_ctx) && "ggml_opt_epoch requires static graphs");
|
|
|
struct ggml_tensor * inputs = ggml_opt_inputs(opt_ctx);
|
|
|
struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
|
|
|
struct ggml_tensor * data = ggml_opt_dataset_data(dataset);
|