if scaled:
if query_head_units is None:
query_shape = npx.shape_array(query)
scale = np.sqrt(query_shape[-1])
else:
scale = math.sqrt(query_head_units)
else:
scale = None
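# NOTE: following scaled dot-product attention, `scale` is sqrt(C_Q). When the number of
# head units is not known ahead of time, it is read from the trailing axis of `query` at
# runtime via npx.shape_array.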
if layout == "NKT":
# 1. Expand the dimension of the mask:
# (B, L_query, L_mem) --> (B, 1, L_query, L_mem)
if mask is not None:
mask = np.expand_dims(mask, axis=1)
# 2. Calculate the attention weights
# Score: (B, N, L_query, C_Q) X (B, N, L_mem, C_Q) --> (B, N, L_query, L_mem)
scores = npx.batch_dot(query, key, transpose_b=True)
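# transpose_b=True transposes the last two axes of `key`, so the batched matmul is
# query (B, N, L_query, C_Q) @ key^T (B, N, C_Q, L_mem).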
if edge_scores is not None:
scores = scores + edge_scores
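# If given, edge_scores is an additive bias (e.g. relative-position scores) that is
# broadcast against the (B, N, L_query, L_mem) score tensor.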
if scaled:
scores = scores / scale
attn_weights = masked_softmax(scores, mask, dtype=dtype, axis=-1)
attn_weights = npx.dropout(attn_weights, p=dropout)
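# masked_softmax normalizes the scores over the memory axis (last axis), assigning zero
# weight to masked positions; dropout then randomly zeroes attention weights during training.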
# 3. Calculate the context vector
# (B, N, L_query, L_mem) X (B, N, L_mem, C_V) --> (B, L_query, N * C_V)
if use_einsum:
context_vec = np.einsum("bnij,bnjc->binc", attn_weights, value)
else:
context_vec = npx.batch_dot(attn_weights, value).transpose((0, 2, 1, 3))
context_vec = npx.reshape(context_vec, (-2, -2, -1))
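# npx.reshape special values: -2 keeps the corresponding input dimension (here B and
# L_query), while -1 infers the last axis, merging (N, C_V) into a single N * C_V axis.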
elif layout == "NTK":
# 1. Expand the dimension of the mask:
# (B, L_query, L_mem) --> (B, 1, L_query, L_mem)
if mask is not None:
mask = np.expand_dims(mask, axis=1)
# 2. Calculate the attention weights
# Score: (B, L_query, N, C_Q) X (B, L_mem, N, C_Q) --> (B, N, L_query, L_mem)
if use_einsum:
scores = np.einsum("binc,bjnc->bnij", query, key)
else:
scores = npx.batch_dot(np.swapaxes(query, 1, 2), np.swapaxes(key, 1, 2),
transpose_b=True)
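# The swapaxes calls move the head axis next to the batch axis, i.e. both operands become
# (B, N, L, C_Q), so the scores can be computed with one batched matmul, matching the
# einsum path above.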
if edge_scores is not None:
scores = scores + edge_scores
if scaled:
scores = scores / scale
attn_weights = masked_softmax(scores, mask, dtype=dtype)
attn_weights = npx.dropout(attn_weights, p=dropout)
# 3. Calculate the context vector
# (B, N, L_query, L_mem) X (B, L_mem, N, C_V) --> (B, L_query, N * C_V)
if use_einsum:
context_vec = np.einsum("bnij,bjnc->binc", attn_weights, value)
else:
context_vec = npx.batch_dot(attn_weights,
np.swapaxes(value, 1, 2)).transpose((0, 2, 1, 3))
context_vec = npx.reshape(context_vec, (-2, -2, -1))
elif layout == "TNK":
# 1. Expand the dimension of the mask:
# (B, L_query, L_mem) --> (B, 1, L_query, L_mem)
if mask is not None:
mask = np.expand_dims(mask, axis=1)
# 2. Calculate the attention weights
# Score: (L_query, B, N, C_Q) X (L_mem, B, N, C_Q) --> (B, N, L_query, L_mem)
# This layout can be implemented very efficiently because B and N are consecutive
# axes. To get a clear picture of what's happening, consider the (i, j)-th element
# of the output:
# out[i, j, :, :] = query[:, i, j, :] X key[:, i, j, :].T, which is just one GEMM call.
# We can thus implement the whole kernel via a single call of batched GEMM with stride.
if use_einsum:
scores = np.einsum("ibnc,jbnc->bnij", query, key)
else:
scores = npx.batch_dot(query.transpose((1, 2, 0, 3)),
key.transpose((1, 2, 3, 0)))
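# query.transpose((1, 2, 0, 3)) has shape (B, N, L_query, C_Q) and key.transpose((1, 2, 3, 0))
# has shape (B, N, C_Q, L_mem), so batch_dot directly yields (B, N, L_query, L_mem).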
if edge_scores is not None:
scores = scores + edge_scores
if scaled:
scores = scores / scale
attn_weights = masked_softmax(scores, mask, dtype=dtype)
attn_weights = npx.dropout(attn_weights, p=dropout)
# 3. Calculate the context vector
# (B, N, L_query, L_mem) X (L_mem, B, N, C_V) --> (L_query, B, N * C_V)
# Again, we can implement it via a single call to batched GEMM with stride.
# Shape (B, N, L_query, C_V)
if use_einsum:
context_vec = np.einsum("bnij,jbnc->ibnc", attn_weights, value)
else:
context_vec = npx.batch_dot(attn_weights,
value.transpose((1, 2, 0, 3))).transpose((2, 0, 1, 3))
context_vec = npx.reshape(context_vec, (-2, -2, -1))
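# In the batch_dot path, the product has shape (B, N, L_query, C_V); the transpose moves it
# to (L_query, B, N, C_V) and the reshape merges the last two axes into N * C_V.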
else:
raise NotImplementedError('layout="{}" is not supported! '
                          'We only support layout = "NKT", "NTK", and "TNK".'
                          .format(layout))
After Change
else:
scale = math.sqrt(query_head_units)
else:
scale = None
if layout == "NKT":
# 1. Expand the dimension of the mask:
# (B, L_query, L_mem) --> (B, 1, L_query, L_mem)
if mask is not None:
mask = np.expand_dims(mask, axis=1).astype(np.bool)
# 2. Calculate the attention weights
# Score: (B, N, L_query, C_Q) X (B, N, L_mem, C_Q) --> (B, N, L_query, L_mem)
scores = npx.batch_dot(query, key, transpose_b=True)
if edge_scores is not None:
scores = scores + edge_scores
attn_weights = masked_softmax(scores, mask, axis=-1, temperature=scale)
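# `scale` is passed as the softmax temperature here; masked_softmax is expected to apply
# the 1/temperature scaling to the logits internally, so there is no separate
# `scores / scale` division in this path.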
attn_weights = npx.dropout(attn_weights, p=dropout)
# 3. Calculate the context vector
# (B, N, L_query, L_mem) X (B, N, L_mem, C_V) --> (B, L_query, N * C_V)
if use_einsum:
context_vec = np.einsum("bnij,bnjc->binc", attn_weights, value)
else:
context_vec = npx.batch_dot(attn_weights, value).transpose((0, 2, 1, 3))
context_vec = npx.reshape(context_vec, (-2, -2, -1))
elif layout == "NTK":
# 1. Expand the dimension of the mask:
# (B, L_query, L_mem) --> (B, 1, L_query, L_mem)
if mask is not None:
mask = np.expand_dims(mask, axis=1).astype(np.bool)
# 2. Calculate the attention weights
# Score: (B, L_query, N, C_Q) X (B, L_mem, N, C_Q) --> (B, N, L_query, L_mem)
if use_einsum:
scores = np.einsum("binc,bjnc->bnij", query, key)
else:
scores = npx.batch_dot(np.swapaxes(query, 1, 2), np.swapaxes(key, 1, 2),
transpose_b=True)
if edge_scores is not None:
scores = scores + edge_scores
attn_weights = masked_softmax(scores, mask, axis=-1, temperature=scale)
attn_weights = npx.dropout(attn_weights, p=dropout)
# 3. Calculate the context vector
# (B, N, L_query, L_mem) X (B, L_mem, N, C_V) --> (B, L_query, N * C_V)
if use_einsum:
context_vec = np.einsum("bnij,bjnc->binc", attn_weights, value)
else:
context_vec = npx.batch_dot(attn_weights,
np.swapaxes(value, 1, 2)).transpose((0, 2, 1, 3))
context_vec = npx.reshape(context_vec, (-2, -2, -1))
elif layout == "TNK":
# 1. Expand the dimension of the mask:
# (B, L_query, L_mem) --> (B, 1, L_query, L_mem)
if mask is not None:
mask = np.expand_dims(mask, axis=1).astype(np.bool)
# 2. Calculate the attention weights
# Score: (L_query, B, N, C_Q) X (L_mem, B, N, C_Q) --> (B, N, L_query, L_mem)
# This layout can be implemented very efficiently because B and N are consecutive
# axes. To get a clear picture of what's happening, consider the (i, j)-th element
# of the output:
# out[i, j, :, :] = query[:, i, j, :] X key[:, i, j, :].T, which is just one GEMM call.
# We can thus implement the whole kernel via a single call of batched GEMM with stride.
if use_einsum:
scores = np.einsum("ibnc,jbnc->bnij", query, key)
else:
scores = npx.batch_dot(query.transpose((1, 2, 0, 3)),
key.transpose((1, 2, 3, 0)))
if edge_scores is not None:
scores = scores + edge_scores
attn_weights = masked_softmax(scores, mask, axis=-1, temperature=scale)
attn_weights = npx.dropout(attn_weights, p=dropout)
# 3. Calculate the context vector
# (B, N, L_query, L_mem) X (L_mem, B, N, C_V) --> (L_query, B, N * C_V)
# Again, we can implement it via a single call to batched GEMM with stride.
# Shape (B, N, L_query, C_V)
if use_einsum:
context_vec = np.einsum("bnij,jbnc->ibnc", attn_weights, value)
else:
context_vec = npx.batch_dot(attn_weights,
value.transpose((1, 2, 0, 3))).transpose((2, 0, 1, 3))
context_vec = npx.reshape(context_vec, (-2, -2, -1))
else:
raise NotImplementedError('layout="{}" is not supported! '
                          'We only support layout = "NKT", "NTK", and "TNK".'
                          .format(layout))