The following might be helpful:
##########################
### VISUALIZATION
##########################
# Plot the first `n_images` input samples (top row) above their
# reconstructions (bottom row) so the two can be compared column by column.
n_images = 15
image_width = 28  # each sample is reshaped to an image_width x image_width image

fig, axes = plt.subplots(nrows=2, ncols=n_images,
                         sharex=True, sharey=True, figsize=(20, 2.5))

# NOTE(review): assumes x_batch_train / reconstructed are eager tensors
# (they expose .numpy()) with image_width**2 pixels per sample -- confirm.
orig_images = x_batch_train[:n_images].numpy()
decoded_images = reconstructed[:n_images].numpy()

# Guard against a batch smaller than n_images: plot what is available and
# leave any remaining columns empty instead of raising IndexError.
for i in range(min(n_images, len(orig_images), len(decoded_images))):
    # axes[0] (first row) shows originals, axes[1] (second row) reconstructions.
    for ax, img in zip(axes, [orig_images, decoded_images]):
        curr_img = img[i]
        ax[i].imshow(curr_img.reshape((image_width, image_width)), cmap='binary')
Code courtesy of Sebastian Raschka.
python/ops/gen_array_ops.py
'.