|
|
@@ -35,7 +35,7 @@
|
|
|
|
|
|
// ggml-backend interface
|
|
|
|
|
|
-std::vector<ggml_backend_buffer_type_t>& ggml_backend_cpu_get_extra_buffers_type() {
|
|
|
+std::vector<ggml_backend_buffer_type_t> & ggml_backend_cpu_get_extra_buffer_types() {
|
|
|
static std::vector<ggml_backend_buffer_type_t> bufts = []() {
|
|
|
std::vector<ggml_backend_buffer_type_t> bufts;
|
|
|
|
|
|
@@ -57,8 +57,6 @@ std::vector<ggml_backend_buffer_type_t>& ggml_backend_cpu_get_extra_buffers_type
|
|
|
}
|
|
|
#endif
|
|
|
|
|
|
- bufts.push_back(NULL);
|
|
|
-
|
|
|
return bufts;
|
|
|
}();
|
|
|
|
|
|
@@ -66,14 +64,20 @@ std::vector<ggml_backend_buffer_type_t>& ggml_backend_cpu_get_extra_buffers_type
|
|
|
}
|
|
|
|
|
|
static ggml_backend_buffer_type_t * ggml_backend_cpu_device_get_extra_buffers_type(ggml_backend_dev_t device) {
|
|
|
- return ggml_backend_cpu_get_extra_buffers_type().data();
|
|
|
+ static std::vector<ggml_backend_buffer_type_t> extra_bufts = [] {
|
|
|
+ std::vector<ggml_backend_buffer_type_t> bufts = ggml_backend_cpu_get_extra_buffer_types();
|
|
|
+ bufts.push_back(nullptr);
|
|
|
+ return bufts;
|
|
|
+ }();
|
|
|
+
|
|
|
+ return extra_bufts.data();
|
|
|
|
|
|
GGML_UNUSED(device);
|
|
|
}
|
|
|
|
|
|
static bool ggml_backend_cpu_is_extra_buffer_type(ggml_backend_buffer_type_t buft) {
|
|
|
- for (auto * extra : ggml_backend_cpu_get_extra_buffers_type()) {
|
|
|
- if (extra && extra == buft) {
|
|
|
+ for (auto * extra : ggml_backend_cpu_get_extra_buffer_types()) {
|
|
|
+ if (extra == buft) {
|
|
|
return true;
|
|
|
}
|
|
|
}
|
|
|
@@ -397,20 +401,13 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
|
|
|
return true;
|
|
|
}
|
|
|
|
|
|
- // extra_buffer_op?
|
|
|
- for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) {
|
|
|
- if (extra) {
|
|
|
- auto buf_extra = (ggml::cpu::extra_buffer_type*) extra->context;
|
|
|
- if (buf_extra && buf_extra->supports_op(dev, op)) {
|
|
|
- return true;
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- // the other case need host buffer.
|
|
|
- for (int i = 0; i < GGML_MAX_SRC; i++) {
|
|
|
- if (op->src[i] && op->src[i]->buffer && !ggml_backend_buft_is_host(op->src[i]->buffer->buft)) {
|
|
|
- return false;
|
|
|
+ // check extra buffer types
|
|
|
+ // note: only the first sources are checked for extra buffer types to reduce overhead, increase if necessary
|
|
|
+ for (int i = 0; i < 4; i++) {
|
|
|
+ if (op->src[i] && op->src[i]->buffer &&
|
|
|
+ ggml_backend_cpu_is_extra_buffer_type(op->src[i]->buffer->buft)) {
|
|
|
+ auto * buf_extra = (ggml::cpu::extra_buffer_type *) op->src[i]->buffer->buft->context;
|
|
|
+ return buf_extra->supports_op(dev, op);
|
|
|
}
|
|
|
}
|
|
|
|