Struct CTCLoss
pub struct CTCLoss { /* private fields */ }
Computes the Connectionist Temporal Classification (CTC) loss.
Calculates the loss between a continuous (unsegmented) time series and a target sequence. CTC sums over the probability of all possible alignments of the input to the target, producing a loss value that is differentiable with respect to each input node.
The input to this loss is expected to be log-probabilities (e.g., via log_softmax),
not raw logits.
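The sum over alignments described above is usually computed with a forward (dynamic programming) recursion in log space. The following is a minimal sketch for a single sequence in plain Rust, not this crate's implementation; the names ctc_nll and log_add are hypothetical helpers introduced only for illustration, and the input is assumed to be a non-empty slice of per-time-step log-probability rows.

```rust
/// Numerically stable log(exp(a) + exp(b)).
fn log_add(a: f64, b: f64) -> f64 {
    if a == f64::NEG_INFINITY {
        return b;
    }
    if b == f64::NEG_INFINITY {
        return a;
    }
    let m = a.max(b);
    m + ((a - m).exp() + (b - m).exp()).ln()
}

/// Negative log-likelihood of `target` under CTC for one sequence.
/// `log_probs[t][c]` is the log-probability of class `c` at time step `t`.
/// Assumes `log_probs` has at least one time step (hypothetical helper).
fn ctc_nll(log_probs: &[Vec<f64>], target: &[usize], blank: usize) -> f64 {
    // Extended target: blank, l1, blank, l2, ..., blank.
    let mut ext = vec![blank; 2 * target.len() + 1];
    for (i, &label) in target.iter().enumerate() {
        ext[2 * i + 1] = label;
    }
    let s_len = ext.len();

    // alpha[s] = log-probability of all alignment prefixes ending in ext[s].
    let mut alpha = vec![f64::NEG_INFINITY; s_len];
    alpha[0] = log_probs[0][ext[0]];
    if s_len > 1 {
        alpha[1] = log_probs[0][ext[1]];
    }

    for frame in &log_probs[1..] {
        let mut next = vec![f64::NEG_INFINITY; s_len];
        for s in 0..s_len {
            // Stay on the same symbol or advance by one.
            let mut acc = alpha[s];
            if s >= 1 {
                acc = log_add(acc, alpha[s - 1]);
            }
            // Skipping the blank is allowed only between two different labels.
            if s >= 2 && ext[s] != blank && ext[s] != ext[s - 2] {
                acc = log_add(acc, alpha[s - 2]);
            }
            next[s] = acc + frame[ext[s]];
        }
        alpha = next;
    }

    // Valid alignments end on the last label or on the trailing blank.
    let mut total = alpha[s_len - 1];
    if s_len > 1 {
        total = log_add(total, alpha[s_len - 2]);
    }
    -total
}
```

With two classes (blank at index 0), two time steps and uniform probabilities, the target [1] has three valid alignments of probability 0.25 each, so the sketch returns -ln(0.75). When the input is too short to emit the target, no alignment exists and the result is infinite.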
§Example
use burn::tensor::{Tensor, Int};
use burn::tensor::activation::log_softmax;
use burn::nn::loss::{CTCLossConfig, CTCLoss};
let device = Default::default();
// Initialize CTC Loss with default configuration
let ctc_loss = CTCLossConfig::new().init();
// Initialize CTC Loss with custom configuration
let ctc_loss = CTCLossConfig::new()
    .with_blank(1)
    .with_zero_infinity(true)
    .init();
// Prepare inputs (logits shape: [Time, Batch, Class])
// In your actual code, the logits would be the output of your model.
// `B` stands for whichever Backend implementation you are using.
let logits = Tensor::<B, 3>::ones([10, 2, 5], &device);
let log_probs = log_softmax(logits, 2);
// Targets shape: [Batch, Max_Target_Len]
// Note: Targets should not contain the blank index (1).
let targets = Tensor::<B, 2, Int>::from_data([[0, 2], [3, 4]], &device);
// Lengths shape: [Batch]
let input_lengths = Tensor::<B, 1, Int>::from_data([10, 8], &device);
let target_lengths = Tensor::<B, 1, Int>::from_data([2, 2], &device);
// Compute loss
let loss = ctc_loss.forward(log_probs, targets, input_lengths, target_lengths);
Implementations§
§impl CTCLoss
pub fn forward<B>(
    &self,
    log_probs: Tensor<B, 3>,
    targets: Tensor<B, 2, Int>,
    input_lengths: Tensor<B, 1, Int>,
    target_lengths: Tensor<B, 1, Int>,
) -> Tensor<B, 1>
where
    B: Backend,
Computes the CTC loss for the input log-probabilities and targets with no reduction applied.
§Arguments
- log_probs: The log-probabilities of the outputs (e.g., from log_softmax).
- targets: A 2D tensor containing the target class indices. These indices should not include the blank index used in CTC loss. The targets are padded to the length of the longest sequence.
- input_lengths: A 1D tensor containing the actual length of the input sequence for each sample in the batch. This allows retrieving the actual sequence of log-probabilities from log_probs if the batch contains sequences of varying lengths.
- target_lengths: A 1D tensor containing the actual length of each target sequence in targets.
§Returns
- A 1D tensor of shape [batch_size] containing the loss for each sample.
§Shapes
- log_probs: [time_steps, batch_size, num_classes] where num_classes includes blank.
- targets: [batch_size, max_target_length]
- input_lengths: [batch_size]
- target_lengths: [batch_size]
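Since targets must be padded to the longest sequence while target_lengths records the real lengths, a batch of variable-length label sequences has to be made rectangular before it is turned into tensors. The following is a minimal sketch in plain Rust; pad_targets is a hypothetical helper, and the padding value is an assumption, since the documentation above does not say which value padded positions should hold.

```rust
/// Pads variable-length label sequences to a rectangular batch.
/// Returns the padded rows ([batch_size, max_target_length]) and the
/// original length of each row ([batch_size]).
/// `pad_value` is an assumption: the source does not specify a padding value.
fn pad_targets(sequences: &[Vec<i64>], pad_value: i64) -> (Vec<Vec<i64>>, Vec<i64>) {
    // Length of the longest sequence; 0 for an empty batch.
    let max_len = sequences.iter().map(Vec::len).max().unwrap_or(0);

    // Record the real lengths before padding.
    let lengths = sequences.iter().map(|s| s.len() as i64).collect();

    // Extend every row with `pad_value` up to `max_len`.
    let padded = sequences
        .iter()
        .map(|s| {
            let mut row = s.clone();
            row.resize(max_len, pad_value);
            row
        })
        .collect();

    (padded, lengths)
}
```

The padded rows would then be converted into the targets tensor and the lengths into target_lengths. The same idea applies to input_lengths when the batch holds inputs with different numbers of time steps.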
pub fn forward_with_reduction<B>(
    &self,
    log_probs: Tensor<B, 3>,
    targets: Tensor<B, 2, Int>,
    input_lengths: Tensor<B, 1, Int>,
    target_lengths: Tensor<B, 1, Int>,
    reduction: Reduction,
) -> Tensor<B, 1>
where
    B: Backend,
Computes the CTC loss for the input log-probabilities and targets with reduction.
§Arguments
- log_probs: The log-probabilities of the outputs (e.g., from log_softmax).
- targets: A 2D tensor containing the target class indices. These indices should not include the blank index used in CTC loss. The targets are padded to the length of the longest sequence.
- input_lengths: A 1D tensor containing the actual length of the input sequence for each sample in the batch. This allows retrieving the actual sequence of log-probabilities from log_probs if the batch contains sequences of varying lengths.
- target_lengths: A 1D tensor containing the actual length of each target sequence in targets.
- reduction: The reduction strategy to apply to the per-sample CTC loss values (e.g., mean, sum). For the mean reduction strategy, each sample's loss is divided by its target length, and the mean over the batch is then taken. This follows PyTorch's behavior.
§Returns
- A 1D tensor of shape [1] containing the reduced loss value.
§Shapes
- log_probs: [time_steps, batch_size, num_classes] where num_classes includes blank.
- targets: [batch_size, max_target_length]
- input_lengths: [batch_size]
- target_lengths: [batch_size]
§Panics
- If reduction is not one of Reduction::Auto, Reduction::Mean, and Reduction::Sum.
- If the blank index is greater than or equal to num_classes.
- If the batch dimensions of log_probs, targets, input_lengths, and target_lengths do not match.
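The arithmetic of the mean reduction described for forward_with_reduction can be written out on plain slices. This is a minimal sketch, not the crate's implementation: mean_reduction and sum_reduction are hypothetical helpers, and the guard against zero-length targets is an assumption that the documentation above does not confirm.

```rust
/// Mean reduction as described above: each per-sample loss is divided by its
/// target length, then the results are averaged over the batch.
/// Assumes a non-empty batch (an empty batch would yield NaN).
fn mean_reduction(losses: &[f64], target_lengths: &[usize]) -> f64 {
    assert_eq!(
        losses.len(),
        target_lengths.len(),
        "batch sizes must match"
    );
    let normalized_sum: f64 = losses
        .iter()
        .zip(target_lengths)
        // Assumption: clamp the length to at least 1 to avoid dividing by zero.
        .map(|(&loss, &len)| loss / len.max(1) as f64)
        .sum();
    normalized_sum / losses.len() as f64
}

/// Sum reduction: the per-sample losses are simply added up.
fn sum_reduction(losses: &[f64]) -> f64 {
    losses.iter().sum()
}
```

For per-sample losses [2.0, 6.0] with target lengths [2, 3], the normalized losses are 1.0 and 2.0, so the mean reduction gives 1.5 while the sum reduction gives 8.0.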
Trait Implementations§
§impl<B> AutodiffModule<B> for CTCLoss
where
    B: AutodiffBackend,
§type InnerModule = CTCLoss
Inner module without auto-differentiation.
§fn valid(&self) -> <CTCLoss as AutodiffModule<B>>::InnerModule
Returns the same module, but on the inner backend without auto-differentiation.
§fn from_inner(module: <CTCLoss as AutodiffModule<B>>::InnerModule) -> CTCLoss
Wraps an inner module back into an auto-diff module.
§impl<B> Module<B> for CTCLoss
where
    B: Backend,
§type Record = EmptyRecord
Type to save and load the module.
§fn visit<V>(&self, _visitor: &mut V)
where
    V: ModuleVisitor<B>,
Visit each tensor parameter in the module with a visitor.
§fn map<M>(self, _mapper: &mut M) -> CTCLoss
where
    M: ModuleMapper<B>,
Map each tensor parameter in the module with a mapper.
§fn load_record(self, _record: <CTCLoss as Module<B>>::Record) -> CTCLoss
Load the module state from a record.
§fn into_record(self) -> <CTCLoss as Module<B>>::Record
Convert the module into a record containing the state.
§fn to_device(self, _: &<B as BackendTypes>::Device) -> CTCLoss
Move the module and all of its sub-modules to the given device.
§fn fork(self, _: &<B as BackendTypes>::Device) -> CTCLoss
Fork the module and all of its sub-modules to the given device.
§fn collect_devices(
&self,
devices: Vec<<B as BackendTypes>::Device>,
) -> Vec<<B as BackendTypes>::Device>
Return all the devices found in the underneath module tree added to the given vector
without duplicates.
§fn devices(&self) -> Vec<<B as BackendTypes>::Device>
Return all the devices found in the underneath module tree without duplicates.
§fn train<AB>(self) -> Self::TrainModule
where
    AB: AutodiffBackend<InnerBackend = B>,
    Self: HasAutodiffModule<AB>,
Move the module and all of its sub-modules to the autodiff backend.
§fn num_params(&self) -> usize
Get the number of parameters the module has, including all of its sub-modules.
§fn save_file<FR, PB>(
self,
file_path: PB,
recorder: &FR,
) -> Result<(), RecorderError>
Save the module to a file using the provided file recorder.
§fn load_file<FR, PB>(
self,
file_path: PB,
recorder: &FR,
device: &<B as BackendTypes>::Device,
) -> Result<Self, RecorderError>
Load the module from a file using the provided file recorder.
§fn quantize_weights(self, quantizer: &mut Quantizer) -> Self
Quantize the weights of the module.
§impl ModuleDisplay for CTCLoss
§fn format(&self, passed_settings: DisplaySettings) -> String
Formats the module with provided display settings.
§fn custom_settings(&self) -> Option<DisplaySettings>
Custom display settings for the module.
Auto Trait Implementations§
impl Freeze for CTCLoss
impl RefUnwindSafe for CTCLoss
impl Send for CTCLoss
impl Sync for CTCLoss
impl Unpin for CTCLoss
impl UnwindSafe for CTCLoss
Blanket Implementations§
§impl<T> BorrowMut<T> for T
where
    T: ?Sized,
§fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.
§impl<C> CloneExpand for C
where
    C: Clone,
fn __expand_clone_method(&self, _scope: &mut Scope) -> C
§impl<T> CloneToUninit for T
where
    T: Clone,
§impl<T> Instrument for T
§fn instrument(self, span: Span) -> Instrumented<Self>
§fn in_current_span(self) -> Instrumented<Self>
§impl<T> IntoEither for T
§fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.
§fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.
§impl<T> Pointable for T
§impl<T> ToCompactString for T
where
    T: Display,
§fn try_to_compact_string(&self) -> Result<CompactString, ToCompactStringError>
Fallible version of ToCompactString::to_compact_string().
§fn to_compact_string(&self) -> CompactString
Converts the given value to a CompactString.