PyTorch image data visualization
how to explore and visualise 3 inputs from the training data
You can visualize your input images using libraries like Matplotlib or OpenCV.
Here’s an example of how you can visualize 3 input images from your training data using PyTorch and Matplotlib:
import matplotlib.pyplot as plt
import numpy as np
import torchvision

# Helper to unnormalize a tensor and convert it to a NumPy array for display
def imshow(inp):
    inp = inp.numpy().transpose((1, 2, 0))  # (C, H, W) -> (H, W, C)
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean  # undo the normalization
    inp = np.clip(inp, 0, 1)  # keep values in the range Matplotlib expects
    plt.imshow(inp)
    plt.show()

# Get a batch of training data
inputs, classes = next(iter(train_loader))

# Make a grid from the first three images in the batch
out = torchvision.utils.make_grid(inputs[:3])
imshow(out)
In this script:
- imshow is a helper function that unnormalizes an image and converts it from a tensor to a NumPy array so that Matplotlib can display it.
- inputs, classes = next(iter(train_loader)) gets a batch of training data. inputs is a batch of input images, and classes are the labels for those images.
- out = torchvision.utils.make_grid(inputs[:3]) takes the first three images from the batch and combines them into a single image that can be displayed.
- imshow(out) displays the grid of images.
This script assumes that your images have been normalized using the mean [0.485, 0.456, 0.406] and standard deviation [0.229, 0.224, 0.225], and that they are loaded using a DataLoader. Adjust it as necessary to match your setup.
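If you also want to see which label belongs to each image, one option is to plot each image in its own subplot rather than in a single grid. The following is a minimal sketch and not the only way to do this. The names show_labeled, to_numpy, and class_names are made up for illustration, and MEAN and STD assume the same normalization values as the script above.

```python
import matplotlib.pyplot as plt
import numpy as np

# Assumed normalization statistics (same values as above); replace with your own.
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])


def to_numpy(x):
    """Convert a tensor or array-like object to a NumPy array."""
    if hasattr(x, "detach"):  # looks like a torch.Tensor
        x = x.detach().cpu().numpy()
    return np.asarray(x)


def show_labeled(images, labels, class_names=None, n=3):
    """Plot the first n images of an (N, 3, H, W) batch, one subplot per image.

    class_names is an optional, hypothetical list mapping label index to name.
    Returns the figure and a 1-D array of axes; the caller calls plt.show().
    """
    images = to_numpy(images)[:n]
    labels = to_numpy(labels)[:n]
    fig, axes = plt.subplots(
        1, len(images), figsize=(3 * len(images), 3), squeeze=False
    )
    for ax, img, label in zip(axes[0], images, labels):
        img = img.transpose((1, 2, 0))          # (C, H, W) -> (H, W, C)
        img = np.clip(STD * img + MEAN, 0, 1)   # undo the normalization
        ax.imshow(img)
        name = class_names[int(label)] if class_names is not None else int(label)
        ax.set_title(str(name))
        ax.axis("off")
    return fig, axes[0]
```

With the batch from the script above, you would call show_labeled(inputs, classes) and then plt.show(). If your dataset exposes a list of class names, you can pass it as class_names to get readable titles instead of integer labels.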
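Note that this code assumes color (RGB) images with three channels, because the mean and standard deviation each have three values. If your images are grayscale or use a different color format, you will likely need to change those values to match your channels and adjust how the image is displayed (for example, by using a grayscale colormap).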
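As a rough illustration of that adjustment, here is a sketch of an unnormalize helper that handles both single-channel and three-channel images. The function name unnormalize is hypothetical, and the grayscale statistics in the usage note (mean 0.5, std 0.5) are placeholder values, not values taken from your dataset.

```python
import numpy as np


def unnormalize(img, mean, std):
    """Undo per-channel normalization on a (C, H, W) array.

    Returns an (H, W) array for single-channel input and an (H, W, C)
    array otherwise, with values clipped to [0, 1] for display.
    """
    img = np.asarray(img, dtype=np.float64)
    mean = np.asarray(mean, dtype=np.float64).reshape(-1, 1, 1)
    std = np.asarray(std, dtype=np.float64).reshape(-1, 1, 1)
    if mean.shape[0] != img.shape[0] or std.shape[0] != img.shape[0]:
        raise ValueError("mean and std need one value per image channel")
    img = np.clip(img * std + mean, 0.0, 1.0)
    if img.shape[0] == 1:
        return img[0]  # drop the channel axis for grayscale
    return img.transpose((1, 2, 0))  # (C, H, W) -> (H, W, C)
```

For a grayscale batch you could then display one image with plt.imshow(unnormalize(inputs[0].numpy(), [0.5], [0.5]), cmap="gray"), substituting the mean and std you actually used when normalizing.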