您当前的位置:首页 > IT编程 > 医学CAD
| C语言 | Java | VB | VC | python | Android | TensorFlow | C++ | oracle | 学术与代码 | cnn卷积神经网络 | gnn | 图像修复 | Keras | 数据集 | Neo4j | 自然语言处理 | 深度学习 | 医学CAD | 医学影像 | 超参数 | pointnet | pytorch |

自学教程:深度学习检测疟疾

51自学网 2020-12-09 11:20:47
  医学CAD
这篇教程深度学习检测疟疾写得很实用,希望能帮到您。

通过深度学习检测疟疾

虽然我既不是医生,也不是医疗保健研究人员,而且我的素质还差得很远,但我对将AI应用于医疗保健研究感兴趣。 我在本文中的目的是展示AI和开源解决方案如何帮助疟疾检测和减少体力劳动。

Python and TensorFlow

Python和TensorFlow:构建开源深度学习解决方案的绝佳组合

得益于Python的强大功能和TensorFlow等深度学习框架,我们可以构建健壮,可扩展且有效的深度学习解决方案。 由于这些工具是免费和开源的,因此我们可以构建成本效益高,易于被任何人采用和使用的解决方案。 让我们开始吧!

项目动机

疟疾是一种致命的,传染性的,由蚊子传播的疾病,由疟原虫的寄生虫引起,这种寄生虫是由被感染的雌性按蚊的叮咬传播的。 造成疟疾的寄生虫有五种,但恶性疟原虫间日 疟原虫有两种类型。

Malaria heat map

该地图显示,疟疾在全球范围内普遍存在,尤其是在热带地区,但该疾病的性质和致命性是该项目的主要动机。

世界卫生组织(WHO)的疟疾事实表明,全球近一半的人口处于疟疾风险之中,每年有2亿多例疟疾病例,约40万人死于疟疾。 这是使疟疾检测和诊断快速,轻松和有效的动力。

疟疾检测方法

有几种方法可用于疟疾的检测和诊断。 我们的项目所基于的论文,Rajaraman等人的“ 预训练的卷积神经网络作为特征提取器,以改善薄血涂片图像中的疟疾寄生虫检测 ”,介绍了一些方法,包括聚合酶链React(PCR) )和快速诊断测试(RDT)。 这两种测试通常用于无法提供高质量显微镜服务的地方。

卡洛斯·阿里扎(Carlos Ariza)的文章“ 疟疾英雄:用于更快地进行疟疾诊断的网络应用程序 ”(我在Adrian Rosebrock的“ 使用Keras进行深度学习和医学图像分析 ”中了解到)表示,标准的疟疾诊断通常基于涂血工作流程。 我感谢这些出色的资源的作者,使我对疟疾的流行,诊断和治疗有了更多的了解。

Blood smear workflow for Malaria detection

疟疾检测的血液涂片工作流程

根据WHO协议,诊断通常涉及以100倍放大倍率对血液涂片进行深入检查。 受过训练的人员手动计算出5,000个细胞中有多少个红细胞含有寄生虫。 正如Rajaraman等人所述,以上引用的论文解释了:

浓血涂片有助于检测寄生虫的存在,而薄血涂片有助于识别引起感染的寄生虫的种类(疾病控制与预防中心,2012)。 诊断的准确性在很大程度上取决于人类的专业知识,并且可能因观察者之间的差异以及疾病流行/资源受限地区的大型诊断所施加的责任而受到不利影响(Mitiku,Mengistu和Gelaw,2003年)。 使用替代技术,例如聚合酶链React(PCR)和快速诊断测试(RDT); 然而,PCR分析的性能有限(Hommelsheim等,2014),而在疾病流行地区,RDTs的成本效益较低(Hawkes,Katsuva和Masumbuko,2009)。

因此,疟疾检测可以受益于使用深度学习的自动化。

深度学习以检测疟疾

血涂片的手动诊断是一个密集的手动过程,需要专业知识来对寄生和未感染的细胞进行分类和计数。 此过程可能无法很好地扩展,尤其是在找不到合适的专业知识的地区。 在利用最新的图像处理和分析技术提取手工设计的特征并建立基于机器学习的分类模型方面取得了一些进展。 但是,由于手工设计的功能会花费大量时间,因此这些模型无法扩展,因为有更多数据可用于训练。

深度学习模型,或更具体地说是卷积神经网络(CNN),已被证明在多种计算机视觉任务中非常有效。 (如果您想了解有关CNN的其他背景知识,我建议阅读CS231n用于视觉识别的卷积神经网络 。)简而言之,CNN模型中的关键层包括卷积和池化层,如下图所示。

A typical CNN architecture

典型的CNN架构

卷积层从数据中学习空间分层模式,这些数据也是平移不变的,因此它们能够学习图像的不同方面。 例如,第一卷积层将学习较小的局部模式,例如边缘和拐角,第二卷积层将基于第一层的特征学习较大的模式,依此类推。 这使CNN可以自动进行要素工程设计并学习有效地概括新数据点的有效要素。 池化层有助于降低采样率和减小尺寸。

因此,CNN有助于自动化和可扩展的特征工程。 此外,在模型末尾插入密集层使我们能够执行诸如图像分类之类的任务。 使用像CNN之类的深度学习模型进行自动疟疾检测可能非常有效,便宜且可扩展,尤其是在迁移学习和预训练模型问世的情况下,即使在数据较少等约束下,该方法也能很好地工作。

Rajaraman等人的论文利用数据集上的六个预训练模型在检测疟疾和未感染样本方面获得了95.9%的惊人准确性。 我们的重点是从头开始尝试一些简单的CNN模型以及使用转移学习的一些预训练模型,以查看我们可以在同一数据集上获得的结果。 我们将使用包括Python和TensorFlow在内的开源工具和框架来构建我们的模型。

数据集

我们用于分析的数据来自利斯特山国家生物医学通信中心(LHNCBC)的研究人员,该中心是国家医学图书馆(NLM)的一部分,他们已经仔细收集并注释了公开的健康和感染血液涂片图像的数据集 。 这些研究人员开发了一种用于疟疾检测的移动应用程序,应用程序可在连接到常规光学显微镜的标准Android智能手机上运行。 他们使用了孟加拉国吉大港医学院医院收集并拍照的150 株恶性疟原虫感染的吉姆萨染色薄载玻片和50名健康患者的照片。 智能手机的内置相机会为每个微观视野获取幻灯片的图像。 图像由泰国曼谷Mahidol-Oxford热带医学研究室的专业幻灯片阅读器手动注释。

让我们简要地检查一下数据集的结构。 首先,我将安装一些基本的依赖项(基于所使用的操作系统)。

Installing dependencies

我在云上使用带有GPU的基于Debian的系统,因此可以更快地运行模型。 要查看目录结构,我们必须使用sudo apt install tr​​ee安装树依赖项(如果没有)。

Installing the tree dependency

我们有两个包含感染细胞和健康细胞图像的文件夹。 通过输入以下内容,我们可以获得有关图像总数的更多详细信息:


 
  1.  
     
  2.  
     
  3.  
     
  4.  
    import
  5.  
    os
  6.  
     
  7.  
     
  8.  
    import
  9.  
    glob
  10.  
     
  11.  
     
  12.  
     
  13.  
    base_dir
  14.  
    =
  15.  
    os .
  16.  
    path .
  17.  
    join
  18.  
    (
  19.  
    './cell_images'
  20.  
    )
  21.  
     
  22.  
    infected_dir
  23.  
    =
  24.  
    os .
  25.  
    path .
  26.  
    join
  27.  
    ( base_dir
  28.  
    ,
  29.  
    'Parasitized'
  30.  
    )
  31.  
     
  32.  
    healthy_dir
  33.  
    =
  34.  
    os .
  35.  
    path .
  36.  
    join
  37.  
    ( base_dir
  38.  
    ,
  39.  
    'Uninfected'
  40.  
    )
  41.  
     
  42.  
     
  43.  
     
  44.  
    infected_files
  45.  
    =
  46.  
    glob .
  47.  
    glob
  48.  
    ( infected_dir+
  49.  
    '/*.png'
  50.  
    )
  51.  
     
  52.  
    healthy_files
  53.  
    =
  54.  
    glob .
  55.  
    glob
  56.  
    ( healthy_dir+
  57.  
    '/*.png'
  58.  
    )
  59.  
     
  60.  
     
  61.  
    len
  62.  
    ( infected_files
  63.  
    )
  64.  
    ,
  65.  
    len
  66.  
    ( healthy_files
  67.  
    )
  68.  
     
  69.  
     
  70.  
     
  71.  
     
  72.  
    # Output
  73.  
     
  74.  
     
  75.  
    (
  76.  
    13779
  77.  
    ,
  78.  
    13779
  79.  
    )
  80.  
     
  81.  
     

看起来我们有一个平衡的数据集,其中包含13779个疟疾和13779个非疟疾(未感染)细胞图像。 让我们以此为基础构建一个数据框架,当我们开始构建数据集时将使用它。


 
  1.  
     
  2.  
     
  3.  
     
  4.  
    import numpy
  5.  
    as np
  6.  
     
  7.  
     
  8.  
    import pandas
  9.  
    as pd
  10.  
     
  11.  
     
  12.  
     
  13.  
    np.
  14.  
    random .
  15.  
    seed
  16.  
    (
  17.  
    42
  18.  
    )
  19.  
     
  20.  
     
  21.  
     
  22.  
    files_df
  23.  
    = pd.
  24.  
    DataFrame
  25.  
    (
  26.  
    {
  27.  
     
  28.  
       
  29.  
    'filename' : infected_files + healthy_files
  30.  
    ,
  31.  
     
  32.  
       
  33.  
    'label' :
  34.  
    [
  35.  
    'malaria'
  36.  
    ] *
  37.  
    len
  38.  
    ( infected_files
  39.  
    ) +
  40.  
    [
  41.  
    'healthy'
  42.  
    ] *
  43.  
    len
  44.  
    ( healthy_files
  45.  
    )
  46.  
     
  47.  
     
  48.  
    }
  49.  
    ) .
  50.  
    sample
  51.  
    ( frac
  52.  
    =
  53.  
    1
  54.  
    , random_state
  55.  
    =
  56.  
    42
  57.  
    ) .
  58.  
    reset_index
  59.  
    ( drop
  60.  
    =
  61.  
    True
  62.  
    )
  63.  
     
  64.  
     
  65.  
     
  66.  
    files_df.
  67.  
    head
  68.  
    (
  69.  
    )
  70.  
     
  71.  
     
Datasets

建立和探索图像数据集

