Example 2: Brain Parcellation using PyTorch
Introduction
In this example, you use a pre-trained PyTorch deep learning model (HighRes3DNet) to perform a full brain parcellation. HighRes3DNet is a 3D residual network presented by Li et al. in "On the Compactness, Efficiency, and Representation of 3D Convolutional Networks: Brain Parcellation as a Pretext Task".
Steps to do
Add a LocalImage module to your workspace and select the file MRI_Head.dcm. For PyTorch, the data has to be resampled to a defined size. Add a Resample3D module, connect it to the LocalImage module, and open its panel. Change Keep Constant to Voxel Size and define Image Size as 176, 217, 160.
The coordinate conventions in PyTorch also differ from those in MeVisLab, so you have to reorient the image. Add an OrthoSwapFlip module and connect it to the Resample3D module. Change View to Other and set Orientation to YXZ. Also check Flip horizontal, Flip vertical, and Flip depth. Apply your changes.
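The effect of swapping two axes and flipping all three can be illustrated outside MeVisLab with a small NumPy sketch. It assumes the volume is stored as an array indexed (z, y, x). The function name swap_xy_and_flip_all is hypothetical, and the sketch only mimics the idea of the OrthoSwapFlip settings above rather than reproducing the module's implementation.

```python
import numpy as np


def swap_xy_and_flip_all(volume):
    """Swap the x and y axes of a (z, y, x) volume, then mirror every axis."""
    swapped = np.swapaxes(volume, 1, 2)      # (z, y, x) -> (z, x, y)
    return np.flip(swapped, axis=(0, 1, 2))  # mirror depth, vertical, horizontal


# Tiny stand-in for the resampled volume (assumption: real data is much larger).
volume = np.arange(2 * 3 * 4).reshape(2, 3, 4)
result = swap_xy_and_flip_all(volume)
print(volume.shape, "->", result.shape)
```

Because all three axes are mirrored, it does not matter whether the swap or the flip is applied first.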
You can use the Output Inspector to see how the image changes after resampling and after the swap and flip.
Add an OrthoView2D module to your network and save the *.mlab file.
Integrate PyTorch and scripting
To integrate PyTorch and Python scripting, we need a PythonImage module. Add it to your workspace. Right-click the PythonImage module and select [ Grouping → Add to new Group... ]. Right-click your new group and convert it into a local macro via its context menu. Name your new local macro DemoAI, select a directory for your project, and leave all settings as default.
Our new module does not provide an input or output.
Adding an interface to the local macro
Right-click the local macro and select [ Related Files → DemoAI.script ]. MATE opens and shows the *.script file of our module. Add an input Field of type Image, an output Field that uses the internalName of the output of our PythonImage module, and a Trigger to start the segmentation.
Also reference a Python file in the Commands section at this point.
DemoAI.script
Interface {
  Inputs {
    Field inputImage { type = Image }
  }
  Outputs {
    Field outImage { internalName = PythonImage.output0 }
  }
  Parameters {
    Field start { type = Trigger }
  }
}

Commands {
  source = $(LOCAL)/DemoAI.py
}
In MATE, right-click the Project Workspace and add a new file DemoAI.py to your project. The workspace now contains an empty Python file.
Change to the MeVisLab IDE, right-click the local macro, and select [ Reload Definition ]. Your new input and output are now available and you can connect images to your module.
Extend your network
We want to show the segmentation results as an overlay on the original image. Add a SoView2DOverlayMPR module and connect it to your DemoAI macro. Connect the output of the SoView2DOverlayMPR module to a SoGroup module. We also need a lookup table for the colors used in the overlay. A prepared *.xml file is available for this purpose. Download the lut.xml file and save it in the current working directory of your project.
Add a LoadBase module and connect it to a SoMLLUT module. The SoMLLUT module needs to be connected to the SoGroup module so that the lookup table is applied to our segmentation results.
Inspect the output of the LoadBase module in the Output Inspector to see if the lookup table has been loaded correctly.
Write Python script
You can now execute the pre-trained PyTorch network on your image. Right-click the local macro and select [ Related Files → DemoAI.script ]. The Python function is supposed to be called whenever the Trigger is touched.
Add the following code to your Commands section:
DemoAI.script
Commands {
  source = $(LOCAL)/DemoAI.py
  FieldListener start { command = onStart }
}
The FieldListener always calls the Python function onStart when the Trigger start is touched. We now need to implement the Python function. Right-click the command onStart and select [ Create Python Function 'onStart' ].
The Python file opens automatically and the function is created.
DemoAI.py
import torch


def onStart():
    # Step 1: Get input image as an array and keep only the spatial axes (z, y, x)
    image = ctx.field("inputImage").image()
    imageArray = image.getTile((0, 0, 0, 0, 0, 0), image.imageExtent())
    inputImage = imageArray[0, 0, 0, :, :, :].astype("float")

    # Step 2: Normalize input image using the voxels brighter than the global mean
    values = inputImage[inputImage > inputImage.mean()]
    inputImage = (inputImage - values.mean()) / values.std()

    # Step 3: Convert into torch tensor of size: [Batch, Channel, z, y, x]
    inputTensor = torch.Tensor(inputImage[None, None, :, :, :])

    # Step 4: Load and prepare AI model
    device = torch.device("cpu")
    model = torch.hub.load("fepegar/highresnet", "highres3dnet", pretrained=True, trust_repo=True)
    model.to(device).eval()

    # Step 5: Run inference (no gradients needed) and pick the most likely label per voxel
    with torch.no_grad():
        output = model(inputTensor.to(device))
    brainParcellationMap = output.argmax(dim=1, keepdim=True).cpu()[0]
    print("...done.")

    # Step 6: Set output image to module
    interface = ctx.module("PythonImage").call("getInterface")
    interface.setImage(brainParcellationMap.numpy(), voxelToWorldMatrix=image.voxelToWorldMatrix())
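The normalization and the tensor layout used in the script can be tried in isolation with plain NumPy. The following is a minimal sketch that assumes a small random volume as stand-in data. The helper name normalize_foreground is hypothetical and not part of MeVisLab or PyTorch.

```python
import numpy as np


def normalize_foreground(volume):
    """Z-score a volume using only the voxels brighter than the global mean."""
    volume = volume.astype("float")
    foreground = volume[volume > volume.mean()]
    return (volume - foreground.mean()) / foreground.std()


# Stand-in data (assumption): a small random volume indexed (z, y, x).
rng = np.random.default_rng(0)
volume = rng.integers(0, 1000, size=(4, 5, 6))

normalized = normalize_foreground(volume)

# Prepend batch and channel axes: [Batch, Channel, z, y, x]
batch = normalized[None, None, :, :, :]
print(batch.shape)
```

Using only the brighter voxels for the statistics keeps the dark background from dominating the mean and standard deviation.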
The function does the following:
- Get the input image of the module
- Normalize the input image
- Convert the image into a torch tensor of size: [Batch, Channel, z, y, x]
- Load and prepare the AI model
- Run the model and select the most likely label for each voxel
- Set the result as the output image of the PythonImage module
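How per-class scores turn into a label map can be sketched with NumPy alone. The example below assumes a tiny hand-made score array with three classes in the [Batch, Channel, z, y, x] layout. The values are made up for illustration and are not output of HighRes3DNet.

```python
import numpy as np

# Hypothetical class scores with shape [Batch, Channel, z, y, x]
scores = np.zeros((1, 3, 2, 2, 2))
scores[0, 2, 0, 0, 0] = 5.0  # class 2 wins at voxel (0, 0, 0)
scores[0, 1, 1, 1, 1] = 2.0  # class 1 wins at voxel (1, 1, 1)

# Pick the highest-scoring class per voxel and drop the batch axis
label_map = scores.argmax(axis=1)[0]

# Count how many voxels were assigned to each label
labels, counts = np.unique(label_map, return_counts=True)
print(dict(zip(labels.tolist(), counts.tolist())))  # {0: 6, 1: 1, 2: 1}
```

A count like this is a quick plausibility check for a parcellation result before it is shown as an overlay.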
Execute the segmentation
Change the alpha value of your SoView2DOverlayMPR module to get a better visualization of the results.
Change to the MeVisLab IDE and select your module DemoAI. In the Module Inspector, click Trigger for start and wait a little until you can see the results.
Without adding a SubImage module, the segmentation results should look like this:
Summary
- Pre-trained PyTorch networks can be used directly in MeVisLab via the PythonImage module