Tensorflow object detection - converting detected objects into images

Smart猫小萌 2021-06-07 03:34:48
I trained an ssd_mobilenet_v1 model to detect small objects in static grayscale images. Now I want to determine things such as the horizontal angle of an object. How can I "extract" the object as an image or image array for further geometric study?

This is my version of the object_detection_tutorial.ipynb file, modified from the Tensorflow Object Detection API on Github (the original can be found here: https://github.com/tensorflow/models/tree/master/research/object_detection)

Code:

Imports

import numpy as np
import os
import six.moves.urllib as urllib
import sys
import tarfile
import tensorflow as tf
import zipfile
from collections import defaultdict
from io import StringIO
from matplotlib import pyplot as plt
from PIL import Image
# This is needed since the notebook is stored in the object_detection folder.
sys.path.append("..")
from object_detection.utils import ops as utils_ops

Object detection imports

from utils import label_map_util
from utils import visualization_utils as vis_util

Variables

# What model to download.
MODEL_NAME = 'shard_graph_ssd'
# Path to frozen detection graph. This is the actual model that is used for the object detection.
PATH_TO_FROZEN_GRAPH = MODEL_NAME + '/frozen_inference_graph.pb'
# List of the strings that is used to add correct label for each box.
PATH_TO_LABELS = os.path.join('data', 'label_map.pbtxt')
NUM_CLASSES = 1

Load a (frozen) Tensorflow model into memory.

detection_graph = tf.Graph()
with detection_graph.as_default():
  od_graph_def = tf.GraphDef()
  with tf.gfile.GFile(PATH_TO_FROZEN_GRAPH, 'rb') as fid:
    serialized_graph = fid.read()
    od_graph_def.ParseFromString(serialized_graph)
    tf.import_graph_def(od_graph_def, name='')

Loading label map

label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)

1 answer

斯蒂芬大帝


I solved this with the following function.

i is the loop variable, essentially the index of the current image.


import cv2

def crop_objects(image, image_np, output_dict, i):
    # image is the PIL image, image_np is the same image as a NumPy array
    width, height = image.size

    # Pixel coordinates of the top-scoring detection (object zero);
    # detection_boxes holds normalized [ymin, xmin, ymax, xmax] values
    ymin = int(output_dict['detection_boxes'][0][0] * height)
    xmin = int(output_dict['detection_boxes'][0][1] * width)
    ymax = int(output_dict['detection_boxes'][0][2] * height)
    xmax = int(output_dict['detection_boxes'][0][3] * width)
    # Copy the slice so that blanking the crop below does not modify image_np
    crop_img = image_np[ymin:ymax, xmin:xmax].copy()

    # 1. Only crop objects detected with a score of at least 0.5;
    #    crops of lower-scoring objects are filled with zeros (black image).
    #    This is something I need in my program.
    # 2. Only crop the object with the highest score (object zero).
    if output_dict['detection_scores'][0] < 0.5:
        crop_img.fill(0)

    # Save the cropped object as an image
    cv2.imwrite('Images/Step_2/' + str(i) + '.png', crop_img)
    return ymin, ymax, xmin, xmax
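The function above handles only the top-scoring detection. If several objects per image are needed, the same scaling and slicing can be applied to every box. The following is a minimal sketch, not part of the original answer: the name crop_detections and its min_score parameter are hypothetical, and it assumes boxes and scores have the same layout as output_dict['detection_boxes'] and output_dict['detection_scores'].

```python
import numpy as np


def crop_detections(image_np, boxes, scores, min_score=0.5):
    """Return (crop, (ymin, xmin, ymax, xmax)) pairs in pixels for detections
    whose score is at least min_score.

    boxes holds normalized [ymin, xmin, ymax, xmax] rows (assumed layout).
    """
    height, width = image_np.shape[:2]
    results = []
    for box, score in zip(boxes, scores):
        if score < min_score:
            continue
        ymin, xmin, ymax, xmax = box
        # Scale to pixels and clamp to the image bounds.
        top = max(0, int(ymin * height))
        left = max(0, int(xmin * width))
        bottom = min(height, int(ymax * height))
        right = min(width, int(xmax * width))
        if bottom <= top or right <= left:
            continue  # skip degenerate (empty) boxes
        # Copy so that later edits to a crop do not touch image_np.
        crop = image_np[top:bottom, left:right].copy()
        results.append((crop, (top, left, bottom, right)))
    return results


# Synthetic example: a 100x200 image with two detections, one below the threshold.
image_np = np.zeros((100, 200, 3), dtype=np.uint8)
boxes = np.array([[0.25, 0.25, 0.5, 0.75], [0.0, 0.0, 0.25, 0.25]])
scores = np.array([0.9, 0.3])
crops = crop_detections(image_np, boxes, scores)
print(len(crops), crops[0][0].shape, crops[0][1])
```

Each returned crop is an independent array, so it can be saved with cv2.imwrite or passed on to further processing without affecting the source image.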

These are required for it to work:


import numpy as np
from PIL import Image

def load_image_into_numpy_array(image):
    # Function required for image recognition:
    # turns a grayscale image into an (H, W, 3) uint8 array
    last_axis = -1
    dim_to_repeat = 2
    repeats = 3
    grscale_img_3dims = np.expand_dims(image, last_axis)
    training_image = np.repeat(grscale_img_3dims, repeats, dim_to_repeat).astype('uint8')
    assert len(training_image.shape) == 3
    assert training_image.shape[-1] == 3
    return training_image

# image_path is the path of the current image
image = Image.open(image_path)
image_np = load_image_into_numpy_array(image)
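To see what this conversion does, here is a small check on a synthetic array. It is only an illustration and assumes a 2-D uint8 array in place of the PIL image: the expand_dims plus repeat combination copies the single gray channel into three identical channels, which is equivalent to stacking the array three times along a new last axis.

```python
import numpy as np

# A tiny stand-in for a grayscale image (3 rows, 4 columns).
gray = np.arange(12, dtype=np.uint8).reshape(3, 4)

# Same steps as load_image_into_numpy_array: add a channel axis, repeat it 3 times.
rgb_like = np.repeat(np.expand_dims(gray, -1), 3, 2).astype('uint8')

# Equivalent formulation with np.stack.
stacked = np.stack((gray,) * 3, axis=-1)

print(rgb_like.shape)
print(np.array_equal(rgb_like, stacked))
```

Because all three channels are identical, any single channel (for example rgb_like[..., 0]) recovers the original grayscale data.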

This is probably more code than is needed just to crop the object.
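For the geometric study mentioned in the question, such as the horizontal angle, one possible follow-up on a cropped object is a principal-axis estimate. This is a sketch of a different, generic technique (PCA on foreground pixel coordinates), not something from the answer above. The function name principal_axis_angle and the fixed threshold are hypothetical, and it assumes a single-channel crop (for example crop_img[..., 0]) in which the object is brighter than the background.

```python
import numpy as np


def principal_axis_angle(gray_crop, threshold=128):
    """Estimate the angle in degrees between the object's main axis and the horizontal.

    Assumes foreground pixels are brighter than `threshold` (an assumption).
    Returns a value in (-90, 90], or None if fewer than two foreground pixels exist.
    """
    ys, xs = np.nonzero(gray_crop > threshold)
    if xs.size < 2:
        return None
    # Center the coordinates; flip y so positive angles are counter-clockwise.
    x = xs - xs.mean()
    y = -(ys - ys.mean())
    cov = np.cov(np.vstack((x, y)))
    eigenvalues, eigenvectors = np.linalg.eigh(cov)
    # The eigenvector with the largest eigenvalue points along the main axis.
    vx, vy = eigenvectors[:, np.argmax(eigenvalues)]
    angle = np.degrees(np.arctan2(vy, vx))
    # An axis has no direction, so fold the result into (-90, 90].
    if angle > 90:
        angle -= 180
    elif angle <= -90:
        angle += 180
    return angle


# Synthetic example: a bright diagonal running from top-left to bottom-right.
diagonal = np.zeros((20, 20), dtype=np.uint8)
for k in range(20):
    diagonal[k, k] = 255
print(principal_axis_angle(diagonal))
```

Whether a fixed threshold separates object from background depends on the images; for real data the threshold would likely need tuning or a different segmentation step.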


Answered 2021-06-09