为了账号安全,请及时绑定邮箱和手机立即绑定

keras 中的 CTC 损失实现

keras 中的 CTC 损失实现

繁花如伊 2024-01-27 16:27:08
我正在尝试使用 keras 为我的简化神经网络实现 CTC 损失:  def ctc_lambda_func(args):    y_pred, y_train, input_length, label_length = args     return K.ctc_batch_cost(y_train, y_pred, input_length, label_length)x_train = x_train.reshape(x_train.shape[0],20, 10).astype('float32')input_data = layers.Input(shape=(20,10,))x=layers.Convolution1D(filters=256, kernel_size=3,  padding="same", strides=1, use_bias=False ,activation= 'relu')(input_data)x=layers.BatchNormalization()(x)x=layers.Dropout(0.2)(x)x=layers.Bidirectional (LSTM(units=200 , return_sequences=True)) (x)x=layers.BatchNormalization()(x)x=layers.Dropout(0.2)(x)y_pred=outputs = layers.Dense(5, activation='softmax')(x)fun = Model(input_data, y_pred)# fun.summary()label_length=np.zeros((3800,1))input_length=np.zeros((3800,1))for i in range (3799):    label_length[i,0]=4    input_length[i,0]=5   y_train = np.array(y_train)x_train = np.array(x_train)input_length = np.array(input_length)label_length = np.array(label_length)   loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')([y_pred, y_train, input_length, label_length])model =keras.models.Model(inputs=[input_data, y_train, input_length, label_length], outputs=loss_out)model.compile(loss={'ctc': lambda y_train, y_pred: y_pred}, optimizer = 'adam')model.fit(x=[x_train, y_train, input_length, label_length],  epochs=10, batch_size=100)我们有 y_true (或 y_train) ,尺寸为(3800,4),因此我将 label_length=4 和 input_length=5 (+1 表示空白)我面临这个错误:ValueError: Input tensors to a Model must come from `tf.keras.Input`. Received: [[0. 1. 0. 0.] [0. 1. 0. 0.] [0. 1. 0. 0.] ... [1. 0. 0. 0.] [1. 0. 0. 0.] [1. 0. 0. 0.]] (missing previous layer metadata).y_true 是这样的: [[0. 1. 0. 0.] [0. 1. 0. 0.] ... [1. 0. 0. 0.] [1. 0. 0. 0.] [1. 0. 0. 0.]]我的问题是什么?
查看完整描述

1 回答

?
千万里不及你

TA贡献1784条经验 获得超9个赞

你误解了长度。它不是标签类别的数量,而是序列的实际长度。CTC只能用于目标符号数量小于输入状态数量的情况。从技术上讲,输入和输出的数量是相同的,但有些输出是空白的。(这通常发生在语音识别中,其中有大量的输入信号窗口,而输出中的音素相对较少。)

假设您必须填充输入和输出才能将它们批量化:

  • input_length对于批次中的每个项目,应包含实际有效的输入数量,即不填充;

  • label_length应包含模型应为批次中的每个项目生成多少个非空白标签。


查看完整回答
反对 回复 2024-01-27
  • 1 回答
  • 0 关注
  • 135 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信