【Python】Visualize a torch.Tensor or numpy.ndarray

In this short guide, you’ll see how to visualize a torch.Tensor or numpy.ndarray. The array's values are written to a text file and plotted as heat maps, one per channel.

How to visualize a torch.Tensor or numpy.ndarray

Visualize a torch.Tensor or numpy.ndarray

# from via import via; via(x)
import numpy as np
import torch
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

def via(arr, save_txt: bool = True, size: tuple = (20, 20),
        out: str = 'array_out.txt', Normalize: bool = False,
        out_img: str = 'array_out.png'):
    """Dump an array to a text file and save it as a heat-map figure.

    Args:
        arr: ``numpy.ndarray`` or ``torch.Tensor`` with 1 to 4 dimensions.
            NumPy arrays with 3 or 4 dimensions are transposed with
            ``(2, 0, 1)`` / ``(3, 2, 0, 1)`` into the layout
            ``(#Channels, #Row, #Column)`` / ``(#Images, #Channels, #Row,
            #Column)``. That is, they are expected channel-last as
            ``(Row, Column, Channel[, Image])``. Torch tensors are not
            transposed (presumably already channel-first -- confirm).
        save_txt: If True, write the values (before normalization) to
            ``out``, one 2-D slice at a time, with ``%-7.3f`` formatting.
        size: Figure size passed to ``plt.figure(figsize=...)``.
        out: Path of the text dump.
        Normalize: If True, min-max scale the plotted values to [0, 1]. A new
            array is created, so the caller's array/tensor is left untouched.
        out_img: Path of the saved figure. The default keeps the previously
            hard-coded ``'array_out.png'``.

    Returns:
        The ``matplotlib.figure.Figure`` that was drawn, so the caller can
        show or close it. Returns None when ``arr.ndim`` is not 1-4; in that
        case only ``"Out of dimension!"`` is printed.
    """
    dim = arr.ndim
    if dim not in (1, 2, 3, 4):
        print("Out of dimension!")
        return None

    if isinstance(arr, torch.Tensor):
        # .numpy() alone raises for tensors that require grad or live on a
        # GPU; detach().cpu() makes the conversion work for both.
        arr = arr.detach().cpu().numpy()
    elif isinstance(arr, np.ndarray):
        # Reorder channel-last NumPy input to
        # (#Images, #Channels, #Row, #Column).
        if dim == 4:
            arr = arr.transpose(3, 2, 0, 1)
        elif dim == 3:
            arr = arr.transpose(2, 0, 1)

    # Computed before the 1-D -> (1, N) reshape, so the figure title reports
    # the same shape as the text-file header.
    title = 'Array shape: {0}\n'.format(arr.shape)

    if save_txt:
        with open(out, 'w') as outfile:
            outfile.write('# Array shape: {0}\n'.format(arr.shape))
            if dim <= 2:
                np.savetxt(outfile, arr, fmt='%-7.3f')
            elif dim == 3:
                for i, arr2d in enumerate(arr):
                    outfile.write('# {0}-th channel\n'.format(i))
                    np.savetxt(outfile, arr2d, fmt='%-7.3f')
            else:
                for j, arr3d in enumerate(arr):
                    outfile.write('\n\n# {0}-th Image\n'.format(j))
                    for i, arr2d in enumerate(arr3d):
                        outfile.write('# {0}-th channel\n'.format(i))
                        np.savetxt(outfile, arr2d, fmt='%-7.3f')

    if Normalize:
        # Out-of-place on purpose. In-place ops would overwrite the caller's
        # data: for 1-D/2-D NumPy input `arr` IS the caller's array, and
        # transpose() / Tensor.numpy() return views. In-place division would
        # also fail for integer dtypes.
        arr = arr - np.min(arr)
        arr = arr / max(np.max(arr), 10e-7)  # avoid division by zero

    fig = plt.figure(figsize=size)
    fig.suptitle(title, fontsize=30)

    if dim <= 2:
        ax = fig.add_subplot(1, 1, 1)
        im = ax.imshow(arr.reshape((1, -1)) if dim == 1 else arr, cmap='jet')
        fig.colorbar(im, ax=ax)
    elif dim == 3:
        # Lay the channels out on the smallest square grid that fits them.
        x_n = int(np.ceil(np.sqrt(arr.shape[0])))
        for i, arr2d in enumerate(arr):
            ax = fig.add_subplot(x_n, x_n, i + 1)
            im = ax.imshow(arr2d, cmap='jet')
            fig.colorbar(im, ax=ax)
            ax.set_title('{0}-channel'.format(i))
    else:
        # One outer row per image; inside each row, a square grid of channels.
        x_n = int(np.ceil(np.sqrt(arr.shape[1])))
        outer = gridspec.GridSpec(arr.shape[0], 1)
        for j, arr3d in enumerate(arr):
            inner = gridspec.GridSpecFromSubplotSpec(
                x_n, x_n, subplot_spec=outer[j], wspace=0.1, hspace=0.3)
            for i, arr2d in enumerate(arr3d):
                ax = fig.add_subplot(inner[i])
                im = ax.imshow(arr2d, cmap='jet')
                fig.colorbar(im, ax=ax)
                ax.set_title('{0}-Image {1}-channel'.format(j, i))

    fig.savefig(out_img)
    return fig

    

if __name__ == '__main__':
    # Demo on a random 3-D tensor. Guarded so that importing this module
    # (e.g. `from via import via`) does not run via() and write
    # array_out.txt / array_out.png as an import-time side effect.
    arr = torch.rand(2, 28, 35)
    via(arr, size=(20, 20))

How to use

Save the code above in a module (e.g. via.py), then call it on any tensor or array x:

from via import via; via(x)


© All rights reserved By Junha Song.