Struct SupervisedTraining
pub struct SupervisedTraining<LC> where
LC: LearningComponentsTypes, { /* private fields */ }
Structure to configure and launch supervised learning trainings.
Implementations
impl<B, LR, M, O> SupervisedTraining<LearningComponentsMarker<B, LR, M, O>> where
B: AutodiffBackend,
LR: LrScheduler + 'static,
M: TrainStep + AutodiffModule<B> + Display + 'static,
<M as AutodiffModule<B>>::InnerModule: InferenceStep,
O: Optimizer<M, B> + 'static,
pub fn new(
directory: impl AsRef<Path>,
dataloader_train: Arc<dyn DataLoader<B, <M as TrainStep>::Input>>,
dataloader_valid: Arc<dyn DataLoader<<B as AutodiffBackend>::InnerBackend, <<M as AutodiffModule<B>>::InnerModule as InferenceStep>::Input>>,
) -> SupervisedTraining<LearningComponentsMarker<B, LR, M, O>>
Creates a new runner for a supervised training.
Arguments
directory - The directory to save the checkpoints.
dataloader_train - The dataloader for the training split.
dataloader_valid - The dataloader for the validation split.
impl<LC> SupervisedTraining<LC> where
LC: LearningComponentsTypes,
pub fn with_training_strategy(
self,
training_strategy: TrainingStrategy<LC>,
) -> SupervisedTraining<LC>
Replace the default training strategy (SingleDeviceTrainingStrategy) with the provided one.
Arguments
training_strategy - The training strategy.
pub fn with_metric_logger<ML>(self, logger: ML) -> SupervisedTraining<LC> where
ML: MetricLogger + 'static,
pub fn with_checkpointing_strategy<CS>(
self,
strategy: CS,
) -> SupervisedTraining<LC> where
CS: CheckpointingStrategy + 'static,
Update the checkpointing_strategy.
pub fn renderer<MR>(self, renderer: MR) -> SupervisedTraining<LC> where
MR: MetricsRenderer + 'static,
pub fn metrics<Me>(self, metrics: Me) -> SupervisedTraining<LC> where
Me: MetricRegistration<LC>,
Register all metrics as numeric for the training and validation set.
pub fn metrics_text<Me>(self, metrics: Me) -> SupervisedTraining<LC> where
Me: TextMetricRegistration<LC>,
Register all metrics as text for the training and validation set.
pub fn metric_train<Me>(self, metric: Me) -> SupervisedTraining<LC> where
Me: Metric + 'static,
<<<LC as LearningComponentsTypes>::TrainingModel as TrainStep>::Output as ItemLazy>::ItemSync: Adaptor<<Me as Metric>::Input>,
Register a training metric.
pub fn metric_valid<Me>(self, metric: Me) -> SupervisedTraining<LC> where
Me: Metric + 'static,
<<<LC as LearningComponentsTypes>::InferenceModel as InferenceStep>::Output as ItemLazy>::ItemSync: Adaptor<<Me as Metric>::Input>,
Register a validation metric.
pub fn grads_accumulation(self, accumulation: usize) -> SupervisedTraining<LC>
Enable gradients accumulation.
Notes
When you enable gradients accumulation, the gradients object used by the optimizer will be the sum of all gradients generated by each backward pass. It might be a good idea to reduce the learning rate to compensate.
The effect is similar to increasing the batch size and the learning rate by the accumulation amount.
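The note on gradient accumulation can be checked with plain arithmetic. The following is a minimal standard-library sketch, not Burn code: `batch_gradient` and `accumulated_gradient` are hypothetical helpers for a one-parameter least-squares model, and the micro-batches are assumed to have equal size. It shows that summing the gradients of k micro-batches gives k times the mean gradient of the combined batch, which is why lowering the learning rate may be appropriate.

```rust
/// Gradient of 0.5 * (w * x - y)^2 with respect to w,
/// averaged over one micro-batch (hypothetical helper).
fn batch_gradient(w: f64, batch: &[(f64, f64)]) -> f64 {
    let sum: f64 = batch.iter().map(|&(x, y)| (w * x - y) * x).sum();
    sum / batch.len() as f64
}

/// Sums the per-batch gradients, as the notes describe for accumulation.
fn accumulated_gradient(w: f64, batches: &[Vec<(f64, f64)>]) -> f64 {
    batches.iter().map(|b| batch_gradient(w, b)).sum()
}

fn main() {
    let w = 0.5;
    // Two micro-batches of equal size (an assumption of this sketch).
    let batches = vec![
        vec![(1.0, 2.0), (2.0, 4.0)],
        vec![(3.0, 6.0), (4.0, 8.0)],
    ];

    let summed = accumulated_gradient(w, &batches);

    // Mean gradient over the combined batch of four samples.
    let combined: Vec<(f64, f64)> = batches.concat();
    let combined_mean = batch_gradient(w, &combined);

    // The summed gradient is `accumulation` times the combined mean.
    let accumulation = batches.len() as f64;
    println!("summed = {summed}, combined mean = {combined_mean}");
    println!("ratio = {}", summed / combined_mean);
    assert_eq!(summed, accumulation * combined_mean);
}
```

Dividing the learning rate by the accumulation amount would make the update match a single step on the combined batch in this toy setting.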
pub fn metric_train_numeric<Me>(self, metric: Me) -> SupervisedTraining<LC>
pub fn metric_valid_numeric<Me>(self, metric: Me) -> SupervisedTraining<LC> where
Me: Metric + Numeric + 'static,
<<<LC as LearningComponentsTypes>::InferenceModel as InferenceStep>::Output as ItemLazy>::ItemSync: Adaptor<<Me as Metric>::Input>,
pub fn num_epochs(self, num_epochs: usize) -> SupervisedTraining<LC>
The number of epochs the training should last.
pub fn checkpoint(self, checkpoint: usize) -> SupervisedTraining<LC>
The epoch from which the training must resume.
pub fn interrupter(&self) -> Interrupter
Provides a handle that can be used to interrupt training.
pub fn with_interrupter(
self,
interrupter: Interrupter,
) -> SupervisedTraining<LC>
Override the handle for stopping training with an externally provided handle.
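To illustrate the idea of a shareable interrupt handle, here is a minimal standard-library sketch. `StopHandle` and `run_epochs` are hypothetical names invented for this example; they are not Burn's `Interrupter`, whose internals this page does not describe. The sketch assumes the handle is a cloneable flag that the training loop polls before each epoch.

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

/// Hypothetical stand-in for an interrupt handle: a cloneable, shared stop flag.
#[derive(Clone, Default)]
struct StopHandle {
    stopped: Arc<AtomicBool>,
}

impl StopHandle {
    /// Requests that the loop stops before its next epoch.
    fn stop(&self) {
        self.stopped.store(true, Ordering::Relaxed);
    }

    /// Returns true once `stop` has been called on any clone.
    fn should_stop(&self) -> bool {
        self.stopped.load(Ordering::Relaxed)
    }
}

/// Runs up to `num_epochs` epochs, polling the handle before each one.
/// Returns the number of epochs that actually ran.
fn run_epochs(
    num_epochs: usize,
    handle: &StopHandle,
    mut on_epoch_end: impl FnMut(usize),
) -> usize {
    let mut completed = 0;
    for epoch in 1..=num_epochs {
        if handle.should_stop() {
            break;
        }
        completed += 1;
        on_epoch_end(epoch);
    }
    completed
}

fn main() {
    let handle = StopHandle::default();
    // A clone shares the same flag, so it can be handed to other code.
    let external = handle.clone();

    let completed = run_epochs(10, &handle, |epoch| {
        if epoch == 3 {
            external.stop();
        }
    });
    println!("completed {completed} of 10 epochs");
}
```

Because the flag lives behind an `Arc`, a clone can also be moved to another thread (for example a UI or signal handler) while the loop keeps its own copy.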
pub fn early_stopping<Strategy>(
self,
strategy: Strategy,
) -> SupervisedTraining<LC>
Register an early stopping strategy to stop the training when the conditions are met.
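As an illustration of what an early stopping condition can look like, the sketch below implements a common patience-based rule using only the standard library. `PatienceStopper` and `stopping_epoch` are hypothetical names for this example; the page does not specify which strategies Burn provides or how they are configured, so this should be read as one possible rule rather than the library's behaviour.

```rust
/// Hypothetical patience-based rule: stop once the validation loss has not
/// improved for `patience` consecutive epochs.
struct PatienceStopper {
    patience: usize,
    best: f64,
    epochs_without_improvement: usize,
}

impl PatienceStopper {
    fn new(patience: usize) -> Self {
        Self {
            patience,
            best: f64::INFINITY,
            epochs_without_improvement: 0,
        }
    }

    /// Records the loss of one epoch and returns true when training should stop.
    fn should_stop(&mut self, valid_loss: f64) -> bool {
        if valid_loss < self.best {
            self.best = valid_loss;
            self.epochs_without_improvement = 0;
        } else {
            self.epochs_without_improvement += 1;
        }
        self.epochs_without_improvement >= self.patience
    }
}

/// Returns the 1-based epoch at which the rule fires, or `None` if it never does.
fn stopping_epoch(losses: &[f64], patience: usize) -> Option<usize> {
    let mut stopper = PatienceStopper::new(patience);
    losses
        .iter()
        .position(|&loss| stopper.should_stop(loss))
        .map(|index| index + 1)
}

fn main() {
    // Made-up validation losses, for illustration only.
    let losses = [0.9, 0.7, 0.6, 0.65, 0.61, 0.62];
    match stopping_epoch(&losses, 2) {
        Some(epoch) => println!("stop after epoch {epoch}"),
        None => println!("ran all epochs"),
    }
}
```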
pub fn with_application_logger(
self,
logger: Option<Box<dyn ApplicationLoggerInstaller>>,
) -> SupervisedTraining<LC>
By default, Rust logs are captured and written into experiment.log. If disabled, standard Rust log handling will apply.
pub fn with_file_checkpointer<FR>(self, recorder: FR) -> SupervisedTraining<LC> where
FR: FileRecorder<<LC as LearningComponentsTypes>::Backend> + 'static + FileRecorder<<<LC as LearningComponentsTypes>::Backend as AutodiffBackend>::InnerBackend>,
pub fn summary(self) -> SupervisedTraining<LC>
Enable the training summary report.
The summary will be displayed after .fit(), when the renderer is dropped.
impl<LC> SupervisedTraining<LC> where
LC: LearningComponentsTypes + Send + 'static,
pub fn launch(
self,
learner: Learner<LC>,
) -> LearningResult<<LC as LearningComponentsTypes>::InferenceModel>
Launch this training with the given Learner.
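Most methods on this page take `self` by value and return the configured value, and `launch` consumes it. The sketch below mimics that calling pattern with a hypothetical `TrainingConfig` type built only on the standard library. It borrows a few method names from this page for readability, but the fields, defaults, and the `String` return value are assumptions made for the example and say nothing about Burn's internals.

```rust
/// Hypothetical configuration type; not the real `SupervisedTraining`.
#[derive(Debug, Default, PartialEq)]
struct TrainingConfig {
    num_epochs: usize,
    grads_accumulation: Option<usize>,
    summary: bool,
}

impl TrainingConfig {
    /// Starts from assumed defaults (one epoch, no accumulation, no summary).
    fn new() -> Self {
        Self {
            num_epochs: 1,
            ..Default::default()
        }
    }

    /// Each setter takes `self` by value and returns it, so calls can be chained.
    fn num_epochs(mut self, num_epochs: usize) -> Self {
        self.num_epochs = num_epochs;
        self
    }

    fn grads_accumulation(mut self, accumulation: usize) -> Self {
        self.grads_accumulation = Some(accumulation);
        self
    }

    fn summary(mut self) -> Self {
        self.summary = true;
        self
    }

    /// Consumes the configuration, so it cannot be reused after launching.
    fn launch(self) -> String {
        format!(
            "epochs={} accumulation={:?} summary={}",
            self.num_epochs, self.grads_accumulation, self.summary
        )
    }
}

fn main() {
    let report = TrainingConfig::new()
        .num_epochs(10)
        .grads_accumulation(4)
        .summary()
        .launch();
    println!("{report}");
}
```

Taking `self` by value means the compiler rejects any use of the configuration after `launch`, which fits a run-once training setup.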
Auto Trait Implementations
impl<LC> Freeze for SupervisedTraining<LC>
impl<LC> !RefUnwindSafe for SupervisedTraining<LC>
impl<LC> !Send for SupervisedTraining<LC>
impl<LC> !Sync for SupervisedTraining<LC>
impl<LC> Unpin for SupervisedTraining<LC>
impl<LC> !UnwindSafe for SupervisedTraining<LC>
Blanket Implementations
impl<T> BorrowMut<T> for T where
T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
impl<T> Instrument for T
fn instrument(self, span: Span) -> Instrumented<Self>
fn in_current_span(self) -> Instrumented<Self>
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.