gpt4 book ai didi

json - 使用 JSON 文件中的值初始化多个类似乎会创建重复项

转载 作者:行者123 更新时间:2023-12-04 10:48:43 24 4
gpt4 key购买 nike

我有一个名为 Dense 的主类,它有许多变量,我有另一个名为 Sequential 的类,它将这些层堆叠在一个数组中,并允许它们之间轻松通信。

在我的程序中的某个时刻,我最终将每个单独的 Dense 类的所有变量保存在一个 JSON 文件中以备后用。它看起来有点像这样:


{
"layer: 0":{
.
.
.
variables
.
.
.
},

"layer: 1":{
.
.
.
variables
.
.
.
}
}

然后,当我将这些 dicts 加载回我的图层时,当我稍后使用它们时,Sequential 类中最后一个图层的值似乎会复制到其他每个图层。

这是我在重新创建代码方面所能做的最好的事情

from abc import ABCMeta, abstractmethod
import os
import json

class Base_Layer(metaclass=ABCMeta):
@abstractmethod
def __init__(self, name:str, **kwargs) -> None:
self.name = name


@abstractmethod
def get_summary(self) -> dict:
return summary


@abstractmethod
def load(self, layer_data:dict) -> None:
pass





class Dense(Base_Layer):
def __init__(self, layer_shape:tuple):
super().__init__("Dense")

self.layer_shape = layer_shape


def get_summary(self):
summary = {
"name":self.name,
"layer_shape":self.layer_shape
}
return summary


def load(self, layer_data):
self.name = layer_data["name"]
self.layer_shape = tuple(layer_data["layer_shape"])





class Sequential:
def __init__(self, model=[]):
self.model = model


def save(self, file_path):

summaries = []
for layer in self.model:
summaries.append(layer.get_summary())


saved_model = {}
for i in range(len(summaries)):
saved_model["layer: %s" % i] = {}
for key in summaries[i]:
saved_model["layer: %s" % i][key] = summaries[i][key]

with open(file_path, "w+") as json_file:
json_file.write(json.dumps(saved_model, indent=2))


def load(self, file_path):
layers = {
"Dense":Dense((0, 0)),
}

# Try to open the file at file_path.
try:
with open(file_path, "r") as json_file:
model_layers = json.loads(json_file.read())

for i in range(len(model_layers)):
layer_data = model_layers["layer: %s" % i]
self.model.append(layers[layer_data["name"]])
self.model[-1].load(layer_data)

# Gets called if the program can't find the file_path.
except Exception as e:
raise FileNotFoundError("Can't find file path %s. Try saving the model or enter a correct file path." % file_path)



seq_model = Sequential([
Dense((128, 784)),
Dense((128, 128)),
Dense((128, 128)),
Dense((32, 128)),
Dense((10, 32)),
])

file_path = os.path.dirname(os.path.realpath(__file__)) + "/recreate_error_test.json"
seq_model.save(file_path)
seq_model.load(file_path)




for layer in seq_model.model:
print(layer.layer_shape)


这实际上输出:
(128, 784)
(128, 128)
(128, 128)
(32, 128)
(10, 32)
(10, 32)
(10, 32)
(10, 32)
(10, 32)
(10, 32)

奇怪的是为什么它打印两次,前 7 个值是正确的,接下来的 7 个显示了我在实际代码中面临的问题。

对这个问题的任何帮助将不胜感激,谢谢。

最佳答案

问题是重用 layer[layer_data["name"]] ,这与添加到列表中的引用相同。 Dense.load() 不会重新创建一个新的 Dense 对象,它会填充它。

这里对加载函数进行了一些重写以查看元素创建。

import copy

(...)



def load(self, file_path):
layers = {
"Dense":Dense((0, 0)),
}

# Try to open the file at file_path.
try:
with open(file_path, "r") as json_file:
model_layers = json.loads(json_file.read())

for i in range(len(model_layers)):
layer_data = model_layers["layer: %s" % i]
print("%s %s" %( i,layer_data["name"]))
# OK could be new_layer=Dense((0, 0)) since it recreate a new object
new_layer=copy.copy(layers[layer_data["name"]])
# NOK same reference is reused
# new_layer=layers[layer_data["name"]]
new_layer.load(layer_data)
self.model.append(new_layer)

# Gets called if the program can't find the file_path.
except Exception as e:
raise FileNotFoundError("Can't find file path %s. Try saving the model or enter a correct file path." % file_path)

这是使用类名而不需要复制的另一种解决方案:
def load(self, file_path):
layerclass = {
"Dense":Dense,
}

# Try to open the file at file_path.
try:
with open(file_path, "r") as json_file:
model_layers = json.loads(json_file.read())

for i in range(len(model_layers)):
layer_data = model_layers["layer: %s" % i]
print("%s %s" %( i,layer_data["name"]))
# requires that all layerclass have a __init__ constructor with layer_shape tuple.
new_layer=layerclass[layer_data["name"]](layer_shape=(0,0))
new_layer.load(layer_data)
self.model.append(new_layer)

# Gets called if the program can't find the file_path.
except Exception as e:
raise FileNotFoundError("Can't find file path %s. Try saving the model or enter a correct file path." % file_path)

关于json - 使用 JSON 文件中的值初始化多个类似乎会创建重复项,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59583808/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com