Brain MRI Segmentation – the TASM Library

19 April 2024 · Jiří Raška

This will be a quick one. My goal is to build a tool for semantic image segmentation (as I have done several times in previous articles), but with as little programming effort as possible. Until now I have always assembled some part of the model by hand from various types of convolutional blocks. But it can be done more simply.

The first shortcut is transfer learning, where I built the contracting phase of the model on top of an existing classification model. That was the previous installment.

I can take this a step further and look for a ready-made solution for the entire segmentation model. If you search the internet for a while, you will find that several people have already solved this before you, and they have made their solutions freely available. You just need to pick the right one for the task at hand.

The first question is always which framework to build on. I use TensorFlow 2, and the choice of library has to match it. One of the options, which I used in this article, is the TensorFlow Advanced Segmentation Models library.

If you look at the project's front page, you will find that it currently supports 14 types of semantic segmentation models and 25 classification models as backbones. I will not list them all here, but there is certainly plenty to choose from. For my quick experiment I picked the classic combination: a U-Net model with a VGG16 backbone.

To keep my opening promise that this will be quick, I will leave out the preparatory parts of the code in this article. I would only be repeating the same things yet again and holding you up. If you are interested, please see the previous articles.

The starting point is loading the data from the dataset. The array X holds 3929 samples at 128×128 pixels with three channels. The array Y holds the corresponding masks at the same resolution with a single grayscale channel (in this case only the values 0 and 1).
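The function get_image_data belongs to the preparatory code omitted above. A minimal sketch of what it might look like, under the assumption (mine, not the original's) that image_paths yields (image_path, mask_path) pairs:

```python
import numpy as np
from PIL import Image

IMAGE_SIZE = (128, 128)

def get_image_data(image_paths):
    """Hypothetical sketch: load each (image, mask) pair, resize to
    128x128, scale images to [0, 1] and binarize the masks to 0/1."""
    xs, ys = [], []
    for img_path, mask_path in image_paths:
        img = Image.open(img_path).convert('RGB').resize(IMAGE_SIZE)
        mask = Image.open(mask_path).convert('L').resize(IMAGE_SIZE)
        xs.append(np.asarray(img, dtype=np.float32) / 255.0)
        # threshold at mid-gray so the mask is strictly 0/1
        ys.append((np.asarray(mask) > 127).astype(np.float32)[..., None])
    return np.stack(xs), np.stack(ys)
```

The original implementation may differ (e.g. deriving mask paths from image paths); this only illustrates the shape contract seen in the output below.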


X, Y = get_image_data(image_paths)

print(f"X: {X.shape}")
print(f"Y: {Y.shape}")
X: (3929, 128, 128, 3)
Y: (3929, 128, 128, 1)

U-Net Model with a VGG16 Backbone

First I need to install the library into my environment. Since the source is available on GitHub, that is quite easy:

In [6]:


pip install -U git+https://github.com/JanMarcelKezmann/TensorFlow-Advanced-Segmentation-Models.git
Collecting git+https://github.com/JanMarcelKezmann/TensorFlow-Advanced-Segmentation-Models.git
  Cloning https://github.com/JanMarcelKezmann/TensorFlow-Advanced-Segmentation-Models.git to /tmp/pip-req-build-vmq5ym1x
  Running command git clone --filter=blob:none --quiet https://github.com/JanMarcelKezmann/TensorFlow-Advanced-Segmentation-Models.git /tmp/pip-req-build-vmq5ym1x
  Resolved https://github.com/JanMarcelKezmann/TensorFlow-Advanced-Segmentation-Models.git to commit 3714839ee49759b26e2b0ae3d3a0aa37b00df962
  Preparing metadata (setup.py) ... done
Requirement already satisfied: tensorflow>=2.0.0 in /opt/conda/lib/python3.10/site-packages (from tensorflow_advanced_segmentation_models==0.4.10) (2.15.0)
Requirement already satisfied: numpy in /opt/conda/lib/python3.10/site-packages (from tensorflow_advanced_segmentation_models==0.4.10) (1.24.4)
Requirement already satisfied: matplotlib in /opt/conda/lib/python3.10/site-packages (from tensorflow_advanced_segmentation_models==0.4.10) (3.7.4)
...
Building wheels for collected packages: tensorflow_advanced_segmentation_models
  Building wheel for tensorflow_advanced_segmentation_models (setup.py) ... - done
  Created wheel for tensorflow_advanced_segmentation_models: filename=tensorflow_advanced_segmentation_models-0.4.10-py3-none-any.whl size=70079 sha256=83441477efbaed56da9309fdeb1b466aba2bceef76dd25346aba6db89088a2d0
  Stored in directory: /tmp/pip-ephem-wheel-cache-d3_bh6rd/wheels/81/31/ae/80a6d6e86cebbc1f084947259d621128bf26c78448511b1461
Successfully built tensorflow_advanced_segmentation_models
Installing collected packages: tensorflow_advanced_segmentation_models
Successfully installed tensorflow_advanced_segmentation_models-0.4.10
Note: you may need to restart the kernel to use updated packages.

After installation, the library can be imported into the Python environment:

In [7]:


import tensorflow_advanced_segmentation_models as tasm

Creating the model takes essentially two lines of code. On the first line I create an instance of the classification model to serve as the backbone. In my case I used VGG16, but I load no weights (so it will be a randomly initialized model that I will train in its entirety).

On the second line I create the U-Net model itself. The segmentation masks have only a single class, so the activation is a sigmoid function. I also have to pass a reference to the classification model – that is what the base_model and layers variables are for. And that is all that is needed for the basic setup.

I also had the model diagram rendered, to give you a better idea of the result:

In [8]:


BACKBONE_NAME = 'vgg16'
WEIGHTS = None
HEIGHT, WIDTH = IMAGE_SIZE

base_model, layers, layer_names = tasm.create_base_model(name=BACKBONE_NAME, weights=WEIGHTS, height=HEIGHT, width=WIDTH)
model = tasm.UNet(n_classes=1, base_model=base_model, output_layers=layers, height=HEIGHT, width=WIDTH, final_activation='sigmoid', backbone_trainable=True).model()

plot_model(model, to_file=f'/kaggle/working/model/{model.name}.png', show_shapes=True, show_layer_names=True, expand_nested=True)
Image(retina=True, filename=f'/kaggle/working/model/{model.name}.png')

The usual data preparation:

In [9]:


x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.2)

print(f"x_train: {x_train.shape}, y_train: {y_train.shape}")
print(f"x_test:  {x_test.shape},  y_test:  {y_test.shape}")

x_train: (3143, 128, 128, 3), y_train: (3143, 128, 128, 1)
x_test:  (786, 128, 128, 3),  y_test:  (786, 128, 128, 1)

Compiling and Training the Model

In [11]:


model.compile(optimizer="adam", loss=bce_dice_loss, metrics=[dice_coefficient, jaccard_index])
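The functions bce_dice_loss, dice_coefficient, and jaccard_index come from the preparatory code omitted earlier. A common way to define them in TensorFlow – a sketch with an assumed smoothing constant of 1.0, not necessarily the author's exact definitions:

```python
import tensorflow as tf
from tensorflow import keras

SMOOTH = 1.0  # assumption: smoothing term to avoid division by zero

def dice_coefficient(y_true, y_pred):
    """Dice = 2*|A∩B| / (|A| + |B|), computed over the whole batch."""
    y_true_f = tf.reshape(tf.cast(y_true, tf.float32), [-1])
    y_pred_f = tf.reshape(tf.cast(y_pred, tf.float32), [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    return (2.0 * intersection + SMOOTH) / (
        tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + SMOOTH)

def jaccard_index(y_true, y_pred):
    """IoU = |A∩B| / |A∪B|."""
    y_true_f = tf.reshape(tf.cast(y_true, tf.float32), [-1])
    y_pred_f = tf.reshape(tf.cast(y_pred, tf.float32), [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    union = tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) - intersection
    return (intersection + SMOOTH) / (union + SMOOTH)

def bce_dice_loss(y_true, y_pred):
    """Binary cross-entropy plus Dice loss (1 - Dice)."""
    bce = tf.reduce_mean(keras.losses.binary_crossentropy(y_true, y_pred))
    return bce + (1.0 - dice_coefficient(y_true, y_pred))
```

Combining cross-entropy with Dice is a common choice for masks where the foreground class is small relative to the background, as is the case with tumor regions here.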

In [12]:


MODEL_CHECKPOINT = f"/kaggle/working/model/{model.name}.ckpt"
EPOCHS = 100

callbacks_list = [
    keras.callbacks.EarlyStopping(monitor='val_dice_coefficient', mode='max', patience=20),
    keras.callbacks.ModelCheckpoint(filepath=MODEL_CHECKPOINT, monitor='val_dice_coefficient', save_best_only=True, mode='max', verbose=1)
]

history = model.fit(
    x=x_train,
    y=y_train,
    epochs=EPOCHS, 
    callbacks=callbacks_list, 
    validation_split=0.2,
    verbose=1)

Epoch 1/100
79/79 [==============================] - ETA: 0s - loss: 0.8577 - dice_coefficient: 0.2310 - jaccard_index: 0.1394
Epoch 1: val_dice_coefficient improved from -inf to 0.48029, saving model to /kaggle/working/model/model_1.ckpt
79/79 [==============================] - 56s 579ms/step - loss: 0.8577 - dice_coefficient: 0.2310 - jaccard_index: 0.1394 - val_loss: 0.5569 - val_dice_coefficient: 0.4803 - val_jaccard_index: 0.3230
Epoch 2/100
79/79 [==============================] - ETA: 0s - loss: 0.5409 - dice_coefficient: 0.4935 - jaccard_index: 0.3354
Epoch 2: val_dice_coefficient improved from 0.48029 to 0.56807, saving model to /kaggle/working/model/model_1.ckpt
79/79 [==============================] - 44s 560ms/step - loss: 0.5409 - dice_coefficient: 0.4935 - jaccard_index: 0.3354 - val_loss: 0.4688 - val_dice_coefficient: 0.5681 - val_jaccard_index: 0.4054
Epoch 3/100
79/79 [==============================] - ETA: 0s - loss: 0.4647 - dice_coefficient: 0.5669 - jaccard_index: 0.4036
Epoch 3: val_dice_coefficient improved from 0.56807 to 0.65534, saving model to /kaggle/working/model/model_1.ckpt
79/79 [==============================] - 44s 560ms/step - loss: 0.4647 - dice_coefficient: 0.5669 - jaccard_index: 0.4036 - val_loss: 0.3744 - val_dice_coefficient: 0.6553 - val_jaccard_index: 0.4945
Epoch 4/100
79/79 [==============================] - ETA: 0s - loss: 0.4040 - dice_coefficient: 0.6275 - jaccard_index: 0.4642
Epoch 4: val_dice_coefficient did not improve from 0.65534
79/79 [==============================] - 38s 481ms/step - loss: 0.4040 - dice_coefficient: 0.6275 - jaccard_index: 0.4642 - val_loss: 0.3907 - val_dice_coefficient: 0.6449 - val_jaccard_index: 0.4825
Epoch 5/100
79/79 [==============================] - ETA: 0s - loss: 0.3982 - dice_coefficient: 0.6313 - jaccard_index: 0.4701
Epoch 5: val_dice_coefficient improved from 0.65534 to 0.67167, saving model to /kaggle/working/model/model_1.ckpt
79/79 [==============================] - 44s 558ms/step - loss: 0.3982 - dice_coefficient: 0.6313 - jaccard_index: 0.4701 - val_loss: 0.3584 - val_dice_coefficient: 0.6717 - val_jaccard_index: 0.5178
Epoch 6/100
79/79 [==============================] - ETA: 0s - loss: 0.3724 - dice_coefficient: 0.6558 - jaccard_index: 0.4969
Epoch 6: val_dice_coefficient improved from 0.67167 to 0.67987, saving model to /kaggle/working/model/model_1.ckpt
79/79 [==============================] - 44s 557ms/step - loss: 0.3724 - dice_coefficient: 0.6558 - jaccard_index: 0.4969 - val_loss: 0.3510 - val_dice_coefficient: 0.6799 - val_jaccard_index: 0.5225
Epoch 7/100
79/79 [==============================] - ETA: 0s - loss: 0.3582 - dice_coefficient: 0.6691 - jaccard_index: 0.5102
Epoch 7: val_dice_coefficient improved from 0.67987 to 0.73302, saving model to /kaggle/working/model/model_1.ckpt
79/79 [==============================] - 44s 558ms/step - loss: 0.3582 - dice_coefficient: 0.6691 - jaccard_index: 0.5102 - val_loss: 0.2931 - val_dice_coefficient: 0.7330 - val_jaccard_index: 0.5856
Epoch 8/100
79/79 [==============================] - ETA: 0s - loss: 0.3823 - dice_coefficient: 0.6455 - jaccard_index: 0.4851
Epoch 8: val_dice_coefficient did not improve from 0.73302
79/79 [==============================] - 36s 450ms/step - loss: 0.3823 - dice_coefficient: 0.6455 - jaccard_index: 0.4851 - val_loss: 0.3242 - val_dice_coefficient: 0.7038 - val_jaccard_index: 0.5485
Epoch 9/100
79/79 [==============================] - ETA: 0s - loss: 0.3468 - dice_coefficient: 0.6786 - jaccard_index: 0.5200
Epoch 9: val_dice_coefficient did not improve from 0.73302
79/79 [==============================] - 36s 450ms/step - loss: 0.3468 - dice_coefficient: 0.6786 - jaccard_index: 0.5200 - val_loss: 0.2941 - val_dice_coefficient: 0.7313 - val_jaccard_index: 0.5835
Epoch 10/100
79/79 [==============================] - ETA: 0s - loss: 0.3455 - dice_coefficient: 0.6823 - jaccard_index: 0.5247
Epoch 10: val_dice_coefficient did not improve from 0.73302
79/79 [==============================] - 35s 449ms/step - loss: 0.3455 - dice_coefficient: 0.6823 - jaccard_index: 0.5247 - val_loss: 0.3045 - val_dice_coefficient: 0.7217 - val_jaccard_index: 0.5713

...

Epoch 86/100
79/79 [==============================] - ETA: 0s - loss: 0.1400 - dice_coefficient: 0.8713 - jaccard_index: 0.7744
Epoch 86: val_dice_coefficient did not improve from 0.87112
79/79 [==============================] - 35s 449ms/step - loss: 0.1400 - dice_coefficient: 0.8713 - jaccard_index: 0.7744 - val_loss: 0.1676 - val_dice_coefficient: 0.8509 - val_jaccard_index: 0.7424
Epoch 87/100
79/79 [==============================] - ETA: 0s - loss: 0.1102 - dice_coefficient: 0.8978 - jaccard_index: 0.8163
Epoch 87: val_dice_coefficient did not improve from 0.87112
79/79 [==============================] - 35s 449ms/step - loss: 0.1102 - dice_coefficient: 0.8978 - jaccard_index: 0.8163 - val_loss: 0.1594 - val_dice_coefficient: 0.8594 - val_jaccard_index: 0.7551
Epoch 88/100
79/79 [==============================] - ETA: 0s - loss: 0.1045 - dice_coefficient: 0.9018 - jaccard_index: 0.8240
Epoch 88: val_dice_coefficient did not improve from 0.87112
79/79 [==============================] - 36s 450ms/step - loss: 0.1045 - dice_coefficient: 0.9018 - jaccard_index: 0.8240 - val_loss: 0.1506 - val_dice_coefficient: 0.8683 - val_jaccard_index: 0.7689
Epoch 89/100
79/79 [==============================] - ETA: 0s - loss: 0.1066 - dice_coefficient: 0.9007 - jaccard_index: 0.8218
Epoch 89: val_dice_coefficient did not improve from 0.87112
79/79 [==============================] - 35s 450ms/step - loss: 0.1066 - dice_coefficient: 0.9007 - jaccard_index: 0.8218 - val_loss: 0.1693 - val_dice_coefficient: 0.8492 - val_jaccard_index: 0.7399
Epoch 90/100
79/79 [==============================] - ETA: 0s - loss: 0.1097 - dice_coefficient: 0.8978 - jaccard_index: 0.8168
Epoch 90: val_dice_coefficient did not improve from 0.87112
79/79 [==============================] - 35s 449ms/step - loss: 0.1097 - dice_coefficient: 0.8978 - jaccard_index: 0.8168 - val_loss: 0.1559 - val_dice_coefficient: 0.8622 - val_jaccard_index: 0.7592
Epoch 91/100
79/79 [==============================] - ETA: 0s - loss: 0.1056 - dice_coefficient: 0.9004 - jaccard_index: 0.8232
Epoch 91: val_dice_coefficient did not improve from 0.87112
79/79 [==============================] - 35s 449ms/step - loss: 0.1056 - dice_coefficient: 0.9004 - jaccard_index: 0.8232 - val_loss: 0.1478 - val_dice_coefficient: 0.8693 - val_jaccard_index: 0.7704
Epoch 92/100
79/79 [==============================] - ETA: 0s - loss: 0.2382 - dice_coefficient: 0.7839 - jaccard_index: 0.6576
Epoch 92: val_dice_coefficient did not improve from 0.87112
79/79 [==============================] - 35s 450ms/step - loss: 0.2382 - dice_coefficient: 0.7839 - jaccard_index: 0.6576 - val_loss: 0.2808 - val_dice_coefficient: 0.7477 - val_jaccard_index: 0.5999
Epoch 93/100
79/79 [==============================] - ETA: 0s - loss: 0.2632 - dice_coefficient: 0.7616 - jaccard_index: 0.6214
Epoch 93: val_dice_coefficient did not improve from 0.87112
79/79 [==============================] - 36s 450ms/step - loss: 0.2632 - dice_coefficient: 0.7616 - jaccard_index: 0.6214 - val_loss: 0.2210 - val_dice_coefficient: 0.8013 - val_jaccard_index: 0.6716

From the training run it is clear that I was already approaching the limits of what the model can do as far as the validation results are concerned.

In [13]:


fig, ax = plt.subplots(1, 2, figsize=(16, 4))
sns.lineplot(data={k: history.history[k] for k in ('loss', 'val_loss')}, ax=ax[0])
sns.lineplot(data={k: history.history[k] for k in history.history.keys() if k not in ('loss', 'val_loss')}, ax=ax[1])
plt.show()

(Training curves: loss and val_loss on the left, the Dice and Jaccard metrics on the right.)

In [14]:


model.load_weights(f"/kaggle/working/model/{model.name}.ckpt")

Out[14]:


<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x7d33381abd90>

Testing the Model's Results

Here I will limit myself to the basics, just to make it clear that the network really has learned something…

In [15]:


y_pred = model.predict(x_test)
y_pred = (y_pred > 0.5).astype(np.float64)
25/25 [==============================] - 4s 133ms/step
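Beyond eyeballing individual predictions, the thresholded masks also allow a single aggregate score over the test set. A small NumPy sketch (using the y_test and y_pred arrays from above):

```python
import numpy as np

def dice_score(y_true, y_pred, eps=1e-7):
    """Global Dice over binary masks: 2*|A∩B| / (|A| + |B|)."""
    y_true = y_true.astype(np.float64).ravel()
    y_pred = y_pred.astype(np.float64).ravel()
    intersection = (y_true * y_pred).sum()
    return (2.0 * intersection + eps) / (y_true.sum() + y_pred.sum() + eps)

# For example:
# print(f"test Dice: {dice_score(y_test, y_pred):.4f}")
```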

In [16]:


for _ in range(20):
    i = np.random.randint(len(y_test))
    if y_test[i].sum() > 0:
        plt.figure(figsize=(8, 8))
        plt.subplot(1,3,1)
        plt.imshow(x_test[i])
        plt.title('Original Image')
        plt.subplot(1,3,2)
        plt.imshow(y_test[i])
        plt.title('Original Mask')
        plt.subplot(1,3,3)
        plt.imshow(y_pred[i])
        plt.title('Prediction')
        plt.show()

(Sample outputs: the original image, the ground-truth mask, and the predicted mask side by side for several test samples.)