要构建深度学习模型,我们需要训练数据,但我们还需要在看不见的数据上测试模型的性能。 我们将分别使用60:10:30的比例划分训练,验证和测试数据集。 我们将在训练过程中利用训练和验证数据集,并在测试数据集上检查模型的性能。


 
  1.  
     
  2.  
     
  3.  
     
  4.  
    from sklearn.
  5.  
    model_selection
  6.  
    import train_test_split
  7.  
     
  8.  
     
  9.  
    from
  10.  
    collections
  11.  
    import Counter
  12.  
     
  13.  
     
  14.  
     
  15.  
    train_files
  16.  
    , test_files
  17.  
    , train_labels
  18.  
    , test_labels
  19.  
    = train_test_split
  20.  
    ( files_df
  21.  
    [
  22.  
    'filename'
  23.  
    ] .
  24.  
    values
  25.  
    ,
  26.  
     
  27.  
                                                                          files_df
  28.  
    [
  29.  
    'label'
  30.  
    ] .
  31.  
    values
  32.  
    ,
  33.  
     
  34.  
                                                                          test_size
  35.  
    =
  36.  
    0.3
  37.  
    , random_state
  38.  
    =
  39.  
    42
  40.  
    )
  41.  
     
  42.  
    train_files
  43.  
    , val_files
  44.  
    , train_labels
  45.  
    , val_labels
  46.  
    = train_test_split
  47.  
    ( train_files
  48.  
    ,
  49.  
     
  50.  
                                                                        train_labels
  51.  
    ,
  52.  
     
  53.  
                                                                        test_size
  54.  
    =
  55.  
    0.1
  56.  
    , random_state
  57.  
    =
  58.  
    42
  59.  
    )
  60.  
     
  61.  
     
  62.  
     
  63.  
     
  64.  
    print
  65.  
    ( train_files.
  66.  
    shape
  67.  
    , val_files.
  68.  
    shape
  69.  
    , test_files.
  70.  
    shape
  71.  
    )
  72.  
     
  73.  
     
  74.  
    print
  75.  
    (
  76.  
    'Train:'
  77.  
    , Counter
  78.  
    ( train_labels
  79.  
    )
  80.  
    ,
  81.  
    ' \n Val:'
  82.  
    , Counter
  83.  
    ( val_labels
  84.  
    )
  85.  
    ,
  86.  
    ' \n Test:'
  87.  
    , Counter
  88.  
    ( test_labels
  89.  
    )
  90.  
    )
  91.  
     
  92.  
     
  93.  
     
  94.  
     
  95.  
    # Output
  96.  
     
  97.  
     
  98.  
    (
  99.  
    17361
  100.  
    ,
  101.  
    )
  102.  
    (
  103.  
    1929
  104.  
    ,
  105.  
    )
  106.  
    (
  107.  
    8268
  108.  
    ,
  109.  
    )
  110.  
     
  111.  
    Train: Counter
  112.  
    (
  113.  
    {
  114.  
    'healthy' :
  115.  
    8734
  116.  
    ,
  117.  
    'malaria' :
  118.  
    8627
  119.  
    }
  120.  
    )
  121.  
     
  122.  
    Val: Counter
  123.  
    (
  124.  
    {
  125.  
    'healthy' :
  126.  
    970
  127.  
    ,
  128.  
    'malaria' :
  129.  
    959
  130.  
    }
  131.  
    )
  132.  
     
  133.  
    Test: Counter
  134.  
    (
  135.  
    {
  136.  
    'malaria' :
  137.  
    4193
  138.  
    ,
  139.  
    'healthy' :
  140.  
    4075
  141.  
    }
  142.  
    )
  143.  
     
  144.  
     

由于血液涂片和细胞图像会根据人,测试方法和照片的方向而变化,因此图像的尺寸将不相等。 让我们获取训练数据集的一些摘要统计信息,以确定最佳图像尺寸(请记住,我们根本不接触测试数据集!)。


 
  1.  
     
  2.  
     
  3.  
     
  4.  
    import cv2
  5.  
     
  6.  
     
  7.  
    from concurrent
  8.  
    import futures
  9.  
     
  10.  
     
  11.  
    import
  12.  
    threading
  13.  
     
  14.  
     
  15.  
     
  16.  
     
  17.  
    def get_img_shape_parallel
  18.  
    ( idx
  19.  
    , img
  20.  
    , total_imgs
  21.  
    ) :
  22.  
     
  23.  
       
  24.  
    if idx %
  25.  
    5000
  26.  
    ==
  27.  
    0
  28.  
    or idx
  29.  
    ==
  30.  
    ( total_imgs -
  31.  
    1
  32.  
    ) :
  33.  
     
  34.  
           
  35.  
    print
  36.  
    (
  37.  
    '{}: working on img num: {}' .
  38.  
    format
  39.  
    (
  40.  
    threading .
  41.  
    current_thread
  42.  
    (
  43.  
    ) .
  44.  
    name
  45.  
    ,
  46.  
     
  47.  
                                                      idx
  48.  
    )
  49.  
    )
  50.  
     
  51.  
       
  52.  
    return cv2.
  53.  
    imread
  54.  
    ( img
  55.  
    ) .
  56.  
    shape
  57.  
     
  58.  
     
  59.  
     
  60.  
    ex
  61.  
    = futures.
  62.  
    ThreadPoolExecutor
  63.  
    ( max_workers
  64.  
    =
  65.  
    None
  66.  
    )
  67.  
     
  68.  
    data_inp
  69.  
    =
  70.  
    [
  71.  
    ( idx
  72.  
    , img
  73.  
    ,
  74.  
    len
  75.  
    ( train_files
  76.  
    )
  77.  
    )
  78.  
    for idx
  79.  
    , img
  80.  
    in
  81.  
    enumerate
  82.  
    ( train_files
  83.  
    )
  84.  
    ]
  85.  
     
  86.  
     
  87.  
    print
  88.  
    (
  89.  
    'Starting Img shape computation:'
  90.  
    )
  91.  
     
  92.  
    train_img_dims_map
  93.  
    = ex.
  94.  
    map
  95.  
    ( get_img_shape_parallel
  96.  
    ,
  97.  
     
  98.  
                               
  99.  
    [ record
  100.  
    [
  101.  
    0
  102.  
    ]
  103.  
    for record
  104.  
    in data_inp
  105.  
    ]
  106.  
    ,
  107.  
     
  108.  
                               
  109.  
    [ record
  110.  
    [
  111.  
    1
  112.  
    ]
  113.  
    for record
  114.  
    in data_inp
  115.  
    ]
  116.  
    ,
  117.  
     
  118.  
                               
  119.  
    [ record
  120.  
    [
  121.  
    2
  122.  
    ]
  123.  
    for record
  124.  
    in data_inp
  125.  
    ]
  126.  
    )
  127.  
     
  128.  
    train_img_dims
  129.  
    =
  130.  
    list
  131.  
    ( train_img_dims_map
  132.  
    )
  133.  
     
  134.  
     
  135.  
    print
  136.  
    (
  137.  
    'Min Dimensions:'
  138.  
    , np.
  139.  
    min
  140.  
    ( train_img_dims
  141.  
    , axis
  142.  
    =
  143.  
    0
  144.  
    )
  145.  
    )
  146.  
     
  147.  
     
  148.  
    print
  149.  
    (
  150.  
    'Avg Dimensions:'
  151.  
    , np.
  152.  
    mean
  153.  
    ( train_img_dims
  154.  
    , axis
  155.  
    =
  156.  
    0
  157.  
    )
  158.  
    )
  159.  
     
  160.  
     
  161.  
    print
  162.  
    (
  163.  
    'Median Dimensions:'
  164.  
    , np.
  165.  
    median
  166.  
    ( train_img_dims
  167.  
    , axis
  168.  
    =
  169.  
    0
  170.  
    )
  171.  
    )
  172.  
     
  173.  
     
  174.  
    print
  175.  
    (
  176.  
    'Max Dimensions:'
  177.  
    , np.
  178.  
    max
  179.  
    ( train_img_dims
  180.  
    , axis
  181.  
    =
  182.  
    0
  183.  
    )
  184.  
    )
  185.  
     
  186.  
     
  187.  
     
  188.  
     
  189.  
     
  190.  
     
  191.  
    # Output
  192.  
     
  193.  
    Starting Img shape computation:
  194.  
     
  195.  
    ThreadPoolExecutor-
  196.  
    0 _0: working on img num:
  197.  
    0
  198.  
     
  199.  
    ThreadPoolExecutor-
  200.  
    0 _17: working on img num:
  201.  
    5000
  202.  
     
  203.  
    ThreadPoolExecutor-
  204.  
    0 _15: working on img num:
  205.  
    10000
  206.  
     
  207.  
    ThreadPoolExecutor-
  208.  
    0 _1: working on img num:
  209.  
    15000
  210.  
     
  211.  
    ThreadPoolExecutor-
  212.  
    0 _7: working on img num:
  213.  
    17360
  214.  
     
  215.  
    Min Dimensions:
  216.  
    [
  217.  
    46
  218.  
    46  
  219.  
    3
  220.  
    ]
  221.  
     
  222.  
    Avg Dimensions:
  223.  
    [
  224.  
    132.77311215
  225.  
    132.45757733  
  226.  
    3 .
  227.  
    ]
  228.  
     
  229.  
    Median Dimensions:
  230.  
    [
  231.  
    130 .
  232.  
    130 .  
  233.  
    3 .
  234.  
    ]
  235.  
     
  236.  
    Max Dimensions:
  237.  
    [
  238.  
    385
  239.  
    394  
  240.  
    3
  241.  
    ]
  242.  
     
  243.  
     

