为了账号安全,请及时绑定邮箱和手机立即绑定

如何使 from_tensor_slice 的嵌套结构通过 tf.py_func

如何使 from_tensor_slice 的嵌套结构通过 tf.py_func

红颜莎娜 2021-12-26 10:24:24
我正在尝试py_func通过使用Dataset.map()来创建我的输入管道来将 .h5 解析器函数与包装器进行映射。我想传递两个参数:filename和window_size在地图函数中。以下代码有调用顺序:Dataset.map--> _pyfn_wrapper-->parse_h5缺点是使用 map() 函数时 _pyfn_wrapper 只能接受一个参数,因为不能压缩from_tensor_slices2 种类型的数据:字符串然后是 intdef helper(window_size, batch_size, ncores=mp.cpu_count()):    flist = []    for dirpath, _, fnames in os.walk('./'):        for fname in fnames:           flist.append(os.path.abspath(os.path.join(dirpath, fname)))    f_len = len(flist)    # init list of files    batch = tf.data.Dataset.from_tensor_slices((tf.constant(flist)))  #fixme: how to zip one list of string and a list of int    batch = batch.map_fn(_pyfn_wrapper, num_parallel_calls=ncores)  #fixme: how to map two args    batch = batch.shuffle(batch_size).batch(batch_size, drop_remainder=True).prefetch(ncores + 6)    # construct iterator    it = batch.make_initializable_iterator()    iter_init_op = it.initializer    # get next img and label    X_it, y_it = it.get_next()    inputs = {'img': X_it, 'label': y_it, 'iterator_init_op': iter_init_op}    return inputs, f_lendef _pyfn_wrapper(filename):  #fixme: args    # filename, window_size = args  #fixme: try to separate args    window_size = 100    return tf.py_func(parse_h5,  #wrapped pythonic function                      [filename, window_size],                      [tf.float32, tf.float32]  #[input, output] dtype                      )def parse_h5(name, window_size):    with h5py.File(name.decode('utf-8'), 'r') as f:        X = f['X'][:].reshape(window_size, window_size, 1)        y = f['y'][:].reshape(window_size, window_size, 1)        return X, y
查看完整描述

1 回答

?
一只名叫tom的猫

TA贡献1906条经验 获得超3个赞

使用嵌套结构Datasets作为@Sharky 的注释是解决方案之一。应该在最后一个函数中解压缩这个嵌套的 args , parse_h5而不是_pyfn_wrapper为了避免错误:


类型错误:张量对象仅在启用急切执行时才可迭代。要迭代此张量,请使用 tf.map_fn。


还应该解码参数,因为通过 tf.py_func() args 传递被转换为二进制文字。


代码修改如下:


def helper(...):

     ...

     flist.append((os.path.abspath(os.path.join(dirpath, fname)), str(window_size)))

     ...

def _pyfn_wrapper(args):

    return tf.py_func(parse_h5,  #wrapped pythonic function

                      [args],

                      [tf.float32, tf.float32]  #output dtype

                      )


def parse_h5(args):

    name, window_size = args  #only unzip the args here

    window_size = int(window_size.decode('utf-8'))  #and decode for converting bin to int

    with h5py.File(name, 'r') as f:

        ...


查看完整回答
反对 回复 2021-12-26
  • 1 回答
  • 0 关注
  • 150 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信