Struct burn::train::LearnerBuilder

pub struct LearnerBuilder<B, T, V, M, O, S>
where T: Send + 'static, V: Send + 'static, B: AutodiffBackend, M: AutodiffModule<B>, O: Optimizer<M, B>, S: LrScheduler,
{ /* private fields */ }

Struct to configure and create a learner.

Implementations§

impl<B, T, V, M, O, S> LearnerBuilder<B, T, V, M, O, S>
where B: AutodiffBackend, T: Send + 'static, V: Send + 'static, M: AutodiffModule<B> + Display + 'static, O: Optimizer<M, B>, S: LrScheduler,

pub fn new(directory: impl AsRef<Path>) -> LearnerBuilder<B, T, V, M, O, S>

Creates a new learner builder.

§Arguments
  • directory - The directory to save the checkpoints.

pub fn metric_loggers<MT, MV>( self, logger_train: MT, logger_valid: MV, ) -> LearnerBuilder<B, T, V, M, O, S>
where MT: MetricLogger + 'static, MV: MetricLogger + 'static,

Replace the default metric loggers with the provided ones.

§Arguments
  • logger_train - The training logger.
  • logger_valid - The validation logger.

pub fn with_checkpointing_strategy<CS>( self, strategy: CS, ) -> LearnerBuilder<B, T, V, M, O, S>
where CS: CheckpointingStrategy + 'static,

Replace the checkpointing strategy with the provided one.

pub fn renderer<MR>(self, renderer: MR) -> LearnerBuilder<B, T, V, M, O, S>
where MR: MetricsRenderer + 'static,

Replace the default CLI renderer with a custom one.

§Arguments
  • renderer - The custom renderer.

pub fn metric_train<Me>(self, metric: Me) -> LearnerBuilder<B, T, V, M, O, S>
where Me: Metric + 'static, T: Adaptor<<Me as Metric>::Input>,

Register a training metric.

pub fn metric_valid<Me>(self, metric: Me) -> LearnerBuilder<B, T, V, M, O, S>
where Me: Metric + 'static, V: Adaptor<<Me as Metric>::Input>,

Register a validation metric.

pub fn grads_accumulation( self, accumulation: usize, ) -> LearnerBuilder<B, T, V, M, O, S>

Enable gradient accumulation.

§Notes

When gradient accumulation is enabled, the gradients object used by the optimizer is the sum of the gradients generated by each backward pass. It might be a good idea to reduce the learning rate to compensate.

The effect is similar to increasing the batch size and the learning rate by the accumulation amount.
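The note above can be checked with a small numeric sketch. It does not use burn; `batch_grad` and `accumulated_grad` are hypothetical helpers for a one-parameter least-squares model, written only to show that summing the gradients of equally sized micro-batches gives the large-batch gradient multiplied by the accumulation count.

```rust
/// Gradient of 0.5 * (w * x - y)^2 with respect to `w`,
/// averaged over one micro-batch of (x, y) pairs.
fn batch_grad(w: f64, batch: &[(f64, f64)]) -> f64 {
    let sum: f64 = batch.iter().map(|&(x, y)| (w * x - y) * x).sum();
    sum / batch.len() as f64
}

/// Sum of the per-batch gradients over consecutive micro-batches,
/// mirroring the "sum of all gradients" behaviour described in the note.
fn accumulated_grad(w: f64, batches: &[Vec<(f64, f64)>]) -> f64 {
    batches.iter().map(|b| batch_grad(w, b)).sum()
}

fn main() {
    let w = 0.5_f64;
    let batches = vec![
        vec![(1.0, 2.0), (2.0, 4.0)],
        vec![(3.0, 6.0), (4.0, 8.0)],
    ];
    let accumulation = batches.len() as f64;

    // Gradient of a single large batch holding every sample.
    let all: Vec<(f64, f64)> = batches.iter().flatten().copied().collect();
    let big_batch = batch_grad(w, &all);

    let summed = accumulated_grad(w, &batches);

    // With equally sized micro-batches, the sum is `accumulation` times larger,
    // so dividing the learning rate by `accumulation` restores the same step.
    assert!((summed - accumulation * big_batch).abs() < 1e-12);
    println!("summed = {summed}, large batch = {big_batch}");
}
```

Under these assumptions, dividing the learning rate by the accumulation count produces the same parameter update as one step on the combined batch.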

pub fn metric_train_numeric<Me>( self, metric: Me, ) -> LearnerBuilder<B, T, V, M, O, S>
where Me: Metric + Numeric + 'static, T: Adaptor<<Me as Metric>::Input>,

Register a numeric training metric.

pub fn metric_valid_numeric<Me>( self, metric: Me, ) -> LearnerBuilder<B, T, V, M, O, S>
where Me: Metric + Numeric + 'static, V: Adaptor<<Me as Metric>::Input>,

Register a numeric validation metric.

pub fn num_epochs(self, num_epochs: usize) -> LearnerBuilder<B, T, V, M, O, S>

Set the number of epochs the training should last.

pub fn devices( self, devices: Vec<<B as Backend>::Device>, ) -> LearnerBuilder<B, T, V, M, O, S>

Run the training loop on multiple devices.

pub fn checkpoint(self, checkpoint: usize) -> LearnerBuilder<B, T, V, M, O, S>

The epoch from which the training must resume.

pub fn interrupter(&self) -> TrainingInterrupter

Provides a handle that can be used to interrupt training.
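The general idea behind such a handle can be sketched with a shared atomic flag. This is not burn's `TrainingInterrupter`; `StopHandle`, `stop`, `should_stop`, and `run` are hypothetical names used to illustrate how a cloned handle held outside the loop can request that the loop stop at its next check.

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

/// Hypothetical stand-in for an interrupt handle: a cloneable shared flag.
#[derive(Clone, Default)]
struct StopHandle {
    flag: Arc<AtomicBool>,
}

impl StopHandle {
    /// Ask the loop to stop at its next check.
    fn stop(&self) {
        self.flag.store(true, Ordering::Relaxed);
    }

    /// Polled by the loop before each epoch.
    fn should_stop(&self) -> bool {
        self.flag.load(Ordering::Relaxed)
    }
}

/// Toy training loop: runs at most `max_epochs` epochs, calls `on_epoch_end`
/// after each one, and returns the number of epochs that completed.
fn run(handle: &StopHandle, max_epochs: usize, mut on_epoch_end: impl FnMut(usize)) -> usize {
    let mut completed = 0;
    for epoch in 1..=max_epochs {
        if handle.should_stop() {
            break;
        }
        completed = epoch;
        on_epoch_end(epoch);
    }
    completed
}

fn main() {
    let handle = StopHandle::default();
    // The clone shares the same flag and can live in another thread or callback.
    let remote = handle.clone();
    let completed = run(&handle, 10, |epoch| {
        if epoch == 3 {
            remote.stop();
        }
    });
    println!("completed {completed} epochs");
}
```

Because the handle is obtained from `&self`, it can be taken before the builder is consumed by later calls and kept for use once training has started.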

pub fn early_stopping<Strategy>( self, strategy: Strategy, ) -> LearnerBuilder<B, T, V, M, O, S>
where Strategy: EarlyStoppingStrategy + 'static,

Register an early stopping strategy to stop the training when the conditions are met.
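One common kind of stopping condition is a patience rule: stop once the validation loss has not improved for a given number of epochs. The sketch below is a plain-Rust illustration of that rule only; `Patience` and `stopping_epoch` are hypothetical names and do not implement burn's `EarlyStoppingStrategy` trait.

```rust
/// Hypothetical patience-based rule: stop once the loss has not improved
/// for `patience` consecutive epochs.
struct Patience {
    patience: usize,
    best: f64,
    stale_epochs: usize,
}

impl Patience {
    fn new(patience: usize) -> Self {
        Self { patience, best: f64::INFINITY, stale_epochs: 0 }
    }

    /// Record this epoch's loss; returns true when training should stop.
    fn should_stop(&mut self, loss: f64) -> bool {
        if loss < self.best {
            self.best = loss;
            self.stale_epochs = 0;
        } else {
            self.stale_epochs += 1;
        }
        self.stale_epochs >= self.patience
    }
}

/// Returns the 1-based epoch at which the rule fires, if it fires at all.
fn stopping_epoch(losses: &[f64], patience: usize) -> Option<usize> {
    let mut rule = Patience::new(patience);
    losses.iter().position(|&loss| rule.should_stop(loss)).map(|i| i + 1)
}

fn main() {
    let losses = [0.9, 0.7, 0.6, 0.65, 0.61, 0.62];
    match stopping_epoch(&losses, 2) {
        Some(epoch) => println!("stop after epoch {epoch}"),
        None => println!("ran all epochs"),
    }
}
```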

pub fn with_application_logger( self, logger: Option<Box<dyn ApplicationLoggerInstaller>>, ) -> LearnerBuilder<B, T, V, M, O, S>

Set the application logger installer. By default, Rust logs are captured and written into experiment.log. If disabled (by passing None), standard Rust log handling will apply.

pub fn with_file_checkpointer<FR>( self, recorder: FR, ) -> LearnerBuilder<B, T, V, M, O, S>
where FR: FileRecorder<B> + 'static + FileRecorder<<B as AutodiffBackend>::InnerBackend>, <O as Optimizer<M, B>>::Record: 'static, <M as Module<B>>::Record: 'static, <S as LrScheduler>::Record<B>: 'static,

Register a checkpointer that will save the optimizer, the model and the scheduler to different files.

pub fn summary(self) -> LearnerBuilder<B, T, V, M, O, S>

Enable the training summary report.

The summary will be displayed at the end of .fit().

pub fn build( self, model: M, optim: O, lr_scheduler: S, ) -> Learner<LearnerComponentsMarker<B, S, M, O, AsyncCheckpointer<<M as Module<B>>::Record, B>, AsyncCheckpointer<<O as Optimizer<M, B>>::Record, B>, AsyncCheckpointer<<S as LrScheduler>::Record<B>, B>, FullEventProcessor<T, V>, Box<dyn CheckpointingStrategy>>>
where <M as Module<B>>::Record: 'static, <O as Optimizer<M, B>>::Record: 'static, <S as LrScheduler>::Record<B>: 'static,

Create the learner from a model and an optimizer. The learning rate scheduler can also be a simple learning rate.
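Accepting either a scheduler or a plain number is typically achieved by implementing the scheduler trait for the floating-point type itself. The sketch below illustrates that pattern in isolation; `Schedule`, `Exponential`, and `rates` are hypothetical names and are not burn's `LrScheduler` API.

```rust
/// Hypothetical stand-in for a scheduler trait: yields the learning rate
/// to use for each successive step.
trait Schedule {
    fn step(&mut self) -> f64;
}

/// A bare number acts as a constant schedule.
impl Schedule for f64 {
    fn step(&mut self) -> f64 {
        *self
    }
}

/// Multiplies the rate by `gamma` after every step.
struct Exponential {
    rate: f64,
    gamma: f64,
}

impl Schedule for Exponential {
    fn step(&mut self) -> f64 {
        let current = self.rate;
        self.rate *= self.gamma;
        current
    }
}

/// Collect the learning rates produced over `steps` steps.
fn rates<S: Schedule>(mut schedule: S, steps: usize) -> Vec<f64> {
    (0..steps).map(|_| schedule.step()).collect()
}

fn main() {
    // The same generic function accepts a constant and a decaying schedule.
    println!("{:?}", rates(1e-3_f64, 3));
    println!("{:?}", rates(Exponential { rate: 1e-3, gamma: 0.5 }, 3));
}
```

With this design, a generic parameter bounded by the trait needs no separate overload for the constant case.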

Auto Trait Implementations§

impl<B, T, V, M, O, S> Freeze for LearnerBuilder<B, T, V, M, O, S>

impl<B, T, V, M, O, S> !RefUnwindSafe for LearnerBuilder<B, T, V, M, O, S>

impl<B, T, V, M, O, S> !Send for LearnerBuilder<B, T, V, M, O, S>

impl<B, T, V, M, O, S> !Sync for LearnerBuilder<B, T, V, M, O, S>

impl<B, T, V, M, O, S> Unpin for LearnerBuilder<B, T, V, M, O, S>
where <B as Backend>::Device: Unpin,

impl<B, T, V, M, O, S> !UnwindSafe for LearnerBuilder<B, T, V, M, O, S>

Blanket Implementations§

impl<T> Adaptor<()> for T

fn adapt(&self)

Adapt the type to be passed to a metric.

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> Downcast<T> for T

fn downcast(&self) -> &T

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T> Instrument for T

fn instrument(self, span: Span) -> Instrumented<Self>

Instruments this type with the provided Span, returning an Instrumented wrapper.

fn in_current_span(self) -> Instrumented<Self>

Instruments this type with the current Span, returning an Instrumented wrapper.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T> IntoEither for T

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.

impl<T> Pointable for T

const ALIGN: usize = _

The alignment of pointer.

type Init = T

The type for initializers.

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes an object with the given initializer.

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer.

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer.

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer.

impl<T> Same for T

type Output = T

Should always be Self.

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.

impl<T> Upcast<T> for T

fn upcast(&self) -> Option<&T>

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

fn vzip(self) -> V

impl<T> WithSubscriber for T

fn with_subscriber<S>(self, subscriber: S) -> WithDispatch<Self>
where S: Into<Dispatch>,

Attaches the provided Subscriber to this type, returning a WithDispatch wrapper.

fn with_current_subscriber(self) -> WithDispatch<Self>

Attaches the current default Subscriber to this type, returning a WithDispatch wrapper.

impl<T> ErasedDestructor for T
where T: 'static,