我们应用并行处理来加快图像读取操作,并基于摘要统计信息,将每个图像的大小调整为125x125像素。 让我们加载所有图像并将它们调整为这些固定尺寸。


 
  1.  
     
  2.  
     
  3.  
     
  4.  
    IMG_DIMS
  5.  
    =
  6.  
    (
  7.  
    125
  8.  
    ,
  9.  
    125
  10.  
    )
  11.  
     
  12.  
     
  13.  
     
  14.  
     
  15.  
    def get_img_data_parallel
  16.  
    ( idx
  17.  
    , img
  18.  
    , total_imgs
  19.  
    ) :
  20.  
     
  21.  
       
  22.  
    if idx %
  23.  
    5000
  24.  
    ==
  25.  
    0
  26.  
    or idx
  27.  
    ==
  28.  
    ( total_imgs -
  29.  
    1
  30.  
    ) :
  31.  
     
  32.  
           
  33.  
    print
  34.  
    (
  35.  
    '{}: working on img num: {}' .
  36.  
    format
  37.  
    (
  38.  
    threading .
  39.  
    current_thread
  40.  
    (
  41.  
    ) .
  42.  
    name
  43.  
    ,
  44.  
     
  45.  
                                                      idx
  46.  
    )
  47.  
    )
  48.  
     
  49.  
        img
  50.  
    = cv2.
  51.  
    imread
  52.  
    ( img
  53.  
    )
  54.  
     
  55.  
        img
  56.  
    = cv2.
  57.  
    resize
  58.  
    ( img
  59.  
    , dsize
  60.  
    = IMG_DIMS
  61.  
    ,
  62.  
     
  63.  
                         interpolation
  64.  
    = cv2.
  65.  
    INTER_CUBIC
  66.  
    )
  67.  
     
  68.  
        img
  69.  
    = np.
  70.  
    array
  71.  
    ( img
  72.  
    , dtype
  73.  
    = np.
  74.  
    float32
  75.  
    )
  76.  
     
  77.  
       
  78.  
    return img
  79.  
     
  80.  
     
  81.  
     
  82.  
    ex
  83.  
    = futures.
  84.  
    ThreadPoolExecutor
  85.  
    ( max_workers
  86.  
    =
  87.  
    None
  88.  
    )
  89.  
     
  90.  
    train_data_inp
  91.  
    =
  92.  
    [
  93.  
    ( idx
  94.  
    , img
  95.  
    ,
  96.  
    len
  97.  
    ( train_files
  98.  
    )
  99.  
    )
  100.  
    for idx
  101.  
    , img
  102.  
    in
  103.  
    enumerate
  104.  
    ( train_files
  105.  
    )
  106.  
    ]
  107.  
     
  108.  
    val_data_inp
  109.  
    =
  110.  
    [
  111.  
    ( idx
  112.  
    , img
  113.  
    ,
  114.  
    len
  115.  
    ( val_files
  116.  
    )
  117.  
    )
  118.  
    for idx
  119.  
    , img
  120.  
    in
  121.  
    enumerate
  122.  
    ( val_files
  123.  
    )
  124.  
    ]
  125.  
     
  126.  
    test_data_inp
  127.  
    =
  128.  
    [
  129.  
    ( idx
  130.  
    , img
  131.  
    ,
  132.  
    len
  133.  
    ( test_files
  134.  
    )
  135.  
    )
  136.  
    for idx
  137.  
    , img
  138.  
    in
  139.  
    enumerate
  140.  
    ( test_files
  141.  
    )
  142.  
    ]
  143.  
     
  144.  
     
  145.  
     
  146.  
     
  147.  
    print
  148.  
    (
  149.  
    'Loading Train Images:'
  150.  
    )
  151.  
     
  152.  
    train_data_map
  153.  
    = ex.
  154.  
    map
  155.  
    ( get_img_data_parallel
  156.  
    ,
  157.  
     
  158.  
                           
  159.  
    [ record
  160.  
    [
  161.  
    0
  162.  
    ]
  163.  
    for record
  164.  
    in train_data_inp
  165.  
    ]
  166.  
    ,
  167.  
     
  168.  
                           
  169.  
    [ record
  170.  
    [
  171.  
    1
  172.  
    ]
  173.  
    for record
  174.  
    in train_data_inp
  175.  
    ]
  176.  
    ,
  177.  
     
  178.  
                           
  179.  
    [ record
  180.  
    [
  181.  
    2
  182.  
    ]
  183.  
    for record
  184.  
    in train_data_inp
  185.  
    ]
  186.  
    )
  187.  
     
  188.  
    train_data
  189.  
    = np.
  190.  
    array
  191.  
    (
  192.  
    list
  193.  
    ( train_data_map
  194.  
    )
  195.  
    )
  196.  
     
  197.  
     
  198.  
     
  199.  
     
  200.  
    print
  201.  
    (
  202.  
    ' \n Loading Validation Images:'
  203.  
    )
  204.  
     
  205.  
    val_data_map
  206.  
    = ex.
  207.  
    map
  208.  
    ( get_img_data_parallel
  209.  
    ,
  210.  
     
  211.  
                           
  212.  
    [ record
  213.  
    [
  214.  
    0
  215.  
    ]
  216.  
    for record
  217.  
    in val_data_inp
  218.  
    ]
  219.  
    ,
  220.  
     
  221.  
                           
  222.  
    [ record
  223.  
    [
  224.  
    1
  225.  
    ]
  226.  
    for record
  227.  
    in val_data_inp
  228.  
    ]
  229.  
    ,
  230.  
     
  231.  
                           
  232.  
    [ record
  233.  
    [
  234.  
    2
  235.  
    ]
  236.  
    for record
  237.  
    in val_data_inp
  238.  
    ]
  239.  
    )
  240.  
     
  241.  
    val_data
  242.  
    = np.
  243.  
    array
  244.  
    (
  245.  
    list
  246.  
    ( val_data_map
  247.  
    )
  248.  
    )
  249.  
     
  250.  
     
  251.  
     
  252.  
     
  253.  
    print
  254.  
    (
  255.  
    ' \n Loading Test Images:'
  256.  
    )
  257.  
     
  258.  
    test_data_map
  259.  
    = ex.
  260.  
    map
  261.  
    ( get_img_data_parallel
  262.  
    ,
  263.  
     
  264.  
                           
  265.  
    [ record
  266.  
    [
  267.  
    0
  268.  
    ]
  269.  
    for record
  270.  
    in test_data_inp
  271.  
    ]
  272.  
    ,
  273.  
     
  274.  
                           
  275.  
    [ record
  276.  
    [
  277.  
    1
  278.  
    ]
  279.  
    for record
  280.  
    in test_data_inp
  281.  
    ]
  282.  
    ,
  283.  
     
  284.  
                           
  285.  
    [ record
  286.  
    [
  287.  
    2
  288.  
    ]
  289.  
    for record
  290.  
    in test_data_inp
  291.  
    ]
  292.  
    )
  293.  
     
  294.  
    test_data
  295.  
    = np.
  296.  
    array
  297.  
    (
  298.  
    list
  299.  
    ( test_data_map
  300.  
    )
  301.  
    )
  302.  
     
  303.  
     
  304.  
     
  305.  
    train_data.
  306.  
    shape
  307.  
    , val_data.
  308.  
    shape
  309.  
    , test_data.
  310.  
    shape  
  311.  
     
  312.  
     
  313.  
     
  314.  
     
  315.  
     
  316.  
     
  317.  
    # Output
  318.  
     
  319.  
    Loading Train Images:
  320.  
     
  321.  
    ThreadPoolExecutor-
  322.  
    1 _0: working on img num:
  323.  
    0
  324.  
     
  325.  
    ThreadPoolExecutor-
  326.  
    1 _12: working on img num:
  327.  
    5000
  328.  
     
  329.  
    ThreadPoolExecutor-
  330.  
    1 _6: working on img num:
  331.  
    10000
  332.  
     
  333.  
    ThreadPoolExecutor-
  334.  
    1 _10: working on img num:
  335.  
    15000
  336.  
     
  337.  
    ThreadPoolExecutor-
  338.  
    1 _3: working on img num:
  339.  
    17360
  340.  
     
  341.  
     
  342.  
     
  343.  
    Loading Validation Images:
  344.  
     
  345.  
    ThreadPoolExecutor-
  346.  
    1 _13: working on img num:
  347.  
    0
  348.  
     
  349.  
    ThreadPoolExecutor-
  350.  
    1 _18: working on img num:
  351.  
    1928
  352.  
     
  353.  
     
  354.  
     
  355.  
    Loading Test Images:
  356.  
     
  357.  
    ThreadPoolExecutor-
  358.  
    1 _5: working on img num:
  359.  
    0
  360.  
     
  361.  
    ThreadPoolExecutor-
  362.  
    1 _19: working on img num:
  363.  
    5000
  364.  
     
  365.  
    ThreadPoolExecutor-
  366.  
    1 _8: working on img num:
  367.  
    8267
  368.  
     
  369.  
     
  370.  
    (
  371.  
    (
  372.  
    17361
  373.  
    ,
  374.  
    125
  375.  
    ,
  376.  
    125
  377.  
    ,
  378.  
    3
  379.  
    )
  380.  
    ,
  381.  
    (
  382.  
    1929
  383.  
    ,
  384.  
    125
  385.  
    ,
  386.  
    125
  387.  
    ,
  388.  
    3
  389.  
    )
  390.  
    ,
  391.  
    (
  392.  
    8268
  393.  
    ,
  394.  
    125
  395.  
    ,
  396.  
    125
  397.  
    ,
  398.  
    3
  399.  
    )
  400.  
    )
  401.  
     
  402.  
     

我们再次利用并行处理来加快与图像加载和调整大小有关的计算。 最后,我们得到所需尺寸的图像张量,如前面的输出所示。 现在,我们可以查看一些样本细胞图像,以了解数据的外观。


 
  1.  
     
  2.  
     
  3.  
     
  4.  
    import matplotlib.
  5.  
    pyplot
  6.  
    as plt
  7.  
     
  8.  
    %matplotlib inline
  9.  
     
  10.  
     
  11.  
     
  12.  
    plt.
  13.  
    figure
  14.  
    (
  15.  
    1
  16.  
    , figsize
  17.  
    =
  18.  
    (
  19.  
    8
  20.  
    ,
  21.  
    8
  22.  
    )
  23.  
    )
  24.  
     
  25.  
    n
  26.  
    =
  27.  
    0
  28.  
     
  29.  
     
  30.  
    for i
  31.  
    in
  32.  
    range
  33.  
    (
  34.  
    16
  35.  
    ) :
  36.  
     
  37.  
        n +
  38.  
    =
  39.  
    1
  40.  
     
  41.  
        r
  42.  
    = np.
  43.  
    random .
  44.  
    randint
  45.  
    (
  46.  
    0
  47.  
    , train_data.
  48.  
    shape
  49.  
    [
  50.  
    0
  51.  
    ]
  52.  
    ,
  53.  
    1
  54.  
    )
  55.  
     
  56.  
        plt.
  57.  
    subplot
  58.  
    (
  59.  
    4
  60.  
    ,
  61.  
    4
  62.  
    , n
  63.  
    )
  64.  
     
  65.  
        plt.
  66.  
    subplots_adjust
  67.  
    ( hspace
  68.  
    =
  69.  
    0.5
  70.  
    , wspace
  71.  
    =
  72.  
    0.5
  73.  
    )
  74.  
     
  75.  
        plt.
  76.  
    imshow
  77.  
    ( train_data
  78.  
    [ r
  79.  
    [
  80.  
    0
  81.  
    ]
  82.  
    ] /
  83.  
    255 .
  84.  
    )
  85.  
     
  86.  
        plt.
  87.  
    title
  88.  
    (
  89.  
    '{}' .
  90.  
    format
  91.  
    ( train_labels
  92.  
    [ r
  93.  
    [
  94.  
    0
  95.  
    ]
  96.  
    ]
  97.  
    )
  98.  
    )
  99.  
     
  100.  
        plt.
  101.  
    xticks
  102.  
    (
  103.  
    [
  104.  
    ]
  105.  
    )
  106.  
    , plt.
  107.  
    yticks
  108.  
    (
  109.  
    [
  110.  
    ]
  111.  
    )
  112.  
     
  113.  
     
Malaria cell samples

根据这些样本图像,我们可以看到疟疾图像与健康细胞图像之间的细微差别。 我们将使我们的深度学习模型在模型训练期间尝试学习这些模式。

在开始训练模型之前,我们必须设置一些基本的配置设置。


 
  1.  
     
  2.  
     
  3.  
     
  4.  
    BATCH_SIZE
  5.  
    =
  6.  
    64
  7.  
     
  8.  
    NUM_CLASSES
  9.  
    =
  10.  
    2
  11.  
     
  12.  
    EPOCHS
  13.  
    =
  14.  
    25
  15.  
     
  16.  
    INPUT_SHAPE
  17.  
    =
  18.  
    (
  19.  
    125
  20.  
    ,
  21.  
    125
  22.  
    ,
  23.  
    3
  24.  
    )
  25.  
     
  26.  
     
  27.  
     
  28.  
    train_imgs_scaled
  29.  
    = train_data /
  30.  
    255 .
  31.  
     
  32.  
     
  33.  
    val_imgs_scaled
  34.  
    = val_data /
  35.  
    255 .
  36.  
     
  37.  
     
  38.  
     
  39.  
     
  40.  
    # encode text category labels
  41.  
     
  42.  
     
  43.  
    from sklearn.
  44.  
    preprocessing
  45.  
    import LabelEncoder
  46.  
     
  47.  
     
  48.  
     
  49.  
    le
  50.  
    = LabelEncoder
  51.  
    (
  52.  
    )
  53.  
     
  54.  
    le.
  55.  
    fit
  56.  
    ( train_labels
  57.  
    )
  58.  
     
  59.  
    train_labels_enc
  60.  
    = le.
  61.  
    transform
  62.  
    ( train_labels
  63.  
    )
  64.  
     
  65.  
    val_labels_enc
  66.  
    = le.
  67.  
    transform
  68.  
    ( val_labels
  69.  
    )
  70.  
     
  71.  
     
  72.  
     
  73.  
     
  74.  
    print
  75.  
    ( train_labels
  76.  
    [ :
  77.  
    6
  78.  
    ]
  79.  
    , train_labels_enc
  80.  
    [ :
  81.  
    6
  82.  
    ]
  83.  
    )
  84.  
     
  85.  
     
  86.  
     
  87.  
     
  88.  
     
  89.  
     
  90.  
    # Output
  91.  
     
  92.  
     
  93.  
    [
  94.  
    'malaria'
  95.  
    'malaria'
  96.  
    'malaria'
  97.  
    'healthy'
  98.  
    'healthy'
  99.  
    'malaria'
  100.  
    ]
  101.  
    [
  102.  
    1
  103.  
    1
  104.  
    1
  105.  
    0
  106.  
    0
  107.  
    1
  108.  
    ]
  109.  
     
  110.  
     

