Segmentace MRI mozku - DeepLabV3 model

5. 4. 2024 0:00 Jiří Raška

Dnes bych se rád podíval na další typ modelu pro sémantickou segmentaci obrázků, a tím je DeepLabV3. Důvodem, proč jsem si vybral právě tento, je použití bloku Spatial Pyramid Pooling a s ním spojené Atrous Convolution. K tomu se dostanu ale až u návrhu samotného modelu.

Pro své pokusy opět využiji datovou sadu Brain MRI segmentation stejným způsobem, jako tomu bylo v předchozím článku Segmentace MRI mozku – U-Net mode. Proto se zde nebudu do hloubky rozepisovat o obsahu datové sady, a jak jsem s daty pracoval. Jen bych se opakoval. Takže pokud vás to zajímá, podívejte se prosím na předchozí článek. Dále uvedu jen nezbytný kód pro načtení dat do pole obrázků X a pole cílových masek Y.

In [1]:


import sys
import os
import shutil
import warnings
import glob
import pathlib

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import train_test_split

import tensorflow as tf
import tensorflow.keras as keras

from keras.models import Sequential
from keras import Input, Model
from keras import layers

import cv2

sns.set_style('darkgrid')

warnings.simplefilter(action='ignore', category=FutureWarning)

In [2]:


def seed_all():
    import random

    random.seed(42)
    np.random.seed(42)
    tf.random.set_seed(42)
    os.environ['PYTHONHASHSEED'] = str(42)
    os.environ['TF_DETERMINISTIC_OPS'] = '1'

seed_all()

In [3]:


DATA_ROOT = "/kaggle/input/lgg-mri-segmentation/kaggle_3m/"
IMAGE_SIZE = (128, 128)

Obrázky budu načítat s rozlišením 128×128 pixelů, tedy s polovičním rozlišením oproti originálů. Je to opět z důvodu omezených především paměťových zdrojů při trénování modelu.

Samotné načtení dat …

In [4]:


image_paths = []

for path in glob.glob(DATA_ROOT + "**/*_mask.tif"):

    def strip_base(p):
        parts = pathlib.Path(p).parts
        return os.path.join(*parts[-2:])

    image = path.replace("_mask", "")
    if os.path.isfile(image):
        image_paths.append((strip_base(image), strip_base(path)))
    else:
        print("MISSING: ", image, "==>", path)

In [6]:


def get_image_data(image_paths):
    x, y = list(), list()
    for image_path, mask_path in image_paths:
        image = cv2.imread(os.path.join(DATA_ROOT, image_path), flags=cv2.IMREAD_COLOR)
        image = cv2.resize(image, IMAGE_SIZE)
        mask = cv2.imread(os.path.join(DATA_ROOT, mask_path), flags=cv2.IMREAD_GRAYSCALE)
        mask = cv2.resize(mask, IMAGE_SIZE)
        x.append(image)
        y.append(mask)
    return np.array(x) / 255, np.expand_dims(np.array(y) / 255, -1)

X, Y = get_image_data(image_paths)

print(f"X: {X.shape}")
print(f"Y: {Y.shape}")
X: (3929, 128, 128, 3)
Y: (3929, 128, 128, 1)

Model DeepLabV3

Jedná se o model postavený na architektuře Encoder-Decoder. Pro kontrakční fázi je autory doporučován klasifikační modem Xception, ale může být v této roli použit i jiný typ modelu (např. VGG, ResNet a podobně). Model je zajímavý především svým přístupem k udržování prostorového kontextu v podobě Spatial Pyramid Pooling bloku.

Nejdříve ukážu základní schéma, jak se obvykle model zobrazuje v literatuře:

DeepLabV3

To, co je ve schématu označeno jako DCNN, je obvyklý konvoluční klasifikační model představující kontrakční fázi. Jedná se pouze o část modelu, tedy nekončí úplným potlačením prostorových dimenzí ve prospěch vlastností.

Výstup kontrakční fáze přechází do úzkého hrdla tvořeného Atrous Spatial Pyramid Pooling bloku (také ASPP). O co se jedná?

