2 回答
TA贡献1829条经验 获得超7个赞
也许你可以试试这一行而不是这两行:
label = tf.argmax(tf.cast(parts[-2] == CLASS_NAMES, tf.int32))
你会得到类似的东西[0, 1, 0](标签的索引在CLASS_NAMES)。
功能和可重现的例子:
import tensorflow as tf
import numpy as np
from string import ascii_lowercase as letters
CLASS_NAMES = [b'class_1', b'class_2', b'class_3']
files = ['\\'.join([np.random.choice(CLASS_NAMES).decode(),
''.join(np.random.choice(list(letters), 5)) + '.jpg'])
for i in range(10)]
ds = tf.data.Dataset.from_tensor_slices(files)
这是我生成的假文件:
['class_3\\jrxog.jpg',
'class_1\\slfiq.jpg',
'class_2\\svldd.jpg',
'class_2\\avrgt.jpg',
'class_3\\wqwuv.jpg']
现在实现这个:
def get_label(file_path):
parts = tf.strings.split(file_path, '\\')
return file_path, tf.argmax(tf.cast(parts[-2] == CLASS_NAMES, tf.int32))
ds = ds.map(get_label)
next(iter(ds))
(<tf.Tensor: shape=(), dtype=string, numpy=b'class_1\\bbqrx.jpg'>,
<tf.Tensor: shape=(), dtype=int64, numpy=0>)
添加回答
举报