专栏名称: 数据STUDIO
点击领取《Python学习手册》,后台回复「福利」获取。『数据STUDIO』专注于数据科学原创文章分享,内容以 Python 为核心语言,涵盖机器学习、数据分析、可视化、MySQL等领域干货知识总结及实战项目。
目录
相关文章推荐
51好读  ›  专栏  ›  数据STUDIO

Matplotlib 丑图到期刊图表改造指南

数据STUDIO  · 公众号  ·  · 2025-03-27 10:30

正文

请到「今天看啥」查看全文



每个使用过 Matplotlib 的人都知道默认图表看起来有多丑。在本系列文章中,我将分享一些技巧,让你的可视化脱颖而出并反映您的个人风格。

我们将从广泛使用的简单折线图开始。主要亮点是添加渐变填充到图表下方 — 这项任务并不完全简单。

那么,我们深入了解这一转变的所有关键步骤!

首先进行所有必要的导入。

import pandas as pd
import numpy as np
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from matplotlib import rcParams
from matplotlib.path import Path
from matplotlib.patches import PathPatch

np.random.seed(38)

现在我们需要为可视化生成样本数据。我们将创建类似于股票价格的东西。

dates = pd.date_range(start='2024-02-01', periods=100, freq='D')
initial_rate = 75
drift = 0.003
volatility = 0.1
returns = np.random.normal(drift, volatility, len(dates))
rates = initial_rate * np.cumprod(1 + returns)

x, y = dates, rates

检查一下它在默认的 Matplotlib 设置下是什么样子。

fix, ax = plt.subplots(figsize=(8, 4))
ax.plot(dates, rates)
ax.xaxis.set_major_locator(mdates.DayLocator(interval=30))
plt.show()

不是很迷人吧?但我们会逐渐让它看起来更好。

  • 设置标题
  • 设置常规图表参数——大小和字体
  • 将 Y 刻度放在右侧
  • 更改主线颜色、样式和宽度
# General parameters
fig, ax = plt.subplots(figsize=(10, 6))
plt.title("Daily visitors", fontsize=18, color="black")
rcParams['font.family'] = 'DejaVu Sans'
rcParams['font.size'] = 14

# Axis Y to the right
ax.yaxis.tick_right()
ax.yaxis.set_label_position("right")

# Plotting main line
ax.plot(dates, rates, color='#268358', linewidth=2)

好了,现在看起来干净一些了。

现在,我们想在背景中添加简约的网格,删除边框以获得更整洁的外观,并从 Y 轴上删除刻度。

# Grid
ax.grid(color="gray", linestyle=(0, (10, 10)), linewidth=0.5, alpha=0.6)
ax.tick_params(axis="x", colors="black")
ax.tick_params(axis="y", left=False, labelleft=False) 

# Borders
ax.spines["top"].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines["bottom"].set_color("black")
ax.spines['left'].set_color('white')
ax.spines['left'].set_linewidth(1)

# Remove ticks from axis Y
ax.tick_params(axis='y', length=0)

现在我们在 X 轴上的第一个刻度附近添加一个美学细节 - 年份。我们还使刻度标签的字体颜色更浅。

