AutoTrain - Part 2: Recipes

model.py

```python
import os
from pathlib import Path as P

from fastai.vision.all import Learner, accuracy, default_split

from auto_train.common import Task
from auto_train.classification.custom_functions import *
from auto_train.classification.timmy import create_timm_model


class ClassificationModel(Task):
    def __init__(self, config, inference_only=True):
        super().__init__(config)  # the base Task parses the config
        config = self.config

        # Build the backbone named in the config, with one output per class
        self.model = create_timm_model(
            config.architecture.backbone.model,
            n_out=config.project.meta.num_classes)
        # Dataloaders and a Learner are only needed when training
        if not inference_only:
            self.dls = self.get_dataloaders()
            self.learn = Learner(
                self.dls, self.model,
                splitter=default_split,
                metrics=[accuracy])

    def get_dataloaders(self):
        training_dir = str(P(self.config.training.dir).expanduser())
        # Download the dataset once; later runs reuse the local copy
        if not os.path.exists(training_dir):
            print(f'downloading data to {training_dir}...')
            self.download_data()

        dls = self.config.training.data.load_datablocks(self.config)
        return dls
```
> The base `Task` class can be found in the source code at AutoTrain/auto_train/common/base.py.
> It handles repetitive work such as parsing the config and downloading the dataset if it doesn't already exist.
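
The actual base class is not reproduced in this post, so the following is only a minimal sketch of what such a `Task` might look like. It assumes the config arrives as a nested dict and is turned into attribute access (so `config.project.meta.num_classes` works); `to_namespace` and `ensure_data` are hypothetical names, not AutoTrain's API.

```python
import os
from pathlib import Path
from types import SimpleNamespace


def to_namespace(obj):
    """Recursively convert nested dicts (and lists of dicts) into
    SimpleNamespace objects so keys can be read as attributes."""
    if isinstance(obj, dict):
        return SimpleNamespace(**{k: to_namespace(v) for k, v in obj.items()})
    if isinstance(obj, list):
        return [to_namespace(v) for v in obj]
    return obj


class Task:
    """Hypothetical stand-in for auto_train.common.Task."""

    def __init__(self, config):
        # Accept either a raw dict or an already-parsed namespace.
        self.config = to_namespace(config) if isinstance(config, dict) else config

    def download_data(self):
        # Placeholder: a concrete task would fetch and extract its dataset here.
        raise NotImplementedError

    def ensure_data(self):
        """Download the dataset only if the training directory is missing."""
        training_dir = str(Path(self.config.training.dir).expanduser())
        if not os.path.exists(training_dir):
            print(f"downloading data to {training_dir}...")
            self.download_data()
        return training_dir
```

Keeping the download check in the base class means every task (classification or otherwise) gets the "fetch once, reuse afterwards" behaviour for free.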

> Creates a model using the timm library

> Optionally loads the dataloaders if the model is going to be trained
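
The point of the flag is that building dataloaders is expensive and pointless for inference. A minimal, framework-free sketch of that pattern, assuming a hypothetical `ModelWrapper` that receives a zero-argument function which builds the dataloaders:

```python
class ModelWrapper:
    """Hypothetical sketch: build the (expensive) dataloaders only when the
    object will be trained, never for inference-only use."""

    def __init__(self, build_dataloaders, inference_only=True):
        self.dls = None
        if not inference_only:
            # Training needs data up front, so build it eagerly here.
            self.dls = build_dataloaders()

    def fit(self):
        if self.dls is None:
            raise RuntimeError("no dataloaders; construct with inference_only=False")
        return f"training on {len(self.dls)} batches"
```

Because the builder is never called in inference mode, even a builder that would fail (for example, because the dataset is not on disk) does not affect loading a model for prediction.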

`self.model` is responsible for loading the right architecture. timm is an excellent library that can create a pretrained model from nothing more than the model's name as a string. We load the data using fastai's DataBlock API.
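
What `load_datablocks` does inside AutoTrain is not shown here. For image classification, a fastai DataBlock is commonly built from `ImageBlock` and `CategoryBlock`, with `get_image_files` to collect items, `parent_label` to label them, and `RandomSplitter` to hold out a validation set. The plain-Python sketch below only illustrates those three steps (collect, label by parent folder, seeded random split); it is an assumption-laden illustration, not fastai's implementation.

```python
import random
from pathlib import PurePosixPath
from typing import Callable, List, Tuple


def parent_label(path: str) -> str:
    """Label an item by its parent folder, e.g. 'data/cats/001.jpg' -> 'cats'."""
    return PurePosixPath(path).parent.name


def random_split(n_items: int, valid_pct: float = 0.2, seed: int = 42) -> Tuple[List[int], List[int]]:
    """Shuffle item indices with a fixed seed and hold out valid_pct of them."""
    idxs = list(range(n_items))
    random.Random(seed).shuffle(idxs)
    cut = int(valid_pct * n_items)
    return idxs[cut:], idxs[:cut]


def build_datasets(items: List[str], get_y: Callable[[str], str] = parent_label):
    """Label every item, map labels to integer ids, and split into train/valid."""
    vocab = sorted({get_y(item) for item in items})

    def to_pair(i: int) -> Tuple[str, int]:
        return items[i], vocab.index(get_y(items[i]))

    train_idx, valid_idx = random_split(len(items))
    return [to_pair(i) for i in train_idx], [to_pair(i) for i in valid_idx], vocab
```

For ten files split across `cats/` and `dogs/` folders, this yields eight training pairs, two validation pairs, and the vocabulary `["cats", "dogs"]`. Fixing the seed keeps the split identical across runs, which matters when comparing recipes.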