
python-3.x - Feeding labels as one-hot encoded vectors into a neural network

Reposted. Author: 行者123. Updated: 2023-11-30 09:08:10

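I am trying to build a classification neural network (NN). The dataset I was given has 169307 rows. My output labels are [0, 1, 2]. I one-hot encoded them, but I cannot get the neural network to work with them: I get a ValueError. I think I made a mistake when reshaping my "target" column. I have converted it to a list, l. Here is the complete code for my solution.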

# coding: utf-8

# In[349]:

import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets
from sklearn import metrics
from sklearn import model_selection
from sklearn import preprocessing
from sklearn.preprocessing import OneHotEncoder


# In[382]:

df =pd.read_csv("train_data.csv")
num_labels = 3


# In[392]:

import numpy as np
nb_classes = 3
targets = np.array([0,1,2]).reshape(-1)
one_hot_targets = np.eye(nb_classes)[targets]
one_hot_targets


# In[420]:

target = df["target"]
feat=df.drop(['target','connection_id'],axis=1)
target[10]
l=[]
l=target.values.tolist()
l=np.array(l)
l[9]


# In[410]:




# In[394]:

logs_path="Server_attack"


# In[395]:

#Hyperparameters
batch_size=100
learning_rate=0.5
training_epochs=10


# In[396]:

X=tf.placeholder(tf.float32,[None,41])
Y_=tf.placeholder(tf.float32,[None,3])
lr=tf.placeholder(tf.float32)
XX=tf.reshape(X,[41,-1])



# In[397]:

#5Layer Neural Network
L=200
M=100
N=60
O=30


# In[398]:

#Weights and Biases

W1=tf.Variable(tf.truncated_normal([41,L],stddev=0.1))
B1=tf.Variable(tf.ones([L]))
W2=tf.Variable(tf.truncated_normal([L,M],stddev=0.1))
B2=tf.Variable(tf.ones([M]))
W3=tf.Variable(tf.truncated_normal([M,N],stddev=0.1))
B3=tf.Variable(tf.ones([N]))
W4=tf.Variable(tf.truncated_normal([N,O],stddev=0.1))
B4=tf.Variable(tf.ones([O]))
W5=tf.Variable(tf.truncated_normal([O,3],stddev=0.1))
B5=tf.Variable(tf.ones([3]))



# In[399]:

Y1=tf.nn.relu(tf.matmul(XX,W1)+B1)
Y2=tf.nn.relu(tf.matmul(Y1,W2)+B2)
Y3=tf.nn.relu(tf.matmul(Y2,W3)+B3)
Y4=tf.nn.relu(tf.matmul(Y3,W4)+B4)
Ylogits=tf.nn.relu(tf.matmul(Y4,W5)+B5)
Y=tf.nn.softmax(Ylogits)


# In[400]:

cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=Ylogits,labels=Y_)
cross_entropy = tf.reduce_mean(cross_entropy)


# In[401]:

correct_prediction=tf.equal(tf.argmax(Y,1),tf.argmax(Y_,1))
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))


# In[402]:

train_step=tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy)


# In[403]:

#TensorBoard Parameters
tf.summary.scalar("cost",cross_entropy)
tf.summary.scalar("accuracy",accuracy)
summary_op=tf.summary.merge_all()


# In[404]:

init = tf.global_variables_initializer()
sess=tf.Session()
sess.run(init)


# In[417]:

with tf.Session() as sess:
    sess.run(init)
    writer = tf.summary.FileWriter(logs_path,graph=tf.get_default_graph())
    for epoch in range(training_epochs):
        batch_count=int(len(feature)/batch_size)
        for i in range(batch_count):
            #batch_x,batch_y=feature.iloc[i, :].values.tolist(),target[i]
            batch_x = np.expand_dims(np.array(feature.iloc[i, :].values.tolist()), axis=0)
            batch_y = np.expand_dims(l, axis=0)

            # batch_y = np.reshape(batch_y,(1, 3))

            _,summary = sess.run([train_step,summary_op],
                                 {X:batch_x,Y:batch_y,learning_rate:0.001}
                                 )

            writer.add_summary(summary, epoch * batch_count + i)
        print("Epoch: ", epoch)

Error:

ValueError: Cannot feed value of shape (1, 169307) for Tensor 'Softmax_16:0', which has shape '(41, 3)'

Please guide me.

Best answer

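You never actually performed the conversion. You only created a 3x3 identity matrix, one_hot_targets, but never used it. As a result, batch_y is just an array of the raw df["target"] values: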

target = df["target"]
l = target.values.tolist()
l = np.array(l)
...
batch_y = np.expand_dims(l, axis=0) # Has shape `(1, 169307)`!
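
One possible way to put the identity-matrix idea from the question to use is to index np.eye with the full label array instead of with [0, 1, 2]. The following is a minimal sketch; the `labels` array is a hypothetical stand-in for df["target"].values, and it assumes the labels are integers in the range 0 to 2.

```python
import numpy as np

nb_classes = 3

# Hypothetical stand-in for df["target"].values (the real column has 169307 entries).
labels = np.array([0, 2, 1, 1, 0])

# Row k of the identity matrix is the one-hot vector for class k,
# so indexing with the label array yields one row per sample.
one_hot_labels = np.eye(nb_classes)[labels]

print(one_hot_labels.shape)
```

With this indexing the result has one row per sample and one column per class, which matches the [None, 3] shape declared for the Y_ placeholder in the question.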

Your batch_x does not look right either, but feature is not defined in the snippet, so I cannot say exactly what it is.
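
Since feature is not shown, the following is only a sketch of how mini-batches with matching shapes could be sliced. The arrays here are synthetic, and the names `features`, `one_hot_labels`, and `iter_batches` are hypothetical. Note also that the error message names the tensor 'Softmax_16:0', which suggests the feed dictionary is keyed on Y (the softmax output) rather than on the placeholder Y_.

```python
import numpy as np

batch_size = 100
n_samples, n_features, n_classes = 250, 41, 3

# Synthetic stand-ins for the real feature matrix and label column.
rng = np.random.default_rng(0)
features = rng.random((n_samples, n_features)).astype(np.float32)
labels = rng.integers(0, n_classes, size=n_samples)
one_hot_labels = np.eye(n_classes, dtype=np.float32)[labels]

def iter_batches(x, y, size):
    """Yield consecutive (x, y) slices of at most `size` rows each."""
    for start in range(0, len(x), size):
        yield x[start:start + size], y[start:start + size]

# Each batch_x has shape (rows, 41) and each batch_y has shape (rows, 3).
shapes = [(bx.shape, by.shape)
          for bx, by in iter_batches(features, one_hot_labels, batch_size)]
print(shapes)
```

Batches shaped like this line up with the [None, 41] and [None, 3] placeholders X and Y_ from the question, so they could be fed as {X: batch_x, Y_: batch_y}.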

[Update] How to do the one-hot encoding:

import numpy as np
from sklearn.preprocessing import OneHotEncoder

# Categorical target: 0, 1 or 2. The value is just an example
target = np.array([1, 2, 2, 1, 0, 2, 1, 1, 0, 2, 1])

target = target.reshape([-1, 1]) # add one extra dimension: shape [N, 1]
# Note: scikit-learn 1.2 renamed this parameter to `sparse_output`
encoder = OneHotEncoder(sparse=False)
encoder.fit(target)
encoded = encoder.transform(target) # now it's one-hot: [N, 3]

Regarding "python-3.x - Feeding labels as one-hot encoded vectors into a neural network", a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/47068683/