我们固定图像尺寸,批处理大小和历元,并对分类类标签进行编码。 TensorFlow 2.0的Alpha版本已于2019年3月发布,此练习是尝试它的完美借口。


 
  1.  
     
  2.  
     
  3.  
     
  4.  
    import tensorflow
  5.  
    as tf
  6.  
     
  7.  
     
  8.  
     
  9.  
     
  10.  
    # Load the TensorBoard notebook extension (optional)
  11.  
     
  12.  
    %load_ext tensorboard.
  13.  
    notebook
  14.  
     
  15.  
     
  16.  
     
  17.  
    tf.
  18.  
    random .
  19.  
    set_seed
  20.  
    (
  21.  
    42
  22.  
    )
  23.  
     
  24.  
    tf.__version__
  25.  
     
  26.  
     
  27.  
     
  28.  
     
  29.  
    # Output
  30.  
     
  31.  
     
  32.  
    '2.0.0-alpha0'
  33.  
     
  34.  
     

深度学习模型训练

在模型训练阶段,我们将构建三个深度学习模型,并用我们的训练数据对其进行训练,并使用验证数据比较它们的性能。 然后,我们将保存这些模型,并在模型评估阶段稍后使用它们。

模型1:从无到有的CNN

我们的第一个疟疾检测模型将从头开始构建和训练基本的CNN。 首先,让我们定义模型架构。


 
  1.  
     
  2.  
     
  3.  
     
  4.  
    inp
  5.  
    = tf.
  6.  
    keras .
  7.  
    layers .
  8.  
    Input
  9.  
    ( shape
  10.  
    = INPUT_SHAPE
  11.  
    )
  12.  
     
  13.  
     
  14.  
     
  15.  
    conv1
  16.  
    = tf.
  17.  
    keras .
  18.  
    layers .
  19.  
    Conv2D
  20.  
    (
  21.  
    32
  22.  
    , kernel_size
  23.  
    =
  24.  
    (
  25.  
    3
  26.  
    ,
  27.  
    3
  28.  
    )
  29.  
    ,
  30.  
     
  31.  
                                   activation
  32.  
    =
  33.  
    'relu'
  34.  
    , padding
  35.  
    =
  36.  
    'same'
  37.  
    )
  38.  
    ( inp
  39.  
    )
  40.  
     
  41.  
    pool1
  42.  
    = tf.
  43.  
    keras .
  44.  
    layers .
  45.  
    MaxPooling2D
  46.  
    ( pool_size
  47.  
    =
  48.  
    (
  49.  
    2
  50.  
    ,
  51.  
    2
  52.  
    )
  53.  
    )
  54.  
    ( conv1
  55.  
    )
  56.  
     
  57.  
    conv2
  58.  
    = tf.
  59.  
    keras .
  60.  
    layers .
  61.  
    Conv2D
  62.  
    (
  63.  
    64
  64.  
    , kernel_size
  65.  
    =
  66.  
    (
  67.  
    3
  68.  
    ,
  69.  
    3
  70.  
    )
  71.  
    ,
  72.  
     
  73.  
                                   activation
  74.  
    =
  75.  
    'relu'
  76.  
    , padding
  77.  
    =
  78.  
    'same'
  79.  
    )
  80.  
    ( pool1
  81.  
    )
  82.  
     
  83.  
    pool2
  84.  
    = tf.
  85.  
    keras .
  86.  
    layers .
  87.  
    MaxPooling2D
  88.  
    ( pool_size
  89.  
    =
  90.  
    (
  91.  
    2
  92.  
    ,
  93.  
    2
  94.  
    )
  95.  
    )
  96.  
    ( conv2
  97.  
    )
  98.  
     
  99.  
    conv3
  100.  
    = tf.
  101.  
    keras .
  102.  
    layers .
  103.  
    Conv2D
  104.  
    (
  105.  
    128
  106.  
    , kernel_size
  107.  
    =
  108.  
    (
  109.  
    3
  110.  
    ,
  111.  
    3
  112.  
    )
  113.  
    ,
  114.  
     
  115.  
                                   activation
  116.  
    =
  117.  
    'relu'
  118.  
    , padding
  119.  
    =
  120.  
    'same'
  121.  
    )
  122.  
    ( pool2
  123.  
    )
  124.  
     
  125.  
    pool3
  126.  
    = tf.
  127.  
    keras .
  128.  
    layers .
  129.  
    MaxPooling2D
  130.  
    ( pool_size
  131.  
    =
  132.  
    (
  133.  
    2
  134.  
    ,
  135.  
    2
  136.  
    )
  137.  
    )
  138.  
    ( conv3
  139.  
    )
  140.  
     
  141.  
     
  142.  
     
  143.  
    flat
  144.  
    = tf.
  145.  
    keras .
  146.  
    layers .
  147.  
    Flatten
  148.  
    (
  149.  
    )
  150.  
    ( pool3
  151.  
    )
  152.  
     
  153.  
     
  154.  
     
  155.  
    hidden1
  156.  
    = tf.
  157.  
    keras .
  158.  
    layers .
  159.  
    Dense
  160.  
    (
  161.  
    512
  162.  
    , activation
  163.  
    =
  164.  
    'relu'
  165.  
    )
  166.  
    ( flat
  167.  
    )
  168.  
     
  169.  
    drop1
  170.  
    = tf.
  171.  
    keras .
  172.  
    layers .
  173.  
    Dropout
  174.  
    ( rate
  175.  
    =
  176.  
    0.3
  177.  
    )
  178.  
    ( hidden1
  179.  
    )
  180.  
     
  181.  
    hidden2
  182.  
    = tf.
  183.  
    keras .
  184.  
    layers .
  185.  
    Dense
  186.  
    (
  187.  
    512
  188.  
    , activation
  189.  
    =
  190.  
    'relu'
  191.  
    )
  192.  
    ( drop1
  193.  
    )
  194.  
     
  195.  
    drop2
  196.  
    = tf.
  197.  
    keras .
  198.  
    layers .
  199.  
    Dropout
  200.  
    ( rate
  201.  
    =
  202.  
    0.3
  203.  
    )
  204.  
    ( hidden2
  205.  
    )
  206.  
     
  207.  
     
  208.  
     
  209.  
    out
  210.  
    = tf.
  211.  
    keras .
  212.  
    layers .
  213.  
    Dense
  214.  
    (
  215.  
    1
  216.  
    , activation
  217.  
    =
  218.  
    'sigmoid'
  219.  
    )
  220.  
    ( drop2
  221.  
    )
  222.  
     
  223.  
     
  224.  
     
  225.  
    model
  226.  
    = tf.
  227.  
    keras .
  228.  
    Model
  229.  
    ( inputs
  230.  
    = inp
  231.  
    , outputs
  232.  
    = out
  233.  
    )
  234.  
     
  235.  
    model.
  236.  
    compile
  237.  
    ( optimizer
  238.  
    =
  239.  
    'adam'
  240.  
    ,
  241.  
     
  242.  
                    loss
  243.  
    =
  244.  
    'binary_crossentropy'
  245.  
    ,
  246.  
     
  247.  
                    metrics
  248.  
    =
  249.  
    [
  250.  
    'accuracy'
  251.  
    ]
  252.  
    )
  253.  
     
  254.  
    model.
  255.  
    summary
  256.  
    (
  257.  
    )
  258.  
     
  259.  
     
  260.  
     
  261.  
     
  262.  
     
  263.  
     
  264.  
    # Output
  265.  
     
  266.  
    Model:
  267.  
    "model"
  268.  
     
  269.  
    _________________________________________________________________
  270.  
     
  271.  
    Layer
  272.  
    (
  273.  
    type
  274.  
    )                 Output Shape              Param
  275.  
    #  
  276.  
     
  277.  
     
  278.  
    =================================================================
  279.  
     
  280.  
    input_1
  281.  
    ( InputLayer
  282.  
    )        
  283.  
    [
  284.  
    (
  285.  
    None
  286.  
    ,
  287.  
    125
  288.  
    ,
  289.  
    125
  290.  
    ,
  291.  
    3
  292.  
    )
  293.  
    ]    
  294.  
    0        
  295.  
     
  296.  
    _________________________________________________________________
  297.  
     
  298.  
    conv2d
  299.  
    ( Conv2D
  300.  
    )              
  301.  
    (
  302.  
    None
  303.  
    ,
  304.  
    125
  305.  
    ,
  306.  
    125
  307.  
    ,
  308.  
    32
  309.  
    )      
  310.  
    896      
  311.  
     
  312.  
    _________________________________________________________________
  313.  
     
  314.  
    max_pooling2d
  315.  
    ( MaxPooling2D
  316.  
    )
  317.  
    (
  318.  
    None
  319.  
    ,
  320.  
    62
  321.  
    ,
  322.  
    62
  323.  
    ,
  324.  
    32
  325.  
    )        
  326.  
    0        
  327.  
     
  328.  
    _________________________________________________________________
  329.  
     
  330.  
    conv2d_1
  331.  
    ( Conv2D
  332.  
    )            
  333.  
    (
  334.  
    None
  335.  
    ,
  336.  
    62
  337.  
    ,
  338.  
    62
  339.  
    ,
  340.  
    64
  341.  
    )        
  342.  
    18496    
  343.  
     
  344.  
    _________________________________________________________________
  345.  
     
  346.  
    ...
  347.  
     
  348.  
    ...
  349.  
     
  350.  
    _________________________________________________________________
  351.  
     
  352.  
    dense_1
  353.  
    ( Dense
  354.  
    )              
  355.  
    (
  356.  
    None
  357.  
    ,
  358.  
    512
  359.  
    )              
  360.  
    262656    
  361.  
     
  362.  
    _________________________________________________________________
  363.  
     
  364.  
    dropout_1
  365.  
    ( Dropout
  366.  
    )          
  367.  
    (
  368.  
    None
  369.  
    ,
  370.  
    512
  371.  
    )              
  372.  
    0        
  373.  
     
  374.  
    _________________________________________________________________
  375.  
     
  376.  
    dense_2
  377.  
    ( Dense
  378.  
    )              
  379.  
    (
  380.  
    None
  381.  
    ,
  382.  
    1
  383.  
    )                
  384.  
    513      
  385.  
     
  386.  
     
  387.  
    =================================================================
  388.  
     
  389.  
    Total params:
  390.  
    15
  391.  
    ,
  392.  
    102
  393.  
    ,
  394.  
    529
  395.  
     
  396.  
    Trainable params:
  397.  
    15
  398.  
    ,
  399.  
    102
  400.  
    ,
  401.  
    529
  402.  
     
  403.  
    Non-trainable params:
  404.  
    0
  405.  
     
  406.  
    _________________________________________________________________
  407.  
     
  408.  
     

