Hlavní navigace

Segmentace MRI mozku - Transfer Learning

12. 4. 2024 0:00 Jiří Raška

V předchozích dvou článcích jsem vytvářel modely pro sémantickou segmentaci obrázků takzvaně „z čistého stolu“. Dělal jsem to proto, abych si lépe ověřil, jak modely fungují. Popravdě řečeno, takto se to dnes asi běžně nedělá. Obvykle autoři vychází z již existujících modelů, které pak přizpůsobují konkrétnímu úkolu. A to tom bude můj dnešní příspěvek.

Základem každého segmentačního modelu založeného na konvolučních vrstvách je klasický klasifikační model. Ten tvoří základ pro kontrakční fázi modelu, obvykle se označuje jako „backbone“. Nabízí se tedy možnost využít některý z již existujících modelů, a případně také váhy vrstev, které byly zjištěny při trénování na nějaké obecné datové sadě. Jinak řečeno, použít přístup tzv. „transfer learning“, a expanzní fázi doplnit vlastníma rukama. To je tedy náplň následujícího povídání.

In [1]:


import sys
import os
import shutil
import warnings
import glob
import pathlib

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import train_test_split

import tensorflow as tf
import tensorflow.keras as keras

from keras.models import Sequential
from keras import Input, Model
from keras import layers

from keras.utils import plot_model
from IPython.display import Image

import cv2

sns.set_style('darkgrid')

warnings.simplefilter(action='ignore', category=FutureWarning)

In [2]:


def seed_all():
    import random

    random.seed(42)
    np.random.seed(42)
    tf.random.set_seed(42)
    os.environ['PYTHONHASHSEED'] = str(42)
    os.environ['TF_DETERMINISTIC_OPS'] = '1'

seed_all()

if not os.path.exists('/kaggle/working/model'):
    os.makedirs('/kaggle/working/model')

In [3]:


DATA_ROOT = "/kaggle/input/lgg-mri-segmentation/kaggle_3m/"

IMAGE_SIZE = (128, 128)

Příprava dat

Dnes budu pro jednoduchost opět vycházet z již dříve použité datové sady Brain MRI segmentation. Pokud by vás zajímaly bližší informace o tom, jak jsem data načítal, odkážu vás na předchozí článek. Tam to je trochu více popsáno. Na tomto místě si pouze data načtu do dvou polí, X – zdrojové obrázky a Y – segmentační masky jako cíl.

In [4]:


image_paths = []

for path in glob.glob(DATA_ROOT + "**/*_mask.tif"):

    def strip_base(p):
        parts = pathlib.Path(p).parts
        return os.path.join(*parts[-2:])

    image = path.replace("_mask", "")
    if os.path.isfile(image):
        image_paths.append((strip_base(image), strip_base(path)))
    else:
        print("MISSING: ", image, "==>", path)

In [5]:


def get_image_data(image_paths):
    x, y = list(), list()
    for image_path, mask_path in image_paths:
        image = cv2.imread(os.path.join(DATA_ROOT, image_path), flags=cv2.IMREAD_COLOR)
        image = cv2.resize(image, IMAGE_SIZE)
        mask = cv2.imread(os.path.join(DATA_ROOT, mask_path), flags=cv2.IMREAD_GRAYSCALE)
        mask = cv2.resize(mask, IMAGE_SIZE)
        x.append(image)
        y.append(mask)
    return np.array(x) / 255, np.expand_dims(np.array(y) / 255, -1)

X, Y = get_image_data(image_paths)

print(f"X: {X.shape}")
print(f"Y: {Y.shape}")

X: (3929, 128, 128, 3)
Y: (3929, 128, 128, 1)

Z posledního výpisu je zřejmé, že jsem načetl 3929 vzorků v rozlišení 128×128 pixelů. Zdrojové obrázky jsou vytvořeny jako RGB snímky se třemi kanály, cílové masky pak jako odstíny šedi (tedy jeden kanál).

U-Net model založený na klasifikačním modelu VGG16

