The starting point is loading the data from the dataset. The array X holds 3,929 samples at a resolution of 128×128 pixels with three channels. The array Y holds the corresponding masks at the same resolution, with a single grayscale channel (here containing only the values 0 and 1).

```python
X, Y = get_image_data(image_paths)
print(f"X: {X.shape}")
print(f"Y: {Y.shape}")
```

```
X: (3929, 128, 128, 3)
Y: (3929, 128, 128, 1)
```
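Before training, it can be worth checking that the arrays really have the layout described above and that the masks contain only 0 and 1. A minimal NumPy sketch; `check_segmentation_arrays` is a hypothetical helper, and the synthetic arrays merely stand in for the real X and Y:

```python
import numpy as np

def check_segmentation_arrays(X, Y):
    """Hypothetical helper: validate image/mask arrays and return the sample count."""
    if X.ndim != 4 or X.shape[-1] != 3:
        raise ValueError(f"X should be (N, H, W, 3), got {X.shape}")
    if Y.ndim != 4 or Y.shape[-1] != 1:
        raise ValueError(f"Y should be (N, H, W, 1), got {Y.shape}")
    if X.shape[:3] != Y.shape[:3]:
        raise ValueError("images and masks must agree in N, H and W")
    values = np.unique(Y)
    if not set(values.tolist()) <= {0, 1}:
        raise ValueError(f"masks should be binary, found values {values}")
    return X.shape[0]

# Synthetic stand-ins with the same layout as the real data (4 samples instead of 3929)
rng = np.random.default_rng(0)
X_demo = rng.random((4, 128, 128, 3))
Y_demo = rng.integers(0, 2, size=(4, 128, 128, 1)).astype(np.float64)
print(check_segmentation_arrays(X_demo, Y_demo))  # 4
```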

U-Net model with a VGG16 backbone

First I need to install the library into my environment. Since the source is available on GitHub, this is quite straightforward:

After installation, the library can be imported into the Python environment:

```python
import tensorflow_advanced_segmentation_models as tasm
```

Creating the model takes essentially two lines of code. In the first line I create an instance of a classification model to serve as the backbone. In my case I used VGG16, but without loading any pretrained weights, so it is a randomly initialized model that I will train in full.

In the second line I create the U-Net model itself. The segmentation masks contain only a single class, so the final activation is a sigmoid. I also have to pass references to the classification model: these are the base_model and layers variables. That is all the basic setup needs.
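Why a sigmoid rather than a softmax: with a single output channel, a softmax normalizes each pixel over just one value and therefore always returns 1, whereas a sigmoid gives an independent foreground probability for every pixel. A small NumPy illustration; the logits are made up purely for demonstration:

```python
import numpy as np

def sigmoid(z):
    # Maps each logit independently to a probability in (0, 1)
    return 1.0 / (1.0 + np.exp(-z))

def softmax(z, axis=-1):
    # Normalizes over the channel axis; subtracting the max keeps exp() stable
    e = np.exp(z - z.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

# Made-up logits for a 2x2 mask with one channel, shape (H, W, 1)
logits = np.array([[[-3.0], [0.0]],
                   [[0.5], [4.0]]])

probs = sigmoid(logits)
mask = (probs > 0.5).astype(np.float64)   # a pixel counts as foreground above 0.5
print(mask[..., 0].tolist())              # [[0.0, 0.0], [1.0, 1.0]]
print(softmax(logits)[..., 0].tolist())   # [[1.0, 1.0], [1.0, 1.0]], no information
```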

I also plotted the model architecture to give a better idea of the result:

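```python
from tensorflow.keras.utils import plot_model
from IPython.display import Image

BACKBONE_NAME = 'vgg16'
WEIGHTS = None               # no pretrained weights: the backbone starts from random initialization
HEIGHT, WIDTH = IMAGE_SIZE   # IMAGE_SIZE (128×128 here) is defined earlier, not shown in this section

# Line 1: the classification backbone and the layers the U-Net decoder connects to
base_model, layers, layer_names = tasm.create_base_model(
    name=BACKBONE_NAME, weights=WEIGHTS, height=HEIGHT, width=WIDTH)

# Line 2: the U-Net itself; a single output class, hence the sigmoid activation
model = tasm.UNet(n_classes=1, base_model=base_model, output_layers=layers,
                  height=HEIGHT, width=WIDTH, final_activation='sigmoid',
                  backbone_trainable=True).model()

# Render the architecture diagram to a PNG and display it in the notebook
plot_model(model, to_file=f'/kaggle/working/model/{model.name}.png',
           show_shapes=True, show_layer_names=True, expand_nested=True)
Image(retina=True, filename=f'/kaggle/working/model/{model.name}.png')
```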

The usual data preparation, splitting into training and test sets:

```python
from sklearn.model_selection import train_test_split

# Hold out 20 % of the samples as a test set
x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.2)
print(f"x_train: {x_train.shape}, y_train: {y_train.shape}")
print(f"x_test: {x_test.shape}, y_test: {y_test.shape}")
```

```
x_train: (3143, 128, 128, 3), y_train: (3143, 128, 128, 1)
x_test: (786, 128, 128, 3), y_test: (786, 128, 128, 1)
```
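The sizes printed above, and the 79 steps per epoch in the training log, follow from simple rounding rules. The sketch below assumes scikit-learn's rule for a float `test_size` (the test part is rounded up), the TensorFlow 2.x Keras rule for `validation_split` (the training part is rounded down), and Keras's default batch size of 32, since `fit` is called without `batch_size`:

```python
import math

N_SAMPLES = 3929
TEST_SIZE = 0.2          # train_test_split(..., test_size=0.2)
VALIDATION_SPLIT = 0.2   # model.fit(..., validation_split=0.2)
BATCH_SIZE = 32          # assumption: Keras default batch size, since none is passed

# scikit-learn: test part = ceil(test_size * n), train part = the rest
n_test = math.ceil(TEST_SIZE * N_SAMPLES)
n_train = N_SAMPLES - n_test

# Keras (TF 2.x): fit() keeps floor(n * (1 - validation_split)) samples for training
n_fit = math.floor(n_train * (1 - VALIDATION_SPLIT))
n_val = n_train - n_fit

steps_per_epoch = math.ceil(n_fit / BATCH_SIZE)   # batches per training epoch
predict_steps = math.ceil(n_test / BATCH_SIZE)    # batches in model.predict(x_test)

print(n_train, n_test)                  # 3143 786
print(n_fit, n_val)                     # 2514 629
print(steps_per_epoch, predict_steps)   # 79 25
```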

Compiling and training the model

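```python
# bce_dice_loss, dice_coefficient and jaccard_index are custom functions defined elsewhere (not shown here)
model.compile(optimizer="adam", loss=bce_dice_loss, metrics=[dice_coefficient, jaccard_index])
```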
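The definitions of bce_dice_loss, dice_coefficient and jaccard_index are not part of this section. For orientation, here is a NumPy sketch of the usual formulas, assuming the common variants: Dice = 2|A∩B| / (|A| + |B|), Jaccard (IoU) = |A∩B| / |A∪B|, each with a small smoothing term, and the loss as binary cross-entropy plus (1 - Dice). The author's Keras versions may differ in details such as the smoothing constant, and a version usable in `model.compile` would operate on tensors with TensorFlow ops:

```python
import numpy as np

def dice_coefficient(y_true, y_pred, smooth=1.0):
    # Dice = (2 * |A ∩ B| + smooth) / (|A| + |B| + smooth)
    y_true = np.asarray(y_true, dtype=np.float64).ravel()
    y_pred = np.asarray(y_pred, dtype=np.float64).ravel()
    intersection = np.sum(y_true * y_pred)
    return (2.0 * intersection + smooth) / (np.sum(y_true) + np.sum(y_pred) + smooth)

def jaccard_index(y_true, y_pred, smooth=1.0):
    # IoU = (|A ∩ B| + smooth) / (|A ∪ B| + smooth), with |A ∪ B| = |A| + |B| - |A ∩ B|
    y_true = np.asarray(y_true, dtype=np.float64).ravel()
    y_pred = np.asarray(y_pred, dtype=np.float64).ravel()
    intersection = np.sum(y_true * y_pred)
    union = np.sum(y_true) + np.sum(y_pred) - intersection
    return (intersection + smooth) / (union + smooth)

def bce_dice_loss(y_true, y_pred, eps=1e-7, smooth=1.0):
    # Mean binary cross-entropy over all pixels plus (1 - Dice)
    y_true = np.asarray(y_true, dtype=np.float64)
    p = np.clip(np.asarray(y_pred, dtype=np.float64), eps, 1.0 - eps)
    bce = -np.mean(y_true * np.log(p) + (1.0 - y_true) * np.log(1.0 - p))
    return bce + (1.0 - dice_coefficient(y_true, y_pred, smooth))

y_true = np.array([1, 1, 0, 0], dtype=np.float64)
y_pred = np.array([1, 0, 0, 0], dtype=np.float64)
print(dice_coefficient(y_true, y_pred, smooth=0.0))  # ≈ 0.667
print(jaccard_index(y_true, y_pred, smooth=0.0))     # 0.5
```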

```python
from tensorflow import keras

MODEL_CHECKPOINT = f"/kaggle/working/model/{model.name}.ckpt"
EPOCHS = 100

callbacks_list = [
    # Stop once val_dice_coefficient has not improved for 20 consecutive epochs
    keras.callbacks.EarlyStopping(monitor='val_dice_coefficient', mode='max', patience=20),
    # Save only when val_dice_coefficient reaches a new maximum
    keras.callbacks.ModelCheckpoint(filepath=MODEL_CHECKPOINT, monitor='val_dice_coefficient',
                                    save_best_only=True, mode='max', verbose=1)
]

# 20 % of the training set is held out for validation during training
history = model.fit(
    x=x_train, y=y_train,
    epochs=EPOCHS,
    callbacks=callbacks_list,
    validation_split=0.2,
    verbose=1)
```

```
Epoch 1/100
79/79 [==============================] - ETA: 0s - loss: 0.8577 - dice_coefficient: 0.2310 - jaccard_index: 0.1394
Epoch 1: val_dice_coefficient improved from -inf to 0.48029, saving model to /kaggle/working/model/model_1.ckpt
79/79 [==============================] - 56s 579ms/step - loss: 0.8577 - dice_coefficient: 0.2310 - jaccard_index: 0.1394 - val_loss: 0.5569 - val_dice_coefficient: 0.4803 - val_jaccard_index: 0.3230
Epoch 2/100
79/79 [==============================] - ETA: 0s - loss: 0.5409 - dice_coefficient: 0.4935 - jaccard_index: 0.3354
Epoch 2: val_dice_coefficient improved from 0.48029 to 0.56807, saving model to /kaggle/working/model/model_1.ckpt
79/79 [==============================] - 44s 560ms/step - loss: 0.5409 - dice_coefficient: 0.4935 - jaccard_index: 0.3354 - val_loss: 0.4688 - val_dice_coefficient: 0.5681 - val_jaccard_index: 0.4054
Epoch 3/100
79/79 [==============================] - ETA: 0s - loss: 0.4647 - dice_coefficient: 0.5669 - jaccard_index: 0.4036
Epoch 3: val_dice_coefficient improved from 0.56807 to 0.65534, saving model to /kaggle/working/model/model_1.ckpt
79/79 [==============================] - 44s 560ms/step - loss: 0.4647 - dice_coefficient: 0.5669 - jaccard_index: 0.4036 - val_loss: 0.3744 - val_dice_coefficient: 0.6553 - val_jaccard_index: 0.4945
Epoch 4/100
79/79 [==============================] - ETA: 0s - loss: 0.4040 - dice_coefficient: 0.6275 - jaccard_index: 0.4642
Epoch 4: val_dice_coefficient did not improve from 0.65534
79/79 [==============================] - 38s 481ms/step - loss: 0.4040 - dice_coefficient: 0.6275 - jaccard_index: 0.4642 - val_loss: 0.3907 - val_dice_coefficient: 0.6449 - val_jaccard_index: 0.4825
Epoch 5/100
79/79 [==============================] - ETA: 0s - loss: 0.3982 - dice_coefficient: 0.6313 - jaccard_index: 0.4701
Epoch 5: val_dice_coefficient improved from 0.65534 to 0.67167, saving model to /kaggle/working/model/model_1.ckpt
79/79 [==============================] - 44s 558ms/step - loss: 0.3982 - dice_coefficient: 0.6313 - jaccard_index: 0.4701 - val_loss: 0.3584 - val_dice_coefficient: 0.6717 - val_jaccard_index: 0.5178
Epoch 6/100
79/79 [==============================] - ETA: 0s - loss: 0.3724 - dice_coefficient: 0.6558 - jaccard_index: 0.4969
Epoch 6: val_dice_coefficient improved from 0.67167 to 0.67987, saving model to /kaggle/working/model/model_1.ckpt
79/79 [==============================] - 44s 557ms/step - loss: 0.3724 - dice_coefficient: 0.6558 - jaccard_index: 0.4969 - val_loss: 0.3510 - val_dice_coefficient: 0.6799 - val_jaccard_index: 0.5225
Epoch 7/100
79/79 [==============================] - ETA: 0s - loss: 0.3582 - dice_coefficient: 0.6691 - jaccard_index: 0.5102
Epoch 7: val_dice_coefficient improved from 0.67987 to 0.73302, saving model to /kaggle/working/model/model_1.ckpt
79/79 [==============================] - 44s 558ms/step - loss: 0.3582 - dice_coefficient: 0.6691 - jaccard_index: 0.5102 - val_loss: 0.2931 - val_dice_coefficient: 0.7330 - val_jaccard_index: 0.5856
Epoch 8/100
79/79 [==============================] - ETA: 0s - loss: 0.3823 - dice_coefficient: 0.6455 - jaccard_index: 0.4851
Epoch 8: val_dice_coefficient did not improve from 0.73302
79/79 [==============================] - 36s 450ms/step - loss: 0.3823 - dice_coefficient: 0.6455 - jaccard_index: 0.4851 - val_loss: 0.3242 - val_dice_coefficient: 0.7038 - val_jaccard_index: 0.5485
Epoch 9/100
79/79 [==============================] - ETA: 0s - loss: 0.3468 - dice_coefficient: 0.6786 - jaccard_index: 0.5200
Epoch 9: val_dice_coefficient did not improve from 0.73302
79/79 [==============================] - 36s 450ms/step - loss: 0.3468 - dice_coefficient: 0.6786 - jaccard_index: 0.5200 - val_loss: 0.2941 - val_dice_coefficient: 0.7313 - val_jaccard_index: 0.5835
Epoch 10/100
79/79 [==============================] - ETA: 0s - loss: 0.3455 - dice_coefficient: 0.6823 - jaccard_index: 0.5247
Epoch 10: val_dice_coefficient did not improve from 0.73302
79/79 [==============================] - 35s 449ms/step - loss: 0.3455 - dice_coefficient: 0.6823 - jaccard_index: 0.5247 - val_loss: 0.3045 - val_dice_coefficient: 0.7217 - val_jaccard_index: 0.5713
...
Epoch 86/100
79/79 [==============================] - ETA: 0s - loss: 0.1400 - dice_coefficient: 0.8713 - jaccard_index: 0.7744
Epoch 86: val_dice_coefficient did not improve from 0.87112
79/79 [==============================] - 35s 449ms/step - loss: 0.1400 - dice_coefficient: 0.8713 - jaccard_index: 0.7744 - val_loss: 0.1676 - val_dice_coefficient: 0.8509 - val_jaccard_index: 0.7424
Epoch 87/100
79/79 [==============================] - ETA: 0s - loss: 0.1102 - dice_coefficient: 0.8978 - jaccard_index: 0.8163
Epoch 87: val_dice_coefficient did not improve from 0.87112
79/79 [==============================] - 35s 449ms/step - loss: 0.1102 - dice_coefficient: 0.8978 - jaccard_index: 0.8163 - val_loss: 0.1594 - val_dice_coefficient: 0.8594 - val_jaccard_index: 0.7551
Epoch 88/100
79/79 [==============================] - ETA: 0s - loss: 0.1045 - dice_coefficient: 0.9018 - jaccard_index: 0.8240
Epoch 88: val_dice_coefficient did not improve from 0.87112
79/79 [==============================] - 36s 450ms/step - loss: 0.1045 - dice_coefficient: 0.9018 - jaccard_index: 0.8240 - val_loss: 0.1506 - val_dice_coefficient: 0.8683 - val_jaccard_index: 0.7689
Epoch 89/100
79/79 [==============================] - ETA: 0s - loss: 0.1066 - dice_coefficient: 0.9007 - jaccard_index: 0.8218
Epoch 89: val_dice_coefficient did not improve from 0.87112
79/79 [==============================] - 35s 450ms/step - loss: 0.1066 - dice_coefficient: 0.9007 - jaccard_index: 0.8218 - val_loss: 0.1693 - val_dice_coefficient: 0.8492 - val_jaccard_index: 0.7399
Epoch 90/100
79/79 [==============================] - ETA: 0s - loss: 0.1097 - dice_coefficient: 0.8978 - jaccard_index: 0.8168
Epoch 90: val_dice_coefficient did not improve from 0.87112
79/79 [==============================] - 35s 449ms/step - loss: 0.1097 - dice_coefficient: 0.8978 - jaccard_index: 0.8168 - val_loss: 0.1559 - val_dice_coefficient: 0.8622 - val_jaccard_index: 0.7592
Epoch 91/100
79/79 [==============================] - ETA: 0s - loss: 0.1056 - dice_coefficient: 0.9004 - jaccard_index: 0.8232
Epoch 91: val_dice_coefficient did not improve from 0.87112
79/79 [==============================] - 35s 449ms/step - loss: 0.1056 - dice_coefficient: 0.9004 - jaccard_index: 0.8232 - val_loss: 0.1478 - val_dice_coefficient: 0.8693 - val_jaccard_index: 0.7704
Epoch 92/100
79/79 [==============================] - ETA: 0s - loss: 0.2382 - dice_coefficient: 0.7839 - jaccard_index: 0.6576
Epoch 92: val_dice_coefficient did not improve from 0.87112
79/79 [==============================] - 35s 450ms/step - loss: 0.2382 - dice_coefficient: 0.7839 - jaccard_index: 0.6576 - val_loss: 0.2808 - val_dice_coefficient: 0.7477 - val_jaccard_index: 0.5999
Epoch 93/100
79/79 [==============================] - ETA: 0s - loss: 0.2632 - dice_coefficient: 0.7616 - jaccard_index: 0.6214
Epoch 93: val_dice_coefficient did not improve from 0.87112
79/79 [==============================] - 36s 450ms/step - loss: 0.2632 - dice_coefficient: 0.7616 - jaccard_index: 0.6214 - val_loss: 0.2210 - val_dice_coefficient: 0.8013 - val_jaccard_index: 0.6716
```
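The EarlyStopping callback configured above (monitor='val_dice_coefficient', mode='max', patience=20) ends training once the monitored value has gone 20 epochs without a new maximum, which is consistent with the log ending at epoch 93 rather than 100. A pure-Python sketch of that rule, assuming Keras's defaults of min_delta=0 and no baseline; `stop_epoch` is a hypothetical helper, not part of Keras:

```python
def stop_epoch(scores, patience):
    """Hypothetical helper: 1-based epoch at which a mode='max' EarlyStopping
    with min_delta=0 would stop, or None if training runs through all epochs."""
    best = float("-inf")
    wait = 0
    for epoch, score in enumerate(scores, start=1):
        if score > best:      # a strict new maximum resets the counter
            best = score
            wait = 0
        else:
            wait += 1         # one more epoch without improvement
            if wait >= patience:
                return epoch
    return None

# Toy validation Dice curve: best value at epoch 3, then a plateau
toy_scores = [0.48, 0.57, 0.66, 0.64, 0.65, 0.63, 0.60]
print(stop_epoch(toy_scores, patience=3))  # 6
print(stop_epoch(toy_scores, patience=5))  # None
```

Because ModelCheckpoint uses the same monitor with save_best_only=True, the saved checkpoint always corresponds to the epoch that reset the early-stopping counter for the last time.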

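The training log shows that I was already reaching the limits of what this model can do on the validation data: the validation Dice coefficient stalled at a best value of 0.87112, while the training Dice kept rising to about 0.90.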

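```python
import matplotlib.pyplot as plt
import seaborn as sns

fig, ax = plt.subplots(1, 2, figsize=(16, 4))
# Left: training vs. validation loss
sns.lineplot(data={k: history.history[k] for k in ('loss', 'val_loss')}, ax=ax[0])
# Right: all remaining metrics (Dice and Jaccard, training and validation)
sns.lineplot(data={k: history.history[k] for k in history.history.keys()
                   if k not in ('loss', 'val_loss')}, ax=ax[1])
plt.show()
```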

```python
# Restore the best checkpoint (highest val_dice_coefficient) saved during training
model.load_weights(f"/kaggle/working/model/{model.name}.ckpt")
```

Testing the model's results

Here I will stick to the basics, just enough to show that the network has really learned something …

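```python
import numpy as np

y_pred = model.predict(x_test)
# Binarize the sigmoid probabilities at 0.5 to get 0/1 masks
y_pred = (y_pred > 0.5).astype(np.float64)
```

```
25/25 [==============================] - 4s 133ms/step
```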
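A simple numeric check of the binarized predictions is the Dice coefficient of each predicted mask against its ground truth in y_test. A minimal NumPy sketch, assuming arrays shaped like y_test and y_pred above, (N, H, W, 1); `per_image_dice` is a hypothetical helper and the tiny toy arrays only stand in for the real ones:

```python
import numpy as np

def per_image_dice(y_true, y_pred, smooth=1e-7):
    """Hypothetical helper: Dice coefficient for each mask in an (N, H, W, 1) batch."""
    axes = tuple(range(1, y_true.ndim))          # sum over H, W and the channel
    intersection = np.sum(y_true * y_pred, axis=axes)
    totals = np.sum(y_true, axis=axes) + np.sum(y_pred, axis=axes)
    # smooth keeps the ratio defined when both masks are empty
    return (2.0 * intersection + smooth) / (totals + smooth)

# Toy stand-ins: two 2x2 masks; the probabilities are thresholded at 0.5
y_true_demo = np.array([[[[1], [1]], [[0], [0]]],
                        [[[0], [1]], [[0], [1]]]], dtype=np.float64)
probs_demo = np.array([[[[0.9], [0.2]], [[0.1], [0.0]]],
                       [[[0.1], [0.8]], [[0.3], [0.7]]]])
y_pred_demo = (probs_demo > 0.5).astype(np.float64)

scores = per_image_dice(y_true_demo, y_pred_demo)
print(scores.round(3).tolist())          # [0.667, 1.0]
print(round(float(scores.mean()), 3))    # 0.833
```

On the real data the call would be `per_image_dice(y_test, y_pred)`, and the mean over the 786 test images gives one summary number to compare with the validation Dice from training.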

