为了账号安全,请及时绑定邮箱和手机立即绑定

如何在 Tensorflow 中编写 LabelEncoder?

如何在 Tensorflow 中编写 LabelEncoder?

扬帆大鱼 2023-02-22 17:04:09
我正在尝试将 Google Storage 上的目录解析为字符串,但我不断收到错误。我想找到每个文件的目录并将目录名称的数字编码作为数据集返回。这在使用 LabelEncoder 的 sklearn 中是微不足道的,但我在 Tensorflow 中做这件事时遇到了麻烦。CLASS_NAMES = [b'class_1', b'class_2', b'class_3']labeler = tfds.features.ClassLabel(names=CLASS_NAMES)def parse_filenames(filename):    label = tf.strings.split(tf.expand_dims(filename, axis=-1), sep='/')    label = label.values[-2]    # Problem is in the two lines below    position_feature = tf.feature_column.categorical_column_with_vocabulary_list('label_names', CLASS_NAMES)    label = tf.io.parse_example(label, features=position_feature)    return labelfolder = b'gs://<bucket>/train/*/*.jpg'filenames_dataset = tf.data.Dataset.list_files(folder)label_dataset = filenames_dataset.map(parse_filenames)next(iter(label_dataset))我得到一个错误ValueError: dictionary update sequence element #0 has length 16; 2 is required如果我删除“# Problem is here”注释下的两行,它工作正常,除了它返回一个字符串而不是一个整数。我已经尝试过其他非张量流选项,例如 <list_name>.index(label),但那些当然会失败,因为一切都是张量而不是字符串。还有另一种方法吗?
查看完整描述

2 回答

?
千巷猫影

TA贡献1829条经验 获得超7个赞

也许你可以试试这一行而不是这两行:


label = tf.argmax(tf.cast(parts[-2] == CLASS_NAMES, tf.int32))

你会得到类似的东西[0, 1, 0](标签的索引在CLASS_NAMES)。


功能和可重现的例子:


import tensorflow as tf

import numpy as np

from string import ascii_lowercase as letters


CLASS_NAMES = [b'class_1', b'class_2', b'class_3']


files = ['\\'.join([np.random.choice(CLASS_NAMES).decode(),

                    ''.join(np.random.choice(list(letters), 5)) + '.jpg']) 

         for i in range(10)]


ds = tf.data.Dataset.from_tensor_slices(files)

这是我生成的假文件:


['class_3\\jrxog.jpg',

 'class_1\\slfiq.jpg',

 'class_2\\svldd.jpg',

 'class_2\\avrgt.jpg',

 'class_3\\wqwuv.jpg']

现在实现这个:


def get_label(file_path):

    parts = tf.strings.split(file_path, '\\')

    return file_path, tf.argmax(tf.cast(parts[-2] == CLASS_NAMES, tf.int32))


ds = ds.map(get_label)


next(iter(ds))

(<tf.Tensor: shape=(), dtype=string, numpy=b'class_1\\bbqrx.jpg'>,

 <tf.Tensor: shape=(), dtype=int64, numpy=0>)


查看完整回答
反对 回复 2023-02-22
?
慕森卡

TA贡献1806条经验 获得超8个赞

我使用了 sklearn 的标签编码器。这是你可能需要腌制的东西,这样你可以在以后反向转换你的结果。我在这方面还是新手,所以我不确定这对你有多好



查看完整回答
反对 回复 2023-02-22
  • 2 回答
  • 0 关注
  • 124 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信