基于此代码中的体系结构,我们的CNN模型具有三个卷积和池化层,然后是两个密集层,以及用于规范化的缺失。 让我们训练模型。


 
  1.  
     
  2.  
     
  3.  
     
  4.  
    import
  5.  
    datetime
  6.  
     
  7.  
     
  8.  
     
  9.  
    logdir
  10.  
    =
  11.  
    os .
  12.  
    path .
  13.  
    join
  14.  
    (
  15.  
    '/home/dipanzan_sarkar/projects/tensorboard_logs'
  16.  
    ,
  17.  
     
  18.  
                         
  19.  
    datetime .
  20.  
    datetime .
  21.  
    now
  22.  
    (
  23.  
    ) .
  24.  
    strftime
  25.  
    (
  26.  
    "%Y%m%d-%H%M%S"
  27.  
    )
  28.  
    )
  29.  
     
  30.  
    tensorboard_callback
  31.  
    = tf.
  32.  
    keras .
  33.  
    callbacks .
  34.  
    TensorBoard
  35.  
    ( logdir
  36.  
    , histogram_freq
  37.  
    =
  38.  
    1
  39.  
    )
  40.  
     
  41.  
    reduce_lr
  42.  
    = tf.
  43.  
    keras .
  44.  
    callbacks .
  45.  
    ReduceLROnPlateau
  46.  
    ( monitor
  47.  
    =
  48.  
    'val_loss'
  49.  
    , factor
  50.  
    =
  51.  
    0.5
  52.  
    ,
  53.  
     
  54.  
                                  patience
  55.  
    =
  56.  
    2
  57.  
    , min_lr
  58.  
    =
  59.  
    0.000001
  60.  
    )
  61.  
     
  62.  
    callbacks
  63.  
    =
  64.  
    [ reduce_lr
  65.  
    , tensorboard_callback
  66.  
    ]
  67.  
     
  68.  
     
  69.  
     
  70.  
    history
  71.  
    = model.
  72.  
    fit
  73.  
    ( x
  74.  
    = train_imgs_scaled
  75.  
    , y
  76.  
    = train_labels_enc
  77.  
    ,
  78.  
     
  79.  
                        batch_size
  80.  
    = BATCH_SIZE
  81.  
    ,
  82.  
     
  83.  
                        epochs
  84.  
    = EPOCHS
  85.  
    ,
  86.  
     
  87.  
                        validation_data
  88.  
    =
  89.  
    ( val_imgs_scaled
  90.  
    , val_labels_enc
  91.  
    )
  92.  
    ,
  93.  
     
  94.  
                        callbacks
  95.  
    = callbacks
  96.  
    ,
  97.  
     
  98.  
                        verbose
  99.  
    =
  100.  
    1
  101.  
    )
  102.  
     
  103.  
                       
  104.  
     
  105.  
     
  106.  
     
  107.  
     
  108.  
    # Output
  109.  
     
  110.  
    Train on
  111.  
    17361 samples
  112.  
    , validate on
  113.  
    1929 samples
  114.  
     
  115.  
    Epoch
  116.  
    1 /
  117.  
    25
  118.  
     
  119.  
     
  120.  
    17361 /
  121.  
    17361
  122.  
    [
  123.  
    ====
  124.  
    ] - 32s 2ms/sample - loss:
  125.  
    0.4373 - accuracy:
  126.  
    0.7814 - val_loss:
  127.  
    0.1834 - val_accuracy:
  128.  
    0.9393
  129.  
     
  130.  
    Epoch
  131.  
    2 /
  132.  
    25
  133.  
     
  134.  
     
  135.  
    17361 /
  136.  
    17361
  137.  
    [
  138.  
    ====
  139.  
    ] - 30s 2ms/sample - loss:
  140.  
    0.1725 - accuracy:
  141.  
    0.9434 - val_loss:
  142.  
    0.1567 - val_accuracy:
  143.  
    0.9513
  144.  
     
  145.  
    ...
  146.  
     
  147.  
    ...
  148.  
     
  149.  
     
  150.  
    Epoch
  151.  
    24 /
  152.  
    25
  153.  
     
  154.  
     
  155.  
    17361 /
  156.  
    17361
  157.  
    [
  158.  
    ====
  159.  
    ] - 30s 2ms/sample - loss:
  160.  
    0.0036 - accuracy:
  161.  
    0.9993 - val_loss:
  162.  
    0.3693 - val_accuracy:
  163.  
    0.9565
  164.  
     
  165.  
    Epoch
  166.  
    25 /
  167.  
    25
  168.  
     
  169.  
     
  170.  
    17361 /
  171.  
    17361
  172.  
    [
  173.  
    ====
  174.  
    ] - 30s 2ms/sample - loss:
  175.  
    0.0034 - accuracy:
  176.  
    0.9994 - val_loss:
  177.  
    0.3699 - val_accuracy:
  178.  
    0.9559
  179.  
     
  180.  
                       
  181.  
     
  182.  
         
  183.  
     
  184.  
     

尽管我们的模型看起来有些过拟合(基于对我们的训练准确度(99.9%)的了解),但我们获得了95.6%的验证准确度,这是相当不错的。 通过绘制训练和验证的准确性以及损失曲线,我们可以对此有一个清晰的认识。


 
  1.  
     
  2.  
     
  3.  
     
  4.  
    f
  5.  
    ,
  6.  
    ( ax1
  7.  
    , ax2
  8.  
    )
  9.  
    = plt.
  10.  
    subplots
  11.  
    (
  12.  
    1
  13.  
    ,
  14.  
    2
  15.  
    , figsize
  16.  
    =
  17.  
    (
  18.  
    12
  19.  
    ,
  20.  
    4
  21.  
    )
  22.  
    )
  23.  
     
  24.  
    t
  25.  
    = f.
  26.  
    suptitle
  27.  
    (
  28.  
    'Basic CNN Performance'
  29.  
    , fontsize
  30.  
    =
  31.  
    12
  32.  
    )
  33.  
     
  34.  
    f.
  35.  
    subplots_adjust
  36.  
    ( top
  37.  
    =
  38.  
    0.85
  39.  
    , wspace
  40.  
    =
  41.  
    0.3
  42.  
    )
  43.  
     
  44.  
     
  45.  
     
  46.  
    max_epoch
  47.  
    =
  48.  
    len
  49.  
    ( history.
  50.  
    history
  51.  
    [
  52.  
    'accuracy'
  53.  
    ]
  54.  
    ) +
  55.  
    1
  56.  
     
  57.  
    epoch_list
  58.  
    =
  59.  
    list
  60.  
    (
  61.  
    range
  62.  
    (
  63.  
    1
  64.  
    , max_epoch
  65.  
    )
  66.  
    )
  67.  
     
  68.  
    ax1.
  69.  
    plot
  70.  
    ( epoch_list
  71.  
    , history.
  72.  
    history
  73.  
    [
  74.  
    'accuracy'
  75.  
    ]
  76.  
    , label
  77.  
    =
  78.  
    'Train Accuracy'
  79.  
    )
  80.  
     
  81.  
    ax1.
  82.  
    plot
  83.  
    ( epoch_list
  84.  
    , history.
  85.  
    history
  86.  
    [
  87.  
    'val_accuracy'
  88.  
    ]
  89.  
    , label
  90.  
    =
  91.  
    'Validation Accuracy'
  92.  
    )
  93.  
     
  94.  
    ax1.
  95.  
    set_xticks
  96.  
    ( np.
  97.  
    arange
  98.  
    (
  99.  
    1
  100.  
    , max_epoch
  101.  
    ,
  102.  
    5
  103.  
    )
  104.  
    )
  105.  
     
  106.  
    ax1.
  107.  
    set_ylabel
  108.  
    (
  109.  
    'Accuracy Value'
  110.  
    )
  111.  
     
  112.  
    ax1.
  113.  
    set_xlabel
  114.  
    (
  115.  
    'Epoch'
  116.  
    )
  117.  
     
  118.  
    ax1.
  119.  
    set_title
  120.  
    (
  121.  
    'Accuracy'
  122.  
    )
  123.  
     
  124.  
    l1
  125.  
    = ax1.
  126.  
    legend
  127.  
    ( loc
  128.  
    =
  129.  
    "best"
  130.  
    )
  131.  
     
  132.  
     
  133.  
     
  134.  
    ax2.
  135.  
    plot
  136.  
    ( epoch_list
  137.  
    , history.
  138.  
    history
  139.  
    [
  140.  
    'loss'
  141.  
    ]
  142.  
    , label
  143.  
    =
  144.  
    'Train Loss'
  145.  
    )
  146.  
     
  147.  
    ax2.
  148.  
    plot
  149.  
    ( epoch_list
  150.  
    , history.
  151.  
    history
  152.  
    [
  153.  
    'val_loss'
  154.  
    ]
  155.  
    , label
  156.  
    =
  157.  
    'Validation Loss'
  158.  
    )
  159.  
     
  160.  
    ax2.
  161.  
    set_xticks
  162.  
    ( np.
  163.  
    arange
  164.  
    (
  165.  
    1
  166.  
    , max_epoch
  167.  
    ,
  168.  
    5
  169.  
    )
  170.  
    )
  171.  
     
  172.  
    ax2.
  173.  
    set_ylabel
  174.  
    (
  175.  
    'Loss Value'
  176.  
    )
  177.  
     
  178.  
    ax2.
  179.  
    set_xlabel
  180.  
    (
  181.  
    'Epoch'
  182.  
    )
  183.  
     
  184.  
    ax2.
  185.  
    set_title
  186.  
    (
  187.  
    'Loss'
  188.  
    )
  189.  
     
  190.  
    l2
  191.  
    = ax2.
  192.  
    legend
  193.  
    ( loc
  194.  
    =
  195.  
    "best"
  196.  
    )
  197.  
     
  198.  
     
Learning curves for basic CNN

基本CNN的学习曲线

在第五个时期之后,我们可以看到总体上似乎一切都没有改善。 让我们保存此模型以供将来评估。

 model. save ( 'basic_cnn.h5' ) 

深度转移学习

就像人类具有跨任务传递知识的固有能力一样,传递学习使我们能够利用先前学习的任务中的知识并将其应用于更新的相关任务,甚至在机器学习或深度学习的情况下。 如果您有兴趣深入学习迁移学习,可以阅读我的文章“ 使用深度学习中的实际应用进行迁移学习的全面动手指南 ”和我的《 Python的动手迁移 学习 》一书。

Ideas for deep transfer learning

我们希望在本练习中探讨的想法是:

我们是否可以利用预先训练的深度学习模型(在像ImageNet这样的大型数据集上进行训练)来解决疟疾检测问题,方法是在问题背景下应用和转移其知识?

我们将采用两种最受欢迎​​的策略进行深度迁移学习。

  • 预训练模型作为特征提取器
  • 预先训练的模型具有微调

我们将使用由牛津大学视觉几何小组(VGG)开发的预先训练的VGG-19深度学习模型进行实验。 像VGG-19这样的预训练模型是在包含许多不同图像类别的巨大数据集( ImageNet )上进行训练的。 因此,该模型应该已经学习了功能强大的层次结构,这些特征对于CNN模型所学习的特征而言是空间,旋转和平移不变的。 因此,该模型已经学习了超过一百万个图像的特征的良好表示形式,可以充当适用于计算机视觉问题(例如疟疾检测)的新图像的良好特征提取器。 在释放转移学习的力量解决我们的问题之前,让我们讨论一下VGG-19模型体系结构。

了解VGG-19模型

VGG-19模型是建立在ImageNet数据库上的19层(卷积和完全连接)深度学习网络,其目的是进行图像识别和分类。 该模型由Karen Simonyan和Andrew Zisserman建立,并在他们的论文“ 用于大规模图像识别的非常深的卷积网络 ”中进行了描述。 VGG-19模型的体系结构为:

VGG-19 Model Architecture

您可以看到,我们总共有16个卷积层,使用3x3卷积滤波器,以及用于下采样的最大池化层,以及每层中4,096个单元的两个完全连接的隐藏层,然后是1,000个单元的密集层,其中每个单元代表一个ImageNet数据库中的图像类别。 我们不需要最后三层,因为我们将使用自己的完全连接的密集层来预测疟疾。 我们更关注前五个块,因此我们可以利用VGG模型作为有效的特征提取器。

我们将冻结五个卷积块,以确保在每个时期之后不更新其权重,从而将其中一个模型用作简单的特征提取器。 对于最后一个模型,我们将对VGG模型进行微调,在该模型中,我们将解冻最后两个块(第4块和第5块),以便在我们训练的每个纪元(每批数据)中更新它们的权重我们自己的模型。

模型2:经过预训练的模型作为特征提取器

为了构建此模型,我们将利用TensorFlow加载VGG-19模型并冻结卷积块,以便将它们用作图像特征提取器。 我们将在最后插入自己的密集层以执行分类任务。


 
  1.  
     
  2.  
     
  3.  
     
  4.  
    vgg
  5.  
    = tf.
  6.  
    keras .
  7.  
    applications .
  8.  
    vgg19 .
  9.  
    VGG19
  10.  
    ( include_top
  11.  
    =
  12.  
    False
  13.  
    , weights
  14.  
    =
  15.  
    'imagenet'
  16.  
    ,
  17.  
     
  18.  
                                            input_shape
  19.  
    = INPUT_SHAPE
  20.  
    )
  21.  
     
  22.  
    vgg.
  23.  
    trainable
  24.  
    =
  25.  
    False
  26.  
     
  27.  
     
  28.  
    # Freeze the layers
  29.  
     
  30.  
     
  31.  
    for layer
  32.  
    in vgg.
  33.  
    layers :
  34.  
     
  35.  
        layer.
  36.  
    trainable
  37.  
    =
  38.  
    False
  39.  
     
  40.  
       
  41.  
     
  42.  
    base_vgg
  43.  
    = vgg
  44.  
     
  45.  
    base_out
  46.  
    = base_vgg.
  47.  
    output
  48.  
     
  49.  
    pool_out
  50.  
    = tf.
  51.  
    keras .
  52.  
    layers .
  53.  
    Flatten
  54.  
    (
  55.  
    )
  56.  
    ( base_out
  57.  
    )
  58.  
     
  59.  
    hidden1
  60.  
    = tf.
  61.  
    keras .
  62.  
    layers .
  63.  
    Dense
  64.  
    (
  65.  
    512
  66.  
    , activation
  67.  
    =
  68.  
    'relu'
  69.  
    )
  70.  
    ( pool_out
  71.  
    )
  72.  
     
  73.  
    drop1
  74.  
    = tf.
  75.  
    keras .
  76.  
    layers .
  77.  
    Dropout
  78.  
    ( rate
  79.  
    =
  80.  
    0.3
  81.  
    )
  82.  
    ( hidden1
  83.  
    )
  84.  
     
  85.  
    hidden2
  86.  
    = tf.
  87.  
    keras .
  88.  
    layers .
  89.  
    Dense
  90.  
    (
  91.  
    512
  92.  
    , activation
  93.  
    =
  94.  
    'relu'
  95.  
    )
  96.  
    ( drop1
  97.  
    )
  98.  
     
  99.  
    drop2
  100.  
    = tf.
  101.  
    keras .
  102.  
    layers .
  103.  
    Dropout
  104.  
    ( rate
  105.  
    =
  106.  
    0.3
  107.  
    )
  108.  
    ( hidden2
  109.  
    )
  110.  
     
  111.  
     
  112.  
     
  113.  
    out
  114.  
    = tf.
  115.  
    keras .
  116.  
    layers .
  117.  
    Dense
  118.  
    (
  119.  
    1
  120.  
    , activation
  121.  
    =
  122.  
    'sigmoid'
  123.  
    )
  124.  
    ( drop2
  125.  
    )
  126.  
     
  127.  
     
  128.  
     
  129.  
    model
  130.  
    = tf.
  131.  
    keras .
  132.  
    Model
  133.  
    ( inputs
  134.  
    = base_vgg.
  135.  
    input
  136.  
    , outputs
  137.  
    = out
  138.  
    )
  139.  
     
  140.  
    model.
  141.  
    compile
  142.  
    ( optimizer
  143.  
    = tf.
  144.  
    keras .
  145.  
    optimizers .
  146.  
    RMSprop
  147.  
    ( lr
  148.  
    =
  149.  
    1e-4
  150.  
    )
  151.  
    ,
  152.  
     
  153.  
                    loss
  154.  
    =
  155.  
    'binary_crossentropy'
  156.  
    ,
  157.  
     
  158.  
                    metrics
  159.  
    =
  160.  
    [
  161.  
    'accuracy'
  162.  
    ]
  163.  
    )
  164.  
     
  165.  
    model.
  166.  
    summary
  167.  
    (
  168.  
    )
  169.  
     
  170.  
     
  171.  
     
  172.  
     
  173.  
     
  174.  
     
  175.  
    # Output
  176.  
     
  177.  
    Model:
  178.  
    "model_1"
  179.  
     
  180.  
    _________________________________________________________________
  181.  
     
  182.  
    Layer
  183.  
    (
  184.  
    type
  185.  
    )                 Output Shape              Param
  186.  
    #  
  187.  
     
  188.  
     
  189.  
    =================================================================
  190.  
     
  191.  
    input_2
  192.  
    ( InputLayer
  193.  
    )        
  194.  
    [
  195.  
    (
  196.  
    None
  197.  
    ,
  198.  
    125
  199.  
    ,
  200.  
    125
  201.  
    ,
  202.  
    3
  203.  
    )
  204.  
    ]    
  205.  
    0        
  206.  
     
  207.  
    _________________________________________________________________
  208.  
     
  209.  
    block1_conv1
  210.  
    ( Conv2D
  211.  
    )        
  212.  
    (
  213.  
    None
  214.  
    ,
  215.  
    125
  216.  
    ,
  217.  
    125
  218.  
    ,
  219.  
    64
  220.  
    )      
  221.  
    1792      
  222.  
     
  223.  
    _________________________________________________________________
  224.  
     
  225.  
    block1_conv2
  226.  
    ( Conv2D
  227.  
    )        
  228.  
    (
  229.  
    None
  230.  
    ,
  231.  
    125
  232.  
    ,
  233.  
    125
  234.  
    ,
  235.  
    64
  236.  
    )      
  237.  
    36928    
  238.  
     
  239.  
    _________________________________________________________________
  240.  
     
  241.  
    ...
  242.  
     
  243.  
    ...
  244.  
     
  245.  
    _________________________________________________________________
  246.  
     
  247.  
    block5_pool
  248.  
    ( MaxPooling2D
  249.  
    )  
  250.  
    (
  251.  
    None
  252.  
    ,
  253.  
    3
  254.  
    ,
  255.  
    3
  256.  
    ,
  257.  
    512
  258.  
    )        
  259.  
    0        
  260.  
     
  261.  
    _________________________________________________________________
  262.  
     
  263.  
    flatten_1
  264.  
    ( Flatten
  265.  
    )          
  266.  
    (
  267.  
    None
  268.  
    ,
  269.  
    4608
  270.  
    )              
  271.  
    0        
  272.  
     
  273.  
    _________________________________________________________________
  274.  
     
  275.  
    dense_3
  276.  
    ( Dense
  277.  
    )              
  278.  
    (
  279.  
    None
  280.  
    ,
  281.  
    512
  282.  
    )              
  283.  
    2359808  
  284.  
     
  285.  
    _________________________________________________________________
  286.  
     
  287.  
    dropout_2
  288.  
    ( Dropout
  289.  
    )          
  290.  
    (
  291.  
    None
  292.  
    ,
  293.  
    512
  294.  
    )              
  295.  
    0        
  296.  
     
  297.  
    _________________________________________________________________
  298.  
     
  299.  
    dense_4
  300.  
    ( Dense
  301.  
    )              
  302.  
    (
  303.  
    None
  304.  
    ,
  305.  
    512
  306.  
    )              
  307.  
    262656    
  308.  
     
  309.  
    _________________________________________________________________
  310.  
     
  311.  
    dropout_3
  312.  
    ( Dropout
  313.  
    )          
  314.  
    (
  315.  
    None
  316.  
    ,
  317.  
    512
  318.  
    )              
  319.  
    0        
  320.  
     
  321.  
    _________________________________________________________________
  322.  
     
  323.  
    dense_5
  324.  
    ( Dense
  325.  
    )              
  326.  
    (
  327.  
    None
  328.  
    ,
  329.  
    1
  330.  
    )                
  331.  
    513      
  332.  
     
  333.  
     
  334.  
    =================================================================
  335.  
     
  336.  
    Total params:
  337.  
    22
  338.  
    ,
  339.  
    647
  340.  
    ,
  341.  
    361
  342.  
     
  343.  
    Trainable params:
  344.  
    2
  345.  
    ,
  346.  
    622
  347.  
    ,
  348.  
    977
  349.  
     
  350.  
    Non-trainable params:
  351.  
    20
  352.  
    ,
  353.  
    024
  354.  
    ,
  355.  
    384
  356.  
     
  357.  
    _________________________________________________________________
  358.  
     
  359.  
     

从此输出中可以明显看出,我们的模型中有很多层,我们将仅将VGG-19模型的冻结层用作特征提取器。 您可以使用以下代码来验证模型中确实可以训练的层数以及网络中存在的总层数。


 
  1.  
     
  2.  
     
  3.  
     
  4.  
    print
  5.  
    (
  6.  
    "Total Layers:"
  7.  
    ,
  8.  
    len
  9.  
    ( model.
  10.  
    layers
  11.  
    )
  12.  
    )
  13.  
     
  14.  
     
  15.  
    print
  16.  
    (
  17.  
    "Total trainable layers:"
  18.  
    ,
  19.  
     
  20.  
         
  21.  
    sum
  22.  
    (
  23.  
    [
  24.  
    1
  25.  
    for l
  26.  
    in model.
  27.  
    layers
  28.  
    if l.
  29.  
    trainable
  30.  
    ]
  31.  
    )
  32.  
    )
  33.  
     
  34.  
     
  35.  
     
  36.  
     
  37.  
    # Output
  38.  
     
  39.  
    Total Layers:
  40.  
    28
  41.  
     
  42.  
    Total trainable layers:
  43.  
    6
  44.  
     
  45.  
     

现在,我们将使用与先前模型中使用的配置和回调相似的配置和回调来训练模型。 请参阅我的GitHub存储库以获取训练模型的完整代码。 我们观察到以下图表,它们显示了模型的准确性和损失。

Learning curves for frozen pre-trained CNN

冻结的预训练CNN的学习曲线

这表明我们的模型并没有像基本的CNN模型那样过度拟合,但是性能略低于基本的CNN模型。 让我们保存此模型以供将来评估。

 model. save ( 'vgg_frozen.h5' ) 

模型3:具有图像增强功能的微调预训练模型

