How to plot MNIST using Matplotlib?

Root
2 min read · Mar 19, 2021


As I mentioned in my last post, the PyTorch MNIST dataset can no longer be installed as smoothly as before.

If you are stuck installing MNIST, check https://kazma-s-1306.medium.com/in-2021-as-it-turned-out-pytorch-mnist-cannot-be-installed-like-before-8b1083f80086.

1. Plot multiple images

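import numpy as np
import matplotlib.pyplot as plt
# Jupyter/IPython only: render figures inside the notebook
%matplotlib inline

# obtain one batch of training images
# (assumes train_loader is a torch DataLoader over MNIST with batch_size >= 20)
dataiter = iter(train_loader)
images, labels = next(dataiter)  # dataiter.next() no longer works in recent PyTorch
images = images.numpy()

# plot the first 20 images in a 2 x 10 grid
rows, cols = 2, 10  # add_subplot needs ints, so 20/2 (a float) would fail
fig = plt.figure(figsize=(20, 5))
for i in np.arange(20):
    ax = fig.add_subplot(rows, cols, i + 1, xticks=[], yticks=[])
    # np.squeeze drops the channel axis: (1, 28, 28) -> (28, 28)
    ax.imshow(np.squeeze(images[i]), cmap='gray')
    ax.set_title(str(labels[i].item()))
plt.show()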
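If you plot grids often, the loop above can be wrapped in a small function. The sketch below is one way to do it, using `plt.subplots` to create the whole grid at once instead of calling `fig.add_subplot` in the loop. The name `plot_grid` is hypothetical, and it assumes `images` is a NumPy array shaped (N, 1, 28, 28) or (N, 28, 28) and `labels` is a sequence of integers; the random arrays in the demo are only dummy data so the sketch runs on its own.

```python
import numpy as np
import matplotlib.pyplot as plt


def plot_grid(images, labels, rows=2, cols=10):
    """Plot the first rows * cols images in a grid, titled with their labels.

    Hypothetical helper: `images` is a NumPy array shaped (N, 1, 28, 28)
    or (N, 28, 28); `labels` is a sequence of ints with len >= rows * cols.
    """
    # squeeze=False keeps `axes` 2-D even for a 1 x 1 grid
    fig, axes = plt.subplots(rows, cols, figsize=(2 * cols, 2.5 * rows),
                             squeeze=False)
    # axes.flat walks the grid row by row with a single index
    for i, ax in enumerate(axes.flat):
        ax.imshow(np.squeeze(images[i]), cmap='gray')
        ax.set_title(str(int(labels[i])))
        ax.set_xticks([])
        ax.set_yticks([])
    fig.tight_layout()
    return fig


if __name__ == "__main__":
    # random stand-in data with MNIST's shape, not real digits
    rng = np.random.default_rng(0)
    fake_images = rng.random((20, 1, 28, 28), dtype=np.float32)
    fake_labels = rng.integers(0, 10, size=20)
    plot_grid(fake_images, fake_labels)
    plt.show()
```

With the batch from above, the call would be `plot_grid(images, labels.numpy())`; the default 2 x 10 grid gives the same 20 x 5 inch figure as the loop.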

2. Plot an individual image

# pick one image from the batch loaded above
idx = 0
# plot the sample
fig = plt.figure()
plt.imshow(images[idx].reshape(28, 28), cmap='gray')
# remove ticks and tick labels on every side
plt.tick_params(bottom=False,
                left=False,
                right=False,
                top=False,
                labelbottom=False,
                labelleft=False,
                labelright=False,
                labeltop=False)
# show the label as the title
plt.title("digit: {}".format(labels[idx].item()))
# alternatively, this removes the ticks as well:
# plt.axis("off")
plt.show()
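The `reshape(28, 28)` call is what lets this code accept an image stored either as a (1, 28, 28) slice of the batch or as a flat vector of 784 pixels. The sketch below wraps the same steps in a hypothetical `show_digit` helper that draws on a given Axes and uses `ax.axis('off')`, the Axes-level counterpart of `plt.axis("off")`. The zero arrays in the demo are dummy data.

```python
import numpy as np
import matplotlib.pyplot as plt


def show_digit(image, label, ax=None, cmap='gray'):
    """Draw one 28x28 digit with its label as the title (hypothetical helper).

    `image` may be shaped (784,), (1, 28, 28) or (28, 28); reshape(28, 28)
    turns each of them into the 2-D array that imshow expects.
    """
    if ax is None:
        ax = plt.figure().add_subplot(111)
    ax.imshow(np.asarray(image).reshape(28, 28), cmap=cmap)
    ax.set_title("digit: {}".format(int(label)))
    # hides ticks, tick labels and the frame, like plt.axis("off")
    ax.axis('off')
    return ax


if __name__ == "__main__":
    flat = np.zeros(784)              # flat vector of pixels
    stacked = np.zeros((1, 28, 28))   # channel-first slice, as in the batch
    show_digit(flat, 3)
    show_digit(stacked, 5)
    plt.show()
```

Passing `ax` explicitly also lets you place single digits inside an existing grid of subplots.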

If you’d like the colors the other way around (black digits on a white background), use cmap='gray_r'.
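`gray_r` is the reversed version of the `gray` colormap (Matplotlib registers an `_r` variant for each of its built-in colormaps), so low values map to white and high values to black. A small sketch for comparing both on the same image; `compare_cmaps` is a hypothetical helper, and the bright square in the demo is dummy data standing in for a digit.

```python
import numpy as np
import matplotlib.pyplot as plt


def compare_cmaps(image, cmaps=('gray', 'gray_r')):
    """Show the same 28x28 image once per colormap, side by side."""
    fig, axes = plt.subplots(1, len(cmaps), figsize=(4 * len(cmaps), 4),
                             squeeze=False)
    for ax, cmap in zip(axes.flat, cmaps):
        ax.imshow(np.asarray(image).reshape(28, 28), cmap=cmap)
        ax.set_title(cmap)
        ax.axis('off')
    return fig


if __name__ == "__main__":
    # a bright square on a dark background stands in for a digit
    demo = np.zeros((28, 28))
    demo[8:20, 8:20] = 1.0
    compare_cmaps(demo)
    plt.show()
```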

3. Plot an individual image with pixel-value annotations

# squeeze (1, 28, 28) down to (28, 28)
img = images[idx].squeeze()
fig = plt.figure(figsize=(12, 12))
ax = fig.add_subplot(111)
ax.imshow(img, cmap='gray')
ax.axis('off')
height, width = img.shape
# text on pixels darker than half the max is white, on brighter pixels black
thresh = img.max() / 2
for row in range(height):
    for col in range(width):
        val = round(img[row][col], 2) if img[row][col] != 0 else 0
        # annotate expects xy=(x, y), i.e. (column, row)
        ax.annotate(str(val), xy=(col, row),
                    horizontalalignment='center',
                    verticalalignment='center',
                    color='white' if img[row][col] < thresh else 'black')
plt.show()
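The numbers printed on each pixel are whatever values your `train_loader` produces. `transforms.ToTensor()` scales pixels to [0, 1], but if your transform pipeline also includes `transforms.Normalize(mean, std)`, each pixel becomes (x - mean) / std, so the background shows up as a negative number instead of 0. A minimal sketch for undoing that before annotating, assuming the mean 0.1307 and std 0.3081 often used for MNIST (replace them with the values in your own transform); `normalize` and `unnormalize` are hypothetical helpers, not torchvision functions.

```python
import numpy as np

# Assumed values: the mean/std often passed to transforms.Normalize for MNIST.
# Replace them with whatever your own transform uses.
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def normalize(img, mean=MNIST_MEAN, std=MNIST_STD):
    """Mimic transforms.Normalize on a single-channel image: (x - mean) / std."""
    return (img - mean) / std


def unnormalize(img, mean=MNIST_MEAN, std=MNIST_STD):
    """Invert normalize(): x * std + mean, back to the [0, 1] range of ToTensor()."""
    return img * std + mean


if __name__ == "__main__":
    pixels = np.array([0.0, 0.5, 1.0])
    standardized = normalize(pixels)
    print(np.round(standardized, 2))               # background is no longer 0
    print(np.round(unnormalize(standardized), 2))  # values back to 0, 0.5 and 1
```

In the annotation code above, that would mean starting from `img = unnormalize(images[idx].squeeze())`.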
