- ubuntu12.04环境下使用kvm ioctl接口实现最简单的虚拟机
- Ubuntu 通过无线网络安装Ubuntu Server启动系统后连接无线网络的方法
- 在Ubuntu上搭建网桥的方法
- ubuntu 虚拟机上网方式及相关配置详解
CFSDN坚持开源创造价值,我们致力于搭建一个资源共享平台,让每一个IT人在这里找到属于你的精彩世界.
这篇CFSDN的博客文章pytorch实现用CNN和LSTM对文本进行分类方式由作者收集整理,如果你对这篇文章有兴趣,记得点赞哟.
model.py:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
|
#!/usr/bin/python
# -*- coding: utf-8 -*-
import
torch
from
torch
import
nn
import
numpy as np
from
torch.autograd
import
Variable
import
torch.nn.functional as F
class
TextRNN(nn.Module):
"""文本分类,RNN模型"""
def
__init__(
self
):
super
(TextRNN,
self
).__init__()
# 三个待输入的数据
self
.embedding
=
nn.Embedding(
5000
,
64
)
# 进行词嵌入
# self.rnn = nn.LSTM(input_size=64, hidden_size=128, num_layers=2, bidirectional=True)
self
.rnn
=
nn.GRU(input_size
=
64
, hidden_size
=
128
, num_layers
=
2
, bidirectional
=
True
)
self
.f1
=
nn.Sequential(nn.Linear(
256
,
128
),
nn.Dropout(
0.8
),
nn.ReLU())
self
.f2
=
nn.Sequential(nn.Linear(
128
,
10
),
nn.Softmax())
def
forward(
self
, x):
x
=
self
.embedding(x)
x,_
=
self
.rnn(x)
x
=
F.dropout(x,p
=
0.8
)
x
=
self
.f1(x[:,
-
1
,:])
return
self
.f2(x)
class
TextCNN(nn.Module):
def
__init__(
self
):
super
(TextCNN,
self
).__init__()
self
.embedding
=
nn.Embedding(
5000
,
64
)
self
.conv
=
nn.Conv1d(
64
,
256
,
5
)
self
.f1
=
nn.Sequential(nn.Linear(
256
*
596
,
128
),
nn.ReLU())
self
.f2
=
nn.Sequential(nn.Linear(
128
,
10
),
nn.Softmax())
def
forward(
self
, x):
x
=
self
.embedding(x)
x
=
x.detach().numpy()
x
=
np.transpose(x,[
0
,
2
,
1
])
x
=
torch.Tensor(x)
x
=
Variable(x)
x
=
self
.conv(x)
x
=
x.view(
-
1
,
256
*
596
)
x
=
self
.f1(x)
return
self
.f2(x)
|
train.py:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
|
# coding: utf-8
from
__future__
import
print_function
import
torch
from
torch
import
nn
from
torch
import
optim
from
torch.autograd
import
Variable
import
os
import
numpy as np
from
model
import
TextRNN,TextCNN
from
cnews_loader
import
read_vocab, read_category, batch_iter, process_file, build_vocab
base_dir
=
'cnews'
train_dir
=
os.path.join(base_dir,
'cnews.train.txt'
)
test_dir
=
os.path.join(base_dir,
'cnews.test.txt'
)
val_dir
=
os.path.join(base_dir,
'cnews.val.txt'
)
vocab_dir
=
os.path.join(base_dir,
'cnews.vocab.txt'
)
def
train():
x_train, y_train
=
process_file(train_dir, word_to_id, cat_to_id,
600
)
#获取训练数据每个字的id和对应标签的oe-hot形式
x_val, y_val
=
process_file(val_dir, word_to_id, cat_to_id,
600
)
#使用LSTM或者CNN
model
=
TextRNN()
# model = TextCNN()
#选择损失函数
Loss
=
nn.MultiLabelSoftMarginLoss()
# Loss = nn.BCELoss()
# Loss = nn.MSELoss()
optimizer
=
optim.Adam(model.parameters(),lr
=
0.001
)
best_val_acc
=
0
for
epoch
in
range
(
1000
):
batch_train
=
batch_iter(x_train, y_train,
100
)
for
x_batch, y_batch
in
batch_train:
x
=
np.array(x_batch)
y
=
np.array(y_batch)
x
=
torch.LongTensor(x)
y
=
torch.Tensor(y)
# y = torch.LongTensor(y)
x
=
Variable(x)
y
=
Variable(y)
out
=
model(x)
loss
=
Loss(out,y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
accracy
=
np.mean((torch.argmax(out,
1
)
=
=
torch.argmax(y,
1
)).numpy())
#对模型进行验证
if
(epoch
+
1
)
%
20
=
=
0
:
batch_val
=
batch_iter(x_val, y_val,
100
)
for
x_batch, y_batch
in
batch_train:
x
=
np.array(x_batch)
y
=
np.array(y_batch)
x
=
torch.LongTensor(x)
y
=
torch.Tensor(y)
# y = torch.LongTensor(y)
x
=
Variable(x)
y
=
Variable(y)
out
=
model(x)
loss
=
Loss(out, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
accracy
=
np.mean((torch.argmax(out,
1
)
=
=
torch.argmax(y,
1
)).numpy())
if
accracy > best_val_acc:
torch.save(model.state_dict(),
'model_params.pkl'
)
best_val_acc
=
accracy
print
(accracy)
if
__name__
=
=
'__main__'
:
#获取文本的类别及其对应id的字典
categories, cat_to_id
=
read_category()
#获取训练文本中所有出现过的字及其所对应的id
words, word_to_id
=
read_vocab(vocab_dir)
#获取字数
vocab_size
=
len
(words)
train()
|
test.py:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
|
# coding: utf-8
from
__future__
import
print_function
import
os
import
tensorflow.contrib.keras as kr
import
torch
from
torch
import
nn
from
cnews_loader
import
read_category, read_vocab
from
model
import
TextRNN
from
torch.autograd
import
Variable
import
numpy as np
try
:
bool
(
type
(
unicode
))
except
NameError:
unicode
=
str
base_dir
=
'cnews'
vocab_dir
=
os.path.join(base_dir,
'cnews.vocab.txt'
)
class
TextCNN(nn.Module):
def
__init__(
self
):
super
(TextCNN,
self
).__init__()
self
.embedding
=
nn.Embedding(
5000
,
64
)
self
.conv
=
nn.Conv1d(
64
,
256
,
5
)
self
.f1
=
nn.Sequential(nn.Linear(
152576
,
128
),
nn.ReLU())
self
.f2
=
nn.Sequential(nn.Linear(
128
,
10
),
nn.Softmax())
def
forward(
self
, x):
x
=
self
.embedding(x)
x
=
x.detach().numpy()
x
=
np.transpose(x,[
0
,
2
,
1
])
x
=
torch.Tensor(x)
x
=
Variable(x)
x
=
self
.conv(x)
x
=
x.view(
-
1
,
152576
)
x
=
self
.f1(x)
return
self
.f2(x)
class
CnnModel:
def
__init__(
self
):
self
.categories,
self
.cat_to_id
=
read_category()
self
.words,
self
.word_to_id
=
read_vocab(vocab_dir)
self
.model
=
TextCNN()
self
.model.load_state_dict(torch.load(
'model_params.pkl'
))
def
predict(
self
, message):
# 支持不论在python2还是python3下训练的模型都可以在2或者3的环境下运行
content
=
unicode
(message)
data
=
[
self
.word_to_id[x]
for
x
in
content
if
x
in
self
.word_to_id]
data
=
kr.preprocessing.sequence.pad_sequences([data],
600
)
data
=
torch.LongTensor(data)
y_pred_cls
=
self
.model(data)
class_index
=
torch.argmax(y_pred_cls[
0
]).item()
return
self
.categories[class_index]
class
RnnModel:
def
__init__(
self
):
self
.categories,
self
.cat_to_id
=
read_category()
self
.words,
self
.word_to_id
=
read_vocab(vocab_dir)
self
.model
=
TextRNN()
self
.model.load_state_dict(torch.load(
'model_rnn_params.pkl'
))
def
predict(
self
, message):
# 支持不论在python2还是python3下训练的模型都可以在2或者3的环境下运行
content
=
unicode
(message)
data
=
[
self
.word_to_id[x]
for
x
in
content
if
x
in
self
.word_to_id]
data
=
kr.preprocessing.sequence.pad_sequences([data],
600
)
data
=
torch.LongTensor(data)
y_pred_cls
=
self
.model(data)
class_index
=
torch.argmax(y_pred_cls[
0
]).item()
return
self
.categories[class_index]
if
__name__
=
=
'__main__'
:
model
=
CnnModel()
# model = RnnModel()
test_demo
=
[
'湖人助教力助科比恢复手感 他也是阿泰的精神导师新浪体育讯记者戴高乐报道 上赛季,科比的右手食指遭遇重创,他的投篮手感也因此大受影响。不过很快科比就调整了自己的投篮手型,并通过这一方式让自己的投篮命中率回升。而在这科比背后,有一位特别助教对科比帮助很大,他就是查克·珀森。珀森上赛季担任湖人的特别助教,除了帮助科比调整投篮手型之外,他的另一个重要任务就是担任阿泰的精神导师。来到湖人队之后,阿泰收敛起了暴躁的脾气,成为湖人夺冠路上不可或缺的一员,珀森的“心灵按摩”功不可没。经历了上赛季的成功之后,珀森本赛季被“升职”成为湖人队的全职助教,每场比赛,他都会坐在球场边,帮助禅师杰克逊一起指挥湖人球员在场上拼杀。对于珀森的工作,禅师非常欣赏,“查克非常善于分析问题,”菲尔·杰克逊说,“他总是在寻找问题的答案,同时也在找造成这一问题的原因,这是我们都非常乐于看到的。我会在平时把防守中出现的一些问题交给他,然后他会通过组织球员练习找到解决的办法。他在球员时代曾是一名很好的外线投手,不过现在他与内线球员的配合也相当不错。'
,
'弗老大被裁美国媒体看热闹“特权”在中国像蠢蛋弗老大要走了。虽然他只在首钢男篮效力了13天,而且表现毫无亮点,大大地让球迷和俱乐部失望了,但就像中国人常说的“好聚好散”,队友还是友好地与他告别,俱乐部与他和平分手,球迷还请他留下了在北京的最后一次签名。相比之下,弗老大的同胞美国人却没那么“宽容”。他们嘲讽这位NBA前巨星的英雄迟暮,批评他在CBA的业余表现,还惊讶于中国人的“大方”。今天,北京首钢俱乐部将与弗朗西斯继续商讨解约一事。从昨日的进展来看,双方可以做到“买卖不成人意在”,但回到美国后,恐怕等待弗朗西斯的就没有这么轻松的环境了。进展@北京昨日与队友告别 最后一次为球迷签名弗朗西斯在13天里为首钢队打了4场比赛,3场的得分为0,只有一场得了2分。昨天是他来到北京的第14天,虽然他与首钢还未正式解约,但双方都明白“缘分已尽”。下午,弗朗西斯来到首钢俱乐部与队友们告别。弗朗西斯走到队友身边,依次与他们握手拥抱。“你们都对我很好,安排的条件也很好,我很喜欢这支球队,想融入你们,但我现在真的很不适应。希望你们'
]
for
i
in
test_demo:
print
(i,
":"
,model.predict(i))
|
cnews_loader.py:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
|
# coding: utf-8
import
sys
from
collections
import
Counter
import
numpy as np
import
tensorflow.contrib.keras as kr
if
sys.version_info[
0
] >
2
:
is_py3
=
True
else
:
reload
(sys)
sys.setdefaultencoding(
"utf-8"
)
is_py3
=
False
def
native_word(word, encoding
=
'utf-8'
):
"""如果在python2下面使用python3训练的模型,可考虑调用此函数转化一下字符编码"""
if
not
is_py3:
return
word.encode(encoding)
else
:
return
word
def
native_content(content):
if
not
is_py3:
return
content.decode(
'utf-8'
)
else
:
return
content
def
open_file(filename, mode
=
'r'
):
"""
常用文件操作,可在python2和python3间切换.
mode: 'r' or 'w' for read or write
"""
if
is_py3:
return
open
(filename, mode, encoding
=
'utf-8'
, errors
=
'ignore'
)
else
:
return
open
(filename, mode)
def
read_file(filename):
"""读取文件数据"""
contents, labels
=
[], []
with open_file(filename) as f:
for
line
in
f:
try
:
label, content
=
line.strip().split(
'\t'
)
if
content:
contents.append(
list
(native_content(content)))
labels.append(native_content(label))
except
:
pass
return
contents, labels
def
build_vocab(train_dir, vocab_dir, vocab_size
=
5000
):
"""根据训练集构建词汇表,存储"""
data_train, _
=
read_file(train_dir)
all_data
=
[]
for
content
in
data_train:
all_data.extend(content)
counter
=
Counter(all_data)
count_pairs
=
counter.most_common(vocab_size
-
1
)
words, _
=
list
(
zip
(
*
count_pairs))
# 添加一个 <PAD> 来将所有文本pad为同一长度
words
=
[
'<PAD>'
]
+
list
(words)
open_file(vocab_dir, mode
=
'w'
).write(
'\n'
.join(words)
+
'\n'
)
def
read_vocab(vocab_dir):
"""读取词汇表"""
# words = open_file(vocab_dir).read().strip().split('\n')
with open_file(vocab_dir) as fp:
# 如果是py2 则每个值都转化为unicode
words
=
[native_content(_.strip())
for
_
in
fp.readlines()]
word_to_id
=
dict
(
zip
(words,
range
(
len
(words))))
return
words, word_to_id
def
read_category():
"""读取分类目录,固定"""
categories
=
[
'体育'
,
'财经'
,
'房产'
,
'家居'
,
'教育'
,
'科技'
,
'时尚'
,
'时政'
,
'游戏'
,
'娱乐'
]
categories
=
[native_content(x)
for
x
in
categories]
cat_to_id
=
dict
(
zip
(categories,
range
(
len
(categories))))
return
categories, cat_to_id
def
to_words(content, words):
"""将id表示的内容转换为文字"""
return
''.join(words[x]
for
x
in
content)
def
process_file(filename, word_to_id, cat_to_id, max_length
=
600
):
"""将文件转换为id表示"""
contents, labels
=
read_file(filename)
#读取训练数据的每一句话及其所对应的类别
data_id, label_id
=
[], []
for
i
in
range
(
len
(contents)):
data_id.append([word_to_id[x]
for
x
in
contents[i]
if
x
in
word_to_id])
#将每句话id化
label_id.append(cat_to_id[labels[i]])
#每句话对应的类别的id
#
# # 使用keras提供的pad_sequences来将文本pad为固定长度
x_pad
=
kr.preprocessing.sequence.pad_sequences(data_id, max_length)
y_pad
=
kr.utils.to_categorical(label_id, num_classes
=
len
(cat_to_id))
# 将标签转换为one-hot表示
#
return
x_pad, y_pad
def
batch_iter(x, y, batch_size
=
64
):
"""生成批次数据"""
data_len
=
len
(x)
num_batch
=
int
((data_len
-
1
)
/
batch_size)
+
1
indices
=
np.random.permutation(np.arange(data_len))
x_shuffle
=
x[indices]
y_shuffle
=
y[indices]
for
i
in
range
(num_batch):
start_id
=
i
*
batch_size
end_id
=
min
((i
+
1
)
*
batch_size, data_len)
yield
x_shuffle[start_id:end_id], y_shuffle[start_id:end_id]
|
以上这篇pytorch实现用CNN和LSTM对文本进行分类方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我.
原文链接:https://blog.csdn.net/weixin_38241876/article/details/90606639 。
最后此篇关于pytorch实现用CNN和LSTM对文本进行分类方式的文章就讲到这里了,如果你想了解更多关于pytorch实现用CNN和LSTM对文本进行分类方式的内容请搜索CFSDN的文章或继续浏览相关文章,希望大家以后支持我的博客! 。
我喜欢 smartcase,也喜欢 * 和 # 搜索命令。但我更希望 * 和 # 搜索命令区分大小写,而/和 ?搜索命令遵循 smartcase 启发式。 是否有隐藏在某个地方我还没有找到的设置?我宁
关闭。这个问题是off-topic .它目前不接受答案。 想改进这个问题? Update the question所以它是on-topic对于堆栈溢出。 10年前关闭。 Improve this qu
从以下网站,我找到了执行java AD身份验证的代码。 http://java2db.com/jndi-ldap-programming/solution-to-sslhandshakeexcepti
似乎 melt 会使用 id 列和堆叠的测量变量 reshape 您的数据框,然后通过转换让您执行聚合。 ddply,从 plyr 包看起来非常相似..你给它一个数据框,几个用于分组的列变量和一个聚合
我的问题是关于 memcached。 Facebook 使用 memcached 作为其结构化数据的缓存,以减少用户的延迟。他们在 Linux 上使用 UDP 优化了 memcached 的性能。 h
在 Camel route ,我正在使用 exec 组件通过 grep 进行 curl ,但使用 ${HOSTNAME} 的 grep 无法正常工作,下面是我的 Camel 路线。请在这方面寻求帮助。
我正在尝试执行相当复杂的查询,在其中我可以排除与特定条件集匹配的项目。这是一个 super 简化的模型来解释我的困境: class Thing(models.Model) user = mod
我正在尝试执行相当复杂的查询,我可以在其中排除符合特定条件集的项目。这里有一个 super 简化的模型来解释我的困境: class Thing(models.Model) user = mod
我发现了很多嵌入/内容项目的旧方法,并且我遵循了在这里找到的最新方法(我假设):https://blog.angular-university.io/angular-ng-content/ 我正在尝试
我正在寻找如何使用 fastify-nextjs 启动 fastify-cli 的建议 我曾尝试将代码简单地添加到建议的位置,但它不起作用。 'use strict' const path = req
我正在尝试将振幅 js 与 React 和 Gatsby 集成。做 gatsby developer 时一切看起来都不错,因为它发生在浏览器中,但是当我尝试 gatsby build 时,我收到以下错
我试图避免过度执行空值检查,但同时我想在需要使代码健壮的时候进行空值检查。但有时我觉得它开始变得如此防御,因为我没有实现 API。然后我避免了一些空检查,但是当我开始单元测试时,它开始总是等待运行时异
尝试进行包含一些 NOT 的 Kibana 搜索,但获得包含 NOT 的结果,因此猜测我的语法不正确: "chocolate" AND "milk" AND NOT "cow" AND NOT "tr
我正在使用开源代码共享包在 iOS 中进行 facebook 集成,但收到错误“FT_Load_Glyph failed: glyph 65535: error 6”。我在另一台 mac 机器上尝试了
我正在尝试估计一个标准的 tobit 模型,该模型被审查为零。 变量是 因变量 : 幸福 自变量 : 城市(芝加哥,纽约), 性别(男,女), 就业(0=失业,1=就业), 工作类型(失业,蓝色,白色
我有一个像这样的项目布局 样本/ 一种/ 源/ 主要的/ java / java 资源/ .jpg 乙/ 源/ 主要的/ java / B.java 资源/ B.jpg 构建.gradle 设置.gr
如何循环遍历数组中的多个属性以及如何使用map函数将数组中的多个属性显示到网页 import React, { Component } from 'react'; import './App.css'
我有一个 JavaScript 函数,它进行 AJAX 调用以返回一些数据,该调用是在选择列表更改事件上触发的。 我尝试了多种方法来在等待时显示加载程序,因为它当前暂停了选择列表,从客户的 Angul
可能以前问过,但找不到。 我正在用以下形式写很多语句: if (bar.getFoo() != null) { this.foo = bar.getFoo(); } 我想到了三元运算符,但我认
我有一个表单,在将其发送到 PHP 之前我正在执行一些验证 JavaScript,验证后的 JavaScript 函数会发布用户在 中输入的文本。页面底部的标签;然而,此消息显示短暂,然后消失...
我是一名优秀的程序员,十分优秀!