在最终模型中,我们将微调预训练的VGG-19模型的最后两个块中各层的权重。 我们还将介绍图像增强的概念。 图像增强背后的想法恰如其名。 我们从训练数据集中加载现有图像,并对它们应用一些图像变换操作,例如旋转,剪切,平移,缩放等,以生成现有图像的新版本。 由于这些随机转换,我们不会每次都获得相同的图像。 我们将在tf.keras中利用一个称为ImageDataGenerator的出色实用程序,该实用程序可以帮助构建图像增强器。


 
  1.  
     
  2.  
     
  3.  
     
  4.  
    train_datagen
  5.  
    = tf.
  6.  
    keras .
  7.  
    preprocessing .
  8.  
    image .
  9.  
    ImageDataGenerator
  10.  
    ( rescale
  11.  
    =
  12.  
    1 ./
  13.  
    255
  14.  
    ,
  15.  
     
  16.  
                                                                    zoom_range
  17.  
    =
  18.  
    0.05
  19.  
    ,
  20.  
     
  21.  
                                                                    rotation_range
  22.  
    =
  23.  
    25
  24.  
    ,
  25.  
     
  26.  
                                                                    width_shift_range
  27.  
    =
  28.  
    0.05
  29.  
    ,
  30.  
     
  31.  
                                                                    height_shift_range
  32.  
    =
  33.  
    0.05
  34.  
    ,
  35.  
     
  36.  
                                                                    shear_range
  37.  
    =
  38.  
    0.05
  39.  
    , horizontal_flip
  40.  
    =
  41.  
    True
  42.  
    ,
  43.  
     
  44.  
                                                                    fill_mode
  45.  
    =
  46.  
    'nearest'
  47.  
    )
  48.  
     
  49.  
     
  50.  
     
  51.  
    val_datagen
  52.  
    = tf.
  53.  
    keras .
  54.  
    preprocessing .
  55.  
    image .
  56.  
    ImageDataGenerator
  57.  
    ( rescale
  58.  
    =
  59.  
    1 ./
  60.  
    255
  61.  
    )
  62.  
     
  63.  
     
  64.  
     
  65.  
     
  66.  
    # build image augmentation generators
  67.  
     
  68.  
    train_generator
  69.  
    = train_datagen.
  70.  
    flow
  71.  
    ( train_data
  72.  
    , train_labels_enc
  73.  
    , batch_size
  74.  
    = BATCH_SIZE
  75.  
    , shuffle
  76.  
    =
  77.  
    True
  78.  
    )
  79.  
     
  80.  
    val_generator
  81.  
    = val_datagen.
  82.  
    flow
  83.  
    ( val_data
  84.  
    , val_labels_enc
  85.  
    , batch_size
  86.  
    = BATCH_SIZE
  87.  
    , shuffle
  88.  
    =
  89.  
    False
  90.  
    )
  91.  
     
  92.  
     

我们将不在验证数据集上应用任何转换(缩放图像除外,这是强制性的),因为我们将使用它来评估每个时期的模型性能。 有关转移学习中图像增强的详细说明,请随时阅读我上面引用的文章 。 让我们看一下一批图像增强转换的一些示例结果。


 
  1.  
     
  2.  
     
  3.  
     
  4.  
    img_id
  5.  
    =
  6.  
    0
  7.  
     
  8.  
    sample_generator
  9.  
    = train_datagen.
  10.  
    flow
  11.  
    ( train_data
  12.  
    [ img_id:img_id+
  13.  
    1
  14.  
    ]
  15.  
    , train_labels
  16.  
    [ img_id:img_id+
  17.  
    1
  18.  
    ]
  19.  
    ,
  20.  
     
  21.  
                                          batch_size
  22.  
    =
  23.  
    1
  24.  
    )
  25.  
     
  26.  
    sample
  27.  
    =
  28.  
    [ next
  29.  
    ( sample_generator
  30.  
    )
  31.  
    for i
  32.  
    in
  33.  
    range
  34.  
    (
  35.  
    0
  36.  
    ,
  37.  
    5
  38.  
    )
  39.  
    ]
  40.  
     
  41.  
    fig
  42.  
    , ax
  43.  
    = plt.
  44.  
    subplots
  45.  
    (
  46.  
    1
  47.  
    ,
  48.  
    5
  49.  
    , figsize
  50.  
    =
  51.  
    (
  52.  
    16
  53.  
    ,
  54.  
    6
  55.  
    )
  56.  
    )
  57.  
     
  58.  
     
  59.  
    print
  60.  
    (
  61.  
    'Labels:'
  62.  
    ,
  63.  
    [ item
  64.  
    [
  65.  
    1
  66.  
    ]
  67.  
    [
  68.  
    0
  69.  
    ]
  70.  
    for item
  71.  
    in sample
  72.  
    ]
  73.  
    )
  74.  
     
  75.  
    l
  76.  
    =
  77.  
    [ ax
  78.  
    [ i
  79.  
    ] .
  80.  
    imshow
  81.  
    ( sample
  82.  
    [ i
  83.  
    ]
  84.  
    [
  85.  
    0
  86.  
    ]
  87.  
    [
  88.  
    0
  89.  
    ]
  90.  
    )
  91.  
    for i
  92.  
    in
  93.  
    range
  94.  
    (
  95.  
    0
  96.  
    ,
  97.  
    5
  98.  
    )
  99.  
    ]
  100.  
     
  101.  
     
Sample augmented images

您可以在前面的输出中清楚地看到我们图像的细微变化。 现在,我们将构建深度学习模型,以确保VGG-19模型的最后两个模块是可训练的。


 
  1.  
     
  2.  
     
  3.  
     
  4.  
    vgg
  5.  
    = tf.
  6.  
    keras .
  7.  
    applications .
  8.  
    vgg19 .
  9.  
    VGG19
  10.  
    ( include_top
  11.  
    =
  12.  
    False
  13.  
    , weights
  14.  
    =
  15.  
    'imagenet'
  16.  
    ,
  17.  
     
  18.  
                                            input_shape
  19.  
    = INPUT_SHAPE
  20.  
    )
  21.  
     
  22.  
     
  23.  
    # Freeze the layers
  24.  
     
  25.  
    vgg.
  26.  
    trainable
  27.  
    =
  28.  
    True
  29.  
     
  30.  
     
  31.  
     
  32.  
    set_trainable
  33.  
    =
  34.  
    False
  35.  
     
  36.  
     
  37.  
    for layer
  38.  
    in vgg.
  39.  
    layers :
  40.  
     
  41.  
       
  42.  
    if layer.
  43.  
    name
  44.  
    in
  45.  
    [
  46.  
    'block5_conv1'
  47.  
    ,
  48.  
    'block4_conv1'
  49.  
    ] :
  50.  
     
  51.  
            set_trainable
  52.  
    =
  53.  
    True
  54.  
     
  55.  
       
  56.  
    if set_trainable:
  57.  
     
  58.  
            layer.
  59.  
    trainable
  60.  
    =
  61.  
    True
  62.  
     
  63.  
       
  64.  
    else :
  65.  
     
  66.  
            layer.
  67.  
    trainable
  68.  
    =
  69.  
    False
  70.  
     
  71.  
       
  72.  
     
  73.  
    base_vgg
  74.  
    = vgg
  75.  
     
  76.  
    base_out
  77.  
    = base_vgg.
  78.  
    output
  79.  
     
  80.  
    pool_out
  81.  
    = tf.
  82.  
    keras .
  83.  
    layers .
  84.  
    Flatten
  85.  
    (
  86.  
    )
  87.  
    ( base_out
  88.  
    )
  89.  
     
  90.  
    hidden1
  91.  
    = tf.
  92.  
    keras .
  93.  
    layers .
  94.  
    Dense
  95.  
    (
  96.  
    512
  97.  
    , activation
  98.  
    =
  99.  
    'relu'
  100.  
    )
  101.  
    ( pool_out
  102.  
    )
  103.  
     
  104.  
    drop1
  105.  
    = tf.
  106.  
    keras .
  107.  
    layers .
  108.  
    Dropout
  109.  
    ( rate
  110.  
    =
  111.  
    0.3
  112.  
    )
  113.  
    ( hidden1
  114.  
    )
  115.  
     
  116.  
    hidden2
  117.  
    = tf.
  118.  
    keras .
  119.  
    layers .
  120.  
    Dense
  121.  
    (
  122.  
    512
  123.  
    , activation
  124.  
    =
  125.  
    'relu'
  126.  
    )
  127.  
    ( drop1
  128.  
    )
  129.  
     
  130.  
    drop2
  131.  
    = tf.
  132.  
    keras .
  133.  
    layers .
  134.  
    Dropout
  135.  
    ( rate
  136.  
    =
  137.  
    0.3
  138.  
    )
  139.  
    ( hidden2
  140.  
    )
  141.  
     
  142.  
     
  143.  
     
  144.  
    out
  145.  
    = tf.
  146.  
    keras .
  147.  
    layers .
  148.  
    Dense
  149.  
    (
  150.  
    1
  151.  
    , activation
  152.  
    =
  153.  
    'sigmoid'
  154.  
    )
  155.  
    ( drop2
  156.  
    )
  157.  
     
  158.  
     
  159.  
     
  160.  
    model
  161.  
    = tf.
  162.  
    keras .
  163.  
    Model
  164.  
    ( inputs
  165.  
    = base_vgg.
  166.  
    input
  167.  
    , outputs
  168.  
    = out
  169.  
    )
  170.  
     
  171.  
    model.
  172.  
    compile
  173.  
    ( optimizer
  174.  
    = tf.
  175.  
    keras .
  176.  
    optimizers .
  177.  
    RMSprop
  178.  
    ( lr
  179.  
    =
  180.  
    1e-5
  181.  
    )
  182.  
    ,
  183.  
     
  184.  
                    loss
  185.  
    =
  186.  
    'binary_crossentropy'
  187.  
    ,
  188.  
     
  189.  
                    metrics
  190.  
    =
  191.  
    [
  192.  
    'accuracy'
  193.  
    ]
  194.  
    )
  195.  
     
  196.  
     
  197.  
     
  198.  
     
  199.  
    print
  200.  
    (
  201.  
    "Total Layers:"
  202.  
    ,
  203.  
    len
  204.  
    ( model.
  205.  
    layers
  206.  
    )
  207.  
    )
  208.  
     
  209.  
     
  210.  
    print
  211.  
    (
  212.  
    "Total trainable layers:"
  213.  
    ,
  214.  
    sum
  215.  
    (
  216.  
    [
  217.  
    1
  218.  
    for l
  219.  
    in model.
  220.  
    layers
  221.  
    if l.
  222.  
    trainable
  223.  
    ]
  224.  
    )
  225.  
    )
  226.  
     
  227.  
     
  228.  
     
  229.  
     
  230.  
     
  231.  
     
  232.  
    # Output
  233.  
     
  234.  
    Total Layers:
  235.  
    28
  236.  
     
  237.  
    Total trainable layers:
  238.  
    16
  239.  
     
  240.  
     

