Python Interface

Overview

The Python interface offers more flexibility in model training. This section walks through training a model to localise the ventral intermediate nucleus (Vim) in the left thalamus. It highlights the ability to use custom training labels, which can come from manual annotations or from segmentations of high-quality data.

Train a Model Using Custom Labels

  1. Setup & Import Required Modules

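import os

import numpy as np
from localise.load import load_data, load_features, ShuffledDataloader
from localise.train import train
from localise.predict import apply_model
# save_nifti_4D is needed in the last step; this module path is an assumption,
# adjust it if your version of localise defines the function elsewhere
from localise.utils import save_nifti_4D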

  2. Gather Subject Lists

# Obtain a list of training subjects (one subject directory per line; blank lines are skipped)
with open('train_subjs.txt', 'r') as f:
    train_list = [line.strip() for line in f if line.strip()]

# Obtain a list of test subjects
with open('test_subjs.txt', 'r') as f:
    test_list = [line.strip() for line in f if line.strip()]
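
A plain list comprehension over the file turns any blank line into an empty subject name. The sketch below is a slightly more defensive reader, assuming one subject directory per line; `read_subject_list` is a hypothetical helper (not part of localise), and treating lines that start with `#` as comments is an added convention of this example.

```python
import os
import tempfile


def read_subject_list(path):
    """Return the subject entries in a list file.

    Hypothetical helper: assumes one subject directory per line and
    ignores blank lines and lines starting with '#'.
    """
    subjects = []
    with open(path, 'r') as f:
        for line in f:
            entry = line.strip()
            # Skip blank lines and comment lines
            if entry and not entry.startswith('#'):
                subjects.append(entry)
    return subjects


# Usage with a throwaway file standing in for train_subjs.txt
with tempfile.TemporaryDirectory() as tmp:
    list_path = os.path.join(tmp, 'train_subjs.txt')
    with open(list_path, 'w') as f:
        f.write('/data/sub-01\n\n# excluded: /data/sub-02\n/data/sub-03\n')
    subjects = read_subject_list(list_path)

print(subjects)  # ['/data/sub-01', '/data/sub-03']
```

Commenting out a subject this way makes it easy to exclude a scan without deleting it from the list.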

  3. Configuration of Data Paths

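# Paths for the seed mask, custom labels, tract-density maps and atlas,
# given relative to each subject directory; target_list is a single absolute path
mask_name = 'roi/left/tha_small.nii.gz'                # seed mask (left thalamus)
label_name = 'high-quality-labels/left/labels.nii.gz'  # custom training labels
target_path = 'tracts/left'                            # folder with the tract-density maps
target_list = '/path/to/target_list.txt'               # text file listing the targets
atlas = 'roi/left/atlas.nii.gz'                        # atlas
output_fname = 'tracts/left/data.npy'                  # output .npy file for the loaded data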
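
Because a missing file is easier to spot before loading than in the middle of it, a quick pre-flight check can help. This is a minimal sketch assuming that `mask_name`, `label_name` and `atlas` are resolved by joining them onto each subject directory, as the final step does with `os.path.join(subject, mask_name)`; `missing_inputs` is a hypothetical helper, not a localise function.

```python
import os
import tempfile

# Relative paths from the configuration step
mask_name = 'roi/left/tha_small.nii.gz'
label_name = 'high-quality-labels/left/labels.nii.gz'
atlas = 'roi/left/atlas.nii.gz'


def missing_inputs(subjects, relative_paths):
    """Map each subject to the relative paths that do not exist under it.

    Hypothetical helper: assumes every path is relative to the subject directory.
    """
    report = {}
    for subject in subjects:
        absent = [rel for rel in relative_paths
                  if not os.path.exists(os.path.join(subject, rel))]
        if absent:
            report[subject] = absent
    return report


def touch(path):
    # Create an empty file, including any missing parent folders
    os.makedirs(os.path.dirname(path), exist_ok=True)
    open(path, 'w').close()


# Usage: one complete subject and one subject without labels
with tempfile.TemporaryDirectory() as tmp:
    good, bad = os.path.join(tmp, 'sub-01'), os.path.join(tmp, 'sub-02')
    for subject in (good, bad):
        touch(os.path.join(subject, mask_name))
        touch(os.path.join(subject, atlas))
    touch(os.path.join(good, label_name))
    report = missing_inputs([good, bad], [mask_name, label_name, atlas])
    result = {os.path.basename(s): v for s, v in report.items()}

print(result)  # {'sub-02': ['high-quality-labels/left/labels.nii.gz']}
```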

  4. Data Loading

# Load the training data
train_data = load_data(subject=train_list, mask_name=mask_name,
                       target_path=target_path, target_list=target_list,
                       atlas=atlas, label_name=label_name,
                       output_fname=output_fname)

# Use ShuffledDataloader to shuffle the order of training subjects for each epoch
train_dataloader = ShuffledDataloader(train_data)

# Load the test data
test_data = load_data(subject=test_list, mask_name=mask_name,
                      target_path=target_path, target_list=target_list,
                      atlas=atlas, label_name=label_name,
                      output_fname=output_fname)

# Wrap the test data the same way so it can be passed to train() alongside the training data
test_dataloader = ShuffledDataloader(test_data)
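
To make the idea behind a shuffled dataloader concrete, here is a toy class that returns the subjects in a new random order on every pass (one pass per epoch). `ToyShuffledLoader` is hypothetical and only illustrates the behaviour described in the comment above; it is not how localise implements `ShuffledDataloader`.

```python
import random


class ToyShuffledLoader:
    """Hypothetical stand-in for ShuffledDataloader, not localise's code.

    Each iteration (one per epoch) yields the per-subject data in a fresh
    random order; the stored list itself is never reordered.
    """

    def __init__(self, data, seed=None):
        self.data = list(data)
        self.rng = random.Random(seed)

    def __len__(self):
        return len(self.data)

    def __iter__(self):
        order = list(range(len(self.data)))
        self.rng.shuffle(order)
        for i in order:
            yield self.data[i]


# Usage: three "subjects" visited over two epochs
loader = ToyShuffledLoader(['sub-01', 'sub-02', 'sub-03'], seed=0)
for epoch in range(2):
    print(epoch, list(loader))
```

Shuffling whole subjects, rather than individual voxels, keeps each subject's data together while still varying the order of updates between epochs.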

  5. Training the Model

# Define the path to save the trained model
model_save_path = 'your_trained_model.pth'

# Train the model; it is returned as 'm' and saved to model_save_path
m = train(train_dataloader, test_dataloader, model_save_path=model_save_path)
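
This section does not say exactly how `train` uses the two dataloaders. One common pattern, assumed here purely for illustration, is to update the model subject by subject in shuffled order and to keep the weights with the lowest held-out loss. The sketch uses a toy softmax-regression classifier in NumPy on synthetic data; `toy_train`, the model and the data are all hypothetical stand-ins, not the localise model.

```python
import numpy as np


def softmax(z):
    # Subtract the row maximum for numerical stability
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def cross_entropy(W, X, Y):
    P = softmax(X @ W)
    return -np.mean(np.sum(Y * np.log(P + 1e-12), axis=1))


def toy_train(train_set, test_set, n_epochs=50, lr=0.5, seed=0):
    """Hypothetical toy trainer: each subject is (features, one-hot labels).

    Returns the weights with the lowest mean test loss and that loss.
    """
    rng = np.random.default_rng(seed)
    n_features = train_set[0][0].shape[1]
    n_classes = train_set[0][1].shape[1]
    W = np.zeros((n_features, n_classes))
    best_W, best_loss = W.copy(), np.inf
    for _ in range(n_epochs):
        # Visit the training subjects in a new random order every epoch
        for i in rng.permutation(len(train_set)):
            X, Y = train_set[i]
            W -= lr * X.T @ (softmax(X @ W) - Y) / len(X)
        # Evaluate on held-out subjects and keep the best weights so far
        loss = np.mean([cross_entropy(W, X, Y) for X, Y in test_set])
        if loss < best_loss:
            best_W, best_loss = W.copy(), loss
    return best_W, best_loss


def make_subject(rng, n_voxels=200):
    # Synthetic voxels x targets features; the class depends on the first target
    X = rng.normal(size=(n_voxels, 3))
    return X, np.eye(2)[(X[:, 0] > 0).astype(int)]


rng = np.random.default_rng(1)
train_set = [make_subject(rng) for _ in range(4)]
test_set = [make_subject(rng) for _ in range(2)]
W, loss = toy_train(train_set, test_set)
```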

  6. Making Predictions and Saving Results

# Make predictions using the trained model
predictions = apply_model(test_data, m)

# Save each test subject's predictions as a 4D NIfTI file in that subject's directory
for prediction, subject in zip(predictions, test_list):
    save_nifti_4D(prediction,
                  os.path.join(subject, mask_name),
                  os.path.join(subject, 'predictions'))
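
Conceptually, writing predictions back to an image means placing each seed-mask voxel's prediction at its position in the volume. The sketch below assumes each prediction is an array of shape (voxels, classes) whose rows follow the order of `np.nonzero(mask)`; that ordering is an assumption, and localise may order voxels differently. `predictions_to_volume` is a hypothetical helper, not the implementation of `save_nifti_4D`.

```python
import numpy as np


def predictions_to_volume(prediction, mask):
    """Scatter per-voxel predictions into a 4D array on the mask's grid.

    Hypothetical helper: assumes rows of `prediction` follow np.nonzero(mask).
    Voxels outside the mask are left at zero.
    """
    idx = np.nonzero(mask)
    if len(idx[0]) != prediction.shape[0]:
        raise ValueError('number of predictions does not match the mask')
    volume = np.zeros(mask.shape + (prediction.shape[1],), dtype=prediction.dtype)
    volume[idx] = prediction
    return volume


# Usage: a 3x3x3 mask with two seed voxels and two classes
mask = np.zeros((3, 3, 3))
mask[1, 1, 1] = 1
mask[0, 2, 1] = 1
prediction = np.array([[0.2, 0.8], [0.9, 0.1]])
vol = predictions_to_volume(prediction, mask)

# Hard labels: 1..n_classes inside the mask, 0 outside
hard = np.where(mask > 0, vol.argmax(axis=-1) + 1, 0)
```

With nibabel, such an array could then be written out with `nibabel.Nifti1Image(vol, affine)` and `nibabel.save`, taking the affine from the subject's seed-mask image so the result stays aligned with it.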