- html - 出于某种原因,IE8 对我的 Sass 文件中继承的 html5 CSS 不友好?
- JMeter 在响应断言中使用 span 标签的问题
- html - 在 :hover and :active? 上具有不同效果的 CSS 动画
- html - 相对于居中的 html 内容固定的 CSS 重复背景?
TensorFlow 2.0(GPU 版)、CUDA 10.0、NVIDIA GeForce RTX 2060、Windows 10 (1903)
我在做文本分类。我使用字符嵌入和 LSTM 将每个单词编码为一个向量。这是我的模型:
import tensorflow as tf
from tensorflow.keras.models import Sequential, Model, load_model, model_from_json
from tensorflow.keras.layers import Input, Embedding, LSTM, TimeDistributed, Multiply, Masking, Concatenate, Dense
from tensorflow.keras.utils import to_categorical
import numpy as np
import re
.
target_word_length = 15
target_sentence_length = 100
mask_embedding = 27
def NetEncoder(string):
d = dict(zip('abcdefghijklmnopqrstuvwxyz', range(1,27)))
text_enc = []
for w in re.findall('[a-z]+', string.lower()):
word_enc = []
for c in w:
word_enc.append(d[c])
if len(word_enc) < target_word_length:
for i in range(target_word_length - len(word_enc)):
word_enc.append(0)
text_enc.append(word_enc)
const = []
for i in range(target_word_length):
const.append(0)
const[0] = mask_embedding
sentence_length = len(text_enc)
if sentence_length < target_sentence_length:
for i in range(target_sentence_length - sentence_length):
text_enc.append(const)
mask = []
for i in range(target_sentence_length):
if i < sentence_length:
mask.append([1])
else:
mask.append([0])
return [np.array(text_enc, dtype=np.float32), np.array(mask, dtype=np.float32)]
.
emb = Sequential([
Embedding(28, 3, mask_zero=True, input_shape=(target_word_length,)),
LSTM(16, return_sequences=False)
])
input_1 = Input(shape=(target_sentence_length, target_word_length))
mask_1 = Input(shape=(target_sentence_length, 1))
net_1 = TimeDistributed(emb)(input_1)
net_1 = Multiply()([net_1, mask_1])
net_1 = Masking(mask_value=np.full(16, 0.))(net_1)
net_1 = LSTM(64, return_sequences=False)(net_1)
input_2 = Input(shape=(target_sentence_length, target_word_length))
mask_2 = Input(shape=(target_sentence_length, 1))
net_2 = TimeDistributed(emb)(input_2)
net_2 = Multiply()([net_2, mask_2])
net_2 = Masking(mask_value=np.full(16, 0.))(net_2)
net_2 = LSTM(64, return_sequences=False)(net_2)
cat = Concatenate()([net_1, net_2])
clf = Dense(32, activation='relu')(cat)
clf = Dense(2, activation='softmax')(clf)
model = Model(inputs=[input_1, mask_1, input_2, mask_2], outputs=clf)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
[input_1, mask_1] = NetEncoder('Hello Max')
[input_2, mask_2] = NetEncoder('Hello Max')
Y = to_categorical([0], num_classes=2)
model.fit([[input_1], [mask_1], [input_2], [mask_2]], Y, epochs=10)
训练时间:每个样本 23 毫秒
Train on 1 samplesEpoch 1/104/1 [==============================================] - 16s 4s/sample - loss: 0.6840 - accuracy: 0.7500Epoch 2/104/1 [==============================================] - 0s 23ms/sample - loss: 0.6618 - accuracy: 1.0000Epoch 3/104/1 [==============================================] - 0s 23ms/sample - loss: 0.6335 - accuracy: 1.0000
Prediction time: 63 ms per sample
%%timemodel.predict([[input_1], [mask_1], [input_2], [mask_2]])
Wall time: 63 ms
array([[1.0000000e+00, 2.6085422e-08]], dtype=float32)
Save and load model (#1)
model.save('model.h5', save_format='tf', include_optimizer=False)
model_1 = load_model('model.h5')
预测时间更短:37 与 63 毫秒。但输出略有不同。
%%timemodel_1.predict([[input_1], [mask_1], [input_2], [mask_2]])
Wall time: 37.4 ms
array([[1.000000e+00, 2.608547e-08]], dtype=float32)
model_1.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model_1.fit([[input_1], [mask_1], [input_2], [mask_2]], Y, epochs=10)
训练速度慢 6 倍:每个样本 139 毫秒对 23 毫秒!
Train on 1 samplesEpoch 1/104/1 [==============================================] - 4s 1s/sample - loss: 0.0000e+00 - accuracy: 1.0000Epoch 2/104/1 [==============================================] - 1s 133ms/sample - loss: 0.0000e+00 - accuracy: 1.0000Epoch 3/104/1 [==============================================] - 1s 139ms/sample - loss: 0.0000e+00 - accuracy: 1.0000
Save and load model (#2)
model.save_weights('model_V2.h5')
model_json = model.to_json()
with open('model_V2.json', "w") as json_file:
json_file.write(model_json)
json_file.close()
json_file = open("model_V2.json", 'r')
loaded_model_json = json_file.read()
json_file.close()
model_2 = model_from_json(loaded_model_json)
#model_2.load_weights("model_V2.h5")
model_2.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model_2.fit([[input_1], [mask_1], [input_2], [mask_2]], Y, epochs=10)
训练时间:每个样本 149 与 23 毫秒!
Train on 1 samplesEpoch 1/104/1 [==============================================] - 4s 1s/sample - loss: 0.6805 - accuracy: 0.7500Epoch 2/104/1 [==============================================] - 1s 145ms/sample - loss: 0.6526 - accuracy: 1.0000Epoch 3/104/1 [==============================================] - 1s 149ms/sample - loss: 0.6186 - accuracy: 1.0000
问题是什么?
最佳答案
使用 mask 时会发生这种情况。我不知道原因,但是,当我在嵌入层中设置 mask_zero = True
时,我的训练也很慢。
关于python - Keras:加载模型的训练速度慢了 6 倍,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58220535/
我的应用程序中有以下查询: SELECT a.*, f.* FROM flights_database f JOIN airports a ON f.airport = a.airportNameCl
我们在使用 MySQL(以及 MariaDB)时遇到了一个奇怪的问题。一个简单的数据库,有 2 个表(InnoDB 引擎),都包含(以及其他一些)3 或 4 个带有 XML 数据的文本列。大小为 1-
我在 MySQL 上的执行路径上遇到问题,导致查询缓慢且不一致。这是一个全新的现象。我们还有其他具有完全相同(好吧,尽可能接近)设置的表,这很好,但出于某种原因,现在创建新表会遇到这个缓慢/不一致的问
我使用 Eclipse Marketplace 的下载速度始终非常慢(现在从 http://download.eclipse.org 开始,下载速度为 3 MB/s,下载速度为 25 kB/s),这使
我正在开发一个 Qt Creator 项目,其中包含大量头文件(点云库、Boost 等)。例如。 Boost 有大约 9000 个头文件。现在看来,包含的数量确实减慢了 IDE。代码完成很慢,大约。
我在一个项目中使用 document.elementFromPoint,它看起来很慢。 100,000 次迭代需要 7051 毫秒。 document.getElementsByTagName("*"
我有一个 tableView ,每行有四个图表,大约 20 行。当我尝试滚动表格时,我将删除现有图表并为每一行构建新图表。 此操作使 TableView 的滚动非常慢。任何使滚动速度更快以及加载新图表
我有一个如下所示的数据框: date,time,metric_x 2016-02-27,00:00:28.0000000,31 2016-02-27,00:01:19.0000000,40 2016-
TLDR:我的微调器瞬间显示了错误的颜色。 我的微调器有问题。每当我运行应用程序时,如果 Activity 没有缓存在内存中,它有时会滞后。在我可以将其设置为正确的颜色之前,文本是默认颜色(如黑色)。
我在使用 SELECT COUNT(*) 对大型表进行 SQLite 时遇到性能问题。 由于我还没有收到可用的答案并且我做了一些进一步的测试,所以我编辑了我的问题以纳入我的新发现。 我有 2 个表:
当音频因加载数据不足(速度慢)而暂停时,我可以使用什么事件? 就像: $audio.on('suspendToLoading',function(){ alert('loading...');
这是我的 MATLAB 程序的分析模拟运行结果。我需要运行此模拟数十万次(约 100,000 次)。 因此我需要一种更快的方法来读取 Excel 文件。 规范:Excel 文件由 10000x2 个单
每当与数据透视表交互时,Excel 都非常慢,这让我感到非常困难。添加/删除字段、更改过滤器或切片器,所有这些都需要 Excel 卡住几分钟才能响应。 看来生成的 MDX 效率极低。我可以理解他们必须
我正在使用 Entity Framework 来检索大型数据集。 数据集有parent/child关系,我需要和parent同时带回child信息。 我发现 EF 最初发送一个查询以获取父对象列表,然
我有一个使用 gridview 的应用程序,它非常慢。 添加 Trace=true 后对于页面,我追踪了时间花费的地方:在 GridView 上调用 BindData() 时。 GridView连接到
我编写了一个小代码来使用 QtCreator 测试 QGraphicsView 的功能。 代码非常简单,只是创建了一个继承自 QGraphicsView 的类,上面有一个 QGraphicsScene
后期以补充作品的形式自动加入成员(member)。数据库速度较慢。有没有办法加快这个速度?用户无所谓..除了自动补码之外如何停止写?(自动补码;城市输入。成员(member)表格位于。) 注册.php
我有一个文件 (insert.sql),其中有 250k 行,没有键,没有索引: INSERT `project_383`.`entity_metrics_build_1` VALUES ('d402
我最近开发了一个应用程序(java 8、spring-boot、hibernate、maven),它通过 REST API 公开数据库。我遇到的问题是数据库调用很慢(3000 毫秒以上),只是为了获取
我正在尝试在 Canvas 上使用旋转,我现在有了它,因此每个对象都有自己的旋转。如果没有它们旋转,我可以在一台非常低端的计算机上在屏幕上显示大约 400 个对象,在一台正常库存的计算机上显示近 20
我是一名优秀的程序员,十分优秀!