
python - Matplotlib dot plot with two categorical variables

Reposted. Author: 行者123. Updated: 2023-12-01 00:54:08

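I would like to produce a specific type of visualization: a rather simple dot plot, but with a twist, in that both axes are categorical variables (i.e. ordinal or non-numerical values). Rather than making things easier, this complicates matters.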

To illustrate the problem, I will use a small example dataset that is a modification of seaborn.load_dataset("tips"), defined as follows:

import pandas
from io import StringIO  # standard library; six is not needed on Python 3
df = """total_bill | tip | sex | smoker | day | time | size
16.99 | 1.01 | Male | No | Mon | Dinner | 2
10.34 | 1.66 | Male | No | Sun | Dinner | 3
21.01 | 3.50 | Male | No | Sun | Dinner | 3
23.68 | 3.31 | Male | No | Sun | Dinner | 2
24.59 | 3.61 | Female | No | Sun | Dinner | 4
25.29 | 4.71 | Female | No | Mon | Lunch | 4
8.77 | 2.00 | Female | No | Tue | Lunch | 2
26.88 | 3.12 | Male | No | Wed | Lunch | 4
15.04 | 3.96 | Male | No | Sat | Lunch | 2
14.78 | 3.23 | Male | No | Sun | Lunch | 2"""
df = pandas.read_csv(StringIO(df.replace(' ','')), sep="|", header=0)

My first approach to generating the graph was to try a call to seaborn, like this:

import seaborn
axes = seaborn.pointplot(x="time", y="sex", data=df)

This fails with:

ValueError: Neither the `x` nor `y` variable appears to be numeric.

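The same goes for the equivalent seaborn.stripplot and seaborn.swarmplot calls. It does work, however, if one of the variables is categorical and the other is numerical. Indeed, seaborn.pointplot(x="total_bill", y="sex", data=df) works, but it is not what I want.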

I also tried a scatter plot like this:

axes = seaborn.scatterplot(x="time", y="sex", size="day", data=df,
                           x_jitter=True, y_jitter=True)

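This produces the following graph, which contains no jitter at all and has every point overlapping, making it useless:

[Figure: SeabornScatterPlot]

Do you know of any elegant approach or library that could solve my problem?

I started writing something myself, which I include below, but this implementation is suboptimal and is limited by the number of points that can overlap at the same spot (currently it fails if more than 4 points overlap).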

# Modules #
import pandas
import seaborn

################################################################################
def amount_to_offets(amount):
    """A function that takes an amount of overlapping points (e.g. 3)
    and returns a list of offset (jittered) coordinates for each of the
    points.

    It follows the logic that two points are displayed side by side:

    2 ->  * *

    Three points are organized in a triangle:

    3 ->   *
          * *

    Four points are arranged in a square:

    4 ->  * *
          * *

    More than four overlapping points are not supported (KeyError).
    """
    assert isinstance(amount, int)
    solutions = {
        1: [( 0.0,  0.0)],
        2: [(-0.5,  0.0), ( 0.5,  0.0)],
        3: [(-0.5, -0.5), ( 0.0,  0.5), ( 0.5, -0.5)],
        4: [(-0.5, -0.5), ( 0.5,  0.5), ( 0.5, -0.5), (-0.5,  0.5)],
    }
    return solutions[amount]

################################################################################
class JitterDotplot(object):

    def __init__(self, data, x_col='time', y_col='sex', z_col='tip'):
        self.data = data
        self.x_col = x_col
        self.y_col = y_col
        self.z_col = z_col

    def plot(self, **kwargs):
        # Load data #
        self.df = self.data.copy()

        # Assign numerical values to the categorical data #
        # So that ['Dinner', 'Lunch'] becomes [0, 1] etc. #
        self.x_values = self.df[self.x_col].unique()
        self.y_values = self.df[self.y_col].unique()
        self.x_mapping = dict(zip(self.x_values, range(len(self.x_values))))
        self.y_mapping = dict(zip(self.y_values, range(len(self.y_values))))
        # Map to floats so that fractional offsets can be added below #
        self.df[self.x_col] = self.df[self.x_col].map(self.x_mapping).astype(float)
        self.df[self.y_col] = self.df[self.y_col].map(self.y_mapping).astype(float)

        # Offset points that are overlapping in the same location #
        # So that (2.0, 3.0) becomes (2.05, 2.95) for instance #
        cols = [self.x_col, self.y_col]
        scaling_factor = 0.05
        for values, df_view in self.df.groupby(cols):
            offsets = amount_to_offets(len(df_view))
            offsets = pandas.DataFrame(offsets, index=df_view.index, columns=cols)
            offsets *= scaling_factor
            self.df.loc[offsets.index, cols] += offsets

        # Plot a standard scatter plot #
        g = seaborn.scatterplot(x=self.x_col, y=self.y_col, size=self.z_col,
                                data=self.df, **kwargs)

        # Put one tick on each integer code and label it with the #
        # original categorical name #
        g.set_xticks(list(self.x_mapping.values()))
        g.set_xticklabels(list(self.x_mapping.keys()))
        g.set_yticks(list(self.y_mapping.values()))
        g.set_yticklabels(list(self.y_mapping.keys()))
        g.grid(False)

        # Expand the axis limits for x and y #
        margin = 0.4
        xmin, xmax, ymin, ymax = g.get_xlim() + g.get_ylim()
        g.set_xlim(xmin-margin, xmax+margin)
        g.set_ylim(ymin-margin, ymax+margin)

        # Return for display in notebooks for instance #
        return g

################################################################################
# Graph #
graph = JitterDotplot(data=df)
axes = graph.plot()
axes.figure.savefig('jitter_dotplot.png')

[Figure: JitterDotPlot]
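The four-point limit comes from the hard-coded lookup table in the offset function. As a rough sketch of one way to lift it, the offsets could be generated instead of listed, assuming that placing the overlapping points evenly on a circle is an acceptable layout. The name ring_offsets and the radius default are illustrative and not part of the original code.

```python
import math

def ring_offsets(amount, radius=0.5):
    """Return `amount` (dx, dy) pairs evenly spaced on a circle.

    A single point stays at the origin, like the one-point entry of a
    hard-coded lookup table would. Hypothetical helper for illustration.
    """
    if amount < 1:
        raise ValueError("amount must be a positive integer")
    if amount == 1:
        return [(0.0, 0.0)]
    step = 2 * math.pi / amount  # angle between neighbouring points
    return [(radius * math.cos(i * step), radius * math.sin(i * step))
            for i in range(amount)]

# Six overlapping points, which a four-entry table could not handle
offsets = ring_offsets(6)
```

The return value has the same shape as the lookup table entries (a list of coordinate pairs), so it could stand in for that function. With many points on a single ring the markers will eventually touch each other, so the radius or the scaling factor may need adjusting.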

Best answer

You can first convert time and sex to the categorical type and then tweak the values a little:

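import numpy as np
import pandas as pd
import seaborn as sns

df.sex = pd.Categorical(df.sex)
df.time = pd.Categorical(df.time)

# the integer category codes plus a small random jitter give the coordinates
axes = sns.scatterplot(x=df.time.cat.codes + np.random.uniform(-0.1, 0.1, len(df)),
                       y=df.sex.cat.codes + np.random.uniform(-0.1, 0.1, len(df)),
                       size=df.tip)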

Output:

[Figure: scatter plot of the randomly jittered points]
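Because the jitter is drawn at random, the figure changes on every run. A minimal sketch of a reproducible variant follows, assuming a fixed seed is acceptable. The helper name jittered_codes, its width default, and the seed values are illustrative choices, not something from the answer.

```python
import numpy as np
import pandas as pd

def jittered_codes(values, width=0.1, seed=0):
    """Return (codes + uniform noise in [-width, width], category names).

    Hypothetical helper: the seed makes the jitter repeatable.
    """
    cat = pd.Categorical(values)
    rng = np.random.default_rng(seed)
    noise = rng.uniform(-width, width, size=len(cat))
    return cat.codes + noise, list(cat.categories)

# Small made-up input, only to show the shape of the result
x, x_labels = jittered_codes(["Dinner", "Dinner", "Lunch", "Lunch", "Dinner"])
```

The first return value can be passed as x or y to the scatter plot. Use a different seed for each axis, otherwise both axes receive the same noise and the points spread along a diagonal. The second return value lists the category names in code order, which can be used to relabel the ticks with set_xticks and set_xticklabels on the returned axes.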

With that idea in mind, you can replace the random offsets (np.random) in the code above with offsets of a fixed distance. For example:

# grouping (observed=True skips category combinations that have no rows)
groups = df.groupby(['time', 'sex'], observed=True)

# compute the number of samples per group
num_samples = groups.tip.transform('size')

# enumerate the samples within a group and convert each rank to an angle
sample_ranks = groups.cumcount() * (2*np.pi) / num_samples

# compute the offset (a lone sample stays at the center)
x_offsets = np.where(num_samples.eq(1), 0, np.cos(sample_ranks) * 0.03)
y_offsets = np.where(num_samples.eq(1), 0, np.sin(sample_ranks) * 0.03)

# plot
axes = sns.scatterplot(x=df.time.cat.codes + x_offsets,
                       y=df.sex.cat.codes + y_offsets,
                       size=df.tip)

Output:

[Figure: scatter plot of the points offset around each category pair]
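The circular offsets can be checked numerically without drawing anything. The sketch below wraps the same idea (rank within the group, converted to an angle) in a function and runs it on a small made-up frame. The function name circular_offsets and the toy data are assumptions for illustration.

```python
import numpy as np
import pandas as pd

def circular_offsets(frame, keys, radius=0.03):
    """Return (dx, dy) arrays placing each group's rows evenly on a circle."""
    groups = frame.groupby(keys, observed=True)
    n = groups[keys[0]].transform("size")         # rows per group, per row
    angle = groups.cumcount() * (2 * np.pi) / n   # rank within group -> angle
    dx = np.where(n.eq(1), 0.0, np.cos(angle) * radius)
    dy = np.where(n.eq(1), 0.0, np.sin(angle) * radius)
    return dx, dy

# Made-up data: four rows share one (time, sex) pair, one row is alone
toy = pd.DataFrame({
    "time": ["Dinner", "Dinner", "Dinner", "Dinner", "Lunch"],
    "sex":  ["Male", "Male", "Male", "Male", "Female"],
})
dx, dy = circular_offsets(toy, ["time", "sex"])
```

Here the four Dinner/Male rows land at angles of 0, 90, 180 and 270 degrees around their shared position, and the single Lunch/Female row is left in place.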

Regarding python - Matplotlib dot plot with two categorical variables, a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/56347325/
