AutoTrain - Part 1: Ingredients

Ingredients for Experiment

With this functionality, we can store everything related to an experiment in exactly one .ini file and one .py file.

We can store information under pertinent headings, with a human-readable name for each variable. Drawing on a cooking analogy, we will call this file the ingredients.

Here is what the ingredients for training MNIST would look like:

[project]
version = 0.0.1
name = mnist-classification
root = /home/me/projects/${project.name}
[project.data]
source = https://files.fast.ai/data/examples/mnist_sample.tgz
root = ${project.root}/data/
[project.meta]
task = classification
n_classes = 10

[model]
version = 0.0.0
[model.architecture]
backbone = resnet18
output_classes = ${project.meta.n_classes}

[training]
[training.load_dataset]
@load_dataset = load_mnist_from_disk
source = ${project.data.root}
[training.preprocess]
@preprocess = divide_by_255
[training.scheme]
epochs = 10
batch_size = 128
early_stop = True
[training.model]
save_dir = ${project.root}/models/
save_name = ${training.model.save_dir}/${model.version}/model.pth
save_all_checkpoints = False

[testing]
debug = False
[testing.preprocess]
@preprocess = resize_and_divide_by_255
mnist.ini
Not all variables need to be used in a project. For example, `project.version` will not be part of any training code, but it is still useful for readability.
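
To make the `${...}` references concrete, here is a minimal sketch of how such placeholders could be resolved. It uses only the standard library and a hand-written `resolve` helper (a hypothetical name, not part of torch_snippets), and it assumes every reference is written as a fully qualified `${section.key}`. The library's own parser may behave differently.

```python
import configparser
import re

# A trimmed-down copy of the ingredients, with fully qualified references.
INGREDIENTS = """
[project]
name = mnist-classification
root = /home/me/projects/${project.name}
[project.data]
root = ${project.root}/data/
[project.meta]
n_classes = 10
[model.architecture]
output_classes = ${project.meta.n_classes}
"""

_REF = re.compile(r"\$\{([^}]+)\}")


def resolve(parser, value, depth=0):
    """Replace every ${section.key} in `value` with the resolved target value."""
    if depth > 10:
        raise ValueError("reference chain too deep (possibly circular)")

    def substitute(match):
        # The last dotted component is the key; everything before it is the section.
        section, _, key = match.group(1).rpartition(".")
        return resolve(parser, parser[section][key], depth + 1)

    return _REF.sub(substitute, value)


# interpolation=None keeps configparser from touching the placeholders itself.
parser = configparser.ConfigParser(interpolation=None)
parser.read_string(INGREDIENTS)

data_root = resolve(parser, parser["project.data"]["root"])
n_out = int(resolve(parser, parser["model.architecture"]["output_classes"]))
```

Here `data_root` is built by following two references in a row (`project.data.root` to `project.root` to `project.name`), which is what lets a single change to `project.name` propagate through every derived path.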

Here is the Python file that contains the custom functions supporting the ingredients above:

from torch_snippets.registry import registry
registry.create('load_dataset')
registry.create('preprocess')

@registry.load_dataset.register('load_mnist_from_disk')
def wrapper(source):
    def loader():
        # Reading the images (X) and labels (y) from `source` is elided here
        return X, y
    return loader

@registry.preprocess.register('divide_by_255')
def wrapper():
    def preprocess(input):
        return input/255.
    return preprocess

@registry.preprocess.register('resize_and_divide_by_255')
def wrapper():
    from torch_snippets import resize
    def preprocess(input):
        input = resize(input, (28,28))
        return input/255.
    return preprocess
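
To illustrate how an entry such as `@preprocess = divide_by_255` could be turned into a callable, here is a rough, self-contained sketch. The `Registry` class and the `build` helper are hypothetical stand-ins written for this example; they are not the torch_snippets implementation, which may differ in its details.

```python
class Registry:
    """Hypothetical stand-in: maps a string name to a factory function."""

    def __init__(self):
        self._functions = {}

    def register(self, name):
        def decorator(func):
            self._functions[name] = func
            return func
        return decorator

    def get(self, name):
        return self._functions[name]


preprocess_registry = Registry()


@preprocess_registry.register("divide_by_255")
def make_divide_by_255():
    def preprocess(x):
        return x / 255.
    return preprocess


def build(registry, section, key):
    """Look up the factory named by section['@<key>'] and call it with
    the remaining entries of the section as keyword arguments."""
    params = {k: v for k, v in section.items() if not k.startswith("@")}
    return registry.get(section["@" + key])(**params)


# Mirrors the [training.preprocess] section of the ingredients.
section = {"@preprocess": "divide_by_255"}
preprocess = build(preprocess_registry, section, "preprocess")
result = preprocess(51.0)
```

In this sketch the key prefixed with `@` selects which registered factory to use, and every other key in the section is passed to that factory as an argument. This is why `load_mnist_from_disk` takes a `source` parameter while `divide_by_255` takes none.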