AutoTrain - Part 2: Recipes
Once the model and learner have been created, loading them for training is easy:
import ...
def train_model(config: P):
    """Build a classification task from ``config``, train it, and save its weights.

    Steps:
      1. Create the task with ``ClassificationModel(config)``, take its
         ``learn`` and ``model``, and continue with ``task.config``.
      2. If ``training.scheme.resume_training_from`` is set, load those weights
         into the model. If the setting is missing/``None``, or loading fails,
         training starts from scratch (a load failure is logged as a warning).
      3. Use ``training.scheme.learning_rate`` when it is set and not ``None``;
         otherwise obtain one from ``find_best_learning_rate(task)``.
      4. Call ``learn.fine_tune`` with ``training.scheme.epochs``, the learning
         rate, and ``freeze_epochs=training.scheme.freeze_epochs``.
      5. Create the parent directory of ``training.scheme.output_path`` and
         save the model weights to that path.

    Args:
        config: Training configuration. ``training.scheme`` must provide
            ``epochs``, ``freeze_epochs`` and ``output_path``;
            ``resume_training_from`` and ``learning_rate`` are optional.
    """
    task = ClassificationModel(config)
    learn = task.learn
    model = task.model
    # From here on, use the config held by the task rather than the argument.
    config = task.config
    training_scheme = config.training.scheme

    # Resume from earlier weights only when a checkpoint path is configured.
    try:
        resume_from = training_scheme.resume_training_from
    except (AttributeError, KeyError):  # setting absent from the config
        resume_from = None
    if resume_from is None:
        logger.info('Training from scratch!')
    else:
        try:
            load_torch_model_weights_to(model, resume_from)
        except Exception as err:
            # Resuming stays best-effort, but is no longer silent: report why
            # the weights could not be loaded before falling back.
            logger.warning(f"Could not load weights from {resume_from}: {err}")
            logger.info('Training from scratch!')

    # Prefer an explicitly configured learning rate; search only if none given.
    try:
        lr = training_scheme.learning_rate
    except (AttributeError, KeyError):  # setting absent from the config
        lr = None
    if lr is None:
        lr = find_best_learning_rate(task)
    logger.info(f"Using learning rate: {lr}")

    learn.fine_tune(
        training_scheme.epochs, lr,
        freeze_epochs=training_scheme.freeze_epochs,
    )

    makedir(parent(training_scheme.output_path))
    save_torch_model_weights_from(model, training_scheme.output_path)
> Load the task and its components
> Resume training from saved weights, if they are found
> Find the best learning rate if `lr` is not given in the config
> Train (fine-tune) the model, starting from pretrained weights
> Save the model weights to the designated output path