Klasifikační model VGG16 jsem si vybral záměrně, neboť se jedná o poměrně jednoduchý a přímočarý model, který bude pro mé pokusy plně vyhovovat. Navíc vzhledem k mému omezenému rozlišení zdrojových obrázků ani nebudu využívat výstupy všech konvolučních bloků. Tento model pro mne bude představovat kontrakční fázi segmentačního modelu U-Net. Dále budu potřebovat ještě expanzní fázi. Tu si napíšu vlastníma rukama s využitím Transpose Convolution a konvolučních bloků.

Nejdříve ale potřebuji instanci klasifikačního modelu VGG16. Nejjednodušší způsob je načtení rovnou z distribuce Keras včetně vah zjištěných při trénování modelu jeho autory:

In [6]:


from keras.applications.vgg16 import VGG16, preprocess_input

In [7]:


vgg16 = VGG16(include_top=False, weights="imagenet", input_shape=X.shape[1:])
vgg16.summary()

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg16/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5
58889256/58889256 [==============================] - 2s 0us/step
Model: "vgg16"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 128, 128, 3)]     0         

 block1_conv1 (Conv2D)       (None, 128, 128, 64)      1792      

 block1_conv2 (Conv2D)       (None, 128, 128, 64)      36928     

 block1_pool (MaxPooling2D)  (None, 64, 64, 64)        0         

 block2_conv1 (Conv2D)       (None, 64, 64, 128)       73856     

 block2_conv2 (Conv2D)       (None, 64, 64, 128)       147584    

 block2_pool (MaxPooling2D)  (None, 32, 32, 128)       0         

 block3_conv1 (Conv2D)       (None, 32, 32, 256)       295168    

 block3_conv2 (Conv2D)       (None, 32, 32, 256)       590080    

 block3_conv3 (Conv2D)       (None, 32, 32, 256)       590080    

 block3_pool (MaxPooling2D)  (None, 16, 16, 256)       0         

 block4_conv1 (Conv2D)       (None, 16, 16, 512)       1180160   

 block4_conv2 (Conv2D)       (None, 16, 16, 512)       2359808   

 block4_conv3 (Conv2D)       (None, 16, 16, 512)       2359808   

 block4_pool (MaxPooling2D)  (None, 8, 8, 512)         0         

 block5_conv1 (Conv2D)       (None, 8, 8, 512)         2359808   

 block5_conv2 (Conv2D)       (None, 8, 8, 512)         2359808   

 block5_conv3 (Conv2D)       (None, 8, 8, 512)         2359808   

 block5_pool (MaxPooling2D)  (None, 4, 4, 512)         0         

