python - How do I save a text classification model in TensorFlow?

Reposted. Author: 行者123. Updated: 2023-11-28 19:02:59

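After reading the TensorFlow documentation on text classification, I put together the script below, which trains a text classification (positive/negative) model. There is one thing I am not sure about: how can I save the model so that I can reuse it later? Also, how do I test it on an input test set that I have?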

import tensorflow as tf
import tensorflow_hub as hub
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import re
import seaborn as sns



# Load all files from a directory in a DataFrame.
def load_directory_data(directory):
    data = {}
    data["sentence"] = []
    data["sentiment"] = []
    for file_path in os.listdir(directory):
        with tf.gfile.GFile(os.path.join(directory, file_path), "r") as f:
            data["sentence"].append(f.read())
            # File names look like "<id>_<rating>.txt"; keep the rating.
            data["sentiment"].append(re.match(r"\d+_(\d+)\.txt", file_path).group(1))
    return pd.DataFrame.from_dict(data)

# Merge positive and negative examples, add a polarity column and shuffle.
def load_dataset(directory):
    pos_df = load_directory_data(os.path.join(directory, "pos"))
    neg_df = load_directory_data(os.path.join(directory, "neg"))
    pos_df["polarity"] = 1
    neg_df["polarity"] = 0
    return pd.concat([pos_df, neg_df]).sample(frac=1).reset_index(drop=True)

# Download and process the dataset files.
def download_and_load_datasets(force_download=False):
    dataset = tf.keras.utils.get_file(
        fname="aclImdb.tar.gz",
        origin="http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz",
        extract=True)

    train_df = load_dataset(os.path.join(os.path.dirname(dataset),
                                         "aclImdb", "train"))
    test_df = load_dataset(os.path.join(os.path.dirname(dataset),
                                        "aclImdb", "test"))

    return train_df, test_df

# Reduce logging output.
tf.logging.set_verbosity(tf.logging.ERROR)

train_df, test_df = download_and_load_datasets()
train_df.head()


# Training input on the whole training set with no limit on training epochs.
train_input_fn = tf.estimator.inputs.pandas_input_fn(
    train_df, train_df["polarity"], num_epochs=None, shuffle=True)

# Prediction on the whole training set.
predict_train_input_fn = tf.estimator.inputs.pandas_input_fn(
    train_df, train_df["polarity"], shuffle=False)
# Prediction on the test set.
predict_test_input_fn = tf.estimator.inputs.pandas_input_fn(
    test_df, test_df["polarity"], shuffle=False)


embedded_text_feature_column = hub.text_embedding_column(
    key="sentence",
    module_spec="https://tfhub.dev/google/nnlm-en-dim128/1")


estimator = tf.estimator.DNNClassifier(
    hidden_units=[500, 100],
    feature_columns=[embedded_text_feature_column],
    n_classes=2,
    optimizer=tf.train.AdagradOptimizer(learning_rate=0.003))

# Training for 1,000 steps means 128,000 training examples with the default
# batch size. This is roughly equivalent to 5 epochs since the training dataset
# contains 25,000 examples.
estimator.train(input_fn=train_input_fn, steps=1000)

train_eval_result = estimator.evaluate(input_fn=predict_train_input_fn)
test_eval_result = estimator.evaluate(input_fn=predict_test_input_fn)

print("Training set accuracy: {accuracy}".format(**train_eval_result))
print("Test set accuracy: {accuracy}".format(**test_eval_result))

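Currently, if I run the script above, it retrains the whole model. I want to reuse the trained model and get its output for some sample texts that I have. How can I do that?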

I tried the following to save it:

sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
saver.save(sess, 'test-model')

But this raises an error: ValueError: No variables to save

Best answer

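You can train and predict with a saved/loaded Estimator model simply by passing the model_dir parameter to both the Estimator instance and a tf.estimator.RunConfig instance, which is in turn passed to the config parameter of the pre-made estimator (available since about TensorFlow 1.4 and still working in TensorFlow 1.12):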

model_path = '/path/to/model'
run_config = tf.estimator.RunConfig(
    model_dir=model_path,
    tf_random_seed=72,  # Default=None
    save_summary_steps=100,
    # save_checkpoints_steps=_USE_DEFAULT,  # unset by default
    # save_checkpoints_secs=_USE_DEFAULT,  # 600 s when neither option is set
    session_config=None,
    keep_checkpoint_max=12,  # Default=5
    keep_checkpoint_every_n_hours=10000,
    log_step_count_steps=100,
    train_distribute=None,
    device_fn=None,
    protocol=None,
    eval_distribute=None,
    experimental_distribute=None)
classifier = tf.estimator.DNNLinearCombinedClassifier(
    config=run_config,
    model_dir=model_path,
    ...
)

You will then be able to call classifier.train() and classifier.predict(), re-run the script while skipping the classifier.train() call, and get the same results when you call classifier.predict() again.
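One way to make the "skip training on later runs" step automatic is to check whether the model directory already holds a checkpoint before calling classifier.train(). The sketch below is a minimal, standard-library-only illustration, not part of the original answer. `has_checkpoint` and `train_if_needed` are hypothetical helper names, and the code assumes that TensorFlow records its checkpoints in a state file named `checkpoint` inside model_dir. If TensorFlow is already imported, `tf.train.latest_checkpoint(model_path)` returning None is a more direct test.

```python
import os


def has_checkpoint(model_dir):
    """Return True if model_dir appears to contain a saved checkpoint.

    Assumption: TensorFlow keeps a small state file named "checkpoint"
    in model_dir once at least one checkpoint has been written.
    """
    return os.path.isfile(os.path.join(model_dir, "checkpoint"))


def train_if_needed(classifier, model_dir, input_fn, steps):
    """Train only when no checkpoint exists yet; return True if it trained."""
    if has_checkpoint(model_dir):
        return False
    classifier.train(input_fn=input_fn, steps=steps)
    return True
```

With the names used above, `train_if_needed(classifier, model_path, train_input_fn, 1000)` would train on the first run and do nothing on later runs, so the script can go straight on to classifier.predict().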

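This works when using a hub.text_embedding_column feature column, and also when using categorical_column_with_identity and embedding_column feature columns together with a manually saved/restored VocabularyProcessor dictionary.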
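For the sample-text part of the question, a restored estimator can be fed new sentences through the same kind of pandas input function the question already uses. The following is a hedged sketch for TensorFlow 1.x rather than a definitive implementation. `LABELS`, `label_predictions` and `predict_sentences` are hypothetical names, the "sentence" key matches the feature column in the question, and the code assumes that predict() yields dicts with `class_ids` and `probabilities` entries, as the pre-made classifiers do.

```python
# Assumption: label ids follow the question's polarity column (1 = positive).
LABELS = {0: "negative", 1: "positive"}


def label_predictions(sentences, predictions):
    """Pair each sentence with its predicted label and that label's probability."""
    results = []
    for sentence, pred in zip(sentences, predictions):
        class_id = int(pred["class_ids"][0])
        probability = float(pred["probabilities"][class_id])
        results.append((sentence, LABELS[class_id], probability))
    return results


def predict_sentences(estimator, sentences):
    """Run a TF 1.x estimator on raw sentences (imports deferred to call time)."""
    import pandas as pd
    import tensorflow as tf  # TensorFlow 1.x API, as in the question

    frame = pd.DataFrame({"sentence": list(sentences)})
    # No labels are needed for prediction, so only the features are passed.
    input_fn = tf.estimator.inputs.pandas_input_fn(frame, shuffle=False)
    return label_predictions(frame["sentence"], estimator.predict(input_fn=input_fn))
```

In a later run that builds the classifier with the same model_dir, a call such as `predict_sentences(classifier, ["What a great film", "Dull and far too long"])` should use the latest saved checkpoint rather than retraining.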

Regarding "python - How do I save a text classification model in TensorFlow?", we found a similar question on Stack Overflow: https://stackoverflow.com/questions/50475348/
