gpt4 book ai didi

python - K.ctc_batch_cost() 中 input_length 的意思是什么

转载 作者:太空宇宙 更新时间:2023-11-04 06:40:19 24 4
gpt4 key购买 nike

我下载了一个使用Keras的ocr代码,它应用了CRNN网络并使用CTC loss作为损失函数。然而,我对 CTC 损失真的很陌生,只是在使用 K.ctc_batch_cost() 时遇到了麻烦,尤其是 input_length 的含义。在keras的文档中,

Arguments of tf.keras.backend.ctc_batch_cost( y_true, y_pred, input_length, label_length )

  1. y_true:包含真值标签的张量(样本,max_string_length)。
  2. y_pred:包含预测或 softmax 输出的张量(samples、time_steps、num_categories)。
  3. input_length:包含 y_pred 中每个批处理项目的序列长度的张量 (samples, 1)。
  4. label_length:张量(样本,1)包含 y_true 中每个批处理项目的序列长度。

    但是,我的问题是input_length 是什么意思?那是 LSTM 输出的维度吗?

最佳答案

一个示例的 CTC 损失是在二维数组 (T,C) 上计算的。 C 必须等于字符数 + 1(空白字符)。 C包含字符在某个时间戳的概率分布。 T 将是时间戳的数量。

T 的长度应为 2* max_string_length。所有可能的长度为 T 的 y_true 编码将用于负对数损失计算。

通常是上一层输出的shape。

关于python - K.ctc_batch_cost() 中 input_length 的意思是什么,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55160939/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com