attention

Function attention 

pub fn attention<B>(
    query: Tensor<B, 4>,
    key: Tensor<B, 4>,
    value: Tensor<B, 4>,
    mask: Option<Tensor<B, 4, Bool>>,
) -> Tensor<B, 4>
where B: Backend,

Computes scaled dot-product attention, softmax(QKᵗ / √d) · V, where d is the head dimension (head_dim), optionally applying a 4D boolean mask to the attention scores before the softmax.
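
For reference, a minimal sketch of this computation, assuming a Burn-like tensor API (matmul, swap_dims, div_scalar, mask_fill, and activation::softmax); it is illustrative and not necessarily the crate's actual implementation:

use burn::tensor::{activation::softmax, backend::Backend, Bool, Tensor};

fn attention_sketch<B: Backend>(
    query: Tensor<B, 4>,
    key: Tensor<B, 4>,
    value: Tensor<B, 4>,
    mask: Option<Tensor<B, 4, Bool>>,
) -> Tensor<B, 4> {
    // d is the head dimension, read from the last axis of the query.
    let [_, _, _, head_dim] = query.dims();

    // QKᵗ / √d → scores of shape [batch_size, num_heads, seq_len_q, seq_len_k].
    let scores = query
        .matmul(key.swap_dims(2, 3))
        .div_scalar((head_dim as f32).sqrt());

    // Masked positions (true) are set to -∞ so they contribute nothing after softmax.
    let scores = match mask {
        Some(mask) => scores.mask_fill(mask, f32::NEG_INFINITY),
        None => scores,
    };

    // Softmax over the key dimension, then weight the values.
    softmax(scores, 3).matmul(value)
}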

§Arguments

  • query: Query tensor of shape [batch_size, num_heads, seq_len_q, head_dim]
  • key: Key tensor of shape [batch_size, num_heads, seq_len_k, head_dim]
  • value: Value tensor of shape [batch_size, num_heads, seq_len_k, head_dim]
  • mask: Optional boolean mask of shape [batch_size, num_heads, seq_len_q, seq_len_k], where true indicates positions to mask (i.e. set to -∞ before softmax).

§Returns

A tensor of shape [batch_size, num_heads, seq_len_q, head_dim] representing the attended context per head.

§Note

This implementation does not support dropout and is intended for inference or use cases where dropout is not needed.
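
§Example

A usage sketch with illustrative shapes (batch_size = 2, num_heads = 8, seq_len_q = seq_len_k = 16, head_dim = 64), assuming a Burn-like API for constructing random inputs (Tensor::random with Distribution::Default). Passing None for the mask lets every query position attend to every key position.

use burn::tensor::{backend::Backend, Distribution, Tensor};

fn example<B: Backend>(device: &B::Device) {
    let query = Tensor::<B, 4>::random([2, 8, 16, 64], Distribution::Default, device);
    let key = Tensor::<B, 4>::random([2, 8, 16, 64], Distribution::Default, device);
    let value = Tensor::<B, 4>::random([2, 8, 16, 64], Distribution::Default, device);

    // No mask: full (unmasked) attention over all key positions.
    let context = attention(query, key, value, None);

    // The output keeps the query's sequence length and the head dimension.
    assert_eq!(context.dims(), [2, 8, 16, 64]);
}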