《利用Python进行数据分析》读书笔记。
第8章:绘图和可视化。
%pylab inline
import pandas as pd
from pandas import Series, DataFrame
# plt.figure accepts several options; figsize in particular sets the figure size.
fig = plt.figure(figsize=(12,3))
# Lay out a 2x2 grid and create its first three subplots.
ax1 = fig.add_subplot(2, 2, 1)
ax2 = fig.add_subplot(2, 2, 2)
ax3 = fig.add_subplot(2, 2, 3)
# A bare plt.plot call draws on the most recently created subplot (ax3 here).
from numpy.random import randn
plt.plot(randn(50).cumsum(), 'k--')
# fig.add_subplot returns an AxesSubplot object; its methods draw on that subplot.
_ = ax1.hist(np.random.randn(100),bins = 20,color = 'k',alpha = 0.3)
ax2.scatter(np.arange(30),np.arange(30) + 3 * np.random.randn(30))
Out[2]:
# Close every open figure before starting the next example.
plt.close('all')
# plt.subplots creates the figure and returns its subplots as a NumPy array
# (2 rows x 3 columns here).
fig,axes = plt.subplots(2,3)
axes[0][0].hist(np.random.randn(100),bins = 20,color = 'k',alpha = 0.3)
Out[3]:
# Show the figure object and the array of subplot objects returned by plt.subplots.
print(fig)
print(axes)
可以指定多个 subplot 共用坐标轴。
subplots 的属性:
- nrows: 行数
- ncols: 列数
- sharex: 使用相同的 x 刻度
- sharey: 使用相同的 y 刻度
- ...
调整subplot周围的间距
subplots_adjust(left = None,bottom = None,right = None,top = None,wspace = None,hspace = None)
# wspace和hspace用于控制宽度和高度的百分比,可以用做subplot之间的间距
# Share both axes across a 2x2 grid so all four histograms use the same ticks.
fig, ax = plt.subplots(2, 2, sharex=True, sharey=True)
for i in range(2):
    for j in range(2):
        ax[i, j].hist(np.random.randn(500), bins=50, color='k', alpha=0.5)
# Remove the horizontal gap and keep a small vertical gap between subplots.
plt.subplots_adjust(wspace=0, hspace=0.2)
# plt.show()
# NOTE: matplotlib does not check whether tick labels overlap.
颜色、标记和线型
# Green dashed line, specified with the compact style string 'g--'.
plt.figure()
plt.plot(randn(30).cumsum(), 'g--')
Out[6]:
# Equivalent to the 'g--' style string, spelled out with keyword arguments.
plt.plot(randn(30).cumsum(), linestyle='--', color='g')
Out[7]:
# Add a marker ('o') at each data point of a black dashed line.
plt.figure()
plt.plot(randn(30).cumsum(), 'ko--')
# Equivalent to:
# plt.plot(randn(30).cumsum(), color='k', linestyle='--', marker='o')
Out[8]:
# drawstyle selects how consecutive points are joined (the interpolation style).
data = randn(30).cumsum()
# Plot the same data twice: the default line and a 'steps-post' version.
plt.plot(data, 'k--', label='Default')
plt.plot(data, 'k-', drawstyle='steps-post', label='steps-post')
# The legend uses the label given to each line.
plt.legend(loc='best')
Out[9]:
刻度、标签和图例
fig = plt.figure(); ax = fig.add_subplot(1, 1, 1)
# Plot a 1000-step random walk.
ax.plot(randn(1000).cumsum())
# Positions of the x-axis ticks.
ticks = ax.set_xticks([0, 250, 500, 750, 1000])
# Text shown at those ticks, rotated 30 degrees in a small font.
labels = ax.set_xticklabels(['one', 'two', 'three', 'four', 'five'],
rotation=30, fontsize='small')
# Title of the subplot.
ax.set_title('My first matplotlib plot')
# Label of the x axis.
ax.set_xlabel('Stages')
Out[10]:
# Add a legend: give each line a label, then call ax.legend.
fig = plt.figure(); ax = fig.add_subplot(1, 1, 1)
ax.plot(randn(1000).cumsum(), 'k', label='one')
ax.plot(randn(1000).cumsum(), 'k--', label='two')
ax.plot(randn(1000).cumsum(), 'k.', label='three')
# loc='best' lets matplotlib choose where to place the legend.
ax.legend(loc='best')
Out[11]:
注解及在 subplot 上绘图
from datetime import datetime

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)

# Read spx.csv with its first column parsed as a date index, then plot the
# 'SPX' column as a solid black line.
data = pd.read_csv('data/ch08/spx.csv', index_col=0, parse_dates=True)
spx = data['SPX']
spx.plot(ax=ax, style='k-')

