为了账号安全,请及时绑定邮箱和手机立即绑定

使用数据集API从两个不同数据源获取数据时了解张量流行为

使用数据集API从两个不同数据源获取数据时了解张量流行为

FFIVE 2021-04-09 06:45:23
我正在尝试dataset使用张量流从两个不同的来源获取数据。我写了下面的代码:首先,我尝试了以下方法:import tensorflow as tfimport numpy as npiters = []def return_data1():    d1 = tf.data.Dataset.range(1, 2000)    iter1 = d1.make_initializable_iterator()    iters.append(iter1)    data1 = iter1.get_next()    return data1def return_data2():    d2 = tf.data.Dataset.range(2000, 4000)    iter2 = d2.make_initializable_iterator()    iters.append(iter2)    data2 = iter2.get_next()    return data2test = tf.placeholder(dtype=tf.bool)data = tf.cond(test, lambda: return_data1(), lambda: return_data2())iter1 = iters[0]iter2 = iters[1]init_op = tf.global_variables_initializer()with tf.Session() as sess:    sess.run(init_op)    sess.run([iter1.initializer, iter2.initializer])    for i in range(2000):        if i < 1000:            print(sess.run(data, feed_dict={test: True}), "..")        else:            print(sess.run(data, feed_dict={test: False}), "--")我得到了以下错误:ValueError: Operation 'cond/MakeIterator' has been marked as not fetchable.1-我想知道为什么我会出现这种现象。
查看完整描述

1 回答

?
ABOUTYOU

TA贡献1812条经验 获得超5个赞

这是如何分别从多个数据集中获取数据的方法。但是我想知道其他有关张量流行为以及为什么应data2 = iter2.get_next()在方法中定义的问题的答案。


import tensorflow as tf

import numpy as np


d1 = tf.data.Dataset.range(1, 1000)

iter1 = d1.make_initializable_iterator()


d2 = tf.data.Dataset.range(1000, 2000)

iter2 = d2.make_initializable_iterator()


d3 = tf.data.Dataset.range(2000, 3000)

iter3 = d3.make_initializable_iterator()


d4 = tf.data.Dataset.range(3000, 4000)

iter4 = d4.make_initializable_iterator()


def return_data1_2():

    data1 = iter1.get_next()

    data2 = iter2.get_next()

    return data1, data2


def return_data2_3():

    data2 = iter2.get_next()

    data3 = iter3.get_next()

    return data2, data3


def return_data3_4():

    data3 = iter3.get_next()

    data4 = iter4.get_next()

    return data3, data4


def return_data4_1():

    data4 = iter4.get_next()

    data1 = iter1.get_next()

    return data4, data1


index1 = tf.placeholder(dtype=tf.int32)

index2 = tf.placeholder(dtype=tf.int32)


data = tf.case(pred_fn_pairs=[

    (tf.logical_and(tf.equal(index1, 1), tf.equal(index2, 2)), lambda: return_data1_2()), 

    (tf.logical_and(tf.equal(index1, 2), tf.equal(index2, 3)), lambda: return_data2_3()),

    (tf.logical_and(tf.equal(index1, 3), tf.equal(index2, 4)), lambda: return_data3_4()),

    (tf.logical_and(tf.equal(index1, 4), tf.equal(index2, 1)), lambda: return_data4_1())], exclusive=False)


init_op = tf.global_variables_initializer()


with tf.Session() as sess:

    sess.run(init_op)

    sess.run([iter1.initializer, iter2.initializer, iter3.initializer, iter4.initializer])



    for i in range(2000):

        try:

            if i < 500:

                print(sess.run(data, feed_dict={index1: 1, index2: 2}), "1-2")

            elif i < 1000:

                print(sess.run(data, feed_dict={index1: 2, index2: 3}), "2-3")

            elif i < 1500:

                print(sess.run(data, feed_dict={index1: 3, index2: 4}), "3-4")

            elif i < 2000:

                print(sess.run(data, feed_dict={index1: 4, index2: 1}), "4-1")

        except tf.errors.OutOfRangeError as error:

            print("error")


查看完整回答
反对 回复 2021-04-13
  • 1 回答
  • 0 关注
  • 135 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信