AutoTrain - Part 2: Recipes
model.py
from auto_train.common import Task
from auto_train.classification.custom_functions import *
from auto_train.classification.timmy import create_timm_model
class ClassificationModel(Task):
    """Classification task built on a timm backbone, optionally with a fastai Learner.

    Args:
        config: Project configuration, handed to the base ``Task``. This class
            reads ``architecture.backbone.model`` and ``project.meta.num_classes``,
            and, when training, ``training.dir`` and ``training.data``.
        inference_only: When True (the default), only ``self.model`` is built;
            ``self.dls`` and ``self.learn`` are not set. When False, the
            dataloaders are loaded (downloading the data first if the training
            directory is missing) and a fastai ``Learner`` is stored in
            ``self.learn``.
    """

    def __init__(self, config, inference_only=True):
        super().__init__(config)
        # Use the config as stored by the base Task, which parses it.
        config = self.config
        self.model = create_timm_model(
            config.architecture.backbone.model,
            n_out=config.project.meta.num_classes,
        )
        # Dataloaders and the Learner are only needed for training; the
        # original condition was inverted and loaded them for inference-only use.
        if not inference_only:
            self.dls = self.get_dataloaders()
            self.learn = Learner(
                self.dls,
                self.model,
                splitter=default_split,
                metrics=[accuracy],
            )

    def get_dataloaders(self):
        """Return the training dataloaders, downloading the data first if needed.

        The data is downloaded via ``self.download_data()`` when the expanded
        ``config.training.dir`` does not exist. The dataloaders themselves come
        from ``config.training.data.load_datablocks(config)``.
        """
        training_dir = str(P(self.config.training.dir).expanduser())
        if not os.path.exists(training_dir):
            print(f'downloading data to {training_dir}...')
            self.download_data()
        return self.config.training.data.load_datablocks(self.config)
> The base `Task` class can be found in the source code at AutoTrain/auto_train/common/base.py
The base `Task` handles repetitive work such as parsing the config and downloading the dataset if it doesn't already exist.
> Creates a model using the timm library
> Optionally loads the dataloaders if the model is going to be trained
`self.model` holds the architecture named in the config. timm is an excellent library that can create a pretrained model from nothing more than its name as a string. We load the data using fastai's DataBlock API.