# (date, text) pairs to annotate on the plot.
crisis_data = [
    (datetime(2007, 10, 11), 'Peak of bull market'),
    (datetime(2008, 3, 12), 'Bear Stearns Fails'),
    (datetime(2008, 9, 15), 'Lehman Bankruptcy')
]

# Add the annotations one by one.
for date, label in crisis_data:
    # Look the value up once; asof returns the last valid value at or before date.
    price = spx.asof(date)
    # The arrow tip sits 50 above the line and the text 200 above it.
    ax.annotate(label, xy=(date, price + 50),
                xytext=(date, price + 200),
                arrowprops=dict(facecolor='black'),
                horizontalalignment='left', verticalalignment='top')

# Zoom in on 2007-2010
ax.set_xlim(['1/1/2007', '1/1/2011'])
ax.set_ylim([600, 1800])

ax.set_title('Important dates in 2008-2009 financial crisis')
Out[12]:
# Draw shapes (patches): build each patch object, then attach it to the
# subplot with add_patch.
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
# Rectangle((x, y), width, height), Circle((x, y), radius), Polygon(vertices).
rect = plt.Rectangle((0.2, 0.75), 0.4, 0.15, color='k', alpha=0.3)
circ = plt.Circle((0.7, 0.2), 0.15, color='b', alpha=0.3)
pgon = plt.Polygon([[0.15, 0.15], [0.35, 0.4], [0.2, 0.6]],
color='g', alpha=0.5)
ax.add_patch(rect)
ax.add_patch(circ)
ax.add_patch(pgon)
Out[13]:
保存到文件
# Display the figure as the notebook cell output.
fig
# Save the figure; the output format follows the file extension.
fig.savefig('figpath.svg')
# dpi sets the resolution; bbox_inches='tight' trims the whitespace around the figure.
fig.savefig('figpath.png', dpi=400, bbox_inches='tight')
# savefig also accepts a file-like object: write the image into an in-memory
# buffer and read the raw bytes back.
from io import BytesIO
buffer = BytesIO()
plt.savefig(buffer)
plot_data = buffer.getvalue()
# Global configuration: make 10x10 the default figure size.
plt.rc('figure', figsize=(10, 10))
# Series.plot draws the values against the index (0, 10, ..., 90 here).
s = Series(np.random.randn(10).cumsum(), index=np.arange(0, 100, 10))
s.plot()
Out[18]:
# Four cumulative random walks (cumsum along axis 0), one per column.
df = DataFrame(np.random.randn(10, 4).cumsum(0),
columns=['A', 'B', 'C', 'D'],
index=np.arange(0, 100, 10))
# DataFrame.plot draws each column as its own line on one subplot.
df.plot()
Out[19]:
柱状图
# Two stacked subplots: vertical bars on top, horizontal bars below.
fig, axes = plt.subplots(2, 1)
data = Series(np.random.rand(16), index=list('abcdefghijklmnop'))
# kind='bar' / 'barh' select vertical / horizontal bars; ax picks the target subplot.
data.plot(kind='bar', ax=axes[0], color='k', alpha=0.7)
data.plot(kind='barh', ax=axes[1], color='k', alpha=0.7)
Out[20]:
# The columns index is named 'Genus'.
df = DataFrame(np.random.rand(6, 4),
index=['one', 'two', 'three', 'four', 'five', 'six'],
columns=pd.Index(['A', 'B', 'C', 'D'], name='Genus'))
# Display the frame as the notebook cell output.
df
# For a DataFrame, bars are grouped by row, one bar per column.
df.plot(kind='bar')
Out[21]:
# Start a new, empty figure.
plt.figure()
Out[22]:
# stacked=True stacks the values of each row into a single horizontal bar.
df.plot(kind='barh', stacked=True, alpha=0.5)
Out[23]:
tips = pd.read_csv('data/ch08/tips.csv')
# Select the 'size' column by name: attribute access (tips.size) resolves to
# the DataFrame.size property (total element count), not to the column.
party_counts = pd.crosstab(tips['day'], tips['size'])
party_counts
# Not many 1- and 6-person parties, so keep party sizes 2 through 5.
# .loc slices by label and includes both end points; .ix no longer exists.
party_counts = party_counts.loc[:, 2:5]
# Normalize to sum to 1 (left disabled, as in the original notes)
# party_pcts = party_counts.div(party_counts.sum(1).astype(float), axis=0)
# party_pcts
# party_pcts.plot(kind='bar', stacked=True)
直方图和密度图
# Tip as a fraction of the total bill, stored as a new column.
tips['tip_pct'] = tips['tip'] / tips['total_bill']
# Histogram of that fraction with 50 bins.
tips['tip_pct'].hist(bins=50)
Out[26]:
# Density plot (kernel density estimate) of the same column.
tips['tip_pct'].plot(kind='kde')
Out[27]:
# Mix two normal samples to get a bimodal distribution.
comp1 = np.random.normal(0, 1, size=200)  # N(0, 1)
comp2 = np.random.normal(10, 2, size=200)  # N(10, 4): std 2, variance 4
values = Series(np.concatenate([comp1, comp2]))
# density=True normalizes the histogram so it shares a scale with the KDE
# curve; the old normed= keyword was removed from matplotlib.
values.hist(bins=100, alpha=0.3, color='k', density=True)
values.plot(kind='kde', style='k--')
Out[28]:
散点图
macro = pd.read_csv('data/ch08/macrodata.csv')
# Keep four of the columns.
data = macro[['cpi', 'm1', 'tbilrate', 'unemp']]
# Log differences of each column; rows containing NaN (such as the first row
# produced by diff) are dropped.
trans_data = np.log(data).diff().dropna()
# Show the last five rows.
trans_data[-5:]
Out[29]:
# Scatter plot of one pair of columns.
plt.scatter(trans_data['m1'], trans_data['unemp'])
plt.title('Changes in log %s vs. log %s' % ('m1', 'unemp'))
# Pairwise scatter plots of every column, with a density plot on the diagonal.
# scatter_matrix lives in pandas.plotting; the top-level pd.scatter_matrix
# alias was removed.
pd.plotting.scatter_matrix(trans_data, diagonal='kde', alpha=0.3)
Out[30]:
# Load the Haiti CSV and inspect its structure.
data = pd.read_csv('data/ch08/Haiti.csv')
data.info()
# First ten rows of the date and coordinate columns.
data[['INCIDENT DATE', 'LATITUDE', 'LONGITUDE']][:10]
Out[32]:
# First six values of the CATEGORY column.
data['CATEGORY'][:6]
Out[33]:
# Summary statistics, useful for spotting out-of-range values.
data.describe()
Out[34]:
# Keep only rows inside the bounding box (latitude 18 to 20, longitude -75 to
# -70, bounds exclusive) that also have a non-null CATEGORY.
data = data[(data.LATITUDE > 18) & (data.LATITUDE < 20) &
(data.LONGITUDE > -75) & (data.LONGITUDE < -70)
& data.CATEGORY.notnull()]
def to_cat_list(catstr):
    """Split a comma-separated category string into a list of names.

    Whitespace around each piece is stripped, and pieces that are empty
    after stripping (for example from a trailing comma) are dropped.
    """
    stripped = (piece.strip() for piece in catstr.split(','))
    return [piece for piece in stripped if piece]
