为了账号安全,请及时绑定邮箱和手机立即绑定

Tensorflow - 从张量中提取字符串

Tensorflow - 从张量中提取字符串

凤凰求蛊 2022-07-26 20:53:06
我正在尝试遵循本教程的“使用 tf.data 加载”部分。在本教程中,他们可以只使用字符串张量,但是,我需要提取文件名的字符串表示,因为我需要从字典中查找额外的数据。我似乎无法提取张量的字符串部分。我很确定.name张量的属性应该返回字符串,但我不断收到一条错误消息,说KeyError: 'strided_slice_1:0'不知何故,切片做了一些奇怪的事情?我正在使用以下方法加载数据集:dataset_list = tf.data.Dataset.list_files(str(DATASET_DIR / "data/*"))然后使用以下方法处理它:def process(t):    return dataset.process_image_path(t, param_data, param_min_max)dataset_labeled = dataset_list.map(    process,     num_parallel_calls=AUTOTUNE)whereparam_data和param_min_max是我加载的两个字典,其中包含构建标签所需的额外数据。这些是我用来处理数据张量的三个函数(来自我的dataset.py):def process_image_path(image_path, param_data_file, param_max_min_file):    label = path_to_label(image_path, param_data_file, param_max_min_file)    img = tf.io.read_file(image_path)    img = decode_img(img)    return (img, label)def decode_img(img):    """Converts an image to a 3D uint8 tensor"""    img = tf.image.decode_jpeg(img, channels=3)    img = tf.image.convert_image_dtype(img, tf.float32)    return imgdef path_to_label(image_path, param_data_file, param_max_min_file):    """Returns the NORMALIZED label (set of parameter values) of an image."""    parts = tf.strings.split(image_path, os.path.sep)    filename = parts[-1]  # Extract filename with extension    filename = tf.strings.split(filename, ".")[0].name  # Extract filename    param_data = param_data_file[filename]  # ERROR! .name above doesn't seem to return just the filename    P = len(param_max_min_file)    label = np.zeros(P)    i = 0    while i < P:        param = param_max_min_file[i]        umin = param["user_min"]        umax = param["user_max"]        sub_index = param["sub_index"]        identifier = param["identifier"]        node = param["node_name"]        value = param_data[node][identifier]        label[i] = _normalize(value[sub_index])        i += 1    return label我已经验证filename = tf.strings.split(filename, ".")[0]inpath_to_label()确实返回了正确的张量,但我需要它作为字符串。整个事情也证明很难调试,因为调试时我无法访问属性(我收到错误消息AttributeError: Tensor.name is meaningless when eager execution is enabled.)。
查看完整描述

1 回答

?
慕斯709654

TA贡献1840条经验 获得超5个赞

该name字段是张量本身的名称,而不是张量的内容。


要进行常规 python 字典查找,请将解析函数包装在tf.py_func.


import tensorflow as tf

tf.enable_eager_execution()


d = {"a": 1, "b": 3, "c": 10}

dataset = tf.data.Dataset.from_tensor_slices(["a", "b", "c"])


def parse(s):

  return s, d[s]

dataset = dataset.map(lambda s: tf.py_func(parse, [s], (tf.string, tf.int64)))


for element in dataset:

  print(element[1].numpy()) # prints 1, 3, 10


查看完整回答
反对 回复 2022-07-26
  • 1 回答
  • 0 关注
  • 140 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信