AutoTrain - Part 2: Recipes

train.py

Once the model and learner are created, loading them for training is easy:

import ...

def train_model(config: P):
    """Fine-tune the classification task described by ``config`` and save its weights.

    Steps, in order:

    1. Build a ``ClassificationModel`` from ``config`` and take its learner,
       model and config.
    2. Try to load weights from ``training.scheme.resume_training_from``;
       if that fails for any reason, training starts from scratch.
    3. Use ``training.scheme.learning_rate`` when it can be read, otherwise
       fall back to ``find_best_learning_rate``.
    4. Call ``learn.fine_tune`` with the configured ``epochs`` and
       ``freeze_epochs``.
    5. Create the parent directory of ``training.scheme.output_path`` and
       save the model weights there.

    Args:
        config: Training configuration; must expose ``training.scheme`` once
            wrapped by ``ClassificationModel``.
    """
    task = ClassificationModel(config)
    learn = task.learn
    model = task.model
    # From here on use the config held by the task (rebinds the parameter).
    config = task.config
    training_scheme = config.training.scheme

    # Resuming is best-effort: a missing setting or an unloadable checkpoint
    # falls back to training from scratch. ``Exception`` (not a bare
    # ``except``) keeps Ctrl-C / SystemExit from being swallowed.
    try:
        load_torch_model_weights_to(model, training_scheme.resume_training_from)
    except Exception as exc:
        logger.info('Training from scratch!')
        # Keep the reason available so a broken checkpoint is diagnosable.
        logger.debug(f"Could not resume training: {exc!r}")

    # Only search for a learning rate when the config does not provide one.
    try:
        lr = training_scheme.learning_rate
    except Exception:
        lr = find_best_learning_rate(task)
        logger.info(f"Using learning rate: {lr}")

    learn.fine_tune(
        training_scheme.epochs,
        lr,
        freeze_epochs=training_scheme.freeze_epochs,
    )

    # Make sure the destination directory exists before writing the weights.
    makedir(parent(training_scheme.output_path))
    save_torch_model_weights_from(model, training_scheme.output_path)



> Load the task and its components






> Resume training if weights are found






> Find best learning rate if lr is not given in config

> Train the model with pretrained weights





> Save the model to the designated directory