
Tutorial: Generative Image Inpainting with Contextual Attention

51自学网 2021-01-06 14:42:53


Today I'm introducing the CVPR 2018 paper Generative Image Inpainting with Contextual Attention.

Paper: https://arxiv.org/abs/1801.07892; demo: http://jiahuiyu.com/deepfill

GitHub: https://github.com/JiahuiYu/generative_inpainting

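First, the results:

The above are the author's inpainting results. Below are results from a model I trained myself:

Two different sets of results were produced here because two different pre-trained models were used.

Here is how to use it: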

  1. Requirements:
    • Install python3.
    • Install tensorflow (tested on Release 1.3.0, 1.4.0, 1.5.0, 1.6.0, 1.7.0).
    • Install tensorflow toolkit neuralgym (run pip install git+https://github.com/JiahuiYu/neuralgym).
  2. Training:
    • Prepare training images filelist and shuffle it (example).
    • Modify inpaint.yml to set DATA_FLIST, LOG_DIR, IMG_SHAPES and other parameters.
    • Run python3 train.py.
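The requirements above list TensorFlow 1.3.0 to 1.7.0 as tested. A minimal pre-flight check can flag an untested install before a long training run. This sketch assumes that matching on major.minor is enough (treating patch releases as compatible is my assumption, not something the README states); `is_tested_tf_version` is a hypothetical helper meant to be called with `tf.__version__`:

```python
# Hypothetical pre-flight check against the TensorFlow releases listed as tested.
TESTED_TF_RELEASES = ("1.3.0", "1.4.0", "1.5.0", "1.6.0", "1.7.0")


def is_tested_tf_version(version):
    """Return True if `version` (e.g. tf.__version__) has the same major.minor as a tested release."""
    major_minor = version.split(".")[:2]
    return any(release.split(".")[:2] == major_minor for release in TESTED_TF_RELEASES)


if __name__ == "__main__":
    for v in ("1.5.0", "1.15.0", "2.4.0"):
        print(v, is_tested_tf_version(v))
```

Note that "1.15.0" is correctly rejected: the comparison is on the split components, not on string prefixes.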

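Here I focus on how to prepare your own training set; a short Python script handles it automatically. gen_flist.py splits the source dataset into a training set and a validation set and writes both in the file-list format the project expects.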


 
  1.  
    # 将原数据集分为training ,validation by gavin
  2.  
    import os
  3.  
    import random
  4.  
     
  5.  
    import argparse
  6.  
     
  7.  
    #划分验证集训练集
  8.  
    _NUM_TEST = 20000
  9.  
     
  10.  
    parser = argparse.ArgumentParser()
  11.  
    parser.add_argument('--folder_path', default='/home/gavin/Dataset/celeba', type=str,
  12.  
    help='The folder path')
  13.  
    parser.add_argument('--train_filename', default='./data/celeba/train_shuffled.flist', type=str,
  14.  
    help='The train filename.')
  15.  
    parser.add_argument('--validation_filename', default='./data/celeba/validation_static_view.flist', type=str,
  16.  
    help='The validation filename.')
  17.  
     
  18.  
     
  19.  
    def _get_filenames(dataset_dir):
  20.  
    photo_filenames = []
  21.  
    image_list = os.listdir(dataset_dir)
  22.  
    photo_filenames = [os.path.join(dataset_dir, _) for _ in image_list]
  23.  
    return photo_filenames
  24.  
     
  25.  
     
  26.  
    if __name__ == "__main__":
  27.  
     
  28.  
    args = parser.parse_args()
  29.  
     
  30.  
    data_dir = args.folder_path
  31.  
     
  32.  
    # get all file names
  33.  
    photo_filenames = _get_filenames(data_dir)
  34.  
    print("size of celeba is %d" % (len(photo_filenames)))
  35.  
     
  36.  
    # 切分数据为测试训练集
  37.  
    random.seed(0)
  38.  
    random.shuffle(photo_filenames)
  39.  
    training_file_names = photo_filenames[_NUM_TEST:]
  40.  
    validation_file_names = photo_filenames[:_NUM_TEST]
  41.  
     
  42.  
    print("training file size:",len(training_file_names))
  43.  
    print("validation file size:", len(validation_file_names))
  44.  
     
  45.  
    # make output file if not existed
  46.  
    if not os.path.exists(args.train_filename):
  47.  
    os.mknod(args.train_filename)
  48.  
     
  49.  
    if not os.path.exists(args.validation_filename):
  50.  
    os.mknod(args.validation_filename)
  51.  
     
  52.  
    # write to file
  53.  
    fo = open(args.train_filename, "w")
  54.  
    fo.write("\n".join(training_file_names))
  55.  
    fo.close()
  56.  
     
  57.  
    fo = open(args.validation_filename, "w")
  58.  
    fo.write("\n".join(validation_file_names))
  59.  
    fo.close()
  60.  
     
  61.  
    # print process
  62.  
    print("Written file is: ", args.train_filename)
  63.  
     
  64.  
     
  65.  
     

The generated .flist files contain one image path per line.
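Before setting DATA_FLIST in inpaint.yml, it may be worth reading a generated list back and confirming that each line points at an existing image. A minimal sketch; `check_flist` and the accepted extensions are my own assumptions, not part of the repository:

```python
import os

IMAGE_EXTS = (".jpg", ".jpeg", ".png")  # assumed image extensions


def check_flist(flist_path):
    """Read a .flist (one image path per line) and return (paths, problems)."""
    with open(flist_path) as f:
        paths = [line.strip() for line in f if line.strip()]
    problems = []
    for p in paths:
        if not p.lower().endswith(IMAGE_EXTS):
            problems.append((p, "not an image extension"))
        elif not os.path.isfile(p):
            problems.append((p, "missing file"))
    return paths, problems
```

An empty `problems` list means the file list looks usable; otherwise each entry names the offending path and the reason.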

  3. Resume training:
    • Modify MODEL_RESTORE flag in inpaint.yml. E.g., MODEL_RESTORE: 20180115220926508503_places2_model.
    • Run python3 train.py.
  4. Testing:
    • Run python test.py --image examples/input.png --mask examples/mask.png --output examples/output.png --checkpoint_dir model_logs/your_model_dir
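The checkpoint argument can also point at one specific snapshot, as in the `model_logs/celebA_model/snap-60000` example further down. To pick the newest snapshot automatically, a helper can scan the run directory. This sketch assumes TensorFlow's usual checkpoint layout, in which snapshot `snap-N` is stored as `snap-N.index` plus data files; `latest_snapshot` is hypothetical:

```python
import os
import re

# Assumed file naming: one snap-<step>.index file per saved snapshot
SNAP_RE = re.compile(r"^snap-(\d+)\.index$")


def latest_snapshot(run_dir):
    """Return the path prefix of the highest-step snap-N checkpoint in run_dir, or None."""
    steps = []
    for name in os.listdir(run_dir):
        m = SNAP_RE.match(name)
        if m:
            steps.append(int(m.group(1)))
    if not steps:
        return None
    return os.path.join(run_dir, "snap-%d" % max(steps))
```

The step is compared as an integer, so `snap-100000` wins over `snap-60000` even though plain string sorting would order them the other way.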

 

That is roughly the whole procedure. Below are the scripts I actually used for training and testing.

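Configuration file

In inpaint.yml, pay attention to the value of MODEL_RESTORE when resuming training: it must name the saved model to restore from, e.g. MODEL_RESTORE: 20180115220926508503_places2_model.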
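The example value `20180115220926508503_places2_model` appears to begin with a timestamp (2018-01-15 22:09:26 followed by sub-second digits). Assuming every run name under `model_logs` shares that fixed-width prefix, the newest run is simply the largest name. A hypothetical helper for finding a value to put in MODEL_RESTORE:

```python
import os


def latest_run(log_root="model_logs", suffix=""):
    """Return the newest run directory name under log_root that ends with `suffix`, or None.

    Assumes run names start with a fixed-width timestamp, so string order equals time order.
    """
    runs = [d for d in os.listdir(log_root)
            if os.path.isdir(os.path.join(log_root, d)) and d.endswith(suffix)]
    return max(runs) if runs else None
```

Passing a suffix such as `"_places2_model"` restricts the search to runs trained on one dataset.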

Multi-GPU training

To train on multiple GPUs you need to change three places. Two of them are in inpaint.yml:


 
  1.  
    # training
  2.  
    NUM_GPUS: 2
  3.  
    GPU_ID: [01] # -1 indicate select any available one, otherwise select gpu ID, e.g. [0,1,3]

These set the number of GPUs to use and their IDs. The third change, the most important and the easiest to overlook, is in train.py:


 
  1.  
    # train generator with primary trainer ,MultiGPUTrainer. for multi gpu,and add num_gpus=config.NUM_GPUS
  2.  
     
  3.  
    trainer = ng.train.Trainer(
  4.  
    optimizer=g_optimizer,
  5.  
    var_list=g_vars,
  6.  
    max_iters=config.MAX_ITERS,
  7.  
    graph_def=multigpu_graph_def,
  8.  
    grads_summary=config.GRADS_SUMMARY,
  9.  
    gradient_processor=gradient_processor,
  10.  
    graph_def_kwargs={
  11.  
    'model': model, 'data': data, 'config': config, 'loss_type': 'g'},
  12.  
    spe=config.TRAIN_SPE,
  13.  
    log_dir=log_prefix,
  14.  
    )
  15.  
    '''
  16.  
     
  17.  
    trainer = ng.train.MultiGPUTrainer(
  18.  
    optimizer=g_optimizer,
  19.  
    var_list=g_vars,
  20.  
    max_iters=config.MAX_ITERS,
  21.  
    graph_def=multigpu_graph_def,
  22.  
    grads_summary=config.GRADS_SUMMARY,
  23.  
    gradient_processor=gradient_processor,
  24.  
    graph_def_kwargs={
  25.  
    'model': model, 'data': data, 'config': config, 'loss_type': 'g'},
  26.  
    spe=config.TRAIN_SPE,
  27.  
    log_dir=log_prefix,
  28.  
    num_gpus = config.NUM_GPUS,
  29.  
    )
  30.  
    '''

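In other words, there are two ways to construct the trainer: ng.train.Trainer for single-GPU runs and ng.train.MultiGPUTrainer for multi-GPU runs. In multi-GPU mode you must also add the argument

num_gpus=config.NUM_GPUS,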
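Rather than commenting one constructor in and the other out by hand, the choice can follow NUM_GPUS. A sketch, assuming the `ng.train` module and the NUM_GPUS and GPU_ID config fields shown above; `build_trainer` is hypothetical and uses only the two constructors and the `num_gpus` argument from the excerpt. It also checks that GPU_ID lists as many IDs as NUM_GPUS, unless it is -1:

```python
def build_trainer(ng_train, config, **trainer_kwargs):
    """Create a neuralgym trainer: MultiGPUTrainer if config.NUM_GPUS > 1, else Trainer.

    ng_train is the `ng.train` module; trainer_kwargs are the keyword arguments
    shared by both constructors (optimizer, var_list, max_iters, graph_def, ...).
    """
    gpu_ids = config.GPU_ID
    if not isinstance(gpu_ids, (list, tuple)):
        gpu_ids = [gpu_ids]
    # GPU_ID -1 means "any available GPU"; otherwise expect exactly one ID per GPU
    if list(gpu_ids) != [-1] and len(gpu_ids) != config.NUM_GPUS:
        raise ValueError("NUM_GPUS is %d but GPU_ID is %r" % (config.NUM_GPUS, config.GPU_ID))
    if config.NUM_GPUS > 1:
        return ng_train.MultiGPUTrainer(num_gpus=config.NUM_GPUS, **trainer_kwargs)
    return ng_train.Trainer(**trainer_kwargs)
```

In train.py this would be called as `trainer = build_trainer(ng.train, config, optimizer=g_optimizer, var_list=g_vars, ...)` with the same keyword arguments as in the excerpt.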

 

 

Scripts:


 
  1.  
    # training
  2.  
    python3 train.py
  3.  
     
  4.  
    # Resume training:
  5.  
    Modify MODEL_RESTORE flag in inpaint.yml. E.g., MODEL_RESTORE: 20180115220926508503_places2_model.
  6.  
    Run python3 train.py.
  7.  
     
  8.  
    #Testing:
  9.  
     
  10.  
    python3 test.py --image examples/input.png --mask examples/mask.png --output examples/output.png --checkpoint model_logs/your_model_dir.
  11.  
     
  12.  
    python3 test.py --image examples/celeba/celebahr_patches_164787_input.png --mask examples/center_mask_256.png
  13.  
    --output examples/output_celeba.png --checkpoint_dir model_logs/celebA_model/snap-60000
  14.  
     
  15.  
     
  16.  
    # for any other image,you can generate mask and masked image first ,then predict
  17.  
     
  18.  
    1. python3 generate_mask.py --img ./examples/celeba/000035.jpg --HEIGHT 64 --WIDTH 64
  19.  
     
  20.  
    2. python3 test.py --image ./data/mask_img/masked/000035.jpg --mask ./data/mask_img/mask/000035.jpg \
  21.  
    --output examples/output_000035.png --checkpoint_dir model_logs/celebA_model/snap-90000
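When many images have been masked, the two-step recipe above can be repeated for each file. A sketch that builds one `test.py` command per image, assuming the default output folders of generate_mask.py (`./data/mask_img/masked/` and `./data/mask_img/mask/`) and the flags used in the commands above; `build_test_commands`, `run_all` and the `output_<name>.png` naming are hypothetical:

```python
import os
import subprocess


def build_test_commands(masked_dir, mask_dir, out_dir, checkpoint_dir):
    """Return one test.py argv list per image that has both a masked image and a mask."""
    commands = []
    for name in sorted(os.listdir(masked_dir)):
        mask_path = os.path.join(mask_dir, name)
        if not os.path.isfile(mask_path):
            continue  # skip images without a matching mask
        stem = os.path.splitext(name)[0]
        commands.append([
            "python3", "test.py",
            "--image", os.path.join(masked_dir, name),
            "--mask", mask_path,
            "--output", os.path.join(out_dir, "output_%s.png" % stem),
            "--checkpoint_dir", checkpoint_dir,
        ])
    return commands


def run_all(commands):
    """Run each command in turn; check=True stops at the first failure."""
    for cmd in commands:
        subprocess.run(cmd, check=True)
```

Each call starts a fresh test.py process and reloads the checkpoint, which is simple but slow for large batches.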

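Testing

In practice, testing any image requires two inputs, a mask and an input (masked) image, which you have to generate yourself. To make random masks easy to produce, I wrote the following script; it can generate both regular (rectangular) and irregular masks.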


 
  1.  
    '''
  2.  
    利用opencv随机给图像生成带mask区域的图
  3.  
    author:gavin
  4.  
    '''
  5.  
     
  6.  
    # import itertools
  7.  
    # import matplotlib
  8.  
    # import matplotlib.pyplot as plt
  9.  
    from copy import deepcopy
  10.  
    from random import randint
  11.  
    import numpy as np
  12.  
    import cv2
  13.  
    import os
  14.  
    import sys
  15.  
    import tensorflow as tf
  16.  
     
  17.  
    import argparse
  18.  
     
  19.  
    parser = argparse.ArgumentParser()
  20.  
    parser.add_argument('--img', default='./examples/celeba/000042.jpg', type=str,
  21.  
    help='The input img for single image ')
  22.  
     
  23.  
    parser.add_argument('--input_dirimg', default='./data/mask_img/src_img/', type=str,
  24.  
    help='The input folder path for multi-images')
  25.  
    parser.add_argument('--output_dirmask', default='./data/mask_img/mask/', type=str,
  26.  
    help='The output file path of mask.')
  27.  
    parser.add_argument('--output_dirmasked', default='./data/mask_img/masked/', type=str,
  28.  
    help='The output file path of masked.')
  29.  
    parser.add_argument('--MAX_MASK_NUMS', default='16', type=int,
  30.  
    help='max numbers of masks')
  31.  
     
  32.  
    parser.add_argument('--MAX_DELTA_HEIGHT', default='32', type=int,
  33.  
    help='max height of delta')
  34.  
    parser.add_argument('--MAX_DELTA_WIDTH', default='32', type=int,
  35.  
    help='max width of delta')
  36.  
     
  37.  
    parser.add_argument('--HEIGHT', default='128', type=int,
  38.  
    help='max height of delta')
  39.  
    parser.add_argument('--WIDTH', default='128', type=int,
  40.  
    help='max width of delta')
  41.  
     
  42.  
    parser.add_argument('--IMG_SHAPES', type=eval, default=(256, 256, 3))
  43.  
     
  44.  
     
  45.  
    # 随机生成不规则掩膜
  46.  
    def random_mask(height, width, config,channels=3):
  47.  
    """Generates a random irregular mask with lines, circles and elipses"""
  48.  
    img = np.zeros((height, width, channels), np.uint8)
  49.  
     
  50.  
    # Set size scale
  51.  
    size = int((width + height) * 0.02)
  52.  
    if width < 64 or height < 64:
  53.  
    raise Exception("Width and Height of mask must be at least 64!")
  54.  
     
  55.  
    # Draw random lines
  56.  
    for _ in range(randint(1, config.MAX_MASK_NUMS)):
  57.  
    x1, x2 = randint(1, width), randint(1, width)
  58.  
    y1, y2 = randint(1, height), randint(1, height)
  59.  
    thickness = randint(3, size)
  60.  
    cv2.line(img, (x1, y1), (x2, y2), (1, 1, 1), thickness)
  61.  
     
  62.  
     
  63.  
    # Draw random circles
  64.  
    for _ in range(randint(1, config.MAX_MASK_NUMS)):
  65.  
    x1, y1 = randint(1, width), randint(1, height)
  66.  
    radius = randint(3, size)
  67.  
    cv2.circle(img, (x1, y1), radius, (1, 1, 1), -1)
  68.  
     
  69.  
    # Draw random ellipses
  70.  
    for _ in range(randint(1, config.MAX_MASK_NUMS)):
  71.  
    x1, y1 = randint(1, width), randint(1, height)
  72.  
    s1, s2 = randint(1, width), randint(1, height)
  73.  
    a1, a2, a3 = randint(3, 180), randint(3, 180), randint(3, 180)
  74.  
    thickness = randint(3, size)
  75.  
    cv2.ellipse(img, (x1, y1), (s1, s2), a1, a2, a3, (1, 1, 1), thickness)
  76.  
     
  77.  
    return 1 - img
  78.  
     
  79.  
     
  80.  
    '''
  81.  
    # this for test
  82.  
    # %matplotlib inline ==> plt.show()
  83.  
    # Plot the results
  84.  
    _, axes = plt.subplots(5, 5, figsize=(20, 20))
  85.  
    axes = list(itertools.chain.from_iterable(axes))
  86.  
     
  87.  
    for i in range(len(axes)):
  88.  
    # Generate image
  89.  
    img = random_mask(500, 500)
  90.  
     
  91.  
    # Plot image on axis
  92.  
    axes[i].imshow(img * 255)
  93.  
     
  94.  
    plt.show()
  95.  
     
  96.  
    '''
  97.  
     
  98.  
     
  99.  
    def random_bbox(config):
  100.  
    """Generate a random tlhw with configuration.
  101.  
     
  102.  
    Args:
  103.  
    config: Config should have configuration including IMG_SHAPES,
  104.  
    VERTICAL_MARGIN, HEIGHT, HORIZONTAL_MARGIN, WIDTH.
  105.  
     
  106.  
    Returns:
  107.  
    tuple: (top, left, height, width)
  108.  
     
  109.  
    """
  110.  
    img_shape = config.IMG_SHAPES
  111.  
    img_height = img_shape[0]
  112.  
    img_width = img_shape[1]
  113.  
    maxt = img_height - config.HEIGHT
  114.  
    maxl = img_width - config.WIDTH
  115.  
    t = tf.random_uniform(
  116.  
    [], minval=0, maxval=maxt, dtype=tf.int32)
  117.  
    l = tf.random_uniform(
  118.  
    [], minval=0, maxval=maxl, dtype=tf.int32)
  119.  
    h = tf.constant(config.HEIGHT)
  120.  
    w = tf.constant(config.WIDTH)
  121.  
    return (t, l, h, w)
  122.  
     
  123.  
    def bbox2mask(bbox, config, name='mask'):
  124.  
    """Generate mask tensor from bbox.
  125.  
     
  126.  
    Args:
  127.  
    bbox: configuration tuple, (top, left, height, width)
  128.  
    config: Config should have configuration including IMG_SHAPES,
  129.  
    MAX_DELTA_HEIGHT, MAX_DELTA_WIDTH.
  130.  
     
  131.  
    Returns:
  132.  
    tf.Tensor: output with shape [1, H, W, 1]
  133.  
     
  134.  
    """
  135.  
    def npmask(bbox, height, width, delta_h, delta_w):
  136.  
    mask = np.zeros((1, height, width, 1), np.float32)
  137.  
    h = np.random.randint(delta_h//2+1)
  138.  
    w = np.random.randint(delta_w//2+1)
  139.  
    mask[:, bbox[0]+h:bbox[0]+bbox[2]-h,
  140.  
    bbox[1]+w:bbox[1]+bbox[3]-w, :] = 1.
  141.  
    return mask
  142.  
    with tf.variable_scope(name), tf.device('/cpu:0'):
  143.  
    img_shape = config.IMG_SHAPES
  144.  
    height = img_shape[0]
  145.  
    width = img_shape[1]
  146.  
    mask = tf.py_func(
  147.  
    npmask,
  148.  
    [bbox, height, width,
  149.  
    config.MAX_DELTA_HEIGHT, config.MAX_DELTA_WIDTH],
  150.  
    tf.float32, stateful=False)
  151.  
    mask.set_shape([1] + [height, width] + [1])
  152.  
    return mask
  153.  
     
  154.  
    # 对于矩形mask随机生成
  155.  
    def random_mask_rect(img_path,config,bsave=True):
  156.  
     
  157.  
    # Load image
  158.  
    img_data = cv2.imread(img_path)
  159.  
    #img_data = cv2.cvtColor(img_data, cv2.COLOR_BGR2RGB)
  160.  
     
  161.  
     
  162.  
    '''
  163.  
    # generate mask, 1 represents masked point
  164.  
    bbox = random_bbox(config)
  165.  
    mask = bbox2mask(bbox, config, name='mask_c')
  166.  
    img_pos = img_data / 127.5 - 1.
  167.  
    masked_img = img_pos * (1. - mask)
  168.  
    '''
  169.  
     
  170.  
    # 创建矩形区域,填充白色255
  171.  
    img_shape = config.IMG_SHAPES
  172.  
    img_height = img_shape[0]
  173.  
    img_width = img_shape[1]
  174.  
     
  175.  
    image = cv2.resize(img_data, (img_width, img_height))
  176.  
    rectangle = np.zeros(image.shape[0:2], dtype=np.uint8)
  177.  
     
  178.  
    maxt = img_height - config.HEIGHT
  179.  
    maxl = img_width - config.WIDTH
  180.  
     
  181.  
    h = config.HEIGHT
  182.  
    w = config.WIDTH
  183.  
     
  184.  
    x = randint(0, maxt - 1)
  185.  
    y = randint(0, maxl - 1)
  186.  
     
  187.  
    mask = cv2.rectangle(rectangle,(x, y), (x+w, y+h) , 255, -1) # 修改这里 (78, 30), (98, 46)
  188.  
     
  189.  
    masked_img = deepcopy(image)
  190.  
    masked_img[mask == 255] = 255
  191.  
     
  192.  
     
  193.  
    print("shape of mask:",mask.shape)
  194.  
    print("shape of masked_img:",masked_img.shape)
  195.  
     
  196.  
    if bsave:
  197.  
    save_name_mask = os.path.join(config.output_dirmask, img_path.split('/')[-1])
  198.  
    cv2.imwrite(save_name_mask,mask)
  199.  
     
  200.  
    save_name_masked = os.path.join(config.output_dirmasked, img_path.split('/')[-1])
  201.  
    cv2.imwrite(save_name_masked, masked_img)
  202.  
     
  203.  
    return masked_img,mask
  204.  
     
  205.  
     
  206.  
    def get_path(config):
  207.  
    if not os.path.exists(config.input_dirimg):
  208.  
    os.mkdir(config.input_dirimg)
  209.  
    if not os.path.exists(config.output_dirmask):
  210.  
    os.mkdir(config.output_dirmask)
  211.  
    if not os.path.exists(config.output_dirmasked):
  212.  
    os.mkdir(config.output_dirmasked)
  213.  
     
  214.  
     
  215.  
     
  216.  
    # 给单个图像生成带mask区域的图
  217.  
    def load_mask(img_path,config,bsave=False):
  218.  
     
  219.  
    # Load image
  220.  
    img = cv2.imread(img_path)
  221.  
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  222.  
    shape = img.shape
  223.  
    print("Shape of image is: ",shape)
  224.  
    # Load mask
  225.  
    mask = random_mask(shape[0], shape[1],config)
  226.  
     
  227.  
    # Image + mask
  228.  
    masked_img = deepcopy(img)
  229.  
    masked_img[mask == 0] = 255
  230.  
     
  231.  
    mask = mask * 255
  232.  
     
  233.  
    if bsave:
  234.  
    save_name_mask = os.path.join(config.output_dirmask, img_path.split('/')[-1])
  235.  
    cv2.imwrite(save_name_mask,mask)
  236.  
     
  237.  
    save_name_masked = os.path.join(config.output_dirmasked, img_path.split('/')[-1])
  238.  
    cv2.imwrite(save_name_masked, masked_img)
  239.  
     
  240.  
    return masked_img,mask
  241.  
     
  242.  
     
  243.  
     
  244.  
    # 批量生成带mask区域的图像
  245.  
    def img2maskedImg(dataset_dir):
  246.  
    files = []
  247.  
    image_list = os.listdir(dataset_dir)
  248.  
    files = [os.path.join(dataset_dir, _) for _ in image_list]
  249.  
    length = len(files)
  250.  
    for index,jpg in enumerate(files):
  251.  
    try:
  252.  
    sys.stdout.write('\r>>Converting image %d/%d ' % (index,length))
  253.  
    sys.stdout.flush()
  254.  
    load_mask(jpg,config,True)
  255.  
    # 将已经转换的图片移动到指定位置
  256.  
    #shutil.move(png, output_dirHR)
  257.  
    except IOError as e:
  258.  
    print('could not read:',jpg)
  259.  
    print('error:',e)
  260.  
    print('skip it\n')
  261.  
     
  262.  
    sys.stdout.write('Convert Over!\n')
  263.  
    sys.stdout.flush()
  264.  
     
  265.  
    # python3 generate_mask.py --img ./examples/celeba/000042.jpg --HEIGHT 64 --WIDTH 64
  266.  
     
  267.  
    if __name__ == '__main__':
  268.  
    config = parser.parse_args()
  269.  
    get_path(config)
  270.  
    # 单张图像生成mask
  271.  
    #img = './data/test.jpg'
  272.  
    #masked_img,mask = load_mask(img,config,True)
  273.  
     
  274.  
    # 批量图像处理==>圆形,椭圆,直线
  275.  
    #img2maskedImg(config.input_dirimg)
  276.  
     
  277.  
    # 矩形特殊处理 处理同样shape的图片(256,256,3) fix me
  278.  
    #img = './examples/celeba/000042.jpg'
  279.  
    img = config.img
  280.  
    masked_img, mask = random_mask_rect(img,config)
  281.  
     
  282.  
    '''
  283.  
    # Show side by side
  284.  
    _, axes = plt.subplots(1, 3, figsize=(20, 5))
  285.  
    axes[0].imshow(img)
  286.  
    axes[1].imshow(mask*255)
  287.  
    axes[2].imshow(masked_img)
  288.  
    plt.show()
  289.  
    '''
  290.  
     
  291.  
     

 

Results:

mask, masked image, and inpainted output
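In random_mask_rect the hole is saved as white (255) on a black background, and the masked image has the same region painted white. A NumPy-only sketch of that convention, handy for sanity-checking a generated mask/masked pair and for measuring how much of the image a mask removes. Thresholding at 127 is my assumption, meant to tolerate masks saved as JPEG; the helper names are hypothetical:

```python
import numpy as np


def rect_mask(height, width, top, left, h, w):
    """Rectangular mask with a white h x w hole whose top-left corner is at (row=top, col=left)."""
    mask = np.zeros((height, width), np.uint8)
    mask[top:top + h, left:left + w] = 255
    return mask


def apply_mask(image, mask):
    """Return a copy of image (H x W x C) with the hole (mask > 127) painted white."""
    hole = mask > 127
    masked = image.copy()
    masked[hole] = 255
    return masked


def hole_ratio(mask):
    """Fraction of pixels inside the hole (white) region."""
    return float(np.count_nonzero(mask > 127)) / mask.size
```

For a 64 x 32 hole in a 256 x 256 image, `hole_ratio` gives 2048 / 65536, about 3% of the pixels.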

 

