我试图在此脚本中使用类,该脚本对目录“ test_images”中的多个图像执行图像分类。我以前没有使用过很多类,因此我对如何正确应用这种情况有些困惑。错误是:TypeError: __init__() missing 1 required positional argument: 'sess'。任何帮助将不胜感激!下面是代码:def image_recognition_algorithm():def load_graph(model_file): graph = tf.Graph() graph_def = tf.GraphDef() with open(model_file, "rb") as f: graph_def.ParseFromString(f.read()) with graph.as_default(): tf.import_graph_def(graph_def) return graphdef read_tensor_from_image_file(file_name, input_height=299, input_width=299, input_mean=0, input_std=255): input_name = "file_reader" output_name = "normalized" file_reader = tf.read_file(file_name, input_name) image_reader = tf.image.decode_jpeg(file_reader, channels = 3, name='jpeg_reader') float_caster = tf.cast(image_reader, tf.float32) dims_expander = tf.expand_dims(float_caster, 0); resized = tf.image.resize_bilinear(dims_expander, [input_height, input_width]) normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std]) sess = tf.Session() result = sess.run(normalized) return resultdef load_labels(label_file): label = [] proto_as_ascii_lines = tf.gfile.GFile(label_file).readlines() for l in proto_as_ascii_lines: label.append(l.rstrip()) return labelclass initiate_session():def __init__(self, sess): self.sess = sess graph = load_graph(model_file) input_name = "import/" + input_layer output_name = "import/" + output_layer input_operation = graph.get_operation_by_name(input_name); output_operation = graph.get_operation_by_name(output_name);
1 回答
杨__羊羊
TA贡献1943条经验 获得超7个赞
您的initiate_session.__init__()
方法有两个参数,self
它们作为对自身的引用自动传递sess
,而您需要传递。在initiate_session
此处实例化时:
if __name__ == '__main__': initiate_session().main()
您需要传递一个sess
参数。
但是,在您的情况下,我认为您实际上想要删除方法的sess
参数__init__()
,因为self.sess
稍后要在构造函数中将分配给,这里:
self.sess = tf.Session(graph=graph, config = config)
卸下sess
参数__init__()
和行
self.sess = sess
应该可以解决您的问题。
添加回答
举报
0/150
提交
取消