为了账号安全,请及时绑定邮箱和手机立即绑定

在 Tensorflow 中 - 是否可以将特定的卷积滤波器锁定在一个层中,或者将它们完全删除?

在 Tensorflow 中 - 是否可以将特定的卷积滤波器锁定在一个层中,或者将它们完全删除?

HUH函数 2022-09-27 16:44:32
在 Tensorflow 中使用迁移学习时,我知道可以通过执行以下操作来锁定层,使其无法进行进一步的训练:for layer in pre_trained_model.layers:     layer.trainable = False是否可以将特定滤镜锁定在图层中?如 - 如果整个图层包含 64 个过滤器,是否可以:只锁定其中一些,似乎包含合理的过滤器并重新训练那些没有的过滤器?或从图层中删除看起来不合理的过滤器,并在没有它们的情况下重新训练?(例如,查看重新训练的过滤器是否会发生很大变化)
查看完整描述

1 回答

?
千万里不及你

TA贡献1784条经验 获得超9个赞

一种可能的解决方案是实现自定义层,该层将卷积拆分为单独的卷积,并将每个通道(具有一个输出通道的卷积)设置为 或 设置为 。例如:number of filterstrainablenot trainable


import tensorflow as tf

import numpy as np


class Conv2DExtended(tf.keras.layers.Layer):

    def __init__(self, filters, kernel_size, **kwargs):

        self.filters = filters

        self.conv_layers = [tf.keras.layers.Conv2D(1, kernel_size, **kwargs) for _ in range(filters)]

        super().__init__()


    def build(self, input_shape):

        _ = [l.build(input_shape) for l in self.conv_layers]

        super().build(input_shape)


    def set_trainable(self, channels):

        """Sets trainable channels."""

        for i in channels:

            self.conv_layers[i].trainable = True


    def set_non_trainable(self, channels):

        """Sets not trainable channels."""

        for i in channels:

            self.conv_layers[i].trainable = False


    def call(self, inputs):

        results = [l(inputs) for l in self.conv_layers]

        return tf.concat(results, -1)

和用法示例:


inputs = tf.keras.layers.Input((28, 28, 1))

conv = Conv2DExtended(filters=4, kernel_size=(3, 3))

conv.set_non_trainable([1, 2]) # only channels 0 and 3 are trainable

res = conv(inputs)

res = tf.keras.layers.Flatten()(res)

res = tf.keras.layers.Dense(1, activation=tf.nn.sigmoid)(res)


model = tf.keras.models.Model(inputs, res)

model.compile(optimizer=tf.keras.optimizers.SGD(),

              loss='binary_crossentropy',

              metrics=['accuracy'])

model.fit(np.random.normal(0, 1, (10, 28, 28, 1)),

          np.random.randint(0, 2, (10)),

          batch_size=2,

          epochs=5)


查看完整回答
反对 回复 2022-09-27
  • 1 回答
  • 0 关注
  • 90 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信