Please break down the computational cost of attention.
Answer
Here is the breakdown of the computational cost of attention:
Input dimensions: sequence length = n, hidden dimension = d, number of heads = h.
(1) Linear projections (Q, K, V):
Each input is projected into queries, keys, and values.
Cost: O(3 · n · d²) = O(n · d²) (for all 3 matrices).
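A minimal numpy sketch of step (1), with illustrative sizes n = 128, d = 64 (my choice, not from the text). Each projection is an (n × d) × (d × d) matmul, so the three together cost 3 · n · d² multiply-adds:

```python
import numpy as np

# Assumed sizes for illustration: n = sequence length, d = hidden dimension.
n, d = 128, 64
rng = np.random.default_rng(0)

x = rng.standard_normal((n, d))      # input sequence
W_q = rng.standard_normal((d, d))    # query projection weights
W_k = rng.standard_normal((d, d))    # key projection weights
W_v = rng.standard_normal((d, d))    # value projection weights

# Each matmul is (n, d) x (d, d) -> (n, d), costing n * d * d multiply-adds.
Q, K, V = x @ W_q, x @ W_k, x @ W_v

flops = 3 * n * d * d                # all three projections
print(Q.shape, flops)                # (128, 64) 1572864
```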
(2) Attention score computation (QKᵀ):
Queries: Q is an n × d matrix.
Keys: K is an n × d matrix.
Score matrix: S = QKᵀ / √d, an n × n matrix.
Cost: O(n² · d)
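The score computation in step (2), sketched with the same assumed sizes (n = 128, d = 64). The (n × d) × (d × n) matmul produces an n × n score matrix and dominates this step at n² · d multiply-adds:

```python
import numpy as np

# Assumed sizes for illustration.
n, d = 128, 64
rng = np.random.default_rng(0)
Q = rng.standard_normal((n, d))
K = rng.standard_normal((n, d))

# S = Q K^T / sqrt(d): an (n, d) x (d, n) matmul -> (n, n) scaled scores.
S = Q @ K.T / np.sqrt(d)

flops = n * n * d                    # multiply-adds for the matmul
print(S.shape, flops)                # (128, 128) 1048576
```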
(3) Softmax normalization:
For each row of the score matrix:
A_ij = exp(s_ij) / Σₖ exp(s_ik)
Where: s_ij = raw score for position pair (i, j)
n = total sequence length
Cost: O(n²)
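Step (3) as numpy, with the standard max-subtraction trick for numerical stability (an implementation detail, not from the text). Each of the n rows is normalized in O(n), giving O(n²) total:

```python
import numpy as np

n = 128                              # assumed sequence length
rng = np.random.default_rng(0)
S = rng.standard_normal((n, n))      # raw score matrix

# Row-wise softmax: shift by the row max (stability), exponentiate,
# then divide so every row sums to 1. O(n) per row, O(n^2) overall.
S_shift = S - S.max(axis=1, keepdims=True)
exp_S = np.exp(S_shift)
A = exp_S / exp_S.sum(axis=1, keepdims=True)

print(np.allclose(A.sum(axis=1), 1.0))  # True: each row is a distribution
```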
(4) Weighted sum with values (AV):
Attention weights A (n × n) applied to values V (n × d):
Output = A · V
Cost: O(n² · d)
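Step (4) sketched with the same assumed sizes: an (n × n) × (n × d) matmul, mirroring the QKᵀ cost at n² · d multiply-adds:

```python
import numpy as np

# Assumed sizes for illustration.
n, d = 128, 64
rng = np.random.default_rng(0)
A = rng.random((n, n))
A /= A.sum(axis=1, keepdims=True)    # stand-in attention weights (rows sum to 1)
V = rng.standard_normal((n, d))

# (n, n) x (n, d) -> (n, d): each output row is a weighted sum of value rows.
out = A @ V

flops = n * n * d
print(out.shape, flops)              # (128, 64) 1048576
```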
(5) Output projection:
Final linear layer to mix heads back to dimension d.
Cost: O(n · d²)
Total Complexity:
Putting it all together: O(n · d²) + O(n² · d) + O(n²) + O(n² · d) + O(n · d²) = O(n · d² + n² · d)
Since n ≫ d for long sequences, the dominant term is: O(n² · d)
Overall, the cost of Multi-Head Attention (MHA) is the same order as single-head attention, because the per-head dimension scales as d/h: each of the h heads costs O(n² · d/h) for its score and weighted-sum steps, and h heads together cost O(n² · d).
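The full tally can be checked numerically. A small sketch (sizes n = 32768, d = 512 are my choice) that sums the per-step multiply-add counts from the breakdown above and shows the two n² · d terms dominating:

```python
def attention_flops(n, d):
    """Multiply-add count per step of the breakdown above."""
    proj = 3 * n * d * d      # (1) Q, K, V projections
    scores = n * n * d        # (2) Q K^T
    softmax = n * n           # (3) row-wise normalization
    av = n * n * d            # (4) A V
    out = n * d * d           # (5) output projection
    return proj + scores + softmax + av + out

# Multi-head note: h heads of size d/h give h * (n^2 * d/h) = n^2 * d
# for the score/AV terms, so the order is unchanged.
n, d = 32768, 512
total = attention_flops(n, d)
quadratic = 2 * n * n * d     # the two n^2 * d terms

# For n >> d the quadratic terms account for nearly all of the cost.
print(round(quadratic / total, 3))
```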