=================================================================
Total params: 14714688 (56.13 MB)
Trainable params: 14714688 (56.13 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________

Výpis modelu uvádím záměrně. Vzhledem k tomu, že budu vytvářet expanzní fázi U-Net modelu, budu potřebovat napojení přes tzv. skip connections z různých úrovní klasifikačního modelu.

Obvykle se využívá výstup poslední konvoluční vrstvy před vrstvou MaxPooling2D. Budu se tohoto doporučení držet, proto budu potřebovat výstupy z vrstev block1_conv2, block2_conv2 , block3_conv3 a block4_conv3 (pro ověření je výše právě ten výpis).

A nyní již mám vše potřebné pro vytvoření U-Net modelu založeného na VGG16 backbone:

In [8]:


def create_model_UNet_VGG16Backbone(X_shape, classes=1, name="UNet_VGG16Backbone"):

    def conv_block(x, *, filters, kernel_size=(3, 3), strides=(1, 1), padding='same', activation='relu', name=""):
        x = layers.Conv2D(filters=filters, kernel_size=kernel_size, strides=strides, padding=padding, kernel_initializer="he_normal", name=f"{name}_conv")(x)
        x = layers.BatchNormalization(name=f"{name}_norm")(x)
        if activation:
            x = layers.Activation(activation, name=f"{name}_acti")(x)
        return x

    def decoder_block(x, s, *, filters, name=""):
        x = layers.Conv2DTranspose(filters, (2, 2), strides=2, padding='same', kernel_initializer="he_normal", name=f"{name}_trans")(x) 
        x = layers.Concatenate(name=f"{name}_concat")([x, s])
        x = conv_block(x, filters=filters, name=f"{name}_conv1")
        x = conv_block(x, filters=filters, name=f"{name}_conv2")
        return x

    # Contracting Path 
    base_model = VGG16(include_top=False, input_shape=X_shape[-3:])
    base_model.trainable = False

    # Bottleneck 
    x = conv_block(base_model.get_layer("block4_conv3").output, filters=512, name="bot1")
    x = conv_block(x, filters=512, name="bot2")

    # Expansive Path 
    x = decoder_block(x, base_model.get_layer("block3_conv3").output, filters=256, name="dec2") 
    x = decoder_block(x, base_model.get_layer("block2_conv2").output, filters=128, name="dec3") 
    x = decoder_block(x, base_model.get_layer("block1_conv2").output, filters=64, name="dec4") 

    # Output 
    outputs = conv_block(x, filters=classes, kernel_size=(1, 1), activation='sigmoid', name="outputs")

    return Model(inputs=base_model.input, outputs=outputs, name=name)

Kontrakční fázi modelu mně představuje instance VGG16. Vzhledem k tomu, že jsem si model načetl včetně vah, nebudu tuto část modelu trénovat, a proto je nastaven příznak trainable=False.

Vstupem pro úzké hrdlo modelu je výstup vrstvy block4_conv3 kontrakční fáze (jen pro připomenutí, poslední konvoluce před pooling). V tomto místě mám rozlišení 16×16 s 256 vlastnostmi. Hrdlo je tvořeno dvěma konvolučními bloky se zachováním rozlišení obrázku, ale se zvětšením počtu vlastností na dvojnásobek.

Expanzní fáze je pak tvořena třemi bloky komplementárními ke kontrakční fázi. Jedná se tedy o bloky s roztažením obrázku na dvojnásobnou velikost vrstvou Conv2DTranspose s krokem konvoluce 2. Následuje zřetězení výstupu vrstvy s tenzorem skip connection. A vše je završeno dvěma konvolučními bloky s redukcí počtu sledovaných vlastností.

Výstupem celého modelu je opět konvoluční blok s rozlišením původního obrázku a aktivační funkcí Sigmoid pro zjištění pravděpodobnosti zařazení pixelu do třídy.

Vyhodnocení modelu

Opět jsem si zde doplnil implementace pro ztrátovou funkci a metriky. Jako ztrátovou funkci používám kombinaci Binary CrossEntropy a Dice Coefficient. Metriky jsem si v tomto případě vybral Dice Coefficient a Jaccard Index (i když by mně asi stačila jenom jedna, neboť jejich výsledky se dost kryjí).

In [9]:


import keras.backend as K
from keras.losses import binary_crossentropy

def dice_coefficient(y_true, y_pred):
    smooth = 1.
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    score = (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
    return score

def dice_loss(y_true, y_pred):
    loss = 1 - dice_coefficient(y_true, y_pred)
    return loss

def bce_dice_loss(y_true, y_pred):
    loss = binary_crossentropy(y_true, y_pred) + dice_loss(y_true, y_pred)
    return loss

def jaccard_index(y_true, y_pred):
    smooth = 1.
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f) + smooth
    union = K.sum((y_true_f + y_pred_f) - (y_true_f * y_pred_f)) + smooth
    return intersection / union

def jaccard_loss(y_true, y_pred):
    return 1 - jaccard_index(y_true, y_pred)

Příprava dat pro vyhodnocení

Ještě si musím rozdělit celou datovou sadu na část pro trénování a pro testování výkonu modelu (poměr 80:20):

In [10]:


x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.2)

print(f"x_train: {x_train.shape}, y_train: {y_train.shape}")
print(f"x_test:  {x_test.shape},  y_test:  {y_test.shape}")

x_train: (3143, 128, 128, 3), y_train: (3143, 128, 128, 1)
x_test:  (786, 128, 128, 3),  y_test:  (786, 128, 128, 1)

