gpt4 book ai didi

python - 序列化 Numpy 数组的意外行为

转载 作者:太空宇宙 更新时间:2023-11-03 11:25:21 25 4
gpt4 key购买 nike

代码

假设我有:

import numpy
import pickle


class Test():
def __init__(self):
self.base = numpy.zeros(6)
self.view = self.base[-3:]

def __len__(self):
return len(self.view)

def update(self):
self.view[0] += 1

def add(self):
self.view = self.base[-len(self.view) - 1:]
self.view[0] = 1

def __repr__(self):
return str(self.view)


def serialize_data():
data = Test()
return pickle.dumps(data)

请注意,Test 类只是一个包含 NumPy 数组 baseview 的类。此 view 只是 base 中最后 N 元素的一部分(N == 3 在初始化时)。

Test 有一个方法 update()1 添加到 View 位置 0 的值, 和一个方法 add() 修改 View 大小 (N = N + 1) 并将位置 0 的值设置为 1

函数 serialize_data 只是创建一个 Test() 实例,然后使用 pickle 返回序列化对象。

行为

如果我创建一个局部变量并更新它两次并添加它一次,一切都按预期工作:

# Local variable
test = Test()
print(test) # [ 0. 0. 0.]

test.update()
test.update()
print(test) # [ 2. 0. 0.]

test.add()
print(test) # [ 1. 2. 0. 0.]

现在,如果我从序列化数据中创建一个局部变量,那么在执行 add2 之后(在调用 update 两次之后设置) 似乎丢失了:

# Serialized variable
data = pickle.loads(serialize_data())
print(data) # [ 0. 0. 0.]

data.update()
data.update()
print(data) # [ 2. 0. 0.]

data.add()
print(data) # [ 1. 0. 0. 0.] <---- This should be [ 1. 2. 0. 0. ] !!!

问题

为什么会发生这种情况,我该如何避免这种行为?

最佳答案

问题是,在 pickling/unpickling 之后, View 不再是基础 View ,而是拥有自己的数据副本。 See here ,不幸的是,没有关于如何防止这种情况的答案。

可以通过定义 __getstate__ and __setstate__ 来克服特定问题。在 unpickling 后重新定义 View 的类的方法。

除了 View 之外,还需要跟踪 View 所查看的基础部分。我选择使用切片对象,但还有其他方法。没有必要对 View 本身进行 pickle,因为它将在 unpickling 时从切片中重建。

class Test():
def __init__(self):
self.base = numpy.zeros(6)
self.slice = slice(-3, self.base.size)
self.view = self.base[self.slice]

def __len__(self):
return len(self.view)

def update(self):
self.view[0] += 1

def add(self):
self.slice = slice(-len(self.view) - 1, self.base.size)
self.view = self.base[self.slice]
self.view[0] = 1

def __getstate__(self):
return {'base': self.base, 'slice': self.slice}

def __setstate__(self, state):
self.base = state['base']
self.slice = state['slice']
self.view = self.base[self.slice]

def __repr__(self):
return str(self.view)

关于python - 序列化 Numpy 数组的意外行为,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/35150386/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com