Asked  7 Months ago    Answers:  5   Viewed   35 times

I am a little confused about how this code works:

fig, axes = plt.subplots(nrows=2, ncols=2)
plt.show()

How does the fig, axes work in this case? What does it do?

Also why wouldn't this work to do the same thing:

fig = plt.figure()
axes = fig.subplots(nrows=2, ncols=2)

 Answers

29

There are several ways to do it. plt.subplots() creates the figure and a grid of subplots in a single call. It returns the Figure and a NumPy array of Axes objects (2x2 here), which this example stores in ax. For example:

import matplotlib.pyplot as plt

x = range(10)
y = range(10)

fig, ax = plt.subplots(nrows=2, ncols=2)

for row in ax:
    for col in row:
        col.plot(x, y)

plt.show()
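
To make the fig, axes unpacking more concrete, here is a minimal sketch that inspects what plt.subplots returns and indexes a single subplot. It also tries the Figure.subplots variant from the question, on the assumption that Matplotlib 2.1 or later is installed (that is the release where, as far as I know, Figure.subplots was added). The titles and the fig2/axes2 names are only illustrative.

```python
import matplotlib.pyplot as plt
import numpy as np

fig, axes = plt.subplots(nrows=2, ncols=2)

# fig is the Figure; axes is a 2x2 NumPy array holding one Axes per subplot
print(type(fig).__name__)
print(type(axes).__name__, axes.shape)

# pick one subplot by [row, column] and draw on it directly
axes[0, 1].plot(range(10), range(10))
axes[0, 1].set_title("row 0, col 1")

# assumption: Matplotlib >= 2.1, where Figure.subplots builds the same grid
fig2 = plt.figure()
axes2 = fig2.subplots(nrows=2, ncols=2)
print(axes2.shape)

# plt.show()  # uncomment to display both figures
```

If both shapes print as (2, 2), the two approaches give you the same kind of grid; the one-call form just saves a line.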

Something like this will also work. It is not as "clean", though, because you create the figure first and then add the subplots to it one at a time:

fig = plt.figure()

plt.subplot(2, 2, 1)
plt.plot(x, y)

plt.subplot(2, 2, 2)
plt.plot(x, y)

plt.subplot(2, 2, 3)
plt.plot(x, y)

plt.subplot(2, 2, 4)
plt.plot(x, y)

plt.show()
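
If you prefer the one-at-a-time style but want to hold on to each Axes object, a rough object-oriented equivalent is fig.add_subplot in a loop. This is only a sketch of the same 2x2 layout; the titles are illustrative and not part of the original example.

```python
import matplotlib.pyplot as plt

x = range(10)
y = range(10)

fig = plt.figure()

# subplot positions are numbered 1..4, left to right, top to bottom
for index in range(1, 5):
    ax = fig.add_subplot(2, 2, index)
    ax.plot(x, y)
    ax.set_title(f"subplot {index}")

# plt.show()  # uncomment to display the figure
```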

Tuesday, June 1, 2021
 
kwichz
answered 7 Months ago
96

Just place the colorbar in its own axis and use subplots_adjust to make room for it.

As a quick example:

import numpy as np
import matplotlib.pyplot as plt

fig, axes = plt.subplots(nrows=2, ncols=2)
for ax in axes.flat:
    im = ax.imshow(np.random.random((10,10)), vmin=0, vmax=1)

fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
fig.colorbar(im, cax=cbar_ax)

plt.show()

Note that the colorbar is built from the last image plotted (the one bound to im), so it only describes the other panels if they use the same color scaling. Here every image is drawn with vmin=0 and vmax=1, so the scale is shared. If another panel contains values above vmax (or below vmin), those points are clipped and all show in the same end color.
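
One way to avoid that clipping is to compute a common vmin and vmax from all the data before plotting. The sketch below assumes four made-up panels with different value ranges (the panels list and the scale factors are hypothetical) and passes ax=axes to fig.colorbar so Matplotlib takes the room for the colorbar from all four subplots, instead of placing it by hand with add_axes.

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)

# hypothetical data: four panels with deliberately different value ranges
panels = [rng.random((10, 10)) * scale for scale in (1, 2, 3, 4)]

# one common range, so every image maps values to colors the same way
vmin = min(panel.min() for panel in panels)
vmax = max(panel.max() for panel in panels)

fig, axes = plt.subplots(nrows=2, ncols=2)
for ax, data in zip(axes.flat, panels):
    im = ax.imshow(data, vmin=vmin, vmax=vmax)

# ax=axes makes the colorbar take its space from all four subplots
cbar = fig.colorbar(im, ax=axes)

# plt.show()  # uncomment to display the figure
```

Because every image shares the same limits, the colorbar built from the last im is valid for all four panels.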

Tuesday, June 1, 2021
 
Sujith
answered 7 Months ago
67

First, you're using calls to plt when you have Axes objects at your disposal. That road leads to pain. Second, imshow sets the aspect ratio of the axes to 1 by default (equal scaling in x and y). That's why the axes are so narrow. Knowing all that, your example becomes:

import numpy as np
import matplotlib.pyplot as plt

data = np.random.rand(10,4)

# create a wide figure with 2 subplots in 1 row
fig, axes = plt.subplots(1, 2, figsize=(9, 3))

for ax in axes.flatten():  # flatten in case you have a second row at some point
    img = ax.imshow(data, interpolation='nearest')
    ax.set_aspect('auto')

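# use the Figure method instead of plt.colorbar, in keeping with the advice above
fig.colorbar(img, ax=ax)

plt.show()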
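
As a variation, imshow accepts an aspect argument, so the separate set_aspect call can be folded into the plotting call. This is a sketch of the same idea rather than a drop-in replacement; attaching the colorbar to both panels with ax=axes is my own choice here.

```python
import numpy as np
import matplotlib.pyplot as plt

data = np.random.rand(10, 4)

fig, axes = plt.subplots(1, 2, figsize=(9, 3))

for ax in axes.flat:
    # aspect="auto" lets the image stretch to fill the axes
    img = ax.imshow(data, interpolation="nearest", aspect="auto")

# one colorbar attached to both panels
fig.colorbar(img, ax=axes)

# plt.show()  # uncomment to display the figure
```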

Saturday, July 31, 2021
 
ajjumma
answered 4 Months ago
27

The right combination of legendgroup and showlegend should do the trick. In the setup below, all 2017 traces are assigned to the same legendgroup="2017", and every 2017 trace except the first has showlegend=False. The same goes for the 2018 traces. Each year then appears once in the legend, and clicking it toggles that year's traces in all subplots. Give it a try!

Complete code

from plotly.subplots import make_subplots
import plotly.graph_objects as go
from plotly import offline

fig = make_subplots(rows=3, cols=1)

fig.add_trace(go.Scatter(x=[3, 4, 5], y=[1000, 1100, 1200],
                         name="2017", legendgroup="2017",
                         line=dict(color='blue')),
              row=1, col=1)

fig.add_trace(go.Scatter(x=[2, 3, 4], y=[1200, 1100, 1000],
                         name="2018",legendgroup="2018",
                         line=dict(color='red')),
              row=1, col=1)


fig.add_trace(go.Scatter(x=[2, 3, 4], y=[100, 110, 120],
                         name="2017", legendgroup="2017",
                         line=dict(color='blue'),
                         showlegend=False),
              row=2, col=1)

# add_trace is used throughout; append_trace is deprecated in recent plotly versions
fig.add_trace(go.Scatter(x=[2, 3, 4], y=[120, 110, 100],
                         name="2018", legendgroup="2018",
                         line=dict(color='red'),
                         showlegend=False),
              row=2, col=1)

fig.add_trace(go.Scatter(x=[0, 1, 2], y=[10, 11, 12],
                         name="2017", legendgroup="2017",
                         line=dict(color='blue'),
                         showlegend=False),
              row=3, col=1)

fig.add_trace(go.Scatter(x=[0, 1, 2], y=[12, 11, 10],
                         name="2018", legendgroup="2018",
                         line=dict(color='red'),
                         showlegend=False),
              row=3, col=1)

fig.update_layout(height=600, width=600, title_text="Stacked Subplots")
#offline.plot(fig,filename="subplots.html")
fig.show()
Saturday, August 7, 2021
 
muffe
answered 4 Months ago
15

The key here is to assign each trace to its subplot through the row and col arguments of fig.add_trace(). With recent plotly versions you also don't need from plotly.offline import iplot; calling fig.show() is enough.

Code:

# imports
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import pandas as pd
import numpy as np

# data
df = pd.DataFrame({
    'Index': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
    'A': [15.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.5, 0.3, 0.1],
    'B': [1.0, 4.0, 2.0, 5.0, 4.0, 6.0, 7.0, 2.0, 8.0, 1.0],
    'C': [12.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.5, 0.2, 0.1],
})

# set up plotly figure
fig = make_subplots(1,2)

# add first bar trace at row = 1, col = 1
fig.add_trace(go.Bar(x=df['Index'], y=df['A'],
                     name='A',
                     marker_color = 'green',
                     opacity=0.4,
                     marker_line_color='rgb(8,48,107)',
                     marker_line_width=2),
              row = 1, col = 1)

# add first scatter trace at row = 1, col = 1
fig.add_trace(go.Scatter(x=df['Index'], y=df['B'], line=dict(color='red'), name='B'),
              row = 1, col = 1)

# add first bar trace at row = 1, col = 2
fig.add_trace(go.Bar(x=df['Index'], y=df['C'],
                     name='C',
                     marker_color = 'green',
                     opacity=0.4,
                     marker_line_color='rgb(8,48,107)',
                     marker_line_width=2),
              row = 1, col = 2)

fig.show()
Wednesday, September 22, 2021
 
anas
answered 3 Months ago