Struct burn::train::LearnerBuilder
pub struct LearnerBuilder<B, T, V, M, O, S> where
T: Send + 'static,
V: Send + 'static,
B: AutodiffBackend,
M: AutodiffModule<B>,
O: Optimizer<M, B>,
S: LrScheduler,
{ /* private fields */ }
Expand description
Struct to configure and create a learner.
Implementations§
§impl<B, T, V, M, O, S> LearnerBuilder<B, T, V, M, O, S>where
B: AutodiffBackend,
T: Send + 'static,
V: Send + 'static,
M: AutodiffModule<B> + Display + 'static,
O: Optimizer<M, B>,
S: LrScheduler,
impl<B, T, V, M, O, S> LearnerBuilder<B, T, V, M, O, S>where
B: AutodiffBackend,
T: Send + 'static,
V: Send + 'static,
M: AutodiffModule<B> + Display + 'static,
O: Optimizer<M, B>,
S: LrScheduler,
pub fn new(directory: impl AsRef<Path>) -> LearnerBuilder<B, T, V, M, O, S>
pub fn new(directory: impl AsRef<Path>) -> LearnerBuilder<B, T, V, M, O, S>
pub fn metric_loggers<MT, MV>(
self,
logger_train: MT,
logger_valid: MV,
) -> LearnerBuilder<B, T, V, M, O, S>where
MT: MetricLogger + 'static,
MV: MetricLogger + 'static,
pub fn metric_loggers<MT, MV>(
self,
logger_train: MT,
logger_valid: MV,
) -> LearnerBuilder<B, T, V, M, O, S>where
MT: MetricLogger + 'static,
MV: MetricLogger + 'static,
Replace the default metric loggers with the provided ones.
§Arguments
logger_train - The training logger.
logger_valid - The validation logger.
pub fn with_checkpointing_strategy<CS>(
self,
strategy: CS,
) -> LearnerBuilder<B, T, V, M, O, S>where
CS: CheckpointingStrategy + 'static,
pub fn with_checkpointing_strategy<CS>(
self,
strategy: CS,
) -> LearnerBuilder<B, T, V, M, O, S>where
CS: CheckpointingStrategy + 'static,
Update the checkpointing_strategy.
pub fn renderer<MR>(self, renderer: MR) -> LearnerBuilder<B, T, V, M, O, S>where
MR: MetricsRenderer + 'static,
pub fn renderer<MR>(self, renderer: MR) -> LearnerBuilder<B, T, V, M, O, S>where
MR: MetricsRenderer + 'static,
pub fn metric_train<Me>(self, metric: Me) -> LearnerBuilder<B, T, V, M, O, S>
pub fn metric_train<Me>(self, metric: Me) -> LearnerBuilder<B, T, V, M, O, S>
Register a training metric.
pub fn metric_valid<Me>(self, metric: Me) -> LearnerBuilder<B, T, V, M, O, S>
pub fn metric_valid<Me>(self, metric: Me) -> LearnerBuilder<B, T, V, M, O, S>
Register a validation metric.
pub fn grads_accumulation(
self,
accumulation: usize,
) -> LearnerBuilder<B, T, V, M, O, S>
pub fn grads_accumulation( self, accumulation: usize, ) -> LearnerBuilder<B, T, V, M, O, S>
Enable gradients accumulation.
§Notes
When you enable gradients accumulation, the gradients object used by the optimizer will be the sum of all gradients generated by each backward pass. It might be a good idea to reduce the learning rate to compensate.
The effect is similar to increasing the batch size and the learning rate by the accumulation amount.
pub fn metric_train_numeric<Me>(
self,
metric: Me,
) -> LearnerBuilder<B, T, V, M, O, S>
pub fn metric_train_numeric<Me>( self, metric: Me, ) -> LearnerBuilder<B, T, V, M, O, S>
pub fn metric_valid_numeric<Me>(
self,
metric: Me,
) -> LearnerBuilder<B, T, V, M, O, S>
pub fn metric_valid_numeric<Me>( self, metric: Me, ) -> LearnerBuilder<B, T, V, M, O, S>
pub fn num_epochs(self, num_epochs: usize) -> LearnerBuilder<B, T, V, M, O, S>
pub fn num_epochs(self, num_epochs: usize) -> LearnerBuilder<B, T, V, M, O, S>
The number of epochs the training should last.
pub fn devices(
self,
devices: Vec<<B as Backend>::Device>,
) -> LearnerBuilder<B, T, V, M, O, S>
pub fn devices( self, devices: Vec<<B as Backend>::Device>, ) -> LearnerBuilder<B, T, V, M, O, S>
Run the training loop on multiple devices.
pub fn checkpoint(self, checkpoint: usize) -> LearnerBuilder<B, T, V, M, O, S>
pub fn checkpoint(self, checkpoint: usize) -> LearnerBuilder<B, T, V, M, O, S>
The epoch from which the training must resume.
pub fn interrupter(&self) -> TrainingInterrupter
pub fn interrupter(&self) -> TrainingInterrupter
Provides a handle that can be used to interrupt training.
pub fn early_stopping<Strategy>(
self,
strategy: Strategy,
) -> LearnerBuilder<B, T, V, M, O, S>where
Strategy: EarlyStoppingStrategy + 'static,
pub fn early_stopping<Strategy>(
self,
strategy: Strategy,
) -> LearnerBuilder<B, T, V, M, O, S>where
Strategy: EarlyStoppingStrategy + 'static,
Register an early stopping strategy to stop the training when the conditions are met.
pub fn with_application_logger(
self,
logger: Option<Box<dyn ApplicationLoggerInstaller>>,
) -> LearnerBuilder<B, T, V, M, O, S>
pub fn with_application_logger( self, logger: Option<Box<dyn ApplicationLoggerInstaller>>, ) -> LearnerBuilder<B, T, V, M, O, S>
By default, Rust logs are captured and written into experiment.log. If disabled, standard Rust log handling will apply.
pub fn with_file_checkpointer<FR>(
self,
recorder: FR,
) -> LearnerBuilder<B, T, V, M, O, S>where
FR: FileRecorder<B> + 'static + FileRecorder<<B as AutodiffBackend>::InnerBackend>,
<O as Optimizer<M, B>>::Record: 'static,
<M as Module<B>>::Record: 'static,
<S as LrScheduler>::Record<B>: 'static,
pub fn with_file_checkpointer<FR>(
self,
recorder: FR,
) -> LearnerBuilder<B, T, V, M, O, S>where
FR: FileRecorder<B> + 'static + FileRecorder<<B as AutodiffBackend>::InnerBackend>,
<O as Optimizer<M, B>>::Record: 'static,
<M as Module<B>>::Record: 'static,
<S as LrScheduler>::Record<B>: 'static,
pub fn summary(self) -> LearnerBuilder<B, T, V, M, O, S>
pub fn summary(self) -> LearnerBuilder<B, T, V, M, O, S>
Enable the training summary report.
The summary will be displayed at the end of .fit().
pub fn build(
self,
model: M,
optim: O,
lr_scheduler: S,
) -> Learner<LearnerComponentsMarker<B, S, M, O, AsyncCheckpointer<<M as Module<B>>::Record, B>, AsyncCheckpointer<<O as Optimizer<M, B>>::Record, B>, AsyncCheckpointer<<S as LrScheduler>::Record<B>, B>, FullEventProcessor<T, V>, Box<dyn CheckpointingStrategy>>>
pub fn build( self, model: M, optim: O, lr_scheduler: S, ) -> Learner<LearnerComponentsMarker<B, S, M, O, AsyncCheckpointer<<M as Module<B>>::Record, B>, AsyncCheckpointer<<O as Optimizer<M, B>>::Record, B>, AsyncCheckpointer<<S as LrScheduler>::Record<B>, B>, FullEventProcessor<T, V>, Box<dyn CheckpointingStrategy>>>
Create the learner from a model and an optimizer. The learning rate scheduler can also be a simple learning rate.
Auto Trait Implementations§
impl<B, T, V, M, O, S> Freeze for LearnerBuilder<B, T, V, M, O, S>
impl<B, T, V, M, O, S> !RefUnwindSafe for LearnerBuilder<B, T, V, M, O, S>
impl<B, T, V, M, O, S> !Send for LearnerBuilder<B, T, V, M, O, S>
impl<B, T, V, M, O, S> !Sync for LearnerBuilder<B, T, V, M, O, S>
impl<B, T, V, M, O, S> Unpin for LearnerBuilder<B, T, V, M, O, S>
impl<B, T, V, M, O, S> !UnwindSafe for LearnerBuilder<B, T, V, M, O, S>
Blanket Implementations§
source§impl<T> BorrowMut<T> for Twhere
T: ?Sized,
impl<T> BorrowMut<T> for Twhere
T: ?Sized,
source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
§impl<T> Instrument for T
impl<T> Instrument for T
§fn instrument(self, span: Span) -> Instrumented<Self>
fn instrument(self, span: Span) -> Instrumented<Self>
§fn in_current_span(self) -> Instrumented<Self>
fn in_current_span(self) -> Instrumented<Self>
source§impl<T> IntoEither for T
impl<T> IntoEither for T
source§fn into_either(self, into_left: bool) -> Either<Self, Self>
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true.
Converts self
into a Right
variant of Either<Self, Self>
otherwise. Read more
source§fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true.
Converts self
into a Right
variant of Either<Self, Self>
otherwise. Read more