can't plot multi-row subplots #8117

Closed
ericdf opened this issue Aug 26, 2014 · 3 comments
@ericdf

ericdf commented Aug 26, 2014

This code works fine to produce a single row of 3 subplots:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

g = np.random.choice([1, 2, 3], 10)
s = np.random.normal(size=10)
s2 = np.random.normal(size=10)
df = pd.DataFrame([g, s, s2]).T
df.columns = ['key', 's1', 's2']
gb = df.groupby('key')

# One row of subplots: axes is a 1-D array, so axes[i] is a single Axes
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(12, 4))
i = 0
for key, df2 in gb:
    df2.plot(ax=axes[i], x='s1', y='s2', title=key)
    i = i + 1

But if I add a second row (nrows=2), it fails with an AttributeError:

AttributeError: 'numpy.ndarray' object has no attribute 'get_figure'

fig, axes = plt.subplots(nrows=2, ncols=3)
fig.tight_layout()  # or equivalently, plt.tight_layout()
i = 1
for key, df2 in gb:
    # axes is now 2-D, so axes[i] is a whole row of Axes (an ndarray), not one Axes
    df2.plot(ax=axes[i])
    i = i + 1
@TomAugspurger
Contributor

That's because axes is now a 2-D array of matplotlib Axes objects, so axes[i] selects a whole row rather than a single Axes.

In [35]: fig, axes = plt.subplots(nrows=2, ncols=3)

In [36]: axes
Out[36]: 
array([[<matplotlib.axes._subplots.AxesSubplot object at 0x10c094ac8>,
        <matplotlib.axes._subplots.AxesSubplot object at 0x10ee40b70>,
        <matplotlib.axes._subplots.AxesSubplot object at 0x10c0ac240>],
       [<matplotlib.axes._subplots.AxesSubplot object at 0x10edf8a90>,
        <matplotlib.axes._subplots.AxesSubplot object at 0x10ec27630>,
        <matplotlib.axes._subplots.AxesSubplot object at 0x10edc9128>]], dtype=object)

In [37]: axes.shape
Out[37]: (2, 3)
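A quick way to see the difference is to compare the shapes that plt.subplots returns. By default (squeeze=True) a single row comes back as a 1-D array and a multi-row grid as a 2-D array; passing squeeze=False keeps the result 2-D in every case. This sketch assumes a reasonably recent matplotlib and selects the non-interactive Agg backend only so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs headless
import matplotlib.pyplot as plt

# Default squeeze=True: one row -> 1-D array, several rows -> 2-D array.
fig1, axes_row = plt.subplots(nrows=1, ncols=3)
fig2, axes_grid = plt.subplots(nrows=2, ncols=3)

print(axes_row.shape)   # (3,)
print(axes_grid.shape)  # (2, 3)

# Indexing a 2-D array with a single integer returns a whole row (an ndarray),
# which is what gets passed as ax= in the failing snippet.
print(type(axes_grid[1]).__name__)  # ndarray

# squeeze=False always returns a 2-D array, so axes[r, c] works for any grid.
fig3, axes_fixed = plt.subplots(nrows=1, ncols=3, squeeze=False)
print(axes_fixed.shape)  # (1, 3)
```

With squeeze=False the same axes[row, col] indexing code works whether the grid has one row or several.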

Try something like this, which puts each group in the first row:

In [38]: for i, (key, df2) in enumerate(gb):
    df2.plot(ax=axes[0][i])

@TomAugspurger
Contributor

If you want to wrap around to the second row, you'll have to do something like axes[i // 3][i % 3].
(I expanded your example to have 6 groups.)
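The i // 3 and i % 3 expressions turn a flat group counter into a (row, column) position, filling the grid row by row. The divisor must be the number of columns (3 here), not the number of rows. A minimal sketch using Python's built-in divmod, which returns both values at once; grid_position is a hypothetical helper name, not part of pandas or matplotlib:

```python
def grid_position(i, ncols):
    """Map flat index i to (row, col) in a grid with ncols columns, filled row by row."""
    return divmod(i, ncols)  # same as (i // ncols, i % ncols)

# Six groups on a 2 x 3 grid
for i in range(6):
    print(i, grid_position(i, ncols=3))
# 0 (0, 0)
# 1 (0, 1)
# 2 (0, 2)
# 3 (1, 0)
# 4 (1, 1)
# 5 (1, 2)
```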

In [67]: df
Out[67]: 
    key        s1        s2
0     3 -1.452043 -0.119374
1     1  0.603860 -1.635034
2     3  0.964165 -0.043124
3     2  0.459628 -0.538155
4     3  0.398761 -0.195261
5     1  0.085750 -0.116766
6     2 -0.397419 -0.140660
7     3 -0.053209  1.547755
8     1 -0.634555 -0.509077
9     3  0.138808  0.608165
10    6 -1.452043 -0.119374
11    4  0.603860 -1.635034
12    6  0.964165 -0.043124
13    5  0.459628 -0.538155
14    6  0.398761 -0.195261
15    4  0.085750 -0.116766
16    5 -0.397419 -0.140660
17    6 -0.053209  1.547755
18    4 -0.634555 -0.509077
19    6  0.138808  0.608165

In [63]: for i, (key, df2) in enumerate(gb):
    df2.plot(ax=axes[i // 3][i % 3])
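An alternative that avoids the index arithmetic is to flatten the grid with axes.ravel() and zip it with the groups. This is a sketch on made-up data (5 groups, so one of the 6 panels is left over and hidden); it assumes NumPy 1.17+ for default_rng and relies on GroupBy.ngroups for the number of groups.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs headless
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Synthetic data (assumed for illustration): 5 groups of 4 rows each
rng = np.random.default_rng(0)
df = pd.DataFrame({
    "key": np.repeat([1, 2, 3, 4, 5], 4),
    "s1": rng.normal(size=20),
    "s2": rng.normal(size=20),
})
gb = df.groupby("key")

fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(12, 8))

# ravel() flattens the 2-D grid row by row, so zip pairs the k-th group
# with the k-th Axes; zip stops at the shorter of the two sequences.
flat_axes = axes.ravel()
for ax, (key, df2) in zip(flat_axes, gb):
    df2.plot(ax=ax, x="s1", y="s2", title=str(key))

# Hide any leftover Axes when there are fewer groups than panels.
for ax in flat_axes[gb.ngroups:]:
    ax.set_visible(False)

fig.tight_layout()
```

Because zip stops at the shorter sequence, groups beyond the number of panels are silently dropped; sizing the grid with something like nrows = ceil(ngroups / ncols) avoids that.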


@ericdf
Author

ericdf commented Aug 27, 2014

Thank you so much, Tom. This now works perfectly. Sincerely appreciate your helpful reply.

ericdf closed this as completed Aug 27, 2014