@@ -7,7 +7,9 @@

#include "ggml-backend-impl.h"
#include "ggml-impl.h"
+#include "ggml-webgpu-shader-lib.hpp"
#include "ggml-wgsl-shaders.hpp"
+#include "pre_wgsl.hpp"

#ifdef __EMSCRIPTEN__
# include <emscripten/emscripten.h>
@@ -30,7 +32,7 @@

#ifdef GGML_WEBGPU_DEBUG
# define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl
-# define WEBGPU_DEBUG_BUF_ELEMS 32
+# define WEBGPU_DEBUG_BUF_ELEMS 512
#else
# define WEBGPU_LOG_DEBUG(msg) ((void) 0)
#endif // GGML_WEBGPU_DEBUG
@@ -251,6 +253,7 @@ struct webgpu_gpu_profile_buf_pool {
struct webgpu_pipeline {
wgpu::ComputePipeline pipeline;
std::string name;
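+ // Optional pipeline-specific data; flash-attention pipelines store their shader decisions here.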
+ void * context = nullptr;
};

struct webgpu_command {
@@ -263,6 +266,46 @@ struct webgpu_command {
#endif
};

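+// Key identifying a specialized flash-attention pipeline variant; used to cache
+// compiled pipelines in webgpu_context_struct::flash_attn_pipelines.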
+struct flash_attn_pipeline_key {
+ int q_type;
+ int kv_type;
+ int dst_type;
+ uint32_t head_dim_qk;
+ uint32_t head_dim_v;
+ bool kv_direct;
+ bool has_mask;
+ bool has_sinks;
+ bool uses_logit_softcap;
+
+ bool operator==(const flash_attn_pipeline_key & other) const {
+ return q_type == other.q_type && kv_type == other.kv_type && dst_type == other.dst_type &&
+ head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && kv_direct == other.kv_direct &&
+ has_mask == other.has_mask && has_sinks == other.has_sinks &&
+ uses_logit_softcap == other.uses_logit_softcap;
+ }
+};
+
+// Hash combine, same formula as boost::hash_combine
+template <typename T> inline void ggml_webgpu_hash_combine(size_t & seed, const T & value) {
+ seed ^= std::hash<T>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+}
+
+struct flash_attn_pipeline_key_hash {
+ size_t operator()(const flash_attn_pipeline_key & key) const {
+ size_t seed = 0;
+ ggml_webgpu_hash_combine(seed, key.q_type);
+ ggml_webgpu_hash_combine(seed, key.kv_type);
+ ggml_webgpu_hash_combine(seed, key.dst_type);
+ ggml_webgpu_hash_combine(seed, key.head_dim_qk);
+ ggml_webgpu_hash_combine(seed, key.head_dim_v);
+ ggml_webgpu_hash_combine(seed, key.kv_direct);
+ ggml_webgpu_hash_combine(seed, key.has_mask);
+ ggml_webgpu_hash_combine(seed, key.has_sinks);
+ ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
+ return seed;
+ }
+};
+
// All the base objects needed to run operations on a WebGPU device
struct webgpu_context_struct {
wgpu::Instance instance;
@@ -271,12 +314,12 @@ struct webgpu_context_struct {
wgpu::Queue queue;
wgpu::Limits limits;

- uint32_t subgroup_size;
+ uint32_t max_subgroup_size;

-#ifndef __EMSCRIPTEN__
- bool supports_subgroup_matrix = false;
- wgpu::SubgroupMatrixConfig subgroup_matrix_config;
-#endif
+ bool supports_subgroup_matrix = false;
+ uint32_t sg_mat_m;
+ uint32_t sg_mat_n;
+ uint32_t sg_mat_k;

std::recursive_mutex mutex;
std::atomic_uint inflight_threads = 0;
@@ -284,20 +327,24 @@ struct webgpu_context_struct {
webgpu_buf_pool param_buf_pool;
webgpu_buf_pool set_rows_error_buf_pool;

+ pre_wgsl::Preprocessor p;
+
std::map<int, webgpu_pipeline> memset_pipelines; // variant or type index

std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> mul_mat_pipelines; // src0_type, src1_type, vectorized
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>>
mul_mat_vec_pipelines; // src0_type, src1_type, vectorized

- std::map<int, std::map<int, webgpu_pipeline>> set_rows_pipelines; // dst_type, vectorized
- std::map<int, std::map<int, webgpu_pipeline>> get_rows_pipelines; // src_type, vectorized
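+ // Flash-attention pipelines are specialized per shape/type configuration and cached by key.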
+ std::unordered_map<flash_attn_pipeline_key, webgpu_pipeline, flash_attn_pipeline_key_hash> flash_attn_pipelines;
+
+ std::map<int, std::map<int, webgpu_pipeline>> set_rows_pipelines; // dst_type, vectorized
+ std::map<int, std::map<int, webgpu_pipeline>> get_rows_pipelines; // src_type, vectorized

- std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
- std::map<int, std::map<int, webgpu_pipeline>> add_pipelines; // type, inplace
- std::map<int, std::map<int, webgpu_pipeline>> sub_pipelines; // type, inplace
- std::map<int, std::map<int, webgpu_pipeline>> mul_pipelines; // type, inplace
- std::map<int, std::map<int, webgpu_pipeline>> div_pipelines; // type, inplace
+ std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
+ std::map<int, std::map<int, webgpu_pipeline>> add_pipelines; // type, inplace
+ std::map<int, std::map<int, webgpu_pipeline>> sub_pipelines; // type, inplace
+ std::map<int, std::map<int, webgpu_pipeline>> mul_pipelines; // type, inplace
+ std::map<int, std::map<int, webgpu_pipeline>> div_pipelines; // type, inplace

std::map<int, webgpu_pipeline> rms_norm_pipelines; // inplace
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> rope_pipelines; // type, ff, inplace
@@ -361,8 +408,6 @@ struct ggml_backend_webgpu_buffer_context {
label(std::move(lbl)) {}
};

-/* End struct definitions */
-
/* WebGPU object initializations */

// Process a WGSL shader string, replacing tokens of the form {{KEY}} with
@@ -484,14 +529,9 @@ static void ggml_backend_webgpu_debug(webgpu_context & ctx) {
encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize());
wgpu::CommandBuffer commands = encoder.Finish();
ctx->queue.Submit(1, &commands);
-
ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, ctx->debug_host_buf.GetSize());
- const uint32_t * debug_data = (const uint32_t *) ctx->debug_host_buf.GetConstMappedRange();
- std::cout << "debug data:";
- for (size_t i = 0; i < WEBGPU_DEBUG_BUF_ELEMS; i++) {
- std::cout << " " << i << ": " << debug_data[i];
- }
- std::cout << "\n";
+ const float * debug_data = (const float *) ctx->debug_host_buf.GetConstMappedRange();
+ std::cout << "debug[0]: " << debug_data[0] << "\n";
ctx->debug_host_buf.Unmap();
}
#endif
@@ -673,6 +713,7 @@ static const char * ggml_backend_webgpu_name(ggml_backend_t backend) {
return ctx->name.c_str();
}

+// TODO: implement proper cleanup
static void ggml_backend_webgpu_free(ggml_backend_t backend) {
ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context;
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_free(" << ctx->name << ")");
@@ -730,12 +771,12 @@ static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) {
return ctx->buffer;
}

-static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, ggml_tensor * t) {
+static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, const ggml_tensor * t) {
size_t offset = ggml_webgpu_tensor_offset(t);
return offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
}

-static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, ggml_tensor * t) {
+static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) {
size_t offset = ggml_webgpu_tensor_offset(t);
return offset & ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
}
@@ -964,12 +1005,10 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
#ifndef __EMSCRIPTEN__
if (ctx->supports_subgroup_matrix) {
// The total number of subgroups/workgroups needed per matrix.
- uint32_t wg_m_sg_tile =
- WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * ctx->subgroup_matrix_config.M;
- wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile);
- uint32_t wg_n_sg_tile =
- WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->subgroup_matrix_config.N;
- wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile);
+ uint32_t wg_m_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * ctx->sg_mat_m;
+ wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile);
+ uint32_t wg_n_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->sg_mat_n;
+ wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile);
} else {
#endif
uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M;
@@ -986,6 +1025,146 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
}

+static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
+ ggml_tensor * Q,
+ ggml_tensor * K,
+ ggml_tensor * V,
+ ggml_tensor * mask,
+ ggml_tensor * sinks,
+ ggml_tensor * dst) {
+ float scale = *(float *) dst->op_params;
+ float max_bias;
+ memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
+ float logit_softcap;
+ memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
+ if (logit_softcap != 0.0f) {
+ scale /= logit_softcap;
+ }
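+ // ALiBi parameters: m0, m1 and n_head_log2 are derived from max_bias and passed to the shader for per-head slope computation.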
+ float n_head_log2 = float(1u << (uint32_t) floor(log2(Q->ne[2])));
+ float m0 = powf(2.0f, -(max_bias) / n_head_log2);
+ float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+ const int has_mask = (mask != nullptr);
+ const int has_sinks = (sinks != nullptr);
+
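+ // Shader parameter block: offsets and strides are given in elements/blocks, float values are bit-cast to u32.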
+ std::vector<uint32_t> params = {
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, Q) / ggml_type_size(Q->type)),
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)),
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)),
+ has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)) : 0,
+ has_sinks ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)) : 0,
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+ (uint32_t) Q->ne[2], // number of heads
+ (uint32_t) Q->ne[1], // sequence length (Q)
+ (uint32_t) K->ne[1], // sequence length (K/V)
+ (uint32_t) (Q->nb[1] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 1
+ (uint32_t) (Q->nb[2] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 2
+ (uint32_t) (Q->nb[3] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 3
+ (uint32_t) (K->nb[1] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 1
+ (uint32_t) (K->nb[2] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 2
+ (uint32_t) (K->nb[3] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 3
+ (uint32_t) (V->nb[1] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 1
+ (uint32_t) (V->nb[2] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 2
+ (uint32_t) (V->nb[3] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 3
+ has_mask ? (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)) : 0, // stride of mask dim 3
+ (uint32_t) (Q->ne[2] / K->ne[2]), // repeat factor for K/V in dim 2 (MHA/MQA/GQA)
+ *(uint32_t *) &scale, // scale (possibly adjusted for logit softcap)
+ *(uint32_t *) &max_bias,
+ *(uint32_t *) &logit_softcap,
+ *(uint32_t *) &n_head_log2,
+ *(uint32_t *) &m0,
+ *(uint32_t *) &m1
+
+ };
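+ // Bind Q, K and V unconditionally; mask and sinks are bound only when present, and dst is always last.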
+ std::vector<wgpu::BindGroupEntry> entries = {
+ { .binding = 0,
+ .buffer = ggml_webgpu_tensor_buf(Q),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, Q),
+ .size = ggml_webgpu_tensor_binding_size(ctx, Q) },
+ { .binding = 1,
+ .buffer = ggml_webgpu_tensor_buf(K),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, K),
+ .size = ggml_webgpu_tensor_binding_size(ctx, K) },
+ { .binding = 2,
+ .buffer = ggml_webgpu_tensor_buf(V),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, V),
+ .size = ggml_webgpu_tensor_binding_size(ctx, V) }
+ };
+ uint32_t binding_index = 3;
+ if (has_mask) {
+ entries.push_back({ .binding = binding_index++,
+ .buffer = ggml_webgpu_tensor_buf(mask),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, mask),
+ .size = ggml_webgpu_tensor_binding_size(ctx, mask) });
+ }
+ if (has_sinks) {
+ entries.push_back({ .binding = binding_index++,
+ .buffer = ggml_webgpu_tensor_buf(sinks),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, sinks),
+ .size = ggml_webgpu_tensor_binding_size(ctx, sinks) });
+ }
+ entries.push_back({ .binding = binding_index++,
+ .buffer = ggml_webgpu_tensor_buf(dst),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
+
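+ // kv_direct: the specialized shader can consume K/V directly when they are F16, the Q/K head
+ // dimension is a multiple of the subgroup-matrix K dimension, and the KV sequence length is a
+ // multiple of GGML_WEBGPU_KV_SEQ_PAD.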
+ bool kv_direct =
+ (K->type == GGML_TYPE_F16) && (Q->ne[0] % ctx->sg_mat_k == 0) && (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
+
+ flash_attn_pipeline_key key = {
+ .q_type = Q->type,
+ .kv_type = K->type,
+ .dst_type = dst->type,
+ .head_dim_qk = (uint32_t) Q->ne[0],
+ .head_dim_v = (uint32_t) V->ne[0],
+ .kv_direct = kv_direct,
+ .has_mask = static_cast<bool>(has_mask),
+ .has_sinks = static_cast<bool>(has_sinks),
+ .uses_logit_softcap = logit_softcap != 0.0f,
+ };
+
+ webgpu_pipeline pipeline;
+ ggml_webgpu_flash_attn_shader_decisions decisions = {};
+
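+ // Look up the cached pipeline for this key; on a miss, take the mutex and re-check before
+ // compiling the specialized shader (double-checked locking).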
+ auto it = ctx->flash_attn_pipelines.find(key);
+ if (it != ctx->flash_attn_pipelines.end()) {
+ pipeline = it->second;
+ decisions = *static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context);
+ } else {
+ std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
+ it = ctx->flash_attn_pipelines.find(key);
+ if (it != ctx->flash_attn_pipelines.end()) {
+ pipeline = it->second;
+ decisions = *static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context);
+ } else {
+ ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { .kv_type = K->type,
+ .head_dim_qk = (uint32_t) Q->ne[0],
+ .head_dim_v = (uint32_t) V->ne[0],
+ .kv_direct = kv_direct,
+ .has_mask = static_cast<bool>(has_mask),
+ .has_sinks = static_cast<bool>(has_sinks),
+ .uses_logit_softcap = logit_softcap != 0.0f,
+ .sg_mat_m = ctx->sg_mat_m,
+ .sg_mat_n = ctx->sg_mat_n,
+ .sg_mat_k = ctx->sg_mat_k,
+ .wg_mem_limit_bytes =
+ ctx->limits.maxComputeWorkgroupStorageSize,
+ .max_subgroup_size = ctx->max_subgroup_size };
+
+ ggml_webgpu_processed_shader processed =
+ ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx);
+ pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
+ pipeline.context = new ggml_webgpu_flash_attn_shader_decisions(processed.decisions);
+ ctx->flash_attn_pipelines.emplace(key, pipeline);
+ decisions = processed.decisions;
+ }
+ }
+
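+ // Each workgroup handles q_tile rows of Q for one (head, batch) pair.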
+ uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions.q_tile);
+ uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches
+ return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
+}
+
static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
uint32_t ne = (uint32_t) ggml_nelements(dst);
ggml_unary_op unary_op = ggml_get_unary_op(dst);
@@ -1397,6 +1576,8 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
return ggml_webgpu_get_rows(ctx, src0, src1, node);
case GGML_OP_MUL_MAT:
return ggml_webgpu_mul_mat(ctx, src0, src1, node);
+ case GGML_OP_FLASH_ATTN_EXT:
+ return ggml_webgpu_flash_attn(ctx, src0, src1, src2, node->src[3], node->src[4], node);
case GGML_OP_ADD:
{
int inplace = ggml_webgpu_tensor_equal(src0, node);
@@ -1466,6 +1647,7 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
webgpu_submission_futures new_futures = ggml_backend_webgpu_submit(ctx, commands);
futures.push_back(new_futures);
}
+
ggml_backend_webgpu_wait(ctx, futures);
ctx->inflight_threads--;
WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx);
@@ -1808,15 +1990,15 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
#ifndef __EMSCRIPTEN__
if (webgpu_ctx->supports_subgroup_matrix) {
std::map<std::string, std::string> sg_matrix_repls;
- sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = std::to_string(webgpu_ctx->subgroup_size);
+ sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = std::to_string(webgpu_ctx->max_subgroup_size);
sg_matrix_repls["WEBGPU_TILE_K"] = std::to_string(WEBGPU_MUL_MAT_TILE_K);
sg_matrix_repls["WEBGPU_SUBGROUP_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M);
sg_matrix_repls["WEBGPU_SUBGROUP_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N);
sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M);
sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N);
- sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.M);
- sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.N);
- sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.K);
+ sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"] = std::to_string(webgpu_ctx->sg_mat_m);
+ sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->sg_mat_n);
+ sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = std::to_string(webgpu_ctx->sg_mat_k);

proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls);
proc_mul_mat_f32_f32_vec =
@@ -2328,6 +2510,7 @@ static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) {
webgpu_ctx->device, wgsl_soft_max_f32_mask_f16_sink_inplace, "soft_max_f32_mask_f16_sink_inplace", constants);
}

+// TODO: move most initialization logic here
static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
GGML_UNUSED(params);

@@ -2489,6 +2672,29 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
}
break;
}
+ case GGML_OP_FLASH_ATTN_EXT:
+ {
+ if (!webgpu_ctx->supports_subgroup_matrix) {
+ break;
+ }
+ // Head dimensions must fit in workgroup memory with minimum tile sizes
+ size_t limit_bytes = webgpu_ctx->limits.maxComputeWorkgroupStorageSize;
+ const bool has_mask = op->src[3] != nullptr;
+ const bool kv_direct = src1->type == GGML_TYPE_F16 && (src0->ne[0] % webgpu_ctx->sg_mat_k) == 0 &&
+ (src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;
+ const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
+ webgpu_ctx->sg_mat_m, webgpu_ctx->sg_mat_n, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0],
+ has_mask, kv_direct);
+ if (min_bytes > limit_bytes) {
+ break;
+ }
+
+ supports_op = src0->type == GGML_TYPE_F32 &&
+ (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 ||
+ src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) &&
+ src2->type == src1->type && op->type == GGML_TYPE_F32;
+ break;
+ }
case GGML_OP_RMS_NORM:
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
break;
@@ -2606,6 +2812,7 @@ static size_t ggml_backend_webgpu_reg_get_device_count(ggml_backend_reg_t reg) {
}

// TODO: Does this need to be thread safe? Is it only called once?
+// TODO: move most logic to device_init function so backend can be freed/initialized properly
// Only one device is supported for now
static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t reg, size_t index) {
GGML_ASSERT(index == 0);
@@ -2665,7 +2872,9 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) &&
config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) {
- ctx->subgroup_matrix_config = config;
+ ctx->sg_mat_m = config.M;
+ ctx->sg_mat_n = config.N;
+ ctx->sg_mat_k = config.K;
valid_subgroup_matrix_config = true;
break;
}
@@ -2676,7 +2885,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
#endif
// For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate.
// Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter.
- ctx->subgroup_size = info.subgroupMaxSize;
+ ctx->max_subgroup_size = info.subgroupMaxSize;

// Initialize device
std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16 };
@@ -2701,8 +2910,11 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
wgpu::CallbackMode::AllowSpontaneous,
[](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
GGML_UNUSED(device);
- GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason),
- std::string(message).c_str());
+ GGML_UNUSED(reason);
+ GGML_UNUSED(message);
+ // TODO: uncomment once proper free logic is in place
+ //GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason),
+ //std::string(message).c_str());
});
dev_desc.SetUncapturedErrorCallback(
[](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {