为了账号安全,请及时绑定邮箱和手机立即绑定

numpy 数组输入到 tensorflow/keras 神经网络的 dtype 有什么关系吗?

numpy 数组输入到 tensorflow/keras 神经网络的 dtype 有什么关系吗?

慕沐林林 2023-05-23 10:39:57
如果我采用 tensorflow.keras 模型并调用model.fit(x, y)(numpy 数组在哪里x),numpy 数组是y什么重要吗?dtype我最好只是尽可能dtype小(例如int8二进制数据),还是这会给 tensorflow/keras 额外的工作以将其转换为浮点数?
查看完整描述

1 回答

?
慕田峪4524236

TA贡献1875条经验 获得超5个赞

您应该将输入转换为np.float32,这是 Keras 的默认数据类型。查一下:


import tensorflow as tf

tf.keras.backend.floatx()

'float32'

如果你给 Keras 输入 in np.float64,它会抱怨:


import tensorflow as tf

from tensorflow.keras.layers import Dense 

from tensorflow.keras import Model

from sklearn.datasets import load_iris

iris, target = load_iris(return_X_y=True)


X = iris[:, :3]

y = iris[:, 3]


ds = tf.data.Dataset.from_tensor_slices((X, y)).shuffle(25).batch(8)


class MyModel(Model):

  def __init__(self):

    super(MyModel, self).__init__()

    self.d0 = Dense(16, activation='relu')

    self.d1 = Dense(32, activation='relu')

    self.d2 = Dense(1, activation='linear')


  def call(self, x):

    x = self.d0(x)

    x = self.d1(x)

    x = self.d2(x)

    return x


model = MyModel()


_ = model(X)

警告:tensorflow:Layer my_model 正在将输入张量从 dtype float64 转换为层的 dtype float32,这是 TensorFlow 2 中的新行为。该层具有 dtype float32,因为它的 dtype 默认为 floatx。如果你打算在 float32 中运行这个层,你可以安全地忽略这个警告。如果有疑问,如果您将 TensorFlow 1.X 模型移植到 TensorFlow 2,则此警告可能只是一个问题。要将所有层更改为默认 dtype float64,请调用tf.keras.backend.set_floatx('float64'). 要仅更改这一层,请将 dtype='float64' 传递给层构造函数。如果您是该层的作者,则可以通过将 autocast=False 传递给基础层构造函数来禁用自动转换。


查看完整回答
反对 回复 2023-05-23
  • 1 回答
  • 0 关注
  • 90 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信