|
|
@@ -1,6 +1,7 @@
|
|
|
#version 450
|
|
|
|
|
|
#extension GL_EXT_control_flow_attributes : require
|
|
|
+#extension GL_KHR_shader_subgroup_basic : enable
|
|
|
#if USE_SUBGROUP_ADD
|
|
|
#extension GL_KHR_shader_subgroup_arithmetic : enable
|
|
|
#endif
|
|
|
@@ -9,7 +10,8 @@
|
|
|
|
|
|
layout(constant_id = 0) const uint D_STATE = 128;
|
|
|
layout(constant_id = 1) const uint SUBGROUP_SIZE = 32;
|
|
|
-layout(constant_id = 2) const uint SPLIT_H = 16;
|
|
|
+
|
|
|
+const uint32_t c_factor = D_STATE / SUBGROUP_SIZE;
|
|
|
|
|
|
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
|
|
|
|
|
@@ -41,22 +43,28 @@ float softplus(float x) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-shared float stateC[SPLIT_H * D_STATE];
|
|
|
+#if !USE_SUBGROUP_ADD
|
|
|
+shared float temp[D_STATE];
|
|
|
+#endif
|
|
|
|
|
|
void main() {
|
|
|
- const uint tid = gl_LocalInvocationID.x;
|
|
|
- const uint head_idx = (gl_WorkGroupID.x * SPLIT_H) / d_head;
|
|
|
- const uint head_off = ((gl_WorkGroupID.x * SPLIT_H) % d_head) * 4;
|
|
|
- const uint seq_idx = gl_WorkGroupID.y;
|
|
|
+ const uint subgroup = gl_SubgroupID;
|
|
|
+ const uint lane = gl_SubgroupInvocationID;
|
|
|
+ const uint tid = gl_SubgroupID * SUBGROUP_SIZE + lane;
|
|
|
+ const uint subgroup_idx = gl_WorkGroupID.x * c_factor + subgroup;
|
|
|
+
|
|
|
+ const uint head_idx = subgroup_idx / d_head;
|
|
|
+ const uint head_off = (subgroup_idx % d_head) * 4;
|
|
|
+ const uint seq_idx = gl_WorkGroupID.y;
|
|
|
|
|
|
const uint group_off = (head_idx / (n_head / n_group)) * D_STATE * 4;
|
|
|
const uint s0_base_idx = (uint(ids[seq_idx]) * nb03 + head_idx * nb02 + head_off * D_STATE) / 4;
|
|
|
- const uint x_base_idx = (seq_idx * nb13 + gl_WorkGroupID.x * SPLIT_H * 4) / 4;
|
|
|
+ const uint x_base_idx = (seq_idx * nb13 + subgroup_idx * 4) / 4;
|
|
|
const uint dt_base_idx = (seq_idx * nb22 + head_idx * 4) / 4;
|
|
|
const uint A_base_idx = (head_idx * nb31) / 4;
|
|
|
const uint B_base_idx = (seq_idx * nb43 + group_off) / 4;
|
|
|
const uint C_base_idx = (seq_idx * nb53 + group_off) / 4;
|
|
|
- const uint y_base_idx = seq_idx * n_tok * n_head * d_head + gl_WorkGroupID.x * SPLIT_H;
|
|
|
+ const uint y_base_idx = seq_idx * n_tok * n_head * d_head + subgroup_idx;
|
|
|
const uint s_base_idx = (s_off + seq_idx * nb03 + head_idx * nb02 + head_off * D_STATE) / 4;
|
|
|
|
|
|
const uint stride_x = nb12 / 4;
|
|
|
@@ -65,76 +73,52 @@ void main() {
|
|
|
const uint stride_C = nb52 / 4;
|
|
|
const uint stride_y = n_head * d_head;
|
|
|
|
|
|
- float state[SPLIT_H];
|
|
|
- [[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
|
|
|
- state[j] = s0[s0_base_idx + j * D_STATE + tid];
|
|
|
- }
|
|
|
+ float state[c_factor];
|
|
|
|
|
|
- for (uint i = 0; i < n_tok; i++) {
|
|
|
- const float dt_soft_plus = softplus(dt[dt_base_idx + i * stride_dt]);
|
|
|
+ [[unroll]] for (uint j = 0; j < c_factor; j++) {
|
|
|
+ state[j] = s0[s0_base_idx + SUBGROUP_SIZE * j + lane];
|
|
|
+ }
|
|
|
|
|
|
- const float dA = exp(dt_soft_plus * A[A_base_idx]);
|
|
|
+ float a = A[A_base_idx];
|
|
|
|
|
|
- const float B_val = B[B_base_idx + i * stride_B + tid];
|
|
|
- const float C_val = C[C_base_idx + i * stride_C + tid];
|
|
|
+ for (uint i = 0; i < n_tok; i++) {
|
|
|
+ float dt_soft_plus = softplus(dt[dt_base_idx + i * stride_dt]);
|
|
|
|
|
|
- [[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
|
|
|
- const float x_dt = x[x_base_idx + i * stride_x + j] * dt_soft_plus;
|
|
|
+ float state_sum = 0.0f;
|
|
|
|
|
|
+ const float dA = exp(dt_soft_plus * a);
|
|
|
+ const float x_dt = x[x_base_idx + i * stride_x] * dt_soft_plus;
|
|
|
+ [[unroll]] for (uint j = 0; j < c_factor; j++) {
|
|
|
+ float B_val = B[B_base_idx + i * stride_B + SUBGROUP_SIZE * j + lane];
|
|
|
+ float C_val = C[C_base_idx + i * stride_C + SUBGROUP_SIZE * j + lane];
|
|
|
state[j] = (state[j] * dA) + (B_val * x_dt);
|
|
|
-
|
|
|
- stateC[j * D_STATE + tid] = state[j] * C_val;
|
|
|
+ state_sum += state[j] * C_val;
|
|
|
}
|
|
|
|
|
|
+#if USE_SUBGROUP_ADD
|
|
|
+ state_sum = subgroupAdd(state_sum);
|
|
|
+#else
|
|
|
+ temp[tid] = state_sum;
|
|
|
barrier();
|
|
|
- [[unroll]]
|
|
|
- for (uint w = D_STATE / 2; w >= SUBGROUP_SIZE; w >>= 1) {
|
|
|
- [[unroll]] for (uint j = 0; j < (w * SPLIT_H + D_STATE - 1) / D_STATE; j++) {
|
|
|
- const uint k = (tid % w) + (D_STATE * (tid / w)) + j * D_STATE * (D_STATE / w);
|
|
|
- if (k < SPLIT_H * D_STATE && (k + w) < SPLIT_H * D_STATE) {
|
|
|
- stateC[k] += stateC[k + w];
|
|
|
- }
|
|
|
+ [[unroll]] for (uint s = SUBGROUP_SIZE / 2; s > 0; s >>= 1) {
|
|
|
+ if (lane < s) {
|
|
|
+ temp[tid] += temp[tid + s];
|
|
|
}
|
|
|
barrier();
|
|
|
}
|
|
|
-
|
|
|
- [[unroll]] for (uint j = 0; j < max(1, SPLIT_H / (D_STATE / SUBGROUP_SIZE)); j++) {
|
|
|
- const uint idx = (tid % SUBGROUP_SIZE) +
|
|
|
- D_STATE * (tid / SUBGROUP_SIZE) +
|
|
|
- j * D_STATE * (D_STATE / SUBGROUP_SIZE);
|
|
|
- const uint max_idx = SUBGROUP_SIZE - 1 +
|
|
|
- D_STATE * ((D_STATE - 1) / SUBGROUP_SIZE) +
|
|
|
- j * D_STATE * (D_STATE / SUBGROUP_SIZE);
|
|
|
-
|
|
|
- if (idx < SPLIT_H * D_STATE ||
|
|
|
- max_idx < SPLIT_H * D_STATE) {
|
|
|
- float sc;
|
|
|
-#if USE_SUBGROUP_ADD
|
|
|
- sc = stateC[idx];
|
|
|
- sc = subgroupAdd(sc);
|
|
|
-#else
|
|
|
- [[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) {
|
|
|
- if (idx + offset < SPLIT_H * D_STATE) {
|
|
|
- stateC[idx] += stateC[idx + offset];
|
|
|
- }
|
|
|
- barrier();
|
|
|
- }
|
|
|
- if (tid % SUBGROUP_SIZE == 0) {
|
|
|
- sc = stateC[idx];
|
|
|
- }
|
|
|
+ // get the value from lane 0
|
|
|
+ state_sum = temp[subgroup * SUBGROUP_SIZE];
|
|
|
+ barrier();
|
|
|
#endif
|
|
|
|
|
|
- if (tid % SUBGROUP_SIZE == 0) {
|
|
|
- const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE);
|
|
|
- d[y_base_idx + i * stride_y + k] = sc;
|
|
|
- }
|
|
|
- }
|
|
|
+ if (lane == 0) {
|
|
|
+ d[y_base_idx + i * stride_y] = state_sum;
|
|
|
}
|
|
|
-
|
|
|
- barrier();
|
|
|
}
|
|
|
|
|
|
- [[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
|
|
|
- d[s_base_idx + j * D_STATE + tid] = state[j];
|
|
|
+ // write back the state
|
|
|
+ [[unroll]]
|
|
|
+ for (int j = 0; j < c_factor; j++) {
|
|
|
+ d[s_base_idx + SUBGROUP_SIZE * j + lane] = state[j];
|
|
|
}
|
|
|
}
|