SupervisedTraining

Struct SupervisedTraining 

pub struct SupervisedTraining<LC>{ /* private fields */ }

Structure to configure and launch supervised learning trainings.

Implementations

impl<B, LR, M, O> SupervisedTraining<LearningComponentsMarker<B, LR, M, O>>
where
    B: AutodiffBackend,
    LR: LrScheduler + 'static,
    M: TrainStep + AutodiffModule<B> + Display + 'static,
    <M as AutodiffModule<B>>::InnerModule: InferenceStep,
    O: Optimizer<M, B> + 'static,

pub fn new(
    directory: impl AsRef<Path>,
    dataloader_train: Arc<dyn DataLoader<B, <M as TrainStep>::Input>>,
    dataloader_valid: Arc<dyn DataLoader<<B as AutodiffBackend>::InnerBackend, <<M as AutodiffModule<B>>::InnerModule as InferenceStep>::Input>>,
) -> SupervisedTraining<LearningComponentsMarker<B, LR, M, O>>

Creates a new runner for a supervised training.

Arguments
  • directory - The directory to save the checkpoints.
  • dataloader_train - The dataloader for the training split.
  • dataloader_valid - The dataloader for the validation split.
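
Each configuration method below takes self by value and returns the updated SupervisedTraining, so the calls chain directly after new(). The std-only sketch that follows illustrates this consuming-builder pattern; TrainingBuilder, its fields and its defaults (one epoch, no gradient accumulation, no summary) are hypothetical and do not reflect Burn's internals.

```rust
use std::path::{Path, PathBuf};

/// Hypothetical stand-in for a consuming builder (not Burn's type).
#[derive(Debug)]
struct TrainingBuilder {
    directory: PathBuf,
    num_epochs: usize,
    grads_accumulation: Option<usize>,
    summary: bool,
}

impl TrainingBuilder {
    /// Like `new`, takes the artifact directory up front; the other settings
    /// get defaults that are assumptions of this sketch.
    fn new(directory: impl AsRef<Path>) -> Self {
        Self {
            directory: directory.as_ref().to_path_buf(),
            num_epochs: 1,
            grads_accumulation: None,
            summary: false,
        }
    }

    /// Each setter consumes `self`, updates one field and returns `self`,
    /// which is what makes method chaining possible.
    fn num_epochs(mut self, num_epochs: usize) -> Self {
        self.num_epochs = num_epochs;
        self
    }

    fn grads_accumulation(mut self, accumulation: usize) -> Self {
        self.grads_accumulation = Some(accumulation);
        self
    }

    fn summary(mut self) -> Self {
        self.summary = true;
        self
    }
}

fn main() {
    let builder = TrainingBuilder::new("/tmp/artifacts")
        .num_epochs(10)
        .grads_accumulation(4)
        .summary();
    println!("{builder:?}");
}
```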

impl<LC> SupervisedTraining<LC>

pub fn with_training_strategy(self, training_strategy: TrainingStrategy<LC>) -> SupervisedTraining<LC>

Replace the default training strategy (SingleDeviceTrainingStrategy) with the provided one.

Arguments
  • training_strategy - The training strategy.

pub fn with_metric_logger<ML>(self, logger: ML) -> SupervisedTraining<LC>
where ML: MetricLogger + 'static,

Replace the default metric loggers with the provided logger.

§Arguments
  • logger - The training logger.
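
One common reason for a bound like ML: MetricLogger + 'static (and the similar bounds on with_checkpointing_strategy and renderer) is that the value is stored as a boxed trait object, since Box<dyn Trait> implicitly means Box<dyn Trait + 'static>. The sketch below assumes exactly that storage; the Logger trait, NoopLogger, SharedLogger and Trainer are hypothetical stand-ins, not Burn's MetricLogger API.

```rust
use std::cell::RefCell;
use std::rc::Rc;

/// Hypothetical logging trait; Burn's `MetricLogger` has a different interface.
trait Logger {
    fn log(&mut self, name: &str, value: f64);
}

/// Default used by this sketch: discards every entry.
struct NoopLogger;

impl Logger for NoopLogger {
    fn log(&mut self, _name: &str, _value: f64) {}
}

/// Owns a shared buffer, so the caller can still read entries after handing it over.
struct SharedLogger(Rc<RefCell<Vec<(String, f64)>>>);

impl Logger for SharedLogger {
    fn log(&mut self, name: &str, value: f64) {
        self.0.borrow_mut().push((name.to_string(), value));
    }
}

struct Trainer {
    // `Box<dyn Logger>` means `Box<dyn Logger + 'static>`, hence the bound below.
    logger: Box<dyn Logger>,
}

impl Trainer {
    fn new() -> Self {
        Self { logger: Box::new(NoopLogger) }
    }

    /// Consumes the builder and swaps the default logger for the provided one.
    fn with_metric_logger<ML>(mut self, logger: ML) -> Self
    where
        ML: Logger + 'static,
    {
        self.logger = Box::new(logger);
        self
    }

    fn log_epoch_loss(&mut self, epoch: usize, loss: f64) {
        self.logger.log(&format!("epoch {epoch} loss"), loss);
    }
}

fn main() {
    let buffer = Rc::new(RefCell::new(Vec::new()));
    let mut trainer = Trainer::new().with_metric_logger(SharedLogger(Rc::clone(&buffer)));
    trainer.log_epoch_loss(1, 0.5);
    println!("{:?}", buffer.borrow());
}
```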

pub fn with_checkpointing_strategy<CS>(self, strategy: CS) -> SupervisedTraining<LC>
where CS: CheckpointingStrategy + 'static,

Replace the checkpointing strategy with the provided one.
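
As an illustration of the kind of decision a checkpointing strategy makes (which saved checkpoints to keep and which to delete), here is a std-only "keep the last N epochs" rule. KeepLast and CheckpointAction are hypothetical names, and the real CheckpointingStrategy trait has a different interface.

```rust
/// Decision taken after a checkpoint is saved (hypothetical type).
#[derive(Debug, PartialEq)]
enum CheckpointAction {
    /// Nothing to clean up yet.
    Keep,
    /// Delete the checkpoint saved at this epoch.
    Delete(usize),
}

/// Keeps only the checkpoints of the `num_keep` most recent epochs (epochs start at 1).
struct KeepLast {
    num_keep: usize,
}

impl KeepLast {
    /// Called right after the checkpoint for `epoch` has been written.
    fn after_save(&self, epoch: usize) -> CheckpointAction {
        if epoch > self.num_keep {
            // The checkpoint that just fell out of the window.
            CheckpointAction::Delete(epoch - self.num_keep)
        } else {
            CheckpointAction::Keep
        }
    }
}

fn main() {
    let strategy = KeepLast { num_keep: 2 };
    for epoch in 1..=5 {
        println!("epoch {epoch}: {:?}", strategy.after_save(epoch));
    }
}
```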

pub fn renderer<MR>(self, renderer: MR) -> SupervisedTraining<LC>
where MR: MetricsRenderer + 'static,

Replace the default CLI renderer with a custom one.

Arguments
  • renderer - The custom renderer.

pub fn metrics<Me>(self, metrics: Me) -> SupervisedTraining<LC>
where Me: MetricRegistration<LC>,

Register all metrics as numeric for the training and validation sets.

pub fn metrics_text<Me>(self, metrics: Me) -> SupervisedTraining<LC>
where Me: TextMetricRegistration<LC>,

Register all metrics as text for the training and validation sets.

pub fn metric_train<Me>(self, metric: Me) -> SupervisedTraining<LC>
where
    Me: Metric + 'static,
    <<<LC as LearningComponentsTypes>::TrainingModel as TrainStep>::Output as ItemLazy>::ItemSync: Adaptor<<Me as Metric>::Input>,

Register a training metric.

pub fn metric_valid<Me>(self, metric: Me) -> SupervisedTraining<LC>

Register a validation metric.

pub fn grads_accumulation(self, accumulation: usize) -> SupervisedTraining<LC>

Enable gradient accumulation.

Notes

When gradient accumulation is enabled, the gradients passed to the optimizer are the sum of the gradients produced by each backward pass. Because they are summed rather than averaged, consider reducing the learning rate to compensate.

The effect is similar to increasing both the batch size and the learning rate by a factor equal to the accumulation amount.
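
The note above can be checked numerically. This sketch assumes a scalar model y ≈ w·x with a mean-squared-error loss (mse_grad and accumulated_grad are hypothetical helpers): summing the gradients of k equally sized micro-batches gives k times the gradient of one batch that is k times larger, which is why the update behaves like a larger batch with a learning rate scaled by k, and why dividing the learning rate by k compensates.

```rust
/// Gradient of the mean squared error of the model y ≈ w * x over one batch:
/// d/dw [ (1/n) Σ (w x - y)^2 ] = (1/n) Σ 2 x (w x - y).
fn mse_grad(w: f64, batch: &[(f64, f64)]) -> f64 {
    let n = batch.len() as f64;
    batch.iter().map(|&(x, y)| 2.0 * x * (w * x - y)).sum::<f64>() / n
}

/// Sum of the per-micro-batch gradients, as an accumulating optimizer would see it.
/// Assumes `data.len()` is a non-zero multiple of `k`, so there are exactly `k` chunks.
fn accumulated_grad(w: f64, data: &[(f64, f64)], k: usize) -> f64 {
    data.chunks(data.len() / k).map(|batch| mse_grad(w, batch)).sum()
}

fn main() {
    let w = 0.5;
    let data = [(1.0, 2.0), (2.0, 4.1), (3.0, 5.9), (4.0, 8.2)];
    let k = 2; // accumulation amount
    let accumulated = accumulated_grad(w, &data, k);
    let full = mse_grad(w, &data);
    println!("accumulated = {accumulated:.4}, k * full-batch = {:.4}", k as f64 * full);

    // Dividing the learning rate by k makes the update match a plain large-batch step.
    let lr = 0.01;
    println!("update: {:.6} vs {:.6}", (lr / k as f64) * accumulated, lr * full);
}
```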

pub fn metric_train_numeric<Me>(self, metric: Me) -> SupervisedTraining<LC>
where
    Me: Metric + Numeric + 'static,
    <<<LC as LearningComponentsTypes>::TrainingModel as TrainStep>::Output as ItemLazy>::ItemSync: Adaptor<<Me as Metric>::Input>,

Register a numeric training metric.

pub fn metric_valid_numeric<Me>(self, metric: Me) -> SupervisedTraining<LC>

Register a numeric validation metric.

pub fn num_epochs(self, num_epochs: usize) -> SupervisedTraining<LC>

The number of epochs the training should last.

pub fn checkpoint(self, checkpoint: usize) -> SupervisedTraining<LC>

The epoch from which the training must resume.
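
num_epochs and checkpoint together determine which epochs still run. The sketch below assumes epochs are numbered from 1 and that resuming from checkpoint N restores the state saved at the end of epoch N and continues with epoch N + 1; remaining_epochs is a hypothetical helper, not part of Burn.

```rust
use std::ops::RangeInclusive;

/// Epochs that still have to run. Assumes epochs are numbered from 1 and that
/// resuming from `checkpoint = Some(n)` restores the state saved at the end of
/// epoch `n`, so training continues with epoch `n + 1`.
fn remaining_epochs(num_epochs: usize, checkpoint: Option<usize>) -> RangeInclusive<usize> {
    let start = checkpoint.map_or(1, |epoch| epoch + 1);
    start..=num_epochs
}

fn main() {
    println!("{:?}", remaining_epochs(10, None).collect::<Vec<_>>());
    println!("{:?}", remaining_epochs(10, Some(7)).collect::<Vec<_>>());
}
```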

pub fn interrupter(&self) -> Interrupter

Provides a handle that can be used to interrupt training.

pub fn with_interrupter(self, interrupter: Interrupter) -> SupervisedTraining<LC>

Override the handle for stopping training with an externally provided handle.
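
An interrupt handle is typically a cheaply cloneable, shared flag: one clone stays with the training loop, which checks it between steps, and another goes to whatever may stop training (a UI, a signal handler, another thread). The following sketch assumes that design using Arc<AtomicBool>; StopHandle and train are hypothetical and only mimic the role of Interrupter.

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::thread;

/// Hypothetical stand-in for an interrupt handle: all clones share one flag.
#[derive(Clone, Default)]
struct StopHandle(Arc<AtomicBool>);

impl StopHandle {
    fn stop(&self) {
        self.0.store(true, Ordering::Relaxed);
    }

    fn should_stop(&self) -> bool {
        self.0.load(Ordering::Relaxed)
    }
}

/// Runs up to `max_steps` steps, checking the handle before each one,
/// and returns how many steps actually ran.
fn train(handle: &StopHandle, max_steps: usize, mut on_step: impl FnMut(usize)) -> usize {
    let mut done = 0;
    for step in 0..max_steps {
        if handle.should_stop() {
            break;
        }
        on_step(step);
        done += 1;
    }
    done
}

fn main() {
    let handle = StopHandle::default();
    let external = handle.clone(); // e.g. held by a UI or a signal handler
    let steps = thread::spawn(move || {
        train(&handle, 1_000, |step| {
            if step == 2 {
                external.stop();
            }
        })
    })
    .join()
    .unwrap();
    println!("stopped after {steps} steps");
}
```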

pub fn early_stopping<Strategy>(self, strategy: Strategy) -> SupervisedTraining<LC>
where Strategy: EarlyStoppingStrategy + Clone + Send + Sync + 'static,

Register an early stopping strategy to stop the training when its conditions are met.
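
As an example of the kind of condition an early stopping strategy can encode, here is a patience rule that stops once the validation loss has not improved for a given number of consecutive epochs. Patience and stop_epoch are hypothetical; this is not Burn's EarlyStoppingStrategy trait, which the value passed to early_stopping must implement.

```rust
/// Hypothetical patience rule: stop once the validation loss has not improved
/// on its best value for `patience` consecutive epochs.
struct Patience {
    patience: usize,
    best: f64,
    epochs_without_improvement: usize,
}

impl Patience {
    fn new(patience: usize) -> Self {
        Self { patience, best: f64::INFINITY, epochs_without_improvement: 0 }
    }

    /// Feeds one epoch's validation loss; returns `true` if training should stop.
    fn should_stop(&mut self, valid_loss: f64) -> bool {
        if valid_loss < self.best {
            self.best = valid_loss;
            self.epochs_without_improvement = 0;
        } else {
            self.epochs_without_improvement += 1;
        }
        self.epochs_without_improvement >= self.patience
    }
}

/// Returns the (1-based) epoch after which the rule stops training, if any.
fn stop_epoch(valid_losses: &[f64], patience: usize) -> Option<usize> {
    let mut rule = Patience::new(patience);
    valid_losses
        .iter()
        .position(|&loss| rule.should_stop(loss))
        .map(|index| index + 1)
}

fn main() {
    let losses = [0.9, 0.7, 0.72, 0.71, 0.75, 0.6];
    println!("{:?}", stop_epoch(&losses, 3));
}
```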

pub fn with_application_logger(self, logger: Option<Box<dyn ApplicationLoggerInstaller>>) -> SupervisedTraining<LC>

By default, Rust logs are captured and written to experiment.log. Passing None disables this capture, and standard Rust log handling applies instead.

pub fn with_file_checkpointer<FR>(self, recorder: FR) -> SupervisedTraining<LC>

Register a checkpointer that will save the optimizer, the model and the scheduler to different files.
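
Saving each component to its own file lets each one be restored independently. The sketch below writes already-serialized bytes for the model, optimizer and scheduler under a hypothetical <component>-<epoch>.bin naming scheme; the file layout and the save_checkpoint and checkpoint_path helpers are assumptions, not the format produced by the recorder FR.

```rust
use std::fs;
use std::io;
use std::path::{Path, PathBuf};

/// Hypothetical naming scheme: one file per component and epoch, e.g. `model-3.bin`.
fn checkpoint_path(dir: &Path, component: &str, epoch: usize) -> PathBuf {
    dir.join(format!("{component}-{epoch}.bin"))
}

/// Writes already-serialized model, optimizer and scheduler states to separate files.
fn save_checkpoint(
    dir: &Path,
    epoch: usize,
    model: &[u8],
    optimizer: &[u8],
    scheduler: &[u8],
) -> io::Result<()> {
    fs::create_dir_all(dir)?;
    for (component, bytes) in [("model", model), ("optim", optimizer), ("scheduler", scheduler)] {
        fs::write(checkpoint_path(dir, component, epoch), bytes)?;
    }
    Ok(())
}

fn main() -> io::Result<()> {
    let dir = std::env::temp_dir().join("checkpoint_sketch");
    save_checkpoint(&dir, 1, b"model bytes", b"optimizer bytes", b"scheduler bytes")?;
    println!("saved to {}", dir.display());
    Ok(())
}
```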

pub fn summary(self) -> SupervisedTraining<LC>

Enable the training summary report.

The summary will be displayed after .fit(), when the renderer is dropped.

impl<LC> SupervisedTraining<LC>
where LC: LearningComponentsTypes + Send + 'static,

pub fn launch(self, learner: Learner<LC>) -> LearningResult<<LC as LearningComponentsTypes>::InferenceModel>

Launch this training with the given Learner.

Blanket Implementations

impl<T> Adaptor<()> for T

fn adapt(&self)

Adapt the type to be passed to a metric.

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> Downcast<T> for T

fn downcast(&self) -> &T

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T> Instrument for T

fn instrument(self, span: Span) -> Instrumented<Self>

Instruments this type with the provided Span, returning an Instrumented wrapper.

fn in_current_span(self) -> Instrumented<Self>

Instruments this type with the current Span, returning an Instrumented wrapper.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T> IntoComptime for T

fn comptime(self) -> Self

impl<T> IntoEither for T

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.

impl<T> Pointable for T

const ALIGN: usize

The alignment of the pointer.

type Init = T

The type for initializers.

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes an object with the given initializer.

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer.

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer.

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer.

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.

impl<T> Upcast<T> for T

fn upcast(&self) -> Option<&T>

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

fn vzip(self) -> V

impl<T> WithSubscriber for T

fn with_subscriber<S>(self, subscriber: S) -> WithDispatch<Self>
where S: Into<Dispatch>,

Attaches the provided Subscriber to this type, returning a WithDispatch wrapper.

fn with_current_subscriber(self) -> WithDispatch<Self>

Attaches the current default Subscriber to this type, returning a WithDispatch wrapper.