为了账号安全,请及时绑定邮箱和手机立即绑定

tf.data.Dataset - 为什么缓存示例时数据管道的性能没有提高?

tf.data.Dataset - 为什么缓存示例时数据管道的性能没有提高?

慕村9548890 2022-07-26 16:29:10
我目前正在尝试了解有关使用 tf.data 构建高效预处理管道的更多信息。根据本教程,缓存数据时应该对性能产生不可忽视的影响。我将我的数据管道简化为一个非常简单的示例来验证这种效果。import osimport tensorflow as tfclass ExperimentalDS:    def __init__(self, hr_img_path, cache, repeat, shuffle_buffer_size=4096):        self.hr_img_path = hr_img_path        self.ids = os.listdir(self.hr_img_path)        self.train_list = self.ids        train_list_ds = tf.data.Dataset.list_files([f"{hr_img_path}/{fname}" for fname in self.train_list])        train_hr_ds = train_list_ds.map(self.load_img)        train_hr_ds = train_hr_ds.shuffle(shuffle_buffer_size)        self.train_ds = train_hr_ds        # should probably call shuffle again after caching        if cache: self.train_ds.cache()        self.train_ds = train_hr_ds.repeat(repeat)    def get_train_ds(self, batch_size=8):        return self.train_ds.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)    def load_img(self, fpath):        img = tf.io.read_file(fpath)        img = tf.image.decode_png(img)        img = tf.image.convert_image_dtype(img, tf.float32)        return img管道基本上只是从文件夹中读取文件名,从这些文件名加载图像,对图像进行随机播放,然后根据提供的参数缓存它们或不缓存它们。为了评估性能,我主要从前面提到的教程中复制了基准测试功能。def benchmark_dataset(ds, num_steps):    start = time.perf_counter()    it = iter(ds)    for i in range(num_steps):        batch = next(it)        if i % 100 == 0:            print(".", end="")    print()    end = time.perf_counter()    duration = end - start    return durationif __name__ == "__main__":    num_steps = 1000    batch_size = 8    durations_no_cache = []    durations_cached = []    for i in range(num_steps):        ds = ExperimentalDS("./test_data/benchmark/16", cache=False, repeat=-1)        ds_train = ds.get_train_ds(batch_size=batch_size)        durations_no_cache.append(benchmark_dataset(ds_train, num_steps))我正在加载一个非常简单的图像数据集,其中包含 16 个图像,每个图像的尺寸为 128x128(因此它应该很容易放入内存中)。我无限期地重复这个数据集并迭代它 1000 个批次(批次大小为 8),使用缓存和不缓存记录运行时,然后在 1000 次运行中平均这些结果。由于这些是相当多的运行,我认为不应该有太大的差异。如果重要的话,基准测试是在 GPU 上运行的。
查看完整描述

1 回答

?
明月笑刀无情

TA贡献1828条经验 获得超4个赞

这个语句什么都不做:

if cache: self.train_ds.cache()

它应该是:

if cache: train_hr_ds = train_hr_ds.cache()

与其他数据集转换一样,cache返回新数据集而不是修改现有数据集。


查看完整回答
反对 回复 2022-07-26
  • 1 回答
  • 0 关注
  • 71 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信