def get_all_categories(cat_series):
    """Return the sorted list of distinct categories found in cat_series.

    Each element of cat_series is a comma-separated category string that
    is split with to_cat_list. An empty input yields an empty list.
    """
    cat_sets = (set(to_cat_list(x)) for x in cat_series)
    # Start from an empty set: the unbound set.union(*sets) raises TypeError
    # when there are no sets to unpack.
    return sorted(set().union(*cat_sets))
def get_english(cat):
    """Split a category label into a (code, name) pair.

    The code is the text before the first '.'. When the rest contains a
    '|' separator, the name is the piece after the first '|'; otherwise it
    is the whole remainder. The name is returned with surrounding
    whitespace stripped.

    Raises:
        ValueError: if cat contains no '.'.
    """
    # Split on the first '.' only, so a '.' inside the name does not break
    # the two-value unpacking.
    code, names = cat.split('.', 1)
    if '|' in names:
        # Split on the bare '|' that the test above checked for; the
        # surrounding spaces are removed by the final strip().
        names = names.split('|')[1]
    return code, names.strip()
# Example call; evaluates to ('2', 'Vital Lines').
get_english('2. Urgences logistiques | Vital Lines')
Out[37]:
# Every distinct category string in the filtered data.
all_cats = get_all_categories(data.CATEGORY)
# dict() consumes the (code, name) pairs produced by the generator expression.
english_mapping = dict(get_english(x) for x in all_cats)
# Look up two codes as a spot check.
english_mapping['2a']
english_mapping['6c']
Out[38]:
def get_code(seq):
    """Return the code (the text before the first '.') of each entry in seq.

    Falsy entries such as empty strings are skipped.
    """
    return [x.split('.')[0] for x in seq if x]
