Function attention 

pub fn attention<B>(
    query: Tensor<B, 4>,
    key: Tensor<B, 4>,
    value: Tensor<B, 4>,
    mask: Option<Tensor<B, 4, Bool>>,
    attn_bias: Option<Tensor<B, 4>>,
    options: AttentionModuleOptions,
) -> Tensor<B, 4>
where
    B: Backend,

Computes scaled dot-product attention: softmax(QKᵀ * scale) · V, where scale defaults to 1/sqrt(head_dim) and can be overridden via options.scale. Optionally applies a boolean mask, an additive bias, causal masking, and softcap.
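To make the formula concrete, here is a minimal reference sketch for a single head in plain Rust, with no backend or tensor type involved. The function name `attention_single_head` and the `Vec<Vec<f32>>` layout are assumptions made for illustration; this is not how the library implements the operation, and it ignores masking, bias, causal masking, and softcap.

```rust
/// Single-head scaled dot-product attention on row-major matrices.
/// `q` is [seq_len_q][head_dim], `k` is [seq_len_k][head_dim],
/// `v` is [seq_len_k][val_dim]; the result is [seq_len_q][val_dim].
/// Hypothetical helper: assumes non-empty, consistently shaped inputs.
fn attention_single_head(
    q: &[Vec<f32>],
    k: &[Vec<f32>],
    v: &[Vec<f32>],
    scale: Option<f32>,
) -> Vec<Vec<f32>> {
    let head_dim = q[0].len();
    let val_dim = v[0].len();
    // Default scale is 1/sqrt(head_dim) when no custom scale is given.
    let scale = scale.unwrap_or(1.0 / (head_dim as f32).sqrt());

    q.iter()
        .map(|q_row| {
            // scores[j] = (q_row · k[j]) * scale
            let scores: Vec<f32> = k
                .iter()
                .map(|k_row| q_row.iter().zip(k_row).map(|(a, b)| a * b).sum::<f32>() * scale)
                .collect();

            // Numerically stable softmax over the key axis.
            let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
            let denom: f32 = exps.iter().sum();

            // Weighted sum of the value rows.
            let mut out = vec![0.0f32; val_dim];
            for (w, v_row) in exps.iter().zip(v) {
                for (o, x) in out.iter_mut().zip(v_row) {
                    *o += (w / denom) * x;
                }
            }
            out
        })
        .collect()
}
```

With a zero query row, every key receives the same score, so the output is the plain average of the value rows.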

§Arguments

  • query: Query tensor of shape [batch_size, num_heads, seq_len_q, head_dim]
  • key: Key tensor of shape [batch_size, num_heads, seq_len_k, head_dim]
  • value: Value tensor of shape [batch_size, num_heads, seq_len_k, val_dim]
  • mask: Optional boolean mask of shape [batch_size, num_heads, seq_len_q, seq_len_k], where true indicates positions to mask (i.e. set to -inf before softmax).
  • attn_bias: Optional float tensor of shape [batch_size, num_heads, seq_len_q, seq_len_k] added to the attention scores before softmax (e.g. ALiBi, relative position biases).
  • options: Additional attention options (custom scale, softcap, causal masking).
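The effect of mask and attn_bias on one row of scores can be sketched as follows. The helper `masked_softmax_row` is hypothetical and works on plain slices; it assumes the bias is added to the already scaled scores and that masked positions are replaced by -inf, as described above. Causal masking and softcap are left out because their exact definition and ordering are not stated here.

```rust
/// Applies an optional additive bias and an optional boolean mask to one
/// row of attention scores, then normalizes the row with softmax.
/// `mask[j] == true` means key position `j` is masked out.
/// Hypothetical helper: a row that is fully masked yields NaN in this sketch.
fn masked_softmax_row(scores: &[f32], mask: Option<&[bool]>, bias: Option<&[f32]>) -> Vec<f32> {
    let adjusted: Vec<f32> = scores
        .iter()
        .enumerate()
        .map(|(j, &s)| {
            // Additive bias (e.g. ALiBi) goes in before softmax.
            let biased = s + bias.map_or(0.0, |b| b[j]);
            // Masked positions become -inf, so they get zero weight.
            if mask.map_or(false, |m| m[j]) {
                f32::NEG_INFINITY
            } else {
                biased
            }
        })
        .collect();

    // Numerically stable softmax.
    let max = adjusted.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = adjusted.iter().map(|s| (s - max).exp()).collect();
    let denom: f32 = exps.iter().sum();
    exps.iter().map(|e| e / denom).collect()
}
```

Because -inf plus any finite bias is still -inf, the order in which mask and bias are applied does not change the result in this sketch.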

§Returns

A tensor of shape [batch_size, num_heads, seq_len_q, val_dim] representing the attended context per head.
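The shape contract above can be written down as a small check. `attention_output_shape` is a hypothetical helper, not part of the library; it assumes the batch and head dimensions must match exactly (no broadcasting), which is the strictest reading of the documented shapes.

```rust
/// Derives the output shape from the query, key, and value shapes,
/// each given as [batch_size, num_heads, seq_len, feature_dim].
/// Hypothetical helper: assumes no broadcasting across batch or heads.
fn attention_output_shape(
    query: [usize; 4],
    key: [usize; 4],
    value: [usize; 4],
) -> Result<[usize; 4], String> {
    let [batch, heads, seq_q, head_dim] = query;
    let [k_batch, k_heads, seq_k, k_dim] = key;
    let [v_batch, v_heads, v_seq, val_dim] = value;

    if (k_batch, k_heads) != (batch, heads) || (v_batch, v_heads) != (batch, heads) {
        return Err("batch_size and num_heads must match across query, key, and value".into());
    }
    if k_dim != head_dim {
        return Err(format!("key head_dim {k_dim} != query head_dim {head_dim}"));
    }
    if v_seq != seq_k {
        return Err(format!("value seq_len {v_seq} != key seq_len {seq_k}"));
    }
    // The query contributes seq_len_q; the value contributes val_dim.
    Ok([batch, heads, seq_q, val_dim])
}
```

For example, query [2, 8, 10, 64], key [2, 8, 12, 64], and value [2, 8, 12, 32] give an output of shape [2, 8, 10, 32].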

§Note

This implementation does not support dropout and is intended for inference or use cases where dropout is not needed.