Pro každý bod ve zmenšeném obrázku je postupně spočítáno několik konvolučních map, které se neliší velikostí kernelu, ale rozestupem (dilatací) mezi body zdrojové plochy. Pokud byste se chtěli dozvědět více o Atrous Convolution, pak můžete např. A Comprehensive Guide on Atrous Convolution in CNNs .

Dle autorů modelu je doporučeno udělat tyto konvoluční mapy, vždy se stejným počtem filtrů:

  • kernel = 1×1, dilatace = 1

  • kernel = 3×3, dilatace = 6

  • kernel = 3×3, dilatace = 12

  • kernel = 3×3, dilatace = 18

  • image pool, což je postupně AveragePooling2D celé plochy, konvoluce 1×1 a následný UpSampling2D na celou plochu

Výsledné mapy jsou řazeny postupně za sebou do jednoho tenzoru, který je zakončen konvolučním blokem s velikostí kernelu 1.

Výstup z ASPP bloku je v expanzní fází rozšířen na čtvrtinovou velikost rozměru původního obrázku s použitím UpSampling2D vrstvy. K výsledku se dále připojí tenzor z kontrakční fáze s odpovídajícím rozměrem. Jedná se obvykle o výstup z druhého bloku (nízkoúrovňové vlastnosti), kdy jsou rozměry zmenšeny na čtvrtinu původního rozměru. Následuje již klasický konvoluční blok zakončeny dalším rozšířením obrázku vrstvou UpSampling2D, tentokrát již na původní velikost.

Další informace o modelu můžete načerpat například zde:

Jako inspiraci pro implementaci modelu jsem vycházel především z: https://keras.io/examples/vi­sion/deeplabv3_plus/ 

Nic dalšího k implementaci již nepotřebuji. Takže si rozdělím data na množiny pro trénování a testování:

In [7]:


x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.2)

print(f"x_train: {x_train.shape}, y_train: {y_train.shape}")
print(f"x_test:  {x_test.shape},  y_test:  {y_test.shape}")

x_train: (3143, 128, 128, 3), y_train: (3143, 128, 128, 1) x_test: (786, 128, 128, 3), y_test: (786, 128, 128, 1)

Následuje funkce, která mně vytvoří model. Nejdříve ukážu zdroj, a pak bude následovat k němu komentář …

In [8]:


def create_model_DeepLabV3(X_shape, classes=1, name="DeepLabV3"):

    def conv_block(x, *, filters, kernel_size=3, strides=1, dilation_rate=1, use_bias=False, padding='same', activation='relu', name=""):
        x = layers.Conv2D(filters=filters, kernel_size=kernel_size, strides=strides, dilation_rate=dilation_rate, use_bias=use_bias, padding=padding, kernel_initializer="he_normal", name=f"{name}_conv")(x)
        x = layers.BatchNormalization(name=f"{name}_norm")(x)
        if activation:
            x = layers.Activation(activation, name=f"{name}_acti")(x)
        return x

    def encoder_block(x, *, filters, name="", pooling=True):
        for i, f in enumerate(filters):
            x = conv_block(x, filters=f, name=f'{name}_block{i}')
        if pooling:
            x = layers.MaxPooling2D((2, 2), name=f'{name}_maxpool')(x)
        return x

    def aspp_block(x, *, filters, name=""):
        dims = x.shape

        out_pool = layers.AveragePooling2D(pool_size=dims[-3:-1], name=f"{name}_avrg_pool")(x)
        out_pool = conv_block(out_pool, filters=filters, kernel_size=1, use_bias=True, name=f"{name}_conv1")
        out_pool = layers.UpSampling2D(size=dims[-3:-1], interpolation="bilinear", name=f"{name}_upsampl")(out_pool)

        out_1 = conv_block(x, filters=filters, kernel_size=1, dilation_rate=1, name=f"{name}_conv2")
        out_4 = conv_block(x, filters=filters, kernel_size=3, dilation_rate=4, name=f"{name}_conv3")
        out_8 = conv_block(x, filters=filters, kernel_size=3, dilation_rate=8, name=f"{name}_conv4")

        x = layers.Concatenate(axis=-1, name=f"{name}_concat")([out_pool, out_1, out_4, out_8])
        output = conv_block(x, filters=filters, kernel_size=1, name=f"{name}_conv5")
        return output    

    inputs = Input(X_shape[-3:], name='inputs')

    x1 = encoder_block(inputs, filters=(32, 32), name="enc_1")
    x2 = encoder_block(x1, filters=(64, 64), name="enc_2")
    x3 = encoder_block(x2, filters=(128, 128), name="enc_3")

    aspp = aspp_block(x3, filters=256, name="aspp")
    dec_input_a = layers.UpSampling2D(size=(IMAGE_SIZE[0] // aspp.shape[-3] // 2, IMAGE_SIZE[1] // aspp.shape[-2] // 2), interpolation="bilinear", name="dec_input_a")(aspp)

    dec_input_b = conv_block(x1, filters=64, kernel_size=1, name="dec_input_b")

    x = layers.Concatenate(axis=-1, name="dec_concat")([dec_input_a, dec_input_b])
    x = conv_block(x, filters=128, kernel_size=3, name=f"dec_conv")
    x = layers.UpSampling2D(size=(IMAGE_SIZE[0] // x.shape[-3], IMAGE_SIZE[1] // x.shape[-2]), interpolation="bilinear", name="dec_output")(x)

    # Output 
    outputs = conv_block(x, filters=classes, kernel_size=(1, 1), activation='sigmoid', name="outputs")

    return Model(inputs=inputs, outputs=outputs, name=name)

První podstatná informace je, že předchozí funkce není plnou implementací modelu se všemi jeho částmi.

Především, pro kontrakční fázi jsem použil tři konvoluční bloky jako v modelu VGG. Důvodem je malé rozlišení zdrojových dat (omezení zdrojů pro trénování),takže nemělo smysl budovat hlubokou klasifikační část. Postupnou kontrakcí jsem se dostal z rozměrů 128×128×3 na 16×16×128, co je vstupem do ASPP bloku. Současně jsem také redukoval počty sledovaných vlastností.

V ASPP bloku je upravena dilatace konvolucí tak, aby lépe odpovídaly rozměrům tenzorů, které do ní vstupují. A také jsem vypustil konvoluci s největší pokrytou plochou, která by výrazně přesahovala zdrojová data.

Expanzní fáze odpovídá docela věrně původnímu návrhu až na to, že jako nízkoúrovňové vlastnosti pro spojení z kontrakční fáze jsem použil výstup prvního konvolučního bloku.

Celý model je zakončen posledním konvolučním blokem s jedním filtrem a aktivační funkcí sigmoid, tedy klasické řešení pro binární segmentaci.

Vyhodnocení modelu

Specifické ztrátové funkce a metricky …

Jako optimalizovanou ztrátovou funkci opět použiji kombinaci Binary CrossEntropy a Dice Coefficient. Pro posouzení kvality modelu jsem v tomto případě použil metriky Dice Coefficient a Jaccard Index.

Implementace funkcí jsem si vypůjčil a částečně upravil z: https://github.com/shruti-jadon/Semantic-Segmentation-Loss-Functions/tree/master

In [9]:


import keras.backend as K
from keras.losses import binary_crossentropy


def dsc(y_true, y_pred):
    smooth = 1.
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    score = (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
    return score

def dice_loss(y_true, y_pred):
    loss = 1 - dsc(y_true, y_pred)
    return loss

def bce_dice_loss(y_true, y_pred):
    loss = binary_crossentropy(y_true, y_pred) + dice_loss(y_true, y_pred)
    return loss

def jaccard_similarity(y_true, y_pred):
    smooth = 1.
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f) + smooth
    union = K.sum((y_true_f + y_pred_f) - (y_true_f * y_pred_f)) + smooth
    return intersection / union

def jaccard_loss(y_true, y_pred):
    return 1 - jaccard_similarity(y_true, y_pred)

Překlad modelu …

In [10]:


model = create_model_DeepLabV3(x_test.shape, 1)

model.compile(optimizer="adam", loss=bce_dice_loss, metrics=[dsc, jaccard_similarity])
model.summary()

Model: "DeepLabV3" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== inputs (InputLayer) [(None, 128, 128, 3)] 0 [] enc_1_block0_conv (Conv2D) (None, 128, 128, 32) 864 ['inputs[0][0]'] enc_1_block0_norm (BatchNo (None, 128, 128, 32) 128 ['enc_1_block0_conv[0][0]'] rmalization) enc_1_block0_acti (Activat (None, 128, 128, 32) 0 ['enc_1_block0_norm[0][0]'] ion) enc_1_block1_conv (Conv2D) (None, 128, 128, 32) 9216 ['enc_1_block0_acti[0][0]'] enc_1_block1_norm (BatchNo (None, 128, 128, 32) 128 ['enc_1_block1_conv[0][0]'] rmalization) enc_1_block1_acti (Activat (None, 128, 128, 32) 0 ['enc_1_block1_norm[0][0]'] ion) enc_1_maxpool (MaxPooling2 (None, 64, 64, 32) 0 ['enc_1_block1_acti[0][0]'] D) enc_2_block0_conv (Conv2D) (None, 64, 64, 64) 18432 ['enc_1_maxpool[0][0]'] enc_2_block0_norm (BatchNo (None, 64, 64, 64) 256 ['enc_2_block0_conv[0][0]'] rmalization) enc_2_block0_acti (Activat (None, 64, 64, 64) 0 ['enc_2_block0_norm[0][0]'] ion) enc_2_block1_conv (Conv2D) (None, 64, 64, 64) 36864 ['enc_2_block0_acti[0][0]'] enc_2_block1_norm (BatchNo (None, 64, 64, 64) 256 ['enc_2_block1_conv[0][0]'] rmalization) enc_2_block1_acti (Activat (None, 64, 64, 64) 0 ['enc_2_block1_norm[0][0]'] ion) enc_2_maxpool (MaxPooling2 (None, 32, 32, 64) 0 ['enc_2_block1_acti[0][0]'] D) enc_3_block0_conv (Conv2D) (None, 32, 32, 128) 73728 ['enc_2_maxpool[0][0]'] enc_3_block0_norm (BatchNo (None, 32, 32, 128) 512 ['enc_3_block0_conv[0][0]'] rmalization) enc_3_block0_acti (Activat (None, 32, 32, 128) 0 ['enc_3_block0_norm[0][0]'] ion) enc_3_block1_conv (Conv2D) (None, 32, 32, 128) 147456 ['enc_3_block0_acti[0][0]'] enc_3_block1_norm (BatchNo (None, 32, 32, 128) 512 ['enc_3_block1_conv[0][0]'] rmalization) enc_3_block1_acti (Activat (None, 32, 32, 128) 0 ['enc_3_block1_norm[0][0]'] ion) enc_3_maxpool (MaxPooling2 (None, 16, 16, 128) 0 ['enc_3_block1_acti[0][0]'] D) aspp_avrg_pool (AveragePoo (None, 1, 1, 128) 0 ['enc_3_maxpool[0][0]'] ling2D) aspp_conv1_conv (Conv2D) (None, 1, 1, 256) 33024 ['aspp_avrg_pool[0][0]'] aspp_conv1_norm (BatchNorm (None, 1, 1, 256) 1024 ['aspp_conv1_conv[0][0]'] alization) aspp_conv2_conv (Conv2D) (None, 16, 16, 256) 32768 ['enc_3_maxpool[0][0]'] aspp_conv3_conv (Conv2D) (None, 16, 16, 256) 294912 ['enc_3_maxpool[0][0]'] aspp_conv4_conv (Conv2D) (None, 16, 16, 256) 294912 ['enc_3_maxpool[0][0]'] aspp_conv1_acti (Activatio (None, 1, 1, 256) 0 ['aspp_conv1_norm[0][0]'] n) aspp_conv2_norm (BatchNorm (None, 16, 16, 256) 1024 ['aspp_conv2_conv[0][0]'] alization) aspp_conv3_norm (BatchNorm (None, 16, 16, 256) 1024 ['aspp_conv3_conv[0][0]'] alization) aspp_conv4_norm (BatchNorm (None, 16, 16, 256) 1024 ['aspp_conv4_conv[0][0]'] alization) aspp_upsampl (UpSampling2D (None, 16, 16, 256) 0 ['aspp_conv1_acti[0][0]'] ) aspp_conv2_acti (Activatio (None, 16, 16, 256) 0 ['aspp_conv2_norm[0][0]'] n) aspp_conv3_acti (Activatio (None, 16, 16, 256) 0 ['aspp_conv3_norm[0][0]'] n) aspp_conv4_acti (Activatio (None, 16, 16, 256) 0 ['aspp_conv4_norm[0][0]'] n) aspp_concat (Concatenate) (None, 16, 16, 1024) 0 ['aspp_upsampl[0][0]', 'aspp_conv2_acti[0][0]', 'aspp_conv3_acti[0][0]', 'aspp_conv4_acti[0][0]'] aspp_conv5_conv (Conv2D) (None, 16, 16, 256) 262144 ['aspp_concat[0][0]'] aspp_conv5_norm (BatchNorm (None, 16, 16, 256) 1024 ['aspp_conv5_conv[0][0]'] alization) dec_input_b_conv (Conv2D) (None, 64, 64, 64) 2048 ['enc_1_maxpool[0][0]'] aspp_conv5_acti (Activatio (None, 16, 16, 256) 0 ['aspp_conv5_norm[0][0]'] n) dec_input_b_norm (BatchNor (None, 64, 64, 64) 256 ['dec_input_b_conv[0][0]'] malization) dec_input_a (UpSampling2D) (None, 64, 64, 256) 0 ['aspp_conv5_acti[0][0]'] dec_input_b_acti (Activati (None, 64, 64, 64) 0 ['dec_input_b_norm[0][0]'] on) dec_concat (Concatenate) (None, 64, 64, 320) 0 ['dec_input_a[0][0]', 'dec_input_b_acti[0][0]'] dec_conv_conv (Conv2D) (None, 64, 64, 128) 368640 ['dec_concat[0][0]'] dec_conv_norm (BatchNormal (None, 64, 64, 128) 512 ['dec_conv_conv[0][0]'] ization) dec_conv_acti (Activation) (None, 64, 64, 128) 0 ['dec_conv_norm[0][0]'] dec_output (UpSampling2D) (None, 128, 128, 128) 0 ['dec_conv_acti[0][0]'] outputs_conv (Conv2D) (None, 128, 128, 1) 128 ['dec_output[0][0]'] outputs_norm (BatchNormali (None, 128, 128, 1) 4 ['outputs_conv[0][0]'] zation) outputs_acti (Activation) (None, 128, 128, 1) 0 ['outputs_norm[0][0]'] ================================================================================================== Total params: 1582820 (6.04 MB) Trainable params: 1578978 (6.02 MB) Non-trainable params: 3842 (15.01 KB) __________________________________________________________________________________________________

Trénování modelu …

V průběhu trénování si budu ukládat model s nejlepšími výsledky na validační sadě. Dále mám v rámci callback funkcí nastaveno dřívější zastavení trénování, pokud se výsledky začnou zhoršovat. Z průběhu testování ale uvidíte, že to potřeba nebylo, a musel jsem využít celého rozsahu 100 epoch.

In [11]:


MODEL_CHECKPOINT = f"/kaggle/working/model/{model.name}.ckpt"
EPOCHS = 100

callbacks_list = [
    keras.callbacks.EarlyStopping(monitor='val_dsc', mode='max', patience=20),
    keras.callbacks.ModelCheckpoint(filepath=MODEL_CHECKPOINT, monitor='val_dsc', save_best_only=True, mode='max', verbose=1)
]

history = model.fit(
    x=x_train,
    y=y_train,
    epochs=EPOCHS, 
    callbacks=callbacks_list, 
    validation_split=0.2,
    verbose=1)

Epoch 1/100 79/79 [==============================] - ETA: 0s - loss: 1.6219 - dsc: 0.0354 - jaccard_similarity: 0.0181 Epoch 1: val_dsc improved from -inf to 0.03709, saving model to /kaggle/working/model/DeepLabV3.ckpt 79/79 [==============================] - 24s 189ms/step - loss: 1.6219 - dsc: 0.0354 - jaccard_similarity: 0.0181 - val_loss: 1.7341 - val_dsc: 0.0371 - val_jaccard_similarity: 0.0190 Epoch 2/100 79/79 [==============================] - ETA: 0s - loss: 1.5701 - dsc: 0.0380 - jaccard_similarity: 0.0194 Epoch 2: val_dsc did not improve from 0.03709 79/79 [==============================] - 8s 100ms/step - loss: 1.5701 - dsc: 0.0380 - jaccard_similarity: 0.0194 - val_loss: 1.5330 - val_dsc: 0.0364 - val_jaccard_similarity: 0.0186 Epoch 3/100 79/79 [==============================] - ETA: 0s - loss: 1.5299 - dsc: 0.0407 - jaccard_similarity: 0.0208 Epoch 3: val_dsc improved from 0.03709 to 0.04513, saving model to /kaggle/working/model/DeepLabV3.ckpt 79/79 [==============================] - 13s 170ms/step - loss: 1.5299 - dsc: 0.0407 - jaccard_similarity: 0.0208 - val_loss: 1.4492 - val_dsc: 0.0451 - val_jaccard_similarity: 0.0232 Epoch 4/100 79/79 [==============================] - ETA: 0s - loss: 1.4917 - dsc: 0.0430 - jaccard_similarity: 0.0220 Epoch 4: val_dsc improved from 0.04513 to 0.04689, saving model to /kaggle/working/model/DeepLabV3.ckpt 79/79 [==============================] - 13s 165ms/step - loss: 1.4917 - dsc: 0.0430 - jaccard_similarity: 0.0220 - val_loss: 1.4577 - val_dsc: 0.0469 - val_jaccard_similarity: 0.0241 Epoch 5/100 79/79 [==============================] - ETA: 0s - loss: 1.4589 - dsc: 0.0450 - jaccard_similarity: 0.0231 Epoch 5: val_dsc improved from 0.04689 to 0.05238, saving model to /kaggle/working/model/DeepLabV3.ckpt 79/79 [==============================] - 13s 171ms/step - loss: 1.4589 - dsc: 0.0450 - jaccard_similarity: 0.0231 - val_loss: 1.4151 - val_dsc: 0.0524 - val_jaccard_similarity: 0.0270 Epoch 6/100 79/79 [==============================] - ETA: 0s - loss: 1.4245 - dsc: 0.0481 - jaccard_similarity: 0.0247 Epoch 6: val_dsc improved from 0.05238 to 0.05462, saving model to /kaggle/working/model/DeepLabV3.ckpt 79/79 [==============================] - 13s 170ms/step - loss: 1.4245 - dsc: 0.0481 - jaccard_similarity: 0.0247 - val_loss: 1.4098 - val_dsc: 0.0546 - val_jaccard_similarity: 0.0282 Epoch 7/100 79/79 [==============================] - ETA: 0s - loss: 1.3946 - dsc: 0.0505 - jaccard_similarity: 0.0260 Epoch 7: val_dsc improved from 0.05462 to 0.06367, saving model to /kaggle/working/model/DeepLabV3.ckpt 79/79 [==============================] - 13s 170ms/step - loss: 1.3946 - dsc: 0.0505 - jaccard_similarity: 0.0260 - val_loss: 1.3397 - val_dsc: 0.0637 - val_jaccard_similarity: 0.0331 Epoch 8/100 79/79 [==============================] - ETA: 0s - loss: 1.3631 - dsc: 0.0537 - jaccard_similarity: 0.0277 Epoch 8: val_dsc improved from 0.06367 to 0.06505, saving model to /kaggle/working/model/DeepLabV3.ckpt 79/79 [==============================] - 13s 165ms/step - loss: 1.3631 - dsc: 0.0537 - jaccard_similarity: 0.0277 - val_loss: 1.3226 - val_dsc: 0.0650 - val_jaccard_similarity: 0.0338 Epoch 9/100 79/79 [==============================] - ETA: 0s - loss: 1.3360 - dsc: 0.0563 - jaccard_similarity: 0.0291 Epoch 9: val_dsc did not improve from 0.06505 79/79 [==============================] - 8s 100ms/step - loss: 1.3360 - dsc: 0.0563 - jaccard_similarity: 0.0291 - val_loss: 1.3331 - val_dsc: 0.0594 - val_jaccard_similarity: 0.0307 Epoch 10/100 79/79 [==============================] - ETA: 0s - loss: 1.3098 - dsc: 0.0597 - jaccard_similarity: 0.0309 Epoch 10: val_dsc improved from 0.06505 to 0.07013, saving model to /kaggle/working/model/DeepLabV3.ckpt ... Epoch 90/100 79/79 [==============================] - ETA: 0s - loss: 0.1356 - dsc: 0.8712 - jaccard_similarity: 0.7743 Epoch 90: val_dsc improved from 0.82367 to 0.82965, saving model to /kaggle/working/model/DeepLabV3.ckpt 79/79 [==============================] - 14s 174ms/step - loss: 0.1356 - dsc: 0.8712 - jaccard_similarity: 0.7743 - val_loss: 0.1887 - val_dsc: 0.8297 - val_jaccard_similarity: 0.7105 Epoch 91/100 79/79 [==============================] - ETA: 0s - loss: 0.1311 - dsc: 0.8757 - jaccard_similarity: 0.7808 Epoch 91: val_dsc did not improve from 0.82965 79/79 [==============================] - 8s 101ms/step - loss: 0.1311 - dsc: 0.8757 - jaccard_similarity: 0.7808 - val_loss: 0.1939 - val_dsc: 0.8257 - val_jaccard_similarity: 0.7049 Epoch 92/100 79/79 [==============================] - ETA: 0s - loss: 0.1307 - dsc: 0.8761 - jaccard_similarity: 0.7821 Epoch 92: val_dsc did not improve from 0.82965 79/79 [==============================] - 8s 100ms/step - loss: 0.1307 - dsc: 0.8761 - jaccard_similarity: 0.7821 - val_loss: 0.1918 - val_dsc: 0.8282 - val_jaccard_similarity: 0.7084 Epoch 93/100 79/79 [==============================] - ETA: 0s - loss: 0.1292 - dsc: 0.8775 - jaccard_similarity: 0.7847 Epoch 93: val_dsc did not improve from 0.82965 79/79 [==============================] - 8s 100ms/step - loss: 0.1292 - dsc: 0.8775 - jaccard_similarity: 0.7847 - val_loss: 0.1986 - val_dsc: 0.8216 - val_jaccard_similarity: 0.6991 Epoch 94/100 79/79 [==============================] - ETA: 0s - loss: 0.1211 - dsc: 0.8854 - jaccard_similarity: 0.7960 Epoch 94: val_dsc did not improve from 0.82965 79/79 [==============================] - 8s 100ms/step - loss: 0.1211 - dsc: 0.8854 - jaccard_similarity: 0.7960 - val_loss: 0.1971 - val_dsc: 0.8220 - val_jaccard_similarity: 0.7000 Epoch 95/100 79/79 [==============================] - ETA: 0s - loss: 0.1232 - dsc: 0.8833 - jaccard_similarity: 0.7934 Epoch 95: val_dsc did not improve from 0.82965 79/79 [==============================] - 8s 100ms/step - loss: 0.1232 - dsc: 0.8833 - jaccard_similarity: 0.7934 - val_loss: 0.2003 - val_dsc: 0.8191 - val_jaccard_similarity: 0.6957 Epoch 96/100 79/79 [==============================] - ETA: 0s - loss: 0.1174 - dsc: 0.8886 - jaccard_similarity: 0.8015 Epoch 96: val_dsc improved from 0.82965 to 0.83262, saving model to /kaggle/working/model/DeepLabV3.ckpt 79/79 [==============================] - 13s 170ms/step - loss: 0.1174 - dsc: 0.8886 - jaccard_similarity: 0.8015 - val_loss: 0.1864 - val_dsc: 0.8326 - val_jaccard_similarity: 0.7149 Epoch 97/100 79/79 [==============================] - ETA: 0s - loss: 0.1136 - dsc: 0.8910 - jaccard_similarity: 0.8057 Epoch 97: val_dsc improved from 0.83262 to 0.84074, saving model to /kaggle/working/model/DeepLabV3.ckpt 79/79 [==============================] - 13s 172ms/step - loss: 0.1136 - dsc: 0.8910 - jaccard_similarity: 0.8057 - val_loss: 0.1788 - val_dsc: 0.8407 - val_jaccard_similarity: 0.7266 Epoch 98/100 79/79 [==============================] - ETA: 0s - loss: 0.1581 - dsc: 0.8527 - jaccard_similarity: 0.7484 Epoch 98: val_dsc did not improve from 0.84074 79/79 [==============================] - 8s 101ms/step - loss: 0.1581 - dsc: 0.8527 - jaccard_similarity: 0.7484 - val_loss: 0.2297 - val_dsc: 0.8082 - val_jaccard_similarity: 0.6801 Epoch 99/100 79/79 [==============================] - ETA: 0s - loss: 0.1327 - dsc: 0.8757 - jaccard_similarity: 0.7814 Epoch 99: val_dsc did not improve from 0.84074 79/79 [==============================] - 8s 100ms/step - loss: 0.1327 - dsc: 0.8757 - jaccard_similarity: 0.7814 - val_loss: 0.1816 - val_dsc: 0.8369 - val_jaccard_similarity: 0.7211 Epoch 100/100 79/79 [==============================] - ETA: 0s - loss: 0.1134 - dsc: 0.8933 - jaccard_similarity: 0.8088 Epoch 100: val_dsc did not improve from 0.84074 79/79 [==============================] - 8s 100ms/step - loss: 0.1134 - dsc: 0.8933 - jaccard_similarity: 0.8088 - val_loss: 0.1869 - val_dsc: 0.8293 - val_jaccard_similarity: 0.7102

Výpisy z průběhu trénování jsem výrazně zkrátil. Stejně by vám asi příliš informací neposkytly. Zajímavější je jistě grafické zobrazení vývoje ztrátové funkce a metrik v průběhu trénování (obě metriky vypadají velice podobně, takže v reálu by mně stačila jenom jedna):

In [12]:


fig, ax = plt.subplots(1, 2, figsize=(16, 4))
sns.lineplot(data={k: history.history[k] for k in ('loss', 'val_loss')}, ax=ax[0])
sns.lineplot(data={k: history.history[k] for k in history.history.keys() if k not in ('loss', 'val_loss')}, ax=ax[1])
plt.show()

Image3

Načtu si zpět do modelu váhy modelu s nejlepšími výsledky na validační sadě, a to proto, že nyní nastupuje fáze ověření výsledků modelu na testovací sadě dat.

In [13]:


model.load_weights(f"/kaggle/working/model/{model.name}.ckpt")

Out[13]:


<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x7ca4ff782410>

Ověření na testovací sadě dat …

Udělám predikci pro celou testovací sadu. Vzhledem k tomu, že výstupem modelu jsou pravděpodobnosti zařazení každého pixelu do třídy nádoru, potřebuji ještě udělat rozdělení na 0/1 – není/je nádor. Nejjednodušší je aplikace prahu o hodnotě 0.5.

In [14]:


y_pred = model.predict(x_test)
y_pred = (y_pred > 0.5).astype(np.float64)
25/25 [==============================] - 1s 35ms/step

Obrázek nikdy neuškodí, takže zde je několik vybraných obrázků z testovací sady, kdy se vedle sebe zobrazují originální snímek, skutečná maska, a na závěr predikce masky provedená modelem:

In [15]:


for _ in range(20):
    i = np.random.randint(len(y_test))
    if y_test[i].sum() > 0:
        plt.figure(figsize=(8, 8))
        plt.subplot(1,3,1)
        plt.imshow(x_test[i])
        plt.title('Original Image')
        plt.subplot(1,3,2)
        plt.imshow(y_test[i])
        plt.title('Original Mask')
        plt.subplot(1,3,3)
        plt.imshow(y_pred[i])
        plt.title('Prediction')
        plt.show()

Image4

Image5

Image6

Image7

Image8

Image9

Image10

Image11

Image12

Pro dokreslení celkového pohledu ještě doplním histogram hodnot metriky Dice Coefficient pro celou testovací sadu:

In [16]:


pred_dice_metric = np.array([dsc(y_test[i], y_pred[i]).numpy() for i in range(len(y_test))])

In [17]:


fig=plt.figure(figsize=(8, 3))
sns.histplot(pred_dice_metric, stat="probability", bins=50)
plt.xlabel("Dice metric")
plt.show()

Image13

A totéž ještě pro metriku Jaccard Index:

In [18]:


pred_jaccard_metric = np.array([jaccard_similarity(y_test[i], y_pred[i]).numpy() for i in range(len(y_test))])

In [19]:


fig=plt.figure(figsize=(8, 3))
sns.histplot(pred_jaccard_metric, stat="probability", bins=50)
plt.xlabel("Jaccard (IoU) metric")
plt.show()

Image14

Sdílet