# One indicator column per distinct category code, initialized to zero and
# aligned with the rows of data.
all_codes = get_code(all_cats)
code_index = pd.Index(np.unique(all_codes))
dummy_frame = DataFrame(np.zeros((len(data), len(code_index))),
                        index=data.index, columns=code_index)
# Positional slice of the first six columns (.ix no longer exists).
dummy_frame.iloc[:, :6].info()

# Set the indicator to 1 for every code listed in a row's CATEGORY string.
for row, cat in zip(data.index, data.CATEGORY):
    codes = get_code(to_cat_list(cat))
    # row and codes are index/column labels, so use label-based .loc.
    dummy_frame.loc[row, codes] = 1

# Attach the indicators to data as 'category_<code>' columns.
data = data.join(dummy_frame.add_prefix('category_'))
data.iloc[:, 10:15].info()
from mpl_toolkits.basemap import Basemap
import matplotlib.pyplot as plt
def basic_haiti_map(ax=None, lllat=17.25, urlat=20.25,
                    lllon=-75, urlon=-71):
    """Create a Basemap for the given bounding box and draw its outlines.

    Args:
        ax: subplot to draw on, passed through to Basemap.
        lllat, urlat: latitude of the lower-left / upper-right corner.
        lllon, urlon: longitude of the lower-left / upper-right corner.

    Returns:
        The Basemap instance, after coastlines, state boundaries and
        country boundaries have been drawn.
    """
    # Stereographic projection ('stere') centered on the middle of the box.
    m = Basemap(ax=ax, projection='stere',
                lon_0=(urlon + lllon) / 2,
                lat_0=(urlat + lllat) / 2,
                llcrnrlat=lllat, urcrnrlat=urlat,
                llcrnrlon=lllon, urcrnrlon=urlon,
                resolution='f')
    # draw coastlines, state and country boundaries, edge of map.
    m.drawcoastlines()
    m.drawstates()
    m.drawcountries()
    return m
# One map per category code, laid out on a 2x2 grid.
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(12, 10))
fig.subplots_adjust(hspace=0.05, wspace=0.05)

to_plot = ['2a', '1', '3c', '7a']

# Map bounds: lower-left and upper-right corners.
lllat=17.25; urlat=20.25; lllon=-75; urlon=-71

for code, ax in zip(to_plot, axes.flat):
    m = basic_haiti_map(ax, lllat=lllat, urlat=urlat,
                        lllon=lllon, urlon=urlon)

    # Rows whose indicator column for this code is set.
    cat_data = data[data['category_%s' % code] == 1]

    # compute map proj coordinates.
    x, y = m(cat_data.LONGITUDE.values, cat_data.LATITUDE.values)

    m.plot(x, y, 'k.', alpha=0.5)
    ax.set_title('%s: %s' % (code, english_mapping[code]))
# Fresh 2x2 figure plus the module-level settings that make_plot reads.
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(12, 10))
fig.subplots_adjust(hspace=0.05, wspace=0.05)
# Category codes to plot, one per subplot.
to_plot = ['2a', '1', '3c', '7a']
# Map bounds: lower-left and upper-right corners.
lllat=17.25; urlat=20.25; lllon=-75; urlon=-71
def make_plot():
    """Plot the reports of each code in to_plot on its own map.

    Reads the module-level names to_plot, axes, data, english_mapping and
    the map bounds (lllat, urlat, lllon, urlon). Each code is paired with
    the next subplot in axes.flat. Returns None.
    """
    for code, ax in zip(to_plot, axes.flat):
        # Rows whose indicator column for this code is set.
        cat_data = data[data['category_%s' % code] == 1]
        lons, lats = cat_data.LONGITUDE, cat_data.LATITUDE

        m = basic_haiti_map(ax, lllat=lllat, urlat=urlat,
                            lllon=lllon, urlon=urlon)

        # compute map proj coordinates.
        x, y = m(lons.values, lats.values)

        m.plot(x, y, 'k.', alpha=0.5)
        ax.set_title('%s: %s' % (code, english_mapping[code]))
# Draw the four category maps on the current 2x2 figure.
make_plot()
# Path to the roads shapefile, given without a file extension.
shapefile_path = 'data/ch08/PortAuPrince_Roads/PortAuPrince_Roads'
# NOTE(review): make_plot binds its own local `m`, so this `m` is whichever
# module-level Basemap was assigned earlier. Confirm it is the intended map.
m.readshapefile(shapefile_path, 'roads')
Out[47]:
Python 图形化工具生态系统
Chaco
能生成可交互的图像
mayavi
3D图像工具包,可以平移、旋转、缩放等。