- ubuntu12.04环境下使用kvm ioctl接口实现最简单的虚拟机
- Ubuntu 通过无线网络安装Ubuntu Server启动系统后连接无线网络的方法
- 在Ubuntu上搭建网桥的方法
- ubuntu 虚拟机上网方式及相关配置详解
CFSDN坚持开源创造价值,我们致力于搭建一个资源共享平台,让每一个IT人在这里找到属于你的精彩世界.
这篇CFSDN的博客文章解决pytorch rnn 变长输入序列的问题由作者收集整理,如果你对这篇文章有兴趣,记得点赞哟.
输入数据是长度不固定的序列数据,主要讲解两个部分 。
1、Data.DataLoader的collate_fn用法,以及按batch进行padding数据 。
2、pack_padded_sequence和pad_packed_sequence来处理变长序列 。
Dataloader的collate_fn参数,定义数据处理和合并成batch的方式.
由于pack_padded_sequence用到的tensor必须按照长度从大到小排过序的,所以在Collate_fn中,需要完成两件事,一是把当前batch的样本按照当前batch最大长度进行padding,二是将padding后的数据从大到小进行排序.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
|
def
pad_tensor(vec, pad):
"""
args:
vec - tensor to pad
pad - the size to pad to
return:
a new tensor padded to 'pad'
"""
return
torch.cat([vec, torch.zeros(pad
-
len
(vec), dtype
=
torch.
float
)], dim
=
0
).data.numpy()
class
Collate:
"""
a variant of callate_fn that pads according to the longest sequence in
a batch of sequences
"""
def
__init__(
self
):
pass
def
_collate(
self
, batch):
"""
args:
batch - list of (tensor, label)
reutrn:
xs - a tensor of all examples in 'batch' before padding like:
'''
[tensor([1,2,3,4]),
tensor([1,2]),
tensor([1,2,3,4,5])]
'''
ys - a LongTensor of all labels in batch like:
'''
[1,0,1]
'''
"""
xs
=
[torch.FloatTensor(v[
0
])
for
v
in
batch]
ys
=
torch.LongTensor([v[
1
]
for
v
in
batch])
# 获得每个样本的序列长度
seq_lengths
=
torch.LongTensor([v
for
v
in
map
(
len
, xs)])
max_len
=
max
([
len
(v)
for
v
in
xs])
# 每个样本都padding到当前batch的最大长度
xs
=
torch.FloatTensor([pad_tensor(v, max_len)
for
v
in
xs])
# 把xs和ys按照序列长度从大到小排序
seq_lengths, perm_idx
=
seq_lengths.sort(
0
, descending
=
True
)
xs
=
xs[perm_idx]
ys
=
ys[perm_idx]
return
xs, seq_lengths, ys
def
__call__(
self
, batch):
return
self
._collate(batch)
|
定义完collate类以后,在DataLoader中直接使用 。
1
|
train_data
=
Data.DataLoader(dataset
=
train_dataset, batch_size
=
32
, num_workers
=
0
, collate_fn
=
Collate())
|
pack_padded_sequence将一个填充过的变长序列压紧。输入参数包括 。
input(Variable)- 被填充过后的变长序列组成的batch data 。
lengths (list[int]) - 变长序列的原始序列长度 。
batch_first (bool,optional) - 如果是True,input的形状应该是(batch_size,seq_len,input_size) 。
返回值:一个PackedSequence对象,可以直接作为rnn,lstm,gru的传入数据.
用法:
1
2
3
|
from
torch.nn.utils.rnn
import
pack_padded_sequence, pad_packed_sequence
# x是填充过后的batch数据,seq_lengths是每个样本的序列长度
packed_input
=
pack_padded_sequence(x, seq_lengths, batch_first
=
True
)
|
定义了一个单向的LSTM模型,因为处理的是变长序列,forward函数传入的值是一个PackedSequence对象,返回值也是一个PackedSequence对象 。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
|
class
Model(nn.Module):
def
__init__(
self
, in_size, hid_size, n_layer, drop
=
0.1
, bi
=
False
):
super
(Model,
self
).__init__()
self
.lstm
=
nn.LSTM(input_size
=
in_size,
hidden_size
=
hid_size,
num_layers
=
n_layer,
batch_first
=
True
,
dropout
=
drop,
bidirectional
=
bi)
# 分类类别数目为2
self
.fc
=
nn.Linear(in_features
=
hid_size, out_features
=
2
)
def
forward(
self
, x):
'''
:param x: 变长序列时,x是一个PackedSequence对象
:return: PackedSequence对象
'''
# lstm_out: tensor of shape (batch, seq_len, num_directions * hidden_size)
lstm_out, _
=
self
.lstm(x)
return
lstm_out
model
=
Model()
lstm_out
=
model(packed_input)
|
这个操作和pack_padded_sequence()是相反的,把压紧的序列再填充回来。因为前面提到的LSTM模型传入和返回的都是PackedSequence对象,所以我们如果想要把返回的PackedSequence对象转换回Tensor,就需要用到pad_packed_sequence函数.
参数说明:
sequence (PackedSequence) – 将要被填充的 batch 。
batch_first (bool, optional) – 如果为True,返回的数据的形状为(batch_size,seq_len,input_size) 。
返回值: 一个tuple,包含被填充后的序列,和batch中序列的长度列表.
用法:
1
2
|
# 此处lstm_out是一个PackedSequence对象
output, _
=
pad_packed_sequence(lstm_out)
|
返回的output是一个形状为(batch_size,seq_len,input_size)的tensor.
1、pytorch在自定义dataset时,可以在DataLoader的collate_fn参数中定义对数据的变换,操作以及合成batch的方式.
2、处理变长rnn问题时,通过pack_padded_sequence()将填充的batch数据转换成PackedSequence对象,直接传入rnn模型中。通过pad_packed_sequence()来将rnn模型输出的PackedSequence对象转换回相应的Tensor.
补充:pytorch实现不定长输入的RNN / LSTM / GRU 。
As we all know,RNN循环神经网络(及其改进模型LSTM、GRU)可以处理序列的顺序信息,如人类自然语言。但是在实际场景中,我们常常向模型输入一个批次(batch)的数据,这个批次中的每个序列往往不是等长的.
pytorch提供的模型(nn.RNN,nn.LSTM,nn.GRU)是支持可变长序列的处理的,但条件是传入的数据必须按序列长度排序。本文针对以下两种场景提出解决方法.
1、每个样本只有一个序列:(seq,label),其中seq是一个长度不定的序列。则使用pytorch训练时,我们将按列把一个批次的数据输入网络,seq这一列的形状就是(batch_size, seq_len),经过编码层(如word2vec)之后的形状是(batch_size, seq_len, emb_size).
2、情况1的拓展:每个样本有两个(或多个)序列,如(seq1, seq2, label)。这种样本形式在问答系统、推荐系统多见.
定义ImprovedRnn类。与nn.RNN,nn.LSTM,nn.GRU相比,除了此两点【①forward函数多一个参数lengths表示每个seq的长度】【②初始化函数(__init__)第一个参数module必须指定三者之一】外,使用方法完全相同.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
|
import
torch
from
torch
import
nn
class
ImprovedRnn(nn.Module):
def
__init__(
self
, module,
*
args,
*
*
kwargs):
assert
module
in
(nn.RNN, nn.LSTM, nn.GRU)
super
().__init__()
self
.module
=
module(
*
args,
*
*
kwargs)
def
forward(
self
,
input
, lengths):
# input shape(batch_size, seq_len, input_size)
if
not
hasattr
(
self
,
'_flattened'
):
self
.module.flatten_parameters()
setattr
(
self
,
'_flattened'
,
True
)
max_len
=
input
.shape[
1
]
# enforce_sorted=False则自动按lengths排序,并且返回值package.unsorted_indices可用于恢复原顺序
package
=
nn.utils.rnn.pack_padded_sequence(
input
, lengths.cpu(), batch_first
=
self
.module.batch_first, enforce_sorted
=
False
)
result, hidden
=
self
.module(package)
# total_length参数一般不需要,因为lengths列表中一般含最大值。但分布式训练时是将一个batch切分了,故一定要有!
result, lens
=
nn.utils.rnn.pad_packed_sequence(result, batch_first
=
self
.module.batch_first, total_length
=
max_len)
return
result[package.unsorted_indices], hidden
# output shape(batch_size, seq_len, rnn_hidden_size)
|
使用示例:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
|
class
TestNet(nn.Module):
def
__init__(
self
, word_emb, gru_in, gru_out):
super
().__init__()
self
.encode
=
nn.Embedding.from_pretrained(torch.Tensor(word_emb))
self
.rnn
=
ImprovedRnn(nn.RNN, input_size
=
gru_in, hidden_size
=
gru_out,
batch_first
=
True
, bidirectional
=
True
)
def
forward(
self
, seq1, seq1_lengths, seq2, seq2_lengths):
seq1_emb
=
self
.encode(seq1)
seq2_emb
=
self
.encode(seq2)
rnn1, hn
=
self
.rnn(seq1_emb, seq1_lengths)
rnn2, hn
=
self
.rnn(seq2_emb, seq2_lengths)
"""
此处略去rnn1和rnn2的后续计算,当前网络最后计算结果记为prediction
"""
return
prediction
|
以上为个人经验,希望能给大家一个参考,也希望大家多多支持我.
原文链接:https://blog.csdn.net/u011550545/article/details/89529977 。
最后此篇关于解决pytorch rnn 变长输入序列的问题的文章就讲到这里了,如果你想了解更多关于解决pytorch rnn 变长输入序列的问题的内容请搜索CFSDN的文章或继续浏览相关文章,希望大家以后支持我的博客! 。
关闭。这个问题是off-topic .它目前不接受答案。 想要改进这个问题? Update the question所以它是on-topic用于堆栈溢出。 关闭 12 年前。 Improve thi
我有一个动态网格,其中的数据功能需要正常工作,这样我才能逐步复制网格中的数据。假设在第 5 行中,我输入 10,则从第 6 行开始的后续行应从 11 开始读取,依此类推。 如果我转到空白的第一行并输入
我有一个关于我的按钮消失的问题 我已经把一个图像作为我的按钮 用这个函数动画 function example_animate(px) { $('#cont
我有一个具有 Facebook 连接和经典用户名/密码登录的网站。目前,如果用户单击 facebook_connect 按钮,系统即可运行。但是,我想将现有帐户链接到 facebook,因为用户可以选
我有一个正在为 iOS 开发的应用程序,该应用程序执行以下操作 加载和设置注释并启动核心定位和缩放到位置。 map 上有很多注释,从数据加载不会花很长时间,但将它们实际渲染到 map 上需要一段时间。
我被推荐使用 Heroku for Ruby on Rails 托管,到目前为止,我认为我真的会喜欢它。只是想知道是否有人可以帮助我找出问题所在。 我按照那里的说明在该网站上创建应用程序,创建并提交
我看过很多关于 SSL 错误的帖子和信息,我自己也偶然发现了一个。 我正在尝试使用 GlobalSign CA BE 证书通过 Android WebView 访问网页,但出现了不可信错误。 对于大多
我想开始使用 OpenGL 3+ 和 4,但我在使用 Glew 时遇到了问题。我试图将 glew32.lib 包含在附加依赖项中,并且我已将库和 .dll 移动到主文件夹中,因此不应该有任何路径问题。
我已经盯着这两个下载页面的源代码看了一段时间,但我似乎找不到问题。 我有两个下载页面,一个 javascript 可以工作,一个没有。 工作:http://justupload.it/v/lfd7不是
我一直在使用 jQuery,只是尝试在单击链接时替换文本字段以及隐藏/显示内容项。它似乎在 IE 中工作得很好,但我似乎无法让它在 FF 中工作。 我的 jQuery: $(function() {
我正在尝试为 NDK 编译套接字库,但出现以下两个错误: error: 'close' was not declared in this scope 和 error: 'min' is not a m
我正在使用 Selenium 浏览器自动化框架测试网站。在测试过程中,我切换到特定的框架,我们将其称为“frame_1”。后来,我在 Select 类中使用了 deselectAll() 方法。不久之
我正在尝试通过 Python 创建到 Heroku PostgreSQL 数据库的连接。我将 Windows10 与 Python 3.6.8 和 PostgreSQL 9.6 一起使用。 我从“ht
我有一个包含 2 列的数据框,我想根据两列之间的比较创建第三列。 所以逻辑是:第 1 列 val = 3,第 2 列 val = 4,因此新列值什么都没有 第 1 列 val = 3,第 2 列 va
我想知道如何调试 iphone 5 中的 css 问题。 我尝试使用 firelite 插件。但是从纵向旋转到横向时,火石占据了整个屏幕。 有没有其他方法可以调试 iphone 5 中的 css 问题
所以我有点难以理解为什么这不起作用。我正在尝试替换我正在处理的示例站点上的类别复选框。我试图让它做以下事情:未选中时以一种方式出现,悬停时以另一种方式出现(选中或未选中)选中时以第三种方式出现(而不是
Javascript CSS 问题: 我正在使用一个文本框来写入一个 div。我使用以下 javascript 获取文本框来执行此操作: function process_input(){
你好,我很难理解 P、NP 和多项式时间缩减的主题。我试过在网上搜索它并问过我的一些 friend ,但我没有得到任何好的答案。 我想问一个关于这个话题的一般性问题: 设 A,B 为 P 中的语言(或
你好,我一直在研究 https://leetcode.com/problems/2-keys-keyboard/并想到了这个动态规划问题。 您从空白页上的“A”开始,完成后得到一个数字 n,页面上应该
我正在使用 Cocoapods 和 KIF 在 Xcode 服务器上运行持续集成。我已经成功地为一个项目设置了它来报告每次提交。我现在正在使用第二个项目并收到错误: Bot Issue: warnin
我是一名优秀的程序员,十分优秀!