# Add year to the first date on the axis
def custom_date_formatter(t, pos, dates, x_interval):
    date = dates[pos*x_interval]
    if pos == 0:
        return date.strftime('%d %b \'%y')  
    else:
        return date.strftime('
%d %b')  
ax.xaxis.set_major_formatter(ticker.FuncFormatter((lambda x, pos: custom_date_formatter(x, pos, dates=dates, x_interval=x_interval))))

# Ticks label color
[t.set_color('
#808079') for t in ax.yaxis.get_ticklabels()]
[t.set_color('#808079'for t in ax.xaxis.get_ticklabels()]

现在我们即将进入最棘手的时刻——如何在曲线下创建渐变。实际上 Matplotlib 中没有这样的选项,但我们可以模拟它,创建渐变图像,然后用图表裁剪它。

# Gradient
numeric_x = np.array([i for i in range(len(x))])
numeric_x_patch = np.append(numeric_x, max(numeric_x))
numeric_x_patch = np.append(numeric_x_patch[0], numeric_x_patch)
y_patch = np.append(y, 0)
y_patch = np.append(0, y_patch)

path = Path(np.array([numeric_x_patch, y_patch]).transpose())
patch = PathPatch(path, facecolor='none')
plt.gca().add_patch(patch)

ax.imshow(numeric_x.reshape(len(numeric_x), 1),  interpolation="bicubic",
                cmap=plt.cm.Greens, 
                origin='lower',
                alpha=0.3,
                extent=[min(numeric_x), max(numeric_x), min(y_patch), max(y_patch) * 1.2], 
                aspect="auto", clip_path=patch, clip_on=True)

现在它看起来干净又漂亮。我们只需要使用任何编辑器(我更喜欢 Google Slides)添加一些细节 — 标题、圆角边框和一些数字指示器。

重现可视化的完整代码如下:

import pandas as pd
import numpy as np
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from matplotlib import rcParams
from matplotlib.path import Path
from matplotlib.patches import PathPatch

np.random.seed(38)

# Data generation
dates = pd.date_range(start='2024-02-01', periods=100, freq='D')
initial_rate = 75
drift = 0.003
volatility = 0.1
returns = np.random.normal(drift, volatility, len(dates))
rates = initial_rate * np.cumprod(1 + returns)

x, y = dates, rates

# General parameters
fig, ax = plt.subplots(figsize=(106))
plt.title("Daily visitors", fontsize=18, color="black")
rcParams['font.family'] = 'DejaVu Sans'
rcParams['font.size'] = 14

# Axis Y to the right
ax.yaxis.tick_right()
ax.yaxis.set_label_position("right")

# Axis
x_interval = 21
ax.xaxis.set_major_formatter(mdates.DateFormatter("%d %b"))
ax.xaxis.set_major_locator(mdates.DayLocator(interval=x_interval))

ax.yaxis.set_major_locator(ticker.MultipleLocator(50))

# Grid
ax.grid(color="gray", linestyle=(0, (1010)), linewidth=0.5, alpha=0.6)
ax.tick_params(axis="x", colors="black")
ax.tick_params(axis="y", left=False, labelleft=False

# Borders
ax.spines["top"].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines["bottom"].set_color("black")
ax.spines['left'].set_color('white')
ax.spines['left'].set_linewidth(1)

# Remove ticks from axis Y
ax.tick_params(axis='y', length=0)

# Add year to the first date on the axis
def custom_date_formatter(t, pos, dates, x_interval):
    date = dates[pos*x_interval]
    if pos == 0:
        return date.strftime('%d %b \'%y')  
    else:
        return date.strftime('%d %b')  
ax.xaxis.set_major_formatter(ticker.FuncFormatter((lambda x, pos: custom_date_formatter(x, pos, dates=dates, x_interval=x_interval))))

# Ticks label color
[t.set_color('#808079'for t in ax.yaxis.get_ticklabels()]
[t.set_color('#808079'for t in ax.xaxis.get_ticklabels()]


# Gradient
numeric_x = np.array([i for i in range(len(x))])
numeric_x_patch = np.append(numeric_x, max(numeric_x))
numeric_x_patch = np.append(numeric_x_patch[0], numeric_x_patch)
y_patch = np.append(y, 0)
y_patch = np.append(0, y_patch)

path = Path(np.array([numeric_x_patch, y_patch]).transpose())
patch = PathPatch(path, facecolor='none')
plt.gca().add_patch(patch)

ax.imshow(numeric_x.reshape(len(numeric_x), 1),  interpolation="bicubic",
                cmap=plt.cm.Greens, 
                origin='lower',
                alpha=0.3,
                extent=[min(numeric_x), max(numeric_x), min(y_patch), max(y_patch) * 1.2], 
                aspect="auto", clip_path=patch, clip_on=True)

# Plotting main line
y_chart = y_patch
y_chart[0] = y_chart[1]
y_chart[-1] = y_chart[-2]
ax.plot(numeric_x_patch, y_chart, color='#268358', linewidth=2)

# fix a grey line of imshow
ax.plot([max(numeric_x_patch), max(numeric_x_patch)], [0, max(y)], color='white', linewidth=2)

plt.savefig('high_quality_plot.png', dpi=300, bbox_inches='tight')
plt.show()

🏴‍☠️宝藏级🏴‍☠️ 原创公众号『 数据STUDIO 』内容超级硬核。公众号以Python为核心语言,垂直于数据科学领域,包括 可戳 👉 Python MySQL 数据分析 数据可视化 机器学习与数据挖掘 爬虫 等,从入门到进阶!

长按👇关注- 数据STUDIO -设为星标,干货速递







请到「今天看啥」查看全文