
python - Using three different labels in machine learning


I'm new to machine learning. I've been studying code that classifies emails as spam or ham. I ran into a problem when adapting the code to another dataset: my dataset doesn't just contain ham/spam values — it has two different categorical columns (age and gender). When I try to pass both categorical columns in the code block below, I get the error "too many values to unpack". How can I use all of my label values?

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(messages_bow, import_data['age'], import_data['gender'], test_size = 0.20, random_state = 0)

Full code:

import numpy as np
import pandas
import nltk
from nltk.corpus import stopwords
import string

# Import Data.
import_data = pandas.read_csv('/root/Desktop/%20/%100.csv' , encoding='cp1252')

# To See Columns Headers.
print(import_data.columns)

# To Remove Duplications.
import_data.drop_duplicates(inplace = True)

# To Find Data Size.
print(import_data.shape)


#Tokenization (a list of tokens), will be used as the analyzer
#1.Punctuations are [!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~]
#2.Stop words in natural language processing, are useless words (data).
def process_text(text):
    '''
    What will be covered:
    1. Remove punctuation
    2. Remove stopwords
    3. Return list of clean text words
    '''

    #1
    nopunc = [char for char in text if char not in string.punctuation]
    nopunc = ''.join(nopunc)

    #2
    clean_words = [word for word in nopunc.split() if word.lower() not in stopwords.words('english')]

    #3
    return clean_words

#Show the Tokenization (a list of tokens )
print(import_data['text'].head().apply(process_text))

# Convert the text into a matrix of token counts.
from sklearn.feature_extraction.text import CountVectorizer
messages_bow = CountVectorizer(analyzer=process_text).fit_transform(import_data['text'])

#Split data into 80% training & 20% testing data sets

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(messages_bow, import_data['gender'], import_data['frequency'], test_size = 0.20, random_state = 0)

#Get the shape of messages_bow
print(messages_bow.shape)

Best answer

train_test_split splits each array you pass it into a training portion and a test portion. Since you are splitting three different arrays, you need six output variables:

X_train, X_test, age_train, age_test, gender_train, gender_test = train_test_split(messages_bow, import_data['age'], import_data['gender'], test_size=0.20, random_state=0)
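As a minimal sketch of why six variables are needed (using small synthetic arrays in place of the asker's messages_bow, age, and gender columns — the names X, age, and gender here are stand-ins, not from the original CSV): train_test_split returns one (train, test) pair per array, in the order the arrays were passed, and applies the same row shuffle to all of them so the labels stay aligned with the features.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic stand-ins for the asker's data.
X = np.arange(20).reshape(10, 2)       # stand-in for messages_bow (10 rows, 2 features)
age = np.arange(10)                    # stand-in for import_data['age']
gender = np.array(list('MFMFMFMFMF'))  # stand-in for import_data['gender']

# Three input arrays -> six outputs: (train, test) for each array.
X_train, X_test, age_train, age_test, gender_train, gender_test = train_test_split(
    X, age, gender, test_size=0.20, random_state=0)

# With test_size=0.20 on 10 rows: 8 training rows, 2 test rows.
print(X_train.shape, X_test.shape)

# The same shuffle is applied to every array, so row i of X_train still
# corresponds to row i of age_train and gender_train.
print(np.array_equal(X_train[:, 0] // 2, age_train))
```

Unpacking into only four variables (as in the question) fails precisely because six values come back — hence "too many values to unpack".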

Regarding "python - Using three different labels in machine learning", a similar question was found on Stack Overflow: https://stackoverflow.com/questions/58742524/
