AutoTrain - Part 2: Recipes

validate.py

Once a model is trained, the same architecture can be rebuilt from the config file and the trained weights loaded into it.

import ...

task = ClassificationModel(config='configs/mnist.ini', inference_only=True)
learn, config, model = task.learn, task.config, task.model
weights_path = config.training.scheme.output_path
load_torch_model_weights_to(model, weights_path)

def infer(img_path):
    logger.info(f'received {img_path} for classification')
    pred, _, cnf = learn.predict(img_path)
    prediction = {
        'prediction': pred,
        'confidence': f'{max(cnf)*100:.2f}%'
    }
    # Apply the optional post-processing hook from the config, if one is defined
    if hasattr(config.testing, 'postprocess'):
        return config.testing.postprocess(prediction)
    return prediction


> Load the same task that was used for training. This time inference_only is True, so no dataloaders are created.

> The weights were already saved during training, so load them from the configured output path.

> Apply the postprocess hook if the config defines one; otherwise return the raw prediction.
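
The optional post-processing step can be tried in isolation. The snippet below is a minimal sketch, not the project's actual code: `apply_postprocess` and `to_label` are hypothetical helpers, `SimpleNamespace` stands in for the `config.testing` section, and the prediction values are made up for illustration.

```python
from types import SimpleNamespace


def apply_postprocess(testing_cfg, prediction):
    """Run testing_cfg.postprocess on the prediction if it exists (hypothetical helper)."""
    hook = getattr(testing_cfg, 'postprocess', None)
    if callable(hook):
        return hook(prediction)
    # No hook configured: hand back the raw prediction unchanged
    return prediction


def to_label(prediction):
    """Example hook: return a copy with a human-readable label (hypothetical)."""
    return {**prediction, 'prediction': f"digit {prediction['prediction']}"}


# Made-up prediction in the same shape that infer() builds
raw = {'prediction': '7', 'confidence': '98.51%'}

with_hook = SimpleNamespace(postprocess=to_label)  # stand-in for config.testing
without_hook = SimpleNamespace()                   # config with no hook defined

print(apply_postprocess(with_hook, raw))
print(apply_postprocess(without_hook, raw))
```

Keeping the hook optional means a recipe that needs no post-processing can simply leave it out of the config, and infer will return the raw dictionary.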