
Function ctc_loss 

pub fn ctc_loss<B>(
    log_probs: Tensor<B, 3>,
    targets: Tensor<B, 2, Int>,
    input_lengths: Tensor<B, 1, Int>,
    target_lengths: Tensor<B, 1, Int>,
    blank: usize,
) -> Tensor<B, 1>
where B: Backend,

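Computes the Connectionist Temporal Classification (CTC) loss: for each batch element, the negative log-likelihood of the target label sequence given the per-step log-probabilities, summed over every alignment that collapses to that target.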
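For reference, the per-sample quantity can be computed with the CTC forward (alpha) recursion over the target interleaved with blanks. The std-only sketch below does this for a single, unbatched sequence stored as plain `Vec<f64>` rows rather than Burn tensors. The helper names `log_add` and `ctc_nll` are hypothetical and not part of Burn, and the sketch is not a description of how Burn implements `ctc_loss` internally.

```rust
/// Numerically stable log(exp(a) + exp(b)).
fn log_add(a: f64, b: f64) -> f64 {
    if a == f64::NEG_INFINITY {
        return b;
    }
    if b == f64::NEG_INFINITY {
        return a;
    }
    let m = a.max(b);
    m + ((a - m).exp() + (b - m).exp()).ln()
}

/// Hypothetical helper: CTC negative log-likelihood of `target` for one sequence.
/// `log_probs[t][c]` is the log-probability of class `c` at time step `t`.
fn ctc_nll(log_probs: &[Vec<f64>], target: &[usize], blank: usize) -> f64 {
    let t_len = log_probs.len();
    // Extended label sequence: blank, l1, blank, l2, ..., lL, blank.
    let mut ext = Vec::with_capacity(2 * target.len() + 1);
    ext.push(blank);
    for &label in target {
        ext.push(label);
        ext.push(blank);
    }
    let s_len = ext.len();
    if t_len == 0 {
        // An empty input can only produce an empty target.
        return if target.is_empty() { 0.0 } else { f64::INFINITY };
    }
    // alpha[s]: log-probability of all path prefixes ending at ext[s].
    let mut alpha = vec![f64::NEG_INFINITY; s_len];
    alpha[0] = log_probs[0][ext[0]];
    if s_len > 1 {
        alpha[1] = log_probs[0][ext[1]];
    }
    for t in 1..t_len {
        let mut next = vec![f64::NEG_INFINITY; s_len];
        for s in 0..s_len {
            // Stay on the same symbol, or advance from the previous one.
            let mut acc = alpha[s];
            if s >= 1 {
                acc = log_add(acc, alpha[s - 1]);
            }
            // Skip a blank, allowed only between two different labels.
            if s >= 2 && ext[s] != blank && ext[s] != ext[s - 2] {
                acc = log_add(acc, alpha[s - 2]);
            }
            next[s] = acc + log_probs[t][ext[s]];
        }
        alpha = next;
    }
    // Valid paths end on the last label or on the trailing blank.
    let mut total = alpha[s_len - 1];
    if s_len > 1 {
        total = log_add(total, alpha[s_len - 2]);
    }
    -total
}
```

With two classes at 0.5 each over two steps and target `[1]`, the paths (1,1), (blank,1) and (1,blank) all collapse to `[1]`, so `ctc_nll` returns -ln 0.75. A target with repeated adjacent labels that cannot fit in the input (for example `[1, 1]` in two steps) yields an infinite loss.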

Arguments

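  • log_probs - Log-probabilities of shape [T, N, C] (e.g. the output of a log-softmax over the class dimension), where T is the maximum number of input time steps, N the batch size, and C the number of classes including the blank
  • targets - Target label indices of shape [N, S], where S is the maximum target length; entries of row n beyond target_lengths[n] are padding
  • input_lengths - Number of valid time steps in log_probs for each batch element, shape [N]
  • target_lengths - Number of valid labels in targets for each batch element, shape [N]
  • blank - Index of the blank label within the class dimension C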
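The arguments describe a padded batch: sample n uses only the first input_lengths[n] time steps of log_probs and the first target_lengths[n] entries of row n of targets. The sketch below illustrates that slicing in plain Rust, assuming the tensors have been copied into row-major flat buffers (an assumption; no Burn API is used). The helper names are hypothetical. `min_input_len` encodes a general CTC constraint: each label needs one step, plus one extra blank step between equal adjacent labels.

```rust
/// Hypothetical helper: the log-prob rows of sample `n` from a row-major
/// flattened [T, N, C] buffer, truncated to `input_len` time steps.
fn sample_log_probs(
    flat: &[f64],
    t: usize,
    n_batch: usize,
    c: usize,
    n: usize,
    input_len: usize,
) -> Vec<Vec<f64>> {
    assert!(input_len <= t, "input length exceeds T");
    (0..input_len)
        .map(|step| {
            // Element (step, n, 0) sits at offset (step * N + n) * C.
            let start = (step * n_batch + n) * c;
            flat[start..start + c].to_vec()
        })
        .collect()
}

/// Hypothetical helper: the first `target_len` labels of row `n` from a
/// row-major flattened [N, S] buffer; the remaining entries are padding.
fn sample_target(flat: &[usize], s: usize, n: usize, target_len: usize) -> Vec<usize> {
    assert!(target_len <= s, "target length exceeds S");
    flat[n * s..n * s + target_len].to_vec()
}

/// Minimum number of input steps needed to emit `target` under CTC.
fn min_input_len(target: &[usize]) -> usize {
    let repeats = target.windows(2).filter(|w| w[0] == w[1]).count();
    target.len() + repeats
}
```

For instance, a target `[2, 2]` needs at least three input steps (2, blank, 2), so an input_lengths entry smaller than that cannot produce a finite loss for that sample.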

Returns

A tensor of shape [N] holding the loss for each batch element; no reduction over the batch is applied.
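Because the result holds one value per batch element, the caller chooses how to reduce it to a scalar for training. Two common reductions are sketched below on plain `f64` slices; the helper names are hypothetical and not Burn functions. The second follows the convention documented for PyTorch's `CTCLoss` with `reduction="mean"` (divide each loss by its target length, then average); the clamp to at least 1 is an assumption added here to avoid dividing by zero for empty targets.

```rust
/// Hypothetical helper: plain mean of the per-sample losses.
fn mean_loss(losses: &[f64]) -> f64 {
    losses.iter().sum::<f64>() / losses.len() as f64
}

/// Hypothetical helper: divide each loss by its target length (clamped to
/// at least 1 to avoid division by zero), then average over the batch.
fn length_normalized_mean(losses: &[f64], target_lengths: &[usize]) -> f64 {
    assert_eq!(losses.len(), target_lengths.len());
    let total: f64 = losses
        .iter()
        .zip(target_lengths)
        .map(|(&loss, &len)| loss / len.max(1) as f64)
        .sum();
    total / losses.len() as f64
}
```

With per-sample losses `[2.0, 6.0]` and target lengths `[1, 3]`, the plain mean is 4.0, while the length-normalized mean is 2.0, because the longer target no longer dominates the average.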