【Python】Visualize a torch.Tensor or numpy.ndarray
In this short guide, you’ll see how to visualize a torch.Tensor or numpy.ndarray as heat-map images (and optionally dump its values to a text file).
How to visualize a torch.Tensor or numpy.ndarray
Visualize a torch.Tensor or numpy.ndarray
# from via import via; via(x)
import numpy as np
import torch
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
def via(arr, save_txt: bool = True, size: tuple = (20, 20),
        out: str = 'array_out.txt', Normalize: bool = False):
    """Dump an array to a text file and draw it as 'jet' heat-map image(s).

    Supported ranks:

    * 1-D / 2-D: a single image (a 1-D array is shown as one row).
    * 3-D: one subplot per channel, laid out on a square grid.
    * 4-D: one block of channel subplots per image.

    ``torch.Tensor`` inputs are used in their given axis order, i.e. they are
    expected as ``(C, H, W)`` or ``(N, C, H, W)``.  ``numpy.ndarray`` inputs
    are treated as channel-last, ``(H, W, C)`` or ``(H, W, C, N)``, and are
    transposed to the channel-first order before use.

    Args:
        arr: ``torch.Tensor`` or ``numpy.ndarray`` with 1 to 4 dimensions.
            The caller's data is never modified.
        save_txt: If True, also write the values to ``out``.
        size: Matplotlib figure size in inches.
        out: Path of the text dump (overwritten if it exists).
        Normalize: If True, rescale the plotted values to [0, 1]. The text
            dump always contains the raw values.

    Returns:
        The created matplotlib figure, or ``None`` when ``arr`` has an
        unsupported number of dimensions (a message is printed instead).
    """
    if isinstance(arr, torch.Tensor):
        # detach()/cpu() make tensors that require grad or live on a GPU
        # convertible; a bare .numpy() raises for both.
        arr = arr.detach().cpu().numpy()
    elif isinstance(arr, np.ndarray):
        # Reorder channel-last numpy data to
        # (#Images, #Channels, #Row, #Column) / (#Channels, #Row, #Column).
        if arr.ndim == 4:
            arr = arr.transpose(3, 2, 0, 1)
        elif arr.ndim == 3:
            arr = arr.transpose(2, 0, 1)

    if arr.ndim not in (1, 2, 3, 4):
        # Bail out before creating an empty figure or a header-only file.
        print("Out of dimension!")
        return None

    fig = plt.figure(figsize=size)
    if save_txt:
        _via_write_txt(arr, out)
    if Normalize and arr.size:
        # Build new arrays rather than using -= / /=: in-place updates would
        # overwrite the caller's data (transposed views and .numpy() results
        # share memory with the input) and fail on integer dtypes.
        arr = arr - np.min(arr)
        arr = arr / max(np.max(arr), 10e-7)
    _via_draw(fig, arr)
    return fig


def _via_write_txt(arr, out):
    """Write a 1-D..4-D channel-first array to the text file ``out``."""
    fmt = '%-7.3f'
    with open(out, 'w') as outfile:
        outfile.write('# Array shape: {0}\n'.format(arr.shape))
        if arr.ndim <= 2:
            np.savetxt(outfile, arr, fmt=fmt)
        elif arr.ndim == 3:
            for i, arr2d in enumerate(arr):
                outfile.write('# {0}-th channel\n'.format(i))
                np.savetxt(outfile, arr2d, fmt=fmt)
        else:
            for j, arr3d in enumerate(arr):
                outfile.write('\n\n# {0}-th Image\n'.format(j))
                for i, arr2d in enumerate(arr3d):
                    outfile.write('# {0}-th channel\n'.format(i))
                    np.savetxt(outfile, arr2d, fmt=fmt)


def _via_draw(fig, arr):
    """Draw a 1-D..4-D channel-first array onto ``fig`` as heat maps."""
    dim = arr.ndim
    if dim == 1:
        # Show a vector as a single-row image.
        arr = arr.reshape((1, -1))
    fig.suptitle('Array shape: {0}\n'.format(arr.shape), fontsize=30)

    if dim <= 2:
        ax = fig.add_subplot(1, 1, 1)
        ax.imshow(arr, cmap='jet')
    elif dim == 3:
        # Smallest square grid that holds every channel.
        grid = int(np.ceil(np.sqrt(arr.shape[0])))
        for i, arr2d in enumerate(arr):
            ax = fig.add_subplot(grid, grid, i + 1)
            ax.imshow(arr2d, cmap='jet')
    else:
        # One outer row per image, each holding a square grid of channels.
        grid = int(np.ceil(np.sqrt(arr.shape[1])))
        outer = gridspec.GridSpec(arr.shape[0], 1)
        for j, arr3d in enumerate(arr):
            inner = gridspec.GridSpecFromSubplotSpec(
                grid, grid, subplot_spec=outer[j], wspace=0.1, hspace=0.3)
            for i, arr2d in enumerate(arr3d):
                ax = fig.add_subplot(inner[i])
                ax.imshow(arr2d, cmap='jet')
                ax.set_title('{0}-Image {1}-channel'.format(j, i))
if __name__ == "__main__":
    # Demo only: guarded so that `from via import via` does not run it.
    # A random 3-D tensor (2 channels of 28x35) -> two channel heat maps.
    arr = torch.rand(2, 28, 35)
    via(arr, size=(20, 20))
    plt.show()
How to use (replace `pythonfile` with the name of the module you saved the code in):
from pythonfile import via; via(x)