由于我们不想在微调时对预先训练的图层进行较大的权重更新,因此降低了模型的学习速度。 由于我们使用的是数据生成器,因此模型的训练过程将略有不同,因此我们将利用fit_generator(…)函数。


 
  1.  
     
  2.  
     
  3.  
     
  4.  
    tensorboard_callback
  5.  
    = tf.
  6.  
    keras .
  7.  
    callbacks .
  8.  
    TensorBoard
  9.  
    ( logdir
  10.  
    , histogram_freq
  11.  
    =
  12.  
    1
  13.  
    )
  14.  
     
  15.  
    reduce_lr
  16.  
    = tf.
  17.  
    keras .
  18.  
    callbacks .
  19.  
    ReduceLROnPlateau
  20.  
    ( monitor
  21.  
    =
  22.  
    'val_loss'
  23.  
    , factor
  24.  
    =
  25.  
    0.5
  26.  
    ,
  27.  
     
  28.  
                                  patience
  29.  
    =
  30.  
    2
  31.  
    , min_lr
  32.  
    =
  33.  
    0.000001
  34.  
    )
  35.  
     
  36.  
     
  37.  
     
  38.  
    callbacks
  39.  
    =
  40.  
    [ reduce_lr
  41.  
    , tensorboard_callback
  42.  
    ]
  43.  
     
  44.  
    train_steps_per_epoch
  45.  
    = train_generator.
  46.  
    n // train_generator.
  47.  
    batch_size
  48.  
     
  49.  
    val_steps_per_epoch
  50.  
    = val_generator.
  51.  
    n // val_generator.
  52.  
    batch_size
  53.  
     
  54.  
    history
  55.  
    = model.
  56.  
    fit_generator
  57.  
    ( train_generator
  58.  
    , steps_per_epoch
  59.  
    = train_steps_per_epoch
  60.  
    , epochs
  61.  
    = EPOCHS
  62.  
    ,
  63.  
     
  64.  
                                  validation_data
  65.  
    = val_generator
  66.  
    , validation_steps
  67.  
    = val_steps_per_epoch
  68.  
    ,
  69.  
     
  70.  
                                  verbose
  71.  
    =
  72.  
    1
  73.  
    )
  74.  
     
  75.  
     
  76.  
     
  77.  
     
  78.  
     
  79.  
     
  80.  
    # Output
  81.  
     
  82.  
    Epoch
  83.  
    1 /
  84.  
    25
  85.  
     
  86.  
     
  87.  
    271 /
  88.  
    271
  89.  
    [
  90.  
    ====
  91.  
    ] - 133s 489ms/step - loss:
  92.  
    0.2267 - accuracy:
  93.  
    0.9117 - val_loss:
  94.  
    0.1414 - val_accuracy:
  95.  
    0.9531
  96.  
     
  97.  
    Epoch
  98.  
    2 /
  99.  
    25
  100.  
     
  101.  
     
  102.  
    271 /
  103.  
    271
  104.  
    [
  105.  
    ====
  106.  
    ] - 129s 475ms/step - loss:
  107.  
    0.1399 - accuracy:
  108.  
    0.9552 - val_loss:
  109.  
    0.1292 - val_accuracy:
  110.  
    0.9589
  111.  
     
  112.  
    ...
  113.  
     
  114.  
    ...
  115.  
     
  116.  
     
  117.  
    Epoch
  118.  
    24 /
  119.  
    25
  120.  
     
  121.  
     
  122.  
    271 /
  123.  
    271
  124.  
    [
  125.  
    ====
  126.  
    ] - 128s 473ms/step - loss:
  127.  
    0.0815 - accuracy:
  128.  
    0.9727 - val_loss:
  129.  
    0.1466 - val_accuracy:
  130.  
    0.9682
  131.  
     
  132.  
    Epoch
  133.  
    25 /
  134.  
    25
  135.  
     
  136.  
     
  137.  
    271 /
  138.  
    271
  139.  
    [
  140.  
    ====
  141.  
    ] - 128s 473ms/step - loss:
  142.  
    0.0792 - accuracy:
  143.  
    0.9729 - val_loss:
  144.  
    0.1127 - val_accuracy:
  145.  
    0.9641
  146.  
     
  147.  
     

这似乎是我们最好的模型。 它提供给我们的验证准确性几乎为96.5%,并且根据训练的准确性,我们的模型看起来并没有比我们的第一个模型拟合得太多。 可以通过以下学习曲线来验证。

Learning curves for fine-tuned pre-trained CNN

微调的预训练CNN的学习曲线

让我们保存该模型,以便可以将其用于测试数据集的模型评估。

 model. save ( 'vgg_finetuned.h5' ) 

这样就完成了我们的模型训练阶段。 现在,我们准备在实际测试数据集上测试模型的性能!

深度学习模型性能评估

我们将通过对测试数据集中的数据进行预测,从而评估在训练阶段构建的三个模型-因为仅进行验证是不够的! 我们还构建了一个漂亮的实用程序模块,称为model_evaluation_utils ,可用于评估具有相关分类指标的深度学习模型的性能。 第一步是扩展我们的测试数据。


 
  1.  
     
  2.  
     
  3.  
     
  4.  
    test_imgs_scaled
  5.  
    = test_data /
  6.  
    255 .
  7.  
     
  8.  
     
  9.  
    test_imgs_scaled .
  10.  
    shape
  11.  
    , test_labels.
  12.  
    shape
  13.  
     
  14.  
     
  15.  
     
  16.  
     
  17.  
    # Output
  18.  
     
  19.  
     
  20.  
    (
  21.  
    (
  22.  
    8268
  23.  
    ,
  24.  
    125
  25.  
    ,
  26.  
    125
  27.  
    ,
  28.  
    3
  29.  
    )
  30.  
    ,
  31.  
    (
  32.  
    8268
  33.  
    ,
  34.  
    )
  35.  
    )
  36.  
     
  37.  
     

下一步涉及加载我们保存的深度学习模型,并对测试数据进行预测。


 
  1.  
     
  2.  
     
  3.  
     
  4.  
    # Load Saved Deep Learning Models
  5.  
     
  6.  
    basic_cnn
  7.  
    = tf.
  8.  
    keras .
  9.  
    models .
  10.  
    load_model
  11.  
    (
  12.  
    './basic_cnn.h5'
  13.  
    )
  14.  
     
  15.  
    vgg_frz
  16.  
    = tf.
  17.  
    keras .
  18.  
    models .
  19.  
    load_model
  20.  
    (
  21.  
    './vgg_frozen.h5'
  22.  
    )
  23.  
     
  24.  
    vgg_ft
  25.  
    = tf.
  26.  
    keras .
  27.  
    models .
  28.  
    load_model
  29.  
    (
  30.  
    './vgg_finetuned.h5'
  31.  
    )
  32.  
     
  33.  
     
  34.  
     
  35.  
     
  36.  
    # Make Predictions on Test Data
  37.  
     
  38.  
    basic_cnn_preds
  39.  
    = basic_cnn.
  40.  
    predict
  41.  
    ( test_imgs_scaled
  42.  
    , batch_size
  43.  
    =
  44.  
    512
  45.  
    )
  46.  
     
  47.  
    vgg_frz_preds
  48.  
    = vgg_frz.
  49.  
    predict
  50.  
    ( test_imgs_scaled
  51.  
    , batch_size
  52.  
    =
  53.  
    512
  54.  
    )
  55.  
     
  56.  
    vgg_ft_preds
  57.  
    = vgg_ft.
  58.  
    predict
  59.  
    ( test_imgs_scaled
  60.  
    , batch_size
  61.  
    =
  62.  
    512
  63.  
    )
  64.  
     
  65.  
     
  66.  
     
  67.  
    basic_cnn_pred_labels
  68.  
    = le.
  69.  
    inverse_transform
  70.  
    (
  71.  
    [
  72.  
    1
  73.  
    if pred
  74.  
    >
  75.  
    0.5
  76.  
    else
  77.  
    0
  78.  
     
  79.  
                                                     
  80.  
    for pred
  81.  
    in basic_cnn_preds.
  82.  
    ravel
  83.  
    (
  84.  
    )
  85.  
    ]
  86.  
    )
  87.  
     
  88.  
    vgg_frz_pred_labels
  89.  
    = le.
  90.  
    inverse_transform
  91.  
    (
  92.  
    [
  93.  
    1
  94.  
    if pred
  95.  
    >
  96.  
    0.5
  97.  
    else
  98.  
    0
  99.  
     
  100.  
                                                     
  101.  
    for pred
  102.  
    in vgg_frz_preds.
  103.  
    ravel
  104.  
    (
  105.  
    )
  106.  
    ]
  107.  
    )
  108.  
     
  109.  
    vgg_ft_pred_labels
  110.  
    = le.
  111.  
    inverse_transform
  112.  
    (
  113.  
    [
  114.  
    1
  115.  
    if pred
  116.  
    >
  117.  
    0.5
  118.  
    else
  119.  
    0
  120.  
     
  121.  
                                                     
  122.  
    for pred
  123.  
    in vgg_ft_preds.
  124.  
    ravel
  125.  
    (
  126.  
    )
  127.  
    ]
  128.  
    )
  129.  
     
  130.  
     

最后一步是利用我们的model_evaluation_utils模块,并使用相关的分类指标来检查每个模型的性能。


 
  1.  
     
  2.  
     
  3.  
     
  4.  
    import model_evaluation_utils
  5.  
    as meu
  6.  
     
  7.  
     
  8.  
    import pandas
  9.  
    as pd
  10.  
     
  11.  
     
  12.  
     
  13.  
    basic_cnn_metrics
  14.  
    = meu.
  15.  
    get_metrics
  16.  
    ( true_labels
  17.  
    = test_labels
  18.  
    , predicted_labels
  19.  
    = basic_cnn_pred_labels
  20.  
    )
  21.  
     
  22.  
    vgg_frz_metrics
  23.  
    = meu.
  24.  
    get_metrics
  25.  
    ( true_labels
  26.  
    = test_labels
  27.  
    , predicted_labels
  28.  
    = vgg_frz_pred_labels
  29.  
    )
  30.  
     
  31.  
    vgg_ft_metrics
  32.  
    = meu.
  33.  
    get_metrics
  34.  
    ( true_labels
  35.  
    = test_labels
  36.  
    , predicted_labels
  37.  
    = vgg_ft_pred_labels
  38.  
    )
  39.  
     
  40.  
     
  41.  
     
  42.  
    pd.
  43.  
    DataFrame
  44.  
    (
  45.  
    [ basic_cnn_metrics
  46.  
    , vgg_frz_metrics
  47.  
    , vgg_ft_metrics
  48.  
    ]
  49.  
    ,
  50.  
     
  51.  
                 index
  52.  
    =
  53.  
    [
  54.  
    'Basic CNN'
  55.  
    ,
  56.  
    'VGG-19 Frozen'
  57.  
    ,
  58.  
    'VGG-19 Fine-tuned'
  59.  
    ]
  60.  
    )
  61.  
     
  62.  
     
Model accuracy

看起来我们的第三个模型在测试数据集上表现最好,给出的模型准确度和F 1得分为96%,这非常好,与我们之前提到的研究论文和文章中提到的更复杂的模型相当。

结论

疟疾检测不是一件容易的事,在病例的诊断和治疗中,全球合格人员的存在是一个严重的问题。 我们看了一个有趣的现实世界中疟疾检测的医学影像案例研究。 利用AI的易于构建的开源技术可以为我们提供检测疟疾的最先进的准确性,从而使AI造福社会。

我鼓励您查看本文中提到的文章和研究论文,否则如果没有它们,我将无法对其进行概念化和编写。 如果您对运行或采用这些技术感兴趣,可以在我的GitHub存储库中找到本文中使用的所有代码。 请记住从官方网站下载数据。

让我们希望在医疗保健中更多地采用开源AI功能,以使其更便宜,更容易为世界各地的所有人使用!


Ubuntu16.04下安装多版本cuda和cudnn
基于初始残差循环神经网络的乳腺癌组织病理学图像的分类
51自学网,即我要自学网,自学EXCEL、自学PS、自学CAD、自学C语言、自学css3实例,是一个通过网络自主学习工作技能的自学平台,网友喜欢的软件自学网站。
京ICP备13026421号-1