8e84bd1c5a4cd0e61ee67abc532ab692e5335914,src/gluonnlp/attention_cell.py,,multi_head_dot_attn,#,393

Before Change


    if scaled:
        if query_head_units is None:
            query_shape = npx.shape_array(query)
            scale = np.sqrt(query_shape[-1])
        else:
            scale = math.sqrt(query_head_units)
    else:
        scale = None
    if layout == "NKT":
        // 1. Expand the dimension of the mask:
        //   (B, L_query, L_mem) --> (B, 1, L_query, L_mem)
        if mask is not None:
            mask = np.expand_dims(mask, axis=1)
        // 2. Calculate the attention weights
        //   Score: (B, N, L_query, C_Q) X (B, N, L_mem, C_Q) --> (B, N, L_query, L_mem)
        scores = npx.batch_dot(query, key, transpose_b=True)
        if edge_scores is not None:
            scores = scores + edge_scores
        if scaled:
            scores = scores / scale
        attn_weights = masked_softmax(scores, mask, dtype=dtype, axis=-1)
        attn_weights = npx.dropout(attn_weights, p=dropout)
        // 3. Calculate the context vector
        // (B, N, L_query, L_mem) X (B, N, L_mem, C_V) --> (B, L_query, N * C_V)
        if use_einsum:
            context_vec = np.einsum("bnij,bnjc->binc", attn_weights, value)
        else:
            context_vec = npx.batch_dot(attn_weights, value).transpose((0, 2, 1, 3))
        context_vec = npx.reshape(context_vec, (-2, -2, -1))
    elif layout == "NTK":
        // 1. Expand the dimension of the mask:
        //   (B, L_query, L_mem) --> (B, 1, L_query, L_mem)
        if mask is not None:
            mask = np.expand_dims(mask, axis=1)
        // 2. Calculate the attention weights
        //   Score: (B, L_query, N, C_Q) X (B, L_mem, N, C_Q) --> (B, N, L_query, L_mem)
        if use_einsum:
            scores = np.einsum("binc,bjnc->bnij", query, key)
        else:
            scores = npx.batch_dot(np.swapaxes(query, 1, 2), np.swapaxes(key, 1, 2),
                                     transpose_b=True)
        if edge_scores is not None:
            scores = scores + edge_scores
        if scaled:
            scores = scores / scale
        attn_weights = masked_softmax(scores, mask, dtype=dtype)
        attn_weights = npx.dropout(attn_weights, p=dropout)
        // 3. Calculate the context vector
        // (B, N, L_query, L_mem) X (B, L_mem, N, C_V) --> (B, L_query, N * C_V)
        if use_einsum:
            context_vec = np.einsum("bnij,bjnc->binc", attn_weights, value)
        else:
            context_vec = npx.batch_dot(attn_weights,
                                          np.swapaxes(value, 1, 2)).transpose((0, 2, 1, 3))
        context_vec = npx.reshape(context_vec, (-2, -2, -1))
    elif layout == "TNK":
        // 1. Expand the dimension of the mask:
        //   (B, L_query, L_mem) --> (B, 1, L_query, L_mem)
        if mask is not None:
            mask = np.expand_dims(mask, axis=1)
        // 2. Calculate the attention weights
        //   Score: (L_query, B, N, C_Q) X (L_mem, B, N, C_Q) --> (B, N, L_query, L_mem)
        //   This layout structure can be implemented very efficiently because B, N are consecutive
        //   to each other. To have a clear picture of what"s happening, we may consider the
        //   (i, j)th element of the output
        //       out[i, j, :, :] = query[:, i, j, :] X key[:, i, j, :].T, which is just one GEMM call
        //   We can thus implement the whole kernel via a single call of batched GEMM with stride.
        if use_einsum:
            scores = np.einsum("ibnc,jbnc->bnij", query, key)
        else:
            scores = npx.batch_dot(query.transpose((1, 2, 0, 3)),
                                     key.transpose((1, 2, 3, 0)))
        if edge_scores is not None:
            scores = scores + edge_scores
        if scaled:
            scores = scores / scale
        attn_weights = masked_softmax(scores, mask, dtype=dtype)
        attn_weights = npx.dropout(attn_weights, p=dropout)
        // 3. Calculate the context vector
        // (B, N, L_query, L_mem) X (L_mem, B, N, C_V) --> (L_query, B, N * C_V)
        // Again, we can implement it via a single call to batched GEMM with stride.

        // Shape (B, N, L_query, C_V)
        if use_einsum:
            context_vec = np.einsum("bnij,jbnc->ibnc", attn_weights, value)
        else:
            context_vec = npx.batch_dot(attn_weights,
                                          value.transpose((1, 2, 0, 3))).transpose((2, 0, 1, 3))
        context_vec = npx.reshape(context_vec, (-2, -2, -1))
    else:
        raise NotImplementedError("layout="{}" is not supported! "
                                  "We only support layout = "NKT", "NTK", and "TNK"."

After Change


        else:
            scale = math.sqrt(query_head_units)
    else:
        scale = None
    if layout == "NKT":
        // 1. Expand the dimension of the mask:
        //   (B, L_query, L_mem) --> (B, 1, L_query, L_mem)
        if mask is not None:
            mask = np.expand_dims(mask, axis=1).astype(np.bool)
        // 2. Calculate the attention weights
        //   Score: (B, N, L_query, C_Q) X (B, N, L_mem, C_Q) --> (B, N, L_query, L_mem)
        scores = npx.batch_dot(query, key, transpose_b=True)
        if edge_scores is not None:
            scores = scores + edge_scores
        attn_weights = masked_softmax(scores, mask, axis=-1, temperature=scale)
        attn_weights = npx.dropout(attn_weights, p=dropout)
        // 3. Calculate the context vector
        // (B, N, L_query, L_mem) X (B, N, L_mem, C_V) --> (B, L_query, N * C_V)
        if use_einsum:
            context_vec = np.einsum("bnij,bnjc->binc", attn_weights, value)
        else:
            context_vec = npx.batch_dot(attn_weights, value).transpose((0, 2, 1, 3))
        context_vec = npx.reshape(context_vec, (-2, -2, -1))
    elif layout == "NTK":
        // 1. Expand the dimension of the mask:
        //   (B, L_query, L_mem) --> (B, 1, L_query, L_mem)
        if mask is not None:
            mask = np.expand_dims(mask, axis=1).astype(np.bool)
        // 2. Calculate the attention weights
        //   Score: (B, L_query, N, C_Q) X (B, L_mem, N, C_Q) --> (B, N, L_query, L_mem)
        if use_einsum:
            scores = np.einsum("binc,bjnc->bnij", query, key)
        else:
            scores = npx.batch_dot(np.swapaxes(query, 1, 2), np.swapaxes(key, 1, 2),
                                   transpose_b=True)
        if edge_scores is not None:
            scores = scores + edge_scores
        attn_weights = masked_softmax(scores, mask, axis=-1, temperature=scale)
        attn_weights = npx.dropout(attn_weights, p=dropout)
        // 3. Calculate the context vector
        // (B, N, L_query, L_mem) X (B, L_mem, N, C_V) --> (B, L_query, N * C_V)
        if use_einsum:
            context_vec = np.einsum("bnij,bjnc->binc", attn_weights, value)
        else:
            context_vec = npx.batch_dot(attn_weights,
                                          np.swapaxes(value, 1, 2)).transpose((0, 2, 1, 3))
        context_vec = npx.reshape(context_vec, (-2, -2, -1))
    elif layout == "TNK":
        // 1. Expand the dimension of the mask:
        //   (B, L_query, L_mem) --> (B, 1, L_query, L_mem)
        if mask is not None:
            mask = np.expand_dims(mask, axis=1).astype(np.bool)
        // 2. Calculate the attention weights
        //   Score: (L_query, B, N, C_Q) X (L_mem, B, N, C_Q) --> (B, N, L_query, L_mem)
        //   This layout structure can be implemented very efficiently because B, N are consecutive
        //   to each other. To have a clear picture of what"s happening, we may consider the
        //   (i, j)th element of the output
        //       out[i, j, :, :] = query[:, i, j, :] X key[:, i, j, :].T, which is just one GEMM call
        //   We can thus implement the whole kernel via a single call of batched GEMM with stride.
        if use_einsum:
            scores = np.einsum("ibnc,jbnc->bnij", query, key)
        else:
            scores = npx.batch_dot(query.transpose((1, 2, 0, 3)),
                                     key.transpose((1, 2, 3, 0)))
        if edge_scores is not None:
            scores = scores + edge_scores
        attn_weights = masked_softmax(scores, mask, axis=-1, temperature=scale)
        attn_weights = npx.dropout(attn_weights, p=dropout)
        // 3. Calculate the context vector
        // (B, N, L_query, L_mem) X (L_mem, B, N, C_V) --> (L_query, B, N * C_V)
        // Again, we can implement it via a single call to batched GEMM with stride.

        // Shape (B, N, L_query, C_V)
        if use_einsum:
            context_vec = np.einsum("bnij,jbnc->ibnc", attn_weights, value)
        else:
            context_vec = npx.batch_dot(attn_weights,
                                          value.transpose((1, 2, 0, 3))).transpose((2, 0, 1, 3))
        context_vec = npx.reshape(context_vec, (-2, -2, -1))
    else:
        raise NotImplementedError("layout="{}" is not supported! "
                                  "We only support layout = "NKT", "NTK", and "TNK"."
Italian Trulli
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 10

Instances


Project Name: dmlc/gluon-nlp
Commit Name: 8e84bd1c5a4cd0e61ee67abc532ab692e5335914
Time: 2020-12-10
Author: xshiab@connect.ust.hk
File Name: src/gluonnlp/attention_cell.py
Class Name:
Method Name: multi_head_dot_attn