gpt4 book ai didi

用于 beta 分发的 Python scipy 重载 _stats 函数

转载 作者:行者123 更新时间:2023-12-01 03:06:26 35 4
gpt4 key购买 nike

我需要为我的 beta 发行版重载 _stats 函数。这是我当前的代码:

from scipy.stats import beta
import scipy.stats as st

class CustomBeta(st.rv_continuous):
def _stats(self, a, b):
# will add own code here
mn = a * 1.0 / (a + b)
var = (a * b * 1.0) / (a + b + 1.0) / (a + b) ** 2.0
g1 = 2.0 * (b - a) * sqrt((1.0 + a + b) / (a * b)) / (2 + a + b)
g2 = 6.0 * (a ** 3 + a ** 2 * (1 - 2 * b) + b ** 2 * (1 + b) - 2 * a * b * (2 + b))
g2 /= a * b * (a + b + 2) * (a + b + 3)
return mn, var, g1, g2

dist = beta(4, 6)
print dist.rvs() # works fine


dist = CustomBeta(4, 6)
print dist.rvs() # crashes

从我的自定义对象获取_rvs()给我一个很长的堆栈跟踪和一个错误

运行时错误:超出最大递归深度

最佳答案

这与重载_stats无关。同样的行为只是由

class CustomBeta(st.rv_continuous):
pass

dist = CustomBeta(4, 6)
print(dist.rvs()) # crashes

documentation of rv_continuous指出

New random variables can be defined by subclassing the rv_continuous class and re-defining at least the _pdf or the _cdf method.

您需要至少提供其中一种方法来计算概率密度函数 (pdf) 或累积概率密度函数 (cdf)。

此外,

[rv_continuous] cannot be used directly as a distribution.

使用方法如下:

class CustomBetaGen(st.rv_continuous):
...

CustomBeta = CustomBetaGen(name='CustomBeta')

dist = CustomBeta(4, 6)

最后,如果您不提供 _rvs 方法,rvs.() 似乎无法在 Beta 发行版中正常工作。

将所有内容放在一起并从 beta 发行版中窃取适当的方法:

from scipy.stats import beta
import scipy.stats as st
import numpy as np

class CustomBetaGen(st.rv_continuous):
def _cdf(self, x, a, b):
return beta.cdf(x, a, b)
def _pdf(self, x, a, b):
return beta.pdf(x, a, b)
def _rvs(self, a, b):
return beta.rvs(a, b)
def _stats(self, a, b):
# will add own code here
mn = a * 1.0 / (a + b)
var = (a * b * 1.0) / (a + b + 1.0) / (a + b) ** 2.0
g1 = 2.0 * (b - a) * np.sqrt((1.0 + a + b) / (a * b)) / (2 + a + b)
g2 = 6.0 * (a ** 3 + a ** 2 * (1 - 2 * b) + b ** 2 * (1 + b) - 2 * a * b * (2 + b))
g2 /= a * b * (a + b + 2) * (a + b + 3)
return mn, var, g1, g2

CustomBeta = CustomBetaGen(name='CustomBeta')

dist = beta(4, 6)
print(dist.rvs()) # works fine
print(dist.stats()) # (array(0.4), array(0.021818181818181816))

dist = CustomBeta(4, 6)
print(dist.rvs()) # works fine
print(dist.stats()) # (array(0.4), array(0.021818181818181816))

关于用于 beta 分发的 Python scipy 重载 _stats 函数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43367874/

35 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com