Function attention
pub fn attention<B>(
    query: Tensor<B, 4>,
    key: Tensor<B, 4>,
    value: Tensor<B, 4>,
    mask: Option<Tensor<B, 4, Bool>>,
) -> Tensor<B, 4>
where
    B: Backend,
Computes scaled dot-product attention, softmax(QKᵗ / √d) · V, where d is the head dimension (head_dim), optionally applying a 4D boolean mask to the attention scores before the softmax.
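A minimal sketch of this computation, assuming Burn's tensor API (matmul, swap_dims, mask_fill, and activation::softmax); the crate's actual implementation may differ:

use burn::tensor::{activation::softmax, backend::Backend, Bool, Tensor};

// Sketch of the attention computation described above; assumes Burn's
// tensor API and is not necessarily this crate's implementation.
fn attention_sketch<B: Backend>(
    query: Tensor<B, 4>,
    key: Tensor<B, 4>,
    value: Tensor<B, 4>,
    mask: Option<Tensor<B, 4, Bool>>,
) -> Tensor<B, 4> {
    let [_, _, _, head_dim] = query.dims();

    // QKᵗ / √d: contract query and key over head_dim, then scale.
    let mut scores = query
        .matmul(key.swap_dims(2, 3))
        .div_scalar((head_dim as f32).sqrt());

    // Where the mask is `true`, overwrite the score with -∞ so that
    // softmax assigns the position zero weight.
    if let Some(mask) = mask {
        scores = scores.mask_fill(mask, f32::NEG_INFINITY);
    }

    // Softmax over the key dimension, then weight the values.
    softmax(scores, 3).matmul(value)
}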
§Arguments
- query: Query tensor of shape [batch_size, num_heads, seq_len_q, head_dim]
- key: Key tensor of shape [batch_size, num_heads, seq_len_k, head_dim]
- value: Value tensor of shape [batch_size, num_heads, seq_len_k, head_dim]
- mask: Optional boolean mask of shape [batch_size, num_heads, seq_len_q, seq_len_k], where true indicates positions to mask (i.e. set to -∞ before softmax), e.g. a causal mask as sketched below
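For example, a causal (autoregressive) mask can be built from a strictly upper-triangular boolean tensor. The sketch below assumes Burn's triu_mask constructor and expand for broadcasting; causal_mask is a hypothetical helper, not part of this crate:

use burn::tensor::{backend::Backend, Bool, Tensor};

// Hypothetical helper: a mask that blocks attention to future positions.
// triu_mask with diagonal offset 1 marks the strictly upper triangle
// `true`, matching the "true = masked" convention above.
fn causal_mask<B: Backend>(
    batch_size: usize,
    num_heads: usize,
    seq_len: usize,
    device: &B::Device,
) -> Tensor<B, 4, Bool> {
    Tensor::<B, 2, Bool>::triu_mask([seq_len, seq_len], 1, device)
        .unsqueeze::<4>() // [1, 1, seq_len, seq_len]
        .expand([batch_size, num_heads, seq_len, seq_len])
}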
§Returns
A tensor of shape [batch_size, num_heads, seq_len_q, head_dim]
representing the attended context per head.
§Note
This implementation does not support dropout and is intended for inference or use cases where dropout is not needed.
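A minimal end-to-end call, assuming the NdArray backend and with attention brought into scope from this crate:

use burn::backend::NdArray;
use burn::tensor::{Distribution, Tensor};

type B = NdArray;

fn main() {
    let device = Default::default();
    let [batch, heads, seq, dim] = [2, 8, 16, 64];

    // Self-attention: queries, keys, and values share one sequence length.
    let q = Tensor::<B, 4>::random([batch, heads, seq, dim], Distribution::Default, &device);
    let k = Tensor::<B, 4>::random([batch, heads, seq, dim], Distribution::Default, &device);
    let v = Tensor::<B, 4>::random([batch, heads, seq, dim], Distribution::Default, &device);

    // No mask: every query position may attend to every key position.
    let out = attention(q, k, v, None);
    assert_eq!(out.dims(), [batch, heads, seq, dim]);
}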