为了账号安全,请及时绑定邮箱和手机立即绑定

在转换过程中从 tensorflow 对象中提取 numpy 值

在转换过程中从 tensorflow 对象中提取 numpy 值

海绵宝宝撒 2023-03-30 10:29:27
我正在尝试使用 tensorflow 获取词嵌入,并且我已经使用我的语料库创建了相邻的工作列表。我的词汇表中唯一单词的数量为 8000,相邻单词列表的数量约为 160 万单词列表示例照片由于数据非常大,我试图将单词列表分批写入 TFRecords 文件。def save_tfrecords_wordlist(toprocess_word_lists, path ):        writer = tf.io.TFRecordWriter(path)    for word_list in toprocess_word_lists:        features=tf.train.Features(            feature={        'word_list_X': tf.train.Feature( bytes_list=tf.train.BytesList(value=[word_list[0].encode('utf-8')] )),        'word_list_Y': tf.train.Feature( bytes_list=tf.train.BytesList(value=[word_list[1].encode('utf-8') ]))                }            )        example = tf.train.Example(features = features)        writer.write(example.SerializeToString())    writer.close()定义批次batches = [0,250000,500000,750000,1000000,1250000,1500000,1641790]for i in range(len(batches) - 1 ):    batches_start = batches[i]    batches_end = batches[i + 1]    print( str(batches_start) + " -- " + str(batches_end ))    toprocess_word_lists = word_lists[batches_start:batches_end]    save_tfrecords_wordlist( toprocess_word_lists, path +"/TFRecords/data_" + str(i) +".tfrecords")##############################def _parse_function(example_proto):  features = {"word_list_X": tf.io.FixedLenFeature((), tf.string),          "word_list_Y": tf.io.FixedLenFeature((), tf.string)}  parsed_features = tf.io.parse_single_example(example_proto, features)
查看完整描述

1 回答

?
沧海一幻觉

TA贡献1824条经验 获得超5个赞

似乎您无法从映射函数(1、2)内部调用 .numpy() 函数,尽管我能够使用来自(doc)的 py_function 进行管理。

在下面的示例中,我已将我解析的数据集映射到一个函数,该函数将我的图像转换np.uint8以便使用 matplotlib绘制它们。

records_path = data_directory+'TFRecords'+'/data_0.tfrecord'

# Create a dataset

dataset = tf.data.TFRecordDataset(filenames=records_path)

# Map our dataset to the parsing function 

parsed_dataset = dataset.map(parsing_fn)

converted_dataset = parsed_dataset.map(lambda image,label:

                                       tf.py_function(func=converting_function,

                                                      inp=[image,label],

                                                      Tout=[np.uint8,tf.int64]))


# Gets the iterator

iterator = tf.compat.v1.data.make_one_shot_iterator(converted_dataset) 


for i in range(5):

    image,label = iterator.get_next()

    plt.imshow(image)

    plt.show()

    print('label: ', label)

输出:

//img1.sycdn.imooc.com/6424f4bc0001ce7c05110204.jpg

解析函数:

def parsing_fn(serialized):

    # Define a dict with the data-names and types we expect to

    # find in the TFRecords file.

    features = \

        {

            'image': tf.io.FixedLenFeature([], tf.string),

            'label': tf.io.FixedLenFeature([], tf.int64)            

        }


    # Parse the serialized data so we get a dict with our data.

    parsed_example = tf.io.parse_single_example(serialized=serialized,

                                             features=features)

    # Get the image as raw bytes.

    image_raw = parsed_example['image']


    # Decode the raw bytes so it becomes a tensor with type.

    image = tf.io.decode_jpeg(image_raw)

    

    # Get the label associated with the image.

    label = parsed_example['label']

    

    # The image and label are now correct TensorFlow types.

    return image, label

更新:实际上并没有签出,但 tf.shape() 似乎也是一个有前途的选择。



查看完整回答
反对 回复 2023-03-30
  • 1 回答
  • 0 关注
  • 100 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信