Překlad modelu

Vytvořím si instanci modelu a přeložím jej:

In [11]:


model = create_model_UNet_VGG16Backbone(x_test.shape, 1)

model.compile(optimizer="adam", loss=bce_dice_loss, metrics=[dice_coefficient, jaccard_index])

plot_model(model, to_file=f'/kaggle/working/model/{model.name}.png',show_shapes=True, show_layer_names=True)
Image(retina=True, filename=f'/kaggle/working/model/{model.name}.png')

Trénování modelu

Použiji pro mne obvyklý postup s callback funkcemi pro úschovu nejlepší verze modelu a dřívější zastavení trénování v případě, že se výsledky začnou zhoršovat při validaci (což se ale nestalo a využil jsem celou sadu 100 epoch).

In [12]:


MODEL_CHECKPOINT = f"/kaggle/working/model/{model.name}.ckpt"
EPOCHS = 100

callbacks_list = [
    keras.callbacks.EarlyStopping(monitor='val_dice_coefficient', mode='max', patience=20),
    keras.callbacks.ModelCheckpoint(filepath=MODEL_CHECKPOINT, monitor='val_dice_coefficient', save_best_only=True, mode='max', verbose=1)
]

history = model.fit(
    x=x_train,
    y=y_train,
    epochs=EPOCHS, 
    callbacks=callbacks_list, 
    validation_split=0.2,
    verbose=1)

Epoch 1/100
I0000 00:00:1709056488.748993      67 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
79/79 [==============================] - ETA: 0s - loss: 1.6200 - dice_coefficient: 0.0358 - jaccard_index: 0.0183
Epoch 1: val_dice_coefficient improved from -inf to 0.03375, saving model to /kaggle/working/model/UNet_VGG16Backbone.ckpt
79/79 [==============================] - 31s 267ms/step - loss: 1.6200 - dice_coefficient: 0.0358 - jaccard_index: 0.0183 - val_loss: 2.6663 - val_dice_coefficient: 0.0337 - val_jaccard_index: 0.0172
Epoch 2/100
79/79 [==============================] - ETA: 0s - loss: 1.5630 - dice_coefficient: 0.0393 - jaccard_index: 0.0201
Epoch 2: val_dice_coefficient improved from 0.03375 to 0.04521, saving model to /kaggle/working/model/UNet_VGG16Backbone.ckpt
79/79 [==============================] - 19s 244ms/step - loss: 1.5630 - dice_coefficient: 0.0393 - jaccard_index: 0.0201 - val_loss: 1.6512 - val_dice_coefficient: 0.0452 - val_jaccard_index: 0.0232
Epoch 3/100
79/79 [==============================] - ETA: 0s - loss: 1.5218 - dice_coefficient: 0.0421 - jaccard_index: 0.0216
Epoch 3: val_dice_coefficient improved from 0.04521 to 0.04943, saving model to /kaggle/working/model/UNet_VGG16Backbone.ckpt
79/79 [==============================] - 20s 249ms/step - loss: 1.5218 - dice_coefficient: 0.0421 - jaccard_index: 0.0216 - val_loss: 1.4917 - val_dice_coefficient: 0.0494 - val_jaccard_index: 0.0254
Epoch 4/100
79/79 [==============================] - ETA: 0s - loss: 1.4834 - dice_coefficient: 0.0448 - jaccard_index: 0.0229
Epoch 4: val_dice_coefficient improved from 0.04943 to 0.05024, saving model to /kaggle/working/model/UNet_VGG16Backbone.ckpt
79/79 [==============================] - 20s 248ms/step - loss: 1.4834 - dice_coefficient: 0.0448 - jaccard_index: 0.0229 - val_loss: 1.4353 - val_dice_coefficient: 0.0502 - val_jaccard_index: 0.0259
Epoch 5/100
79/79 [==============================] - ETA: 0s - loss: 1.4490 - dice_coefficient: 0.0469 - jaccard_index: 0.0241
Epoch 5: val_dice_coefficient improved from 0.05024 to 0.05573, saving model to /kaggle/working/model/UNet_VGG16Backbone.ckpt
79/79 [==============================] - 20s 248ms/step - loss: 1.4490 - dice_coefficient: 0.0469 - jaccard_index: 0.0241 - val_loss: 1.4291 - val_dice_coefficient: 0.0557 - val_jaccard_index: 0.0288
Epoch 6/100
79/79 [==============================] - ETA: 0s - loss: 1.4185 - dice_coefficient: 0.0492 - jaccard_index: 0.0253
Epoch 6: val_dice_coefficient did not improve from 0.05573
79/79 [==============================] - 14s 176ms/step - loss: 1.4185 - dice_coefficient: 0.0492 - jaccard_index: 0.0253 - val_loss: 1.7762 - val_dice_coefficient: 0.0477 - val_jaccard_index: 0.0246
Epoch 7/100
79/79 [==============================] - ETA: 0s - loss: 1.3861 - dice_coefficient: 0.0520 - jaccard_index: 0.0268
Epoch 7: val_dice_coefficient improved from 0.05573 to 0.06047, saving model to /kaggle/working/model/UNet_VGG16Backbone.ckpt
79/79 [==============================] - 19s 243ms/step - loss: 1.3861 - dice_coefficient: 0.0520 - jaccard_index: 0.0268 - val_loss: 1.3667 - val_dice_coefficient: 0.0605 - val_jaccard_index: 0.0313
Epoch 8/100
79/79 [==============================] - ETA: 0s - loss: 1.3565 - dice_coefficient: 0.0547 - jaccard_index: 0.0282
Epoch 8: val_dice_coefficient improved from 0.06047 to 0.06562, saving model to /kaggle/working/model/UNet_VGG16Backbone.ckpt
79/79 [==============================] - 20s 250ms/step - loss: 1.3565 - dice_coefficient: 0.0547 - jaccard_index: 0.0282 - val_loss: 1.3166 - val_dice_coefficient: 0.0656 - val_jaccard_index: 0.0341
Epoch 9/100
79/79 [==============================] - ETA: 0s - loss: 1.3301 - dice_coefficient: 0.0572 - jaccard_index: 0.0296
Epoch 9: val_dice_coefficient did not improve from 0.06562
79/79 [==============================] - 14s 176ms/step - loss: 1.3301 - dice_coefficient: 0.0572 - jaccard_index: 0.0296 - val_loss: 1.3041 - val_dice_coefficient: 0.0646 - val_jaccard_index: 0.0335
Epoch 10/100
79/79 [==============================] - ETA: 0s - loss: 1.3042 - dice_coefficient: 0.0605 - jaccard_index: 0.0313
Epoch 10: val_dice_coefficient improved from 0.06562 to 0.06736, saving model to /kaggle/working/model/UNet_VGG16Backbone.ckpt
79/79 [==============================] - 19s 248ms/step - loss: 1.3042 - dice_coefficient: 0.0605 - jaccard_index: 0.0313 - val_loss: 1.2948 - val_dice_coefficient: 0.0674 - val_jaccard_index: 0.0350

...

Epoch 90/100
79/79 [==============================] - ETA: 0s - loss: 0.1222 - dice_coefficient: 0.8841 - jaccard_index: 0.7946
Epoch 90: val_dice_coefficient improved from 0.82025 to 0.82255, saving model to /kaggle/working/model/UNet_VGG16Backbone.ckpt
79/79 [==============================] - 20s 251ms/step - loss: 0.1222 - dice_coefficient: 0.8841 - jaccard_index: 0.7946 - val_loss: 0.2062 - val_dice_coefficient: 0.8226 - val_jaccard_index: 0.7001
Epoch 91/100
79/79 [==============================] - ETA: 0s - loss: 0.1175 - dice_coefficient: 0.8885 - jaccard_index: 0.8013
Epoch 91: val_dice_coefficient improved from 0.82255 to 0.82593, saving model to /kaggle/working/model/UNet_VGG16Backbone.ckpt
79/79 [==============================] - 20s 249ms/step - loss: 0.1175 - dice_coefficient: 0.8885 - jaccard_index: 0.8013 - val_loss: 0.2047 - val_dice_coefficient: 0.8259 - val_jaccard_index: 0.7049
Epoch 92/100
79/79 [==============================] - ETA: 0s - loss: 0.1202 - dice_coefficient: 0.8863 - jaccard_index: 0.7982
Epoch 92: val_dice_coefficient improved from 0.82593 to 0.83059, saving model to /kaggle/working/model/UNet_VGG16Backbone.ckpt
79/79 [==============================] - 20s 249ms/step - loss: 0.1202 - dice_coefficient: 0.8863 - jaccard_index: 0.7982 - val_loss: 0.1978 - val_dice_coefficient: 0.8306 - val_jaccard_index: 0.7115
Epoch 93/100
79/79 [==============================] - ETA: 0s - loss: 0.1188 - dice_coefficient: 0.8877 - jaccard_index: 0.8009
Epoch 93: val_dice_coefficient did not improve from 0.83059
79/79 [==============================] - 14s 176ms/step - loss: 0.1188 - dice_coefficient: 0.8877 - jaccard_index: 0.8009 - val_loss: 0.2024 - val_dice_coefficient: 0.8245 - val_jaccard_index: 0.7026
Epoch 94/100
79/79 [==============================] - ETA: 0s - loss: 0.1095 - dice_coefficient: 0.8965 - jaccard_index: 0.8139
Epoch 94: val_dice_coefficient did not improve from 0.83059
79/79 [==============================] - 14s 176ms/step - loss: 0.1095 - dice_coefficient: 0.8965 - jaccard_index: 0.8139 - val_loss: 0.2007 - val_dice_coefficient: 0.8264 - val_jaccard_index: 0.7057
Epoch 95/100
79/79 [==============================] - ETA: 0s - loss: 0.1102 - dice_coefficient: 0.8958 - jaccard_index: 0.8135
Epoch 95: val_dice_coefficient improved from 0.83059 to 0.83182, saving model to /kaggle/working/model/UNet_VGG16Backbone.ckpt
79/79 [==============================] - 20s 250ms/step - loss: 0.1102 - dice_coefficient: 0.8958 - jaccard_index: 0.8135 - val_loss: 0.1987 - val_dice_coefficient: 0.8318 - val_jaccard_index: 0.7133
Epoch 96/100
79/79 [==============================] - ETA: 0s - loss: 0.1051 - dice_coefficient: 0.9003 - jaccard_index: 0.8206
Epoch 96: val_dice_coefficient improved from 0.83182 to 0.83433, saving model to /kaggle/working/model/UNet_VGG16Backbone.ckpt
79/79 [==============================] - 19s 244ms/step - loss: 0.1051 - dice_coefficient: 0.9003 - jaccard_index: 0.8206 - val_loss: 0.1952 - val_dice_coefficient: 0.8343 - val_jaccard_index: 0.7170
Epoch 97/100
79/79 [==============================] - ETA: 0s - loss: 0.1092 - dice_coefficient: 0.8957 - jaccard_index: 0.8136
Epoch 97: val_dice_coefficient did not improve from 0.83433
79/79 [==============================] - 14s 176ms/step - loss: 0.1092 - dice_coefficient: 0.8957 - jaccard_index: 0.8136 - val_loss: 0.1974 - val_dice_coefficient: 0.8312 - val_jaccard_index: 0.7124
Epoch 98/100
79/79 [==============================] - ETA: 0s - loss: 0.1062 - dice_coefficient: 0.8998 - jaccard_index: 0.8203
Epoch 98: val_dice_coefficient did not improve from 0.83433
79/79 [==============================] - 14s 176ms/step - loss: 0.1062 - dice_coefficient: 0.8998 - jaccard_index: 0.8203 - val_loss: 0.1964 - val_dice_coefficient: 0.8307 - val_jaccard_index: 0.7119
Epoch 99/100
79/79 [==============================] - ETA: 0s - loss: 0.0996 - dice_coefficient: 0.9061 - jaccard_index: 0.8303
Epoch 99: val_dice_coefficient improved from 0.83433 to 0.83926, saving model to /kaggle/working/model/UNet_VGG16Backbone.ckpt
79/79 [==============================] - 20s 251ms/step - loss: 0.0996 - dice_coefficient: 0.9061 - jaccard_index: 0.8303 - val_loss: 0.1908 - val_dice_coefficient: 0.8393 - val_jaccard_index: 0.7243
Epoch 100/100
79/79 [==============================] - ETA: 0s - loss: 0.0943 - dice_coefficient: 0.9111 - jaccard_index: 0.8381
Epoch 100: val_dice_coefficient improved from 0.83926 to 0.84058, saving model to /kaggle/working/model/UNet_VGG16Backbone.ckpt
79/79 [==============================] - 20s 250ms/step - loss: 0.0943 - dice_coefficient: 0.9111 - jaccard_index: 0.8381 - val_loss: 0.1873 - val_dice_coefficient: 0.8406 - val_jaccard_index: 0.7261

A takto vypadal průběh ztrátové funkce a metrik při trénování modelu:

In [13]:


fig, ax = plt.subplots(1, 2, figsize=(16, 4))
sns.lineplot(data={k: history.history[k] for k in ('loss', 'val_loss')}, ax=ax[0])
sns.lineplot(data={k: history.history[k] for k in history.history.keys() if k not in ('loss', 'val_loss')}, ax=ax[1])
plt.show()

__results___18_0.png

Načtu si váhy modelu, který vyšel při validaci nejlépe. Budu doufat, že bude mít také nejlepší výsledky při testování:

In [14]:


model.load_weights(f"/kaggle/working/model/{model.name}.ckpt")

Out[14]:


<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x79015c4e3a90>

Testování výsledků modelu

Udělám tedy predikci pro celou testovací sadu. Výsledek pak převedu z pravděpodobnosti na hodnoty ano/ne (dělící hladinou v mém případě je pravděpodobnost 50%):

In [15]:


y_pred = model.predict(x_test)
y_pred = (y_pred > 0.5).astype(np.float64)
25/25 [==============================] - 2s 79ms/step

Následuje ukázka několika náhodně vybraných vzorků, neboť obrázkem nikdy neurazíš:

In [16]:


for _ in range(20):
    i = np.random.randint(len(y_test))
    if y_test[i].sum() > 0:
        plt.figure(figsize=(8, 8))
        plt.subplot(1,3,1)
        plt.imshow(x_test[i])
        plt.title('Original Image')
        plt.subplot(1,3,2)
        plt.imshow(y_test[i])
        plt.title('Original Mask')
        plt.subplot(1,3,3)
        plt.imshow(y_pred[i])
        plt.title('Prediction')
        plt.show()

__results___22_0.png

__results___22_1.png

__results___22_2.png

__results___22_3.png

__results___22_4.png

:__results___22_5.png

__results___22_6.png

__results___22_7.png

__results___22_8.png

Pro dokreslení výsledků modelu ještě doplňuji dva histogramy pro metriky Dice Coefficient a Jaccard Index stejně, jak tomu bylo v předchozích dvou článcích:

In [17]:


pred_dice_metric = np.array([dice_coefficient(y_test[i], y_pred[i]).numpy() for i in range(len(y_test))])

In [18]:


fig=plt.figure(figsize=(8, 3))
sns.histplot(pred_dice_metric, stat="probability", bins=50)
plt.xlabel("Dice metric")
plt.show()

__results___24_0.png

In [19]:


pred_jaccard_metric = np.array([jaccard_index(y_test[i], y_pred[i]).numpy() for i in range(len(y_test))])

In [20]:


fig=plt.figure(figsize=(8, 3))
sns.histplot(pred_jaccard_metric, stat="probability", bins=50)
plt.xlabel("Jaccard (IoU) metric")
plt.show()

__results___26_0.png

Sdílet