Matplotlib: add colorbar to non-mappable object

I have a series of lines representing the change of a variable, each drawn in a unique color. For this reason, I want to add a colorbar next to the plot. The required output is shown below.

The problem is that plot does not return a mappable object, i.e. the colorbar has to be added manually. I find my current solution (below) suboptimal because it involves sizing options that I am not interested in controlling. I would prefer a solution similar to the one for a mappable object such as the one returned by imshow (example below the current solution).
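To illustrate why plot needs the manual route: the Line2D objects it returns carry one fixed color and no colormap or norm, whereas a LineCollection is a mappable artist. The sketch below is an alternative that is not part of the question or the answer: it draws the same family of curves as a single LineCollection, assuming the exponent values 0 to 2 are what the colorbar should show, and passes the collection directly to fig.colorbar. The names exponents, segments and lc are illustrative.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
from matplotlib.colors import Normalize
from matplotlib.lines import Line2D

# Line2D (what ax.plot returns) has no cmap/norm, so there is nothing
# for fig.colorbar to read; a LineCollection does have them.
print(hasattr(Line2D, "get_cmap"))          # False
print(hasattr(LineCollection, "get_cmap"))  # True

x = np.linspace(0, 5, 100)
exponents = np.linspace(0, 2, 20)

fig, ax = plt.subplots()

# One (100, 2) array of (x, y) points per curve.
segments = [np.column_stack([x, np.sin(x) * x**n]) for n in exponents]

# One scalar per curve (its exponent); the collection maps it to a color.
lc = LineCollection(segments, cmap="jet", norm=Normalize(vmin=0, vmax=2))
lc.set_array(exponents)
ax.add_collection(lc)
ax.autoscale_view()  # make sure the view limits include the new curves

ax.set_xlabel("x")
ax.set_ylabel("y")

# The collection is mappable, so no hand-made axes or sizes are needed.
cbar = fig.colorbar(lc, ax=ax)
plt.show()
```

The trade-off is that all curves live in one artist, so per-line styling (labels for a legend, individual line widths) is less convenient than with separate plot calls.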


Desired output

Current solution

import numpy             as np
import matplotlib        as mpl
import matplotlib.pyplot as plt

x    = np.linspace(0, 5, 100)
N    = 20
cmap = plt.get_cmap('jet',N)

fig  = plt.figure(figsize=(8,6))
ax1  = fig.add_axes([0.10,0.10,0.70,0.85])

for i,n in enumerate(np.linspace(0,2,N)):
    y = np.sin(x)*x**n
    ax1.plot(x,y,c=cmap(i))

plt.xlabel('x')
plt.ylabel('y')

ax2  = fig.add_axes([0.85,0.10,0.05,0.85])
norm = mpl.colors.Normalize(vmin=0,vmax=2)
cb1  = mpl.colorbar.ColorbarBase(ax2,cmap=cmap,norm=norm,orientation='vertical')

plt.show()


Desired solution

(obviously replacing imshow)

fig,ax = plt.subplots()
cax    = ax.imshow(..)
cbar   = fig.colorbar(cax,aspect=10)
plt.show()

Answer


You can define your own ScalarMappable and pass it to the colorbar as if it were present in the plot.
(Note that I changed the number of colors to 21 to get nice spacings of 0.1.)

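import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

x = np.linspace(0, 5, 100)
N = 21
cmap = plt.get_cmap('jet',N)

fig = plt.figure(figsize=(8,6))
ax1 = fig.add_axes([0.10,0.10,0.70,0.85])

# one line per exponent n = 0, 0.1, ..., 2, colored by its index i
for i,n in enumerate(np.linspace(0,2,N)):
    y = np.sin(x)*x**n
    ax1.plot(x,y,c=cmap(i))

plt.xlabel('x')
plt.ylabel('y')

# A ScalarMappable that is never drawn, but carries the cmap and norm
# the colorbar needs.
norm = mpl.colors.Normalize(vmin=0,vmax=2)
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])  # only required by old Matplotlib versions
# ax=ax1 says which axes to take space from (without it, recent Matplotlib
# versions cannot tell); the boundaries give each of the 21 colors a bin
# of width 0.1 centered on its tick.
plt.colorbar(sm, ax=ax1, ticks=np.linspace(0,2,N),
             boundaries=np.arange(-0.05,2.1,.1))


plt.show()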
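A variant of this approach in the style the question asks for, sketched under the assumption of a recent Matplotlib: plt.subplots instead of hand-placed axes, fig.colorbar(sm, ax=ax) so no sizes are set manually, and a BoundaryNorm (swapped in for Normalize plus an explicit boundaries argument) whose bin edges lie halfway between neighbouring exponents. Each line is colored through the same norm, so line colors and colorbar bins match by construction. Recent releases accept a ScalarMappable without data, so set_array is omitted; very old versions may need it back. The names exponents, step and bounds are illustrative.

```python
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

x = np.linspace(0, 5, 100)
N = 21
exponents = np.linspace(0, 2, N)    # 0.0, 0.1, ..., 2.0
step = exponents[1] - exponents[0]  # 0.1

# Bin edges halfway between neighbouring exponents: -0.05, 0.05, ..., 2.05.
# linspace yields exactly N + 1 edges, with no float end-point surprises.
bounds = np.linspace(exponents[0] - step / 2, exponents[-1] + step / 2, N + 1)

cmap = plt.get_cmap("jet", N)
norm = mpl.colors.BoundaryNorm(bounds, cmap.N)  # value -> color index 0..N-1

fig, ax = plt.subplots()
for n in exponents:
    # Same norm as the colorbar, so each line gets exactly its bin's color.
    ax.plot(x, np.sin(x) * x**n, color=cmap(norm(n)))
ax.set_xlabel("x")
ax.set_ylabel("y")

sm = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
# ax=ax lets the colorbar take its space from ax; no axes positions needed.
cbar = fig.colorbar(sm, ax=ax, ticks=exponents[::2])
plt.show()
```

Changing N then only means editing one number: the bin edges follow from the exponents instead of being written out by hand as in arange(-0.05, 2.1, .1).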
