marsvision.pipeline.Model module
-
class marsvision.pipeline.Model.Model(model, model_type: str = 'pytorch', **kwargs)
Bases: object
-
PYTORCH = 'pytorch'
-
SKLEARN = 'sklearn'
-
__init__(model, model_type: str = 'pytorch', **kwargs)
Model class that serves as a common wrapper for either an sklearn model or a PyTorch model.
Parameters:
model: Either an sklearn machine learning model or a PyTorch neural network. Can be a path to a file or a model object.
model_type (str): String identifier for the type of model, either 'pytorch' or 'sklearn'. Determines how this class trains the model.
**kwargs:
training_images (numpy.ndarray): Batch of images to train on.
training_labels: Class labels for the training images.
dataset_root_directory: The root directory of the Deep Mars dataset to train on.
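A minimal sketch of how the PYTORCH and SKLEARN identifiers could be used to validate model_type and store the optional keyword arguments. The class name ModelTypeSketch and its validation behavior are assumptions for illustration, not the marsvision implementation.

```python
class ModelTypeSketch:
    """Hypothetical stand-in for the wrapper's model_type handling."""

    PYTORCH = "pytorch"
    SKLEARN = "sklearn"

    def __init__(self, model, model_type: str = PYTORCH, **kwargs):
        # Hypothetical check: reject identifiers other than the two documented backends.
        if model_type not in (self.PYTORCH, self.SKLEARN):
            raise ValueError(
                f"model_type must be {self.PYTORCH!r} or {self.SKLEARN!r}, got {model_type!r}"
            )
        self.model = model
        self.model_type = model_type
        # Optional training data, using the keyword names listed above.
        self.training_images = kwargs.get("training_images")
        self.training_labels = kwargs.get("training_labels")
        self.dataset_root_directory = kwargs.get("dataset_root_directory")
```

Storing the constants on the class keeps the valid identifiers in one place, so callers can write ModelTypeSketch.SKLEARN instead of repeating the string.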
-
cross_validate(n_folds: int = 10, scoring: list = ['accuracy', 'precision', 'recall', 'roc_auc'], **kwargs)
Run cross validation on the model with its training data and labels, and return the results.
Parameters:
n_folds (int): Number of cross validation folds. Default: 10.
scoring (list): List of sklearn cross validation scoring identifiers. Default: ['accuracy', 'precision', 'recall', 'roc_auc'], which assumes binary classification. Valid identifiers are the sklearn classification scoring identifiers listed at https://scikit-learn.org/stable/modules/model_evaluation.html.
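To make n_folds concrete: k-fold cross validation splits the training samples into n_folds roughly equal parts and, for each part, trains on the remaining samples and scores on the held-out part. The sketch below generates such index splits in plain Python. It illustrates the technique only; the sklearn scoring identifiers suggest the method delegates the actual splitting to sklearn (for example sklearn.model_selection.KFold), and kfold_indices is a hypothetical name.

```python
from typing import Iterator, List, Tuple


def kfold_indices(n_samples: int, n_folds: int = 10) -> Iterator[Tuple[List[int], List[int]]]:
    """Yield (train_indices, test_indices) for contiguous, unshuffled k-fold splits."""
    if not 2 <= n_folds <= n_samples:
        raise ValueError("n_folds must be between 2 and n_samples")
    # The first n_samples % n_folds folds get one extra sample, as in sklearn's KFold.
    base, extra = divmod(n_samples, n_folds)
    start = 0
    for fold in range(n_folds):
        size = base + (1 if fold < extra else 0)
        test = list(range(start, start + size))
        train = list(range(0, start)) + list(range(start + size, n_samples))
        yield train, test
        start += size
```

Every sample lands in exactly one test fold, so each score in the results comes from data the model did not see during that fold's training.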
-
cross_validate_binary_metrics(n_folds: int = 5, ax=None)
Run cross validation on a binary classification problem using the basic pipeline, and return the results as a dictionary.
This is mainly a helper function for cross_validate_plot, which cross validates and plots an ROC curve for each fold.
The returned dictionary contains the domain over which the plot is constructed and the true positive rate (TPR) and false positive rate (FPR) values, alongside standard binary classification measures for each fold: precision, recall, accuracy, and AUC.
This method assumes that the training labels of this object contain exactly two classes.
Parameters:
n_folds (int): Number of folds.
ax: Matplotlib axis on which to show the plot.
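The per-fold measures named above follow the standard binary classification definitions. A minimal pure-Python sketch, assuming labels are encoded as 0 and 1, the classifier produces a score for the positive class, and both classes appear in the fold. The function name binary_fold_metrics, the 0.5 threshold, and the result keys are hypothetical. AUC is computed with the rank (Mann-Whitney) formulation, which equals the area under the ROC curve.

```python
from typing import Dict, Sequence


def binary_fold_metrics(y_true: Sequence[int], y_score: Sequence[float],
                        threshold: float = 0.5) -> Dict[str, float]:
    """Precision, recall, accuracy and ROC AUC for one fold (labels are 0/1)."""
    y_pred = [1 if s >= threshold else 0 for s in y_score]
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)

    # Requires at least one positive and one negative sample in the fold.
    pos = [s for t, s in zip(y_true, y_score) if t == 1]
    neg = [s for t, s in zip(y_true, y_score) if t == 0]
    # AUC = probability that a random positive outscores a random negative (ties count half).
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)

    return {
        "precision": tp / (tp + fp) if tp + fp else 0.0,
        "recall": tp / (tp + fn) if tp + fn else 0.0,
        "accuracy": (tp + tn) / len(y_true),
        "auc": wins / (len(pos) * len(neg)),
    }
```

For example, binary_fold_metrics([0, 0, 1, 1], [0.1, 0.6, 0.4, 0.9]) gives precision, recall, and accuracy of 0.5 and an AUC of 0.75.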
-
cross_validate_plot(title: str = 'Binary Cross Validation Results', n_folds: int = 2)
Run cross validation on a binary classification problem and make a matplotlib plot of the results.
Parameters:
title (str): Title of the figure.
n_folds (int): Number of folds.
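A common way to summarize per-fold ROC curves in one figure, as in the scikit-learn example on ROC with cross validation, is to interpolate each fold's TPR onto a shared grid of FPR values (a common domain) and average the results. Whether cross_validate_plot does exactly this is an assumption. The sketch below shows only the interpolation and averaging step in plain Python; interp and mean_roc are hypothetical names, and the plotting itself is omitted.

```python
from bisect import bisect_right
from typing import List, Sequence, Tuple


def interp(x: float, xp: Sequence[float], fp: Sequence[float]) -> float:
    """Piecewise-linear interpolation similar to numpy.interp; xp must be non-decreasing."""
    if x <= xp[0]:
        return fp[0]
    if x >= xp[-1]:
        return fp[-1]
    i = bisect_right(xp, x)  # now xp[i - 1] <= x < xp[i], so the denominator is nonzero
    x0, x1, y0, y1 = xp[i - 1], xp[i], fp[i - 1], fp[i]
    return y0 + (y1 - y0) * (x - x0) / (x1 - x0)


def mean_roc(fold_curves: List[Tuple[Sequence[float], Sequence[float]]],
             n_points: int = 101) -> Tuple[List[float], List[float]]:
    """Average (fpr, tpr) curves from several folds over an evenly spaced FPR domain."""
    if n_points < 2:
        raise ValueError("n_points must be at least 2")
    domain = [k / (n_points - 1) for k in range(n_points)]
    mean_tpr = []
    for x in domain:
        values = [interp(x, fpr, tpr) for fpr, tpr in fold_curves]
        mean_tpr.append(sum(values) / len(values))
    return domain, mean_tpr
```

The returned domain and mean TPR can then be drawn as a single line on the same axis as the per-fold curves.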
-
load_model(input_path: str, model_type: str)
Loads a model from a pickle file into this object's self.model member.
Parameters:
input_path (str): The location of the file to be read.
model_type (str): The model type, either 'sklearn' or 'pytorch'.
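Loading a pickled model amounts to unpickling the file and recording the model type. A minimal sketch using the standard-library pickle module; load_pickled_model and SUPPORTED_TYPES are hypothetical names, not part of marsvision. Only unpickle files from trusted sources, because unpickling can execute arbitrary code.

```python
import pickle
from typing import Any, Tuple

# Hypothetical constant mirroring the two documented model types.
SUPPORTED_TYPES = ("sklearn", "pytorch")


def load_pickled_model(input_path: str, model_type: str) -> Tuple[Any, str]:
    """Unpickle a model from input_path and return it with its validated type."""
    if model_type not in SUPPORTED_TYPES:
        raise ValueError(f"model_type must be one of {SUPPORTED_TYPES}, got {model_type!r}")
    # Pickle files must be opened in binary mode.
    with open(input_path, "rb") as f:
        model = pickle.load(f)
    return model, model_type
```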
-
predict(image_list: numpy.ndarray)
Run inference on a batch of images using the currently instantiated model (self.model).
The model can be either an sklearn model or a PyTorch model.
Returns a list of inferences.
Parameters:
image_list (numpy.ndarray): Batch of images to run inference on with this model.
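For a PyTorch classifier, the network typically outputs one score per class for each image, and the predicted label is the index of the highest score (argmax); an sklearn classifier instead returns labels directly from its predict method. A minimal sketch of the argmax step in plain Python, assuming the per-image scores have already been computed; scores_to_labels and its class_names argument are hypothetical.

```python
from typing import List, Optional, Sequence


def scores_to_labels(batch_scores: Sequence[Sequence[float]],
                     class_names: Optional[Sequence[str]] = None) -> List:
    """Turn per-image class scores into one predicted label per image (argmax)."""
    predictions = []
    for scores in batch_scores:
        # Index of the highest score; ties resolve to the first maximum.
        best = max(range(len(scores)), key=scores.__getitem__)
        # Map the index to a name if names are given, otherwise return the index.
        predictions.append(class_names[best] if class_names else best)
    return predictions
```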
-
save_model(out_path: str = 'model.p')
Saves a pickle file containing this object’s model.
Parameters:
out_path (str): The output location for the file.
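Saving is the counterpart of load_model: the model object is serialized with the standard-library pickle module. A minimal sketch; save_pickled_model is a hypothetical helper, and creating a missing parent directory is an added convenience rather than documented behavior. The default file name mirrors the documented default 'model.p'.

```python
import os
import pickle
from typing import Any


def save_pickled_model(model: Any, out_path: str = "model.p") -> str:
    """Pickle model to out_path and return the absolute path written."""
    # Hypothetical convenience: create the parent directory if it does not exist yet.
    parent = os.path.dirname(out_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(out_path, "wb") as f:
        pickle.dump(model, f, protocol=pickle.HIGHEST_PROTOCOL)
    return os.path.abspath(out_path)
```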
-
set_extracted_features()
Run feature extraction on the training images stored in self.training_images.
For more details on feature extraction, see the FeatureExtractor module.
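The FeatureExtractor module is not documented in this section, so the sketch below only illustrates the general idea with a hypothetical feature: each image of shape (height, width, channels) is reduced to a fixed-length vector of per-channel mean values, which is the kind of tabular input an sklearn classifier expects. These are not the features FeatureExtractor actually computes, and both function names are hypothetical.

```python
from typing import List, Sequence

# Indexed as image[row][col][channel]; a NumPy array of that shape also works.
Image = Sequence[Sequence[Sequence[float]]]


def channel_means(image: Image) -> List[float]:
    """Hypothetical feature vector: the mean value of each channel."""
    n_channels = len(image[0][0])
    totals = [0.0] * n_channels
    n_pixels = 0
    for row in image:
        for pixel in row:
            for c, value in enumerate(pixel):
                totals[c] += value
            n_pixels += 1
    return [t / n_pixels for t in totals]


def extract_features(images: Sequence[Image]) -> List[List[float]]:
    """One feature row per image, suitable as X for an sklearn estimator's fit."""
    return [channel_means(img) for img in images]
```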
-
set_training_data(training_images: numpy.ndarray, training_labels: List[str])
Setter for training image data.
Parameters:
training_images (numpy.ndarray): Images to train the model on. The array is expected to have the shape (image count, image height, image width, channels).
training_labels (List[str]): Labels associated with the training images. Should be a list parallel to training_images, with one label per image.
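Because the setter expects a 4-D image array and a parallel label list, a quick consistency check can catch mismatches early. A minimal sketch that works on the array's shape tuple (images.shape for a NumPy array); check_training_data is a hypothetical helper, and set_training_data itself may not perform these checks.

```python
from typing import Sequence, Tuple


def check_training_data(image_shape: Tuple[int, ...], labels: Sequence[str]) -> None:
    """Validate (image count, height, width, channels) against a parallel label list."""
    if len(image_shape) != 4:
        raise ValueError(
            f"expected shape (count, height, width, channels), got {image_shape}"
        )
    count = image_shape[0]
    # The label list must be parallel to the images: exactly one label per image.
    if count != len(labels):
        raise ValueError(f"{count} images but {len(labels)} labels")
```

For example, calling check_training_data(images.shape, labels) before set_training_data fails fast when a grayscale array is missing its channel axis.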
-
train_model()
Trains a classifier using this object’s configuration, as specified in the constructor.
Either an sklearn or a PyTorch model is trained on this object’s data. The sklearn model is trained on image features extracted as specified in the FeatureExtractor module. The PyTorch model is trained by running a CNN on the image data.
Parameters:
root_dir (str): Root directory of the Deep Mars dataset.
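The two training paths described above can be pictured as a branch on model_type. In the sketch below, extract_features and train_cnn are hypothetical callables standing in for the FeatureExtractor step and the CNN training routine; only fit(X, y) is the standard sklearn estimator method. None of these names come from marsvision.

```python
from typing import Any, Callable, Sequence


def train_sketch(model: Any, model_type: str, images: Sequence, labels: Sequence,
                 extract_features: Callable[[Sequence], Any],
                 train_cnn: Callable[[Any, Sequence, Sequence], None]) -> Any:
    """Hypothetical outline of the two training paths described for train_model."""
    if model_type == "sklearn":
        # sklearn path: turn images into feature rows, then call the estimator's fit(X, y).
        features = extract_features(images)
        model.fit(features, labels)
    elif model_type == "pytorch":
        # PyTorch path: train a CNN directly on the image data (hypothetical routine).
        train_cnn(model, images, labels)
    else:
        raise ValueError(f"unsupported model_type {model_type!r}")
    return model
```

Passing the two steps in as callables keeps the sketch independent of any particular feature extractor or network.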
-
train_model_pytorchcnn()
Internal helper function that handles the training of a PyTorch CNN model.
The hyperparameters for CNN training, such as the learning rate and the number of epochs, are set in the package’s config file.
Parameters:
root_dir (str): Path to the Deep Mars dataset.
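The config file's format and key names are not documented here. Purely as an illustration, assuming an INI-style file read with the standard-library configparser, hyperparameters such as the learning rate and number of epochs could be loaded as below. The [cnn] section, its keys, and the fallback values are all hypothetical placeholders.

```python
import configparser
from typing import Dict, Union

# Hypothetical config contents; the real file's format and keys may differ.
EXAMPLE_CONFIG = """
[cnn]
learning_rate = 0.001
num_epochs = 25
batch_size = 16
"""


def read_cnn_hyperparameters(text: str) -> Dict[str, Union[int, float]]:
    """Parse hypothetical CNN hyperparameters from INI-formatted text."""
    parser = configparser.ConfigParser()
    parser.read_string(text)
    section = parser["cnn"]
    return {
        # getfloat / getint convert the raw strings; the fallbacks are arbitrary placeholders.
        "learning_rate": section.getfloat("learning_rate", fallback=0.001),
        "num_epochs": section.getint("num_epochs", fallback=10),
        "batch_size": section.getint("batch_size", fallback=32),
    }
```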
-
write_cv_results(output_path: str = 'cv_test_results.txt')
Save cross validation results to a file. The file shows the result for each individual fold, and the mean result over all folds, for every user-specified classification metric.
Parameters:
output_path (str): Path of the file to write the results to.
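A minimal sketch of the kind of report described above, assuming the results are held as a mapping from metric name to a list of per-fold scores (similar in spirit to the dictionary returned by sklearn.model_selection.cross_validate, whose keys look like 'test_accuracy'). The function name write_cv_report and the output layout are hypothetical.

```python
from statistics import mean
from typing import Dict, Sequence


def write_cv_report(results: Dict[str, Sequence[float]],
                    output_path: str = "cv_test_results.txt") -> None:
    """Write each metric's per-fold scores and their mean to a text file."""
    lines = []
    for metric, scores in results.items():
        lines.append(f"{metric}:")
        # One line per fold, numbered from 1.
        for fold, score in enumerate(scores, start=1):
            lines.append(f"  fold {fold}: {score:.4f}")
        lines.append(f"  mean: {mean(scores):.4f}")
    with open(output_path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines) + "\n")
```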
-