SupervisedLearningStrategy

Trait SupervisedLearningStrategy 

pub trait SupervisedLearningStrategy<LC> {
    // Required method
    fn fit(
        &self,
        training_components: TrainingComponents<LC>,
        learner: Learner<LC>,
        dataloader_train: Arc<dyn DataLoader<<LC as LearningComponentsTypes>::Backend, <<LC as LearningComponentsTypes>::TrainingModel as TrainStep>::Input>>,
        dataloader_valid: Arc<dyn DataLoader<<<LC as LearningComponentsTypes>::Backend as AutodiffBackend>::InnerBackend, <<LC as LearningComponentsTypes>::InferenceModel as InferenceStep>::Input>>,
        starting_epoch: usize,
    ) -> (<LC as LearningComponentsTypes>::TrainingModel, AsyncProcessorTraining<FullEventProcessorTraining<<<LC as LearningComponentsTypes>::TrainingModel as TrainStep>::Output, <<LC as LearningComponentsTypes>::InferenceModel as InferenceStep>::Output>>);

    // Provided method
    fn train(
        &self,
        learner: Learner<LC>,
        dataloader_train: Arc<dyn DataLoader<<LC as LearningComponentsTypes>::Backend, <<LC as LearningComponentsTypes>::TrainingModel as TrainStep>::Input>>,
        dataloader_valid: Arc<dyn DataLoader<<<LC as LearningComponentsTypes>::Backend as AutodiffBackend>::InnerBackend, <<LC as LearningComponentsTypes>::InferenceModel as InferenceStep>::Input>>,
        training_components: TrainingComponents<LC>,
    ) -> LearningResult<<LC as LearningComponentsTypes>::InferenceModel> { ... }
}
Defines the `fit` training loop that every supervised learning strategy must implement, along with a provided `train` entry point.

Required Methods

fn fit(
    &self,
    training_components: TrainingComponents<LC>,
    learner: Learner<LC>,
    dataloader_train: Arc<dyn DataLoader<<LC as LearningComponentsTypes>::Backend, <<LC as LearningComponentsTypes>::TrainingModel as TrainStep>::Input>>,
    dataloader_valid: Arc<dyn DataLoader<<<LC as LearningComponentsTypes>::Backend as AutodiffBackend>::InnerBackend, <<LC as LearningComponentsTypes>::InferenceModel as InferenceStep>::Input>>,
    starting_epoch: usize,
) -> (<LC as LearningComponentsTypes>::TrainingModel, AsyncProcessorTraining<FullEventProcessorTraining<<<LC as LearningComponentsTypes>::TrainingModel as TrainStep>::Output, <<LC as LearningComponentsTypes>::InferenceModel as InferenceStep>::Output>>)

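Runs the training loop for this strategy, beginning at `starting_epoch`, and returns the trained model together with the training event processor.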

Provided Methods

fn train(
    &self,
    learner: Learner<LC>,
    dataloader_train: Arc<dyn DataLoader<<LC as LearningComponentsTypes>::Backend, <<LC as LearningComponentsTypes>::TrainingModel as TrainStep>::Input>>,
    dataloader_valid: Arc<dyn DataLoader<<<LC as LearningComponentsTypes>::Backend as AutodiffBackend>::InnerBackend, <<LC as LearningComponentsTypes>::InferenceModel as InferenceStep>::Input>>,
    training_components: TrainingComponents<LC>,
) -> LearningResult<<LC as LearningComponentsTypes>::InferenceModel>

Train the learner’s model with this strategy.
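
The split between a required `fit` and a provided `train` can be illustrated with a much smaller, self-contained sketch. Everything in it is a hypothetical stand-in and not part of this crate: `ToyStrategy`, `ToyResult`, and `SingleDevice` are invented names, the model is a single `f64` weight, and the data loader is a plain slice. It also assumes that the provided method delegates to `fit` with a starting epoch of 1, which is a common arrangement but is not stated on this page.

```rust
// Simplified stand-ins only: none of these are the real training types.

/// Hypothetical result of a completed run.
#[derive(Debug, PartialEq)]
struct ToyResult {
    weight: f64,
    epochs_run: usize,
}

/// Hypothetical trait mirroring the required/provided split.
trait ToyStrategy {
    /// Required: the strategy-specific training loop.
    /// Returns the trained weight and the number of epochs actually run.
    fn fit(&self, weight: f64, data: &[(f64, f64)], starting_epoch: usize) -> (f64, usize);

    /// Provided: shared entry point that delegates to `fit`.
    /// Assumption: a fresh run starts at epoch 1.
    fn train(&self, data: &[(f64, f64)]) -> ToyResult {
        let (weight, epochs_run) = self.fit(0.0, data, 1);
        ToyResult { weight, epochs_run }
    }
}

/// Hypothetical strategy: plain per-sample gradient descent on one device.
struct SingleDevice {
    num_epochs: usize,
    lr: f64,
}

impl ToyStrategy for SingleDevice {
    fn fit(&self, mut weight: f64, data: &[(f64, f64)], starting_epoch: usize) -> (f64, usize) {
        let mut epochs_run = 0;
        for _epoch in starting_epoch..=self.num_epochs {
            for &(x, y) in data {
                // Gradient of 0.5 * (weight * x - y)^2 with respect to weight.
                let grad = (weight * x - y) * x;
                weight -= self.lr * grad;
            }
            epochs_run += 1;
        }
        (weight, epochs_run)
    }
}

fn main() {
    // Samples drawn from y = 2x, so the weight should approach 2.
    let data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)];
    let strategy = SingleDevice { num_epochs: 50, lr: 0.05 };
    let result = strategy.train(&data);
    println!("weight = {:.3}, epochs = {}", result.weight, result.epochs_run);
}
```

With this shape, an implementor only writes the loop in `fit`, while callers use `train`. Passing a larger `starting_epoch` to `fit` directly models resuming a run part-way through.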

Implementors