DL0042 Attention Computation

Please break down the computational cost of attention.

Answer

The total computational cost of standard self-attention is
\mathcal{O}(n^2 \cdot d + n \cdot d^2)
Here is the step-by-step breakdown:
Input dimensions: sequence length = n, hidden dimension = d, number of heads = h.
(1) Linear projections (Q, K, V):
Each input  X \in \mathbb{R}^{n \times d} is projected into queries, keys, and values.
Cost: \mathcal{O}(n \cdot d^2) (for all 3 matrices).
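As a concrete sketch (toy sizes, NumPy; the d × d weight shapes are a single-head simplification, not taken from the source), the three projections are three n × d by d × d matmuls:

```python
import numpy as np

n, d = 128, 64                     # toy sequence length and hidden dimension
rng = np.random.default_rng(0)

X = rng.standard_normal((n, d))    # input sequence
# One d x d weight matrix per projection; each X @ W is n*d*d multiply-adds,
# so all three together cost O(n * d^2).
W_q, W_k, W_v = (rng.standard_normal((d, d)) for _ in range(3))
Q, K, V = X @ W_q, X @ W_k, X @ W_v
```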

(2) Attention score computation (QKᵀ):
Queries:  Q \in \mathbb{R}^{n \times d_k}
Keys:  K \in \mathbb{R}^{n \times d_k}
Score matrix:
S = QK^\top \in \mathbb{R}^{n \times n}
Cost: \mathcal{O}(n^2 \cdot d_k)
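A minimal sketch of this step (toy sizes; the 1/√d_k scaling is the usual scaled dot-product variant, added here for completeness):

```python
import numpy as np

n, d_k = 128, 64
rng = np.random.default_rng(0)
Q = rng.standard_normal((n, d_k))
K = rng.standard_normal((n, d_k))

# Each of the n^2 score entries is a dot product of length d_k,
# so forming S costs O(n^2 * d_k) multiply-adds.
S = (Q @ K.T) / np.sqrt(d_k)
```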

(3) Softmax normalization:
For each row  i  of the score matrix:
\mathrm{Softmax}(s_{ij}) = \frac{e^{s_{ij}}}{\sum_{k=1}^n e^{s_{ik}}}
Where:
 s_{ij} = raw score between positions  i  and  j
 n = total sequence length
Cost: \mathcal{O}(n^2)
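The row-wise softmax can be sketched as follows (the max-subtraction is a standard numerical-stability trick, not part of the cost analysis above):

```python
import numpy as np

def row_softmax(S):
    # Row-wise softmax over an n x n score matrix. Subtracting each
    # row's max leaves the result unchanged but avoids overflow in exp.
    # Total work is a constant number of passes over n^2 entries: O(n^2).
    S = S - S.max(axis=-1, keepdims=True)
    E = np.exp(S)
    return E / E.sum(axis=-1, keepdims=True)
```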

(4) Weighted sum with values (AV):
Attention weights  A \in \mathbb{R}^{n \times n} applied to values  V \in \mathbb{R}^{n \times d_v} :
O = AV
Cost: \mathcal{O}(n^2 \cdot d_v)
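A sketch of the weighted sum, using random row-stochastic weights in place of a real softmax output:

```python
import numpy as np

n, d_v = 128, 64
rng = np.random.default_rng(0)

A = rng.random((n, n))
A /= A.sum(axis=-1, keepdims=True)   # row-stochastic attention weights
V = rng.standard_normal((n, d_v))

# Each output row mixes all n value rows: n^2 * d_v multiply-adds,
# i.e. O(n^2 * d_v).
O = A @ V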

(5) Output projection:
Final linear layer to mix heads back to  d .
Cost: \mathcal{O}(n \cdot d^2)
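A sketch of the final mixing step (toy sizes; per-head width d_v = d/h as assumed throughout):

```python
import numpy as np

n, d, h = 128, 64, 8                      # toy sizes
d_v = d // h                              # per-head dimension
rng = np.random.default_rng(0)

# Per-head outputs are concatenated back to width d, then mixed by W_o.
heads = [rng.standard_normal((n, d_v)) for _ in range(h)]
concat = np.concatenate(heads, axis=-1)   # shape (n, d)
W_o = rng.standard_normal((d, d))
out = concat @ W_o                        # n*d*d multiply-adds: O(n * d^2)
```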

Total Complexity:
Putting it all together:
\mathcal{O}(n \cdot d^2) + \mathcal{O}(n^2 \cdot d_k) + \mathcal{O}(n^2) + \mathcal{O}(n^2 \cdot d_v)
Since  d_k, d_v \approx d/h , summing the score and value terms over all  h  heads gives the dominant cost:
\mathcal{O}(n^2 \cdot d + n \cdot d^2)
Overall, the cost of Multi-Head Attention (MHA) is of the same order as single-head attention, because the per-head dimensions scale as  d/h .
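A quick sanity check of which term dominates, counting rough multiply-adds at two operating points (the constant factors 4 and 2 below are illustrative, not exact):

```python
def attention_flops(n, d):
    # Rough multiply-add counts for the two complexity terms.
    proj = 4 * n * d * d        # Q, K, V and output projections: O(n * d^2)
    scores = 2 * n * n * d      # QK^T and AV, summed over heads: O(n^2 * d)
    return proj, scores

# Short sequences (n << d): the projection term dominates.
proj, scores = attention_flops(n=128, d=1024)
print(proj > scores)   # True

# Long sequences (n >> d): the quadratic n^2 * d term takes over.
proj, scores = attention_flops(n=8192, d=1024)
print(proj < scores)   # True
```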

