Function attention
pub fn attention<B>(
    query: Tensor<B, 4>,
    key: Tensor<B, 4>,
    value: Tensor<B, 4>,
    mask: Option<Tensor<B, 4, Bool>>,
    attn_bias: Option<Tensor<B, 4>>,
    options: AttentionModuleOptions,
) -> Tensor<B, 4>
where
    B: Backend,
Computes scaled dot-product attention: softmax(QKᵗ * scale) · V,
where scale defaults to 1/sqrt(head_dim) (configurable via options.scale).
Optionally applies masking, additive bias, causal masking, and softcap.
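To make the formula concrete, here is a minimal single-head reference sketch in plain Rust, using nested `Vec<f32>` instead of tensors. It is not this crate's implementation: the names `softmax_row` and `attention_ref` are hypothetical, inputs are assumed non-empty, and masking, bias, causal masking, and softcap are left out.

```rust
/// Numerically stable softmax over one row of scores.
fn softmax_row(row: &[f32]) -> Vec<f32> {
    // Subtract the row maximum so the exponentials cannot overflow.
    let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = row.iter().map(|&s| (s - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// Hypothetical single-head reference for softmax(QKᵗ * scale) · V.
///
/// q: [seq_len_q][head_dim], k: [seq_len_k][head_dim], v: [seq_len_k][val_dim].
/// Returns [seq_len_q][val_dim]. Assumes all inputs are non-empty.
fn attention_ref(
    q: &[Vec<f32>],
    k: &[Vec<f32>],
    v: &[Vec<f32>],
    scale: Option<f32>,
) -> Vec<Vec<f32>> {
    let head_dim = q[0].len();
    let val_dim = v[0].len();
    // Default scale is 1/sqrt(head_dim), as described above.
    let scale = scale.unwrap_or(1.0 / (head_dim as f32).sqrt());

    q.iter()
        .map(|q_row| {
            // One score per key position: dot(q_row, k_row) * scale.
            let scores: Vec<f32> = k
                .iter()
                .map(|k_row| {
                    q_row.iter().zip(k_row).map(|(a, b)| a * b).sum::<f32>() * scale
                })
                .collect();
            let weights = softmax_row(&scores);

            // Weighted sum of the value rows.
            let mut out = vec![0.0f32; val_dim];
            for (w, v_row) in weights.iter().zip(v) {
                for (o, x) in out.iter_mut().zip(v_row) {
                    *o += w * x;
                }
            }
            out
        })
        .collect()
}
```

Passing `None` for `scale` uses the `1/sqrt(head_dim)` default; passing `Some(s)` corresponds to overriding it, as `options.scale` does here. The real function performs this computation independently for every batch element and head.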
§Arguments
query: Query tensor of shape [batch_size, num_heads, seq_len_q, head_dim]
key: Key tensor of shape [batch_size, num_heads, seq_len_k, head_dim]
value: Value tensor of shape [batch_size, num_heads, seq_len_k, val_dim]
mask: Optional boolean mask of shape [batch_size, num_heads, seq_len_q, seq_len_k], where true indicates positions to mask (i.e. set to -inf before softmax).
attn_bias: Optional float tensor of shape [batch_size, num_heads, seq_len_q, seq_len_k] added to the attention scores before softmax (e.g. ALiBi, relative position biases).
options: Additional attention options (custom scale, softcap, causal masking).
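The mask and bias semantics can be illustrated on a single row of pre-softmax scores. The sketch below is plain Rust with hypothetical helper names (`masked_softmax`, `causal_mask`), not this crate's code. It rests on three assumptions: the bias is added before the mask is applied, a fully masked row yields all-zero weights, and the causal mask aligns query position i with key position i. The actual implementation may differ on each of these points.

```rust
/// Hypothetical helper: add an optional bias to one row of scores, set masked
/// positions (mask value `true`) to -inf, then apply softmax.
fn masked_softmax(scores: &[f32], bias: Option<&[f32]>, mask: Option<&[bool]>) -> Vec<f32> {
    let adjusted: Vec<f32> = scores
        .iter()
        .enumerate()
        .map(|(j, &s)| {
            // Additive bias (e.g. ALiBi) goes onto the raw score.
            let biased = s + bias.map_or(0.0, |b| b[j]);
            // `true` means "mask this position out".
            if mask.map_or(false, |m| m[j]) {
                f32::NEG_INFINITY
            } else {
                biased
            }
        })
        .collect();

    let max = adjusted.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    if max == f32::NEG_INFINITY {
        // Assumption of this sketch: a fully masked row produces zero weights
        // rather than NaN.
        return vec![0.0; scores.len()];
    }
    let exps: Vec<f32> = adjusted.iter().map(|&s| (s - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// Hypothetical helper: causal mask of shape [seq_len_q][seq_len_k] that is
/// `true` (masked) wherever key position j lies after query position i.
fn causal_mask(seq_len_q: usize, seq_len_k: usize) -> Vec<Vec<bool>> {
    (0..seq_len_q)
        .map(|i| (0..seq_len_k).map(|j| j > i).collect())
        .collect()
}
```

A masked position receives exactly zero weight, because exp(-inf) is 0. Note that `true` means "mask out" here, which is the opposite of libraries where `true` means "attend".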
§Returns
A tensor of shape [batch_size, num_heads, seq_len_q, val_dim]
representing the attended context per head.
§Note
This implementation does not support dropout and is intended for inference or use cases where dropout is not needed.