This one will be quick. My goal is to build a semantic image segmentation tool (as I have already done several times in previous articles), but with as little programming effort as possible. So far I have always assembled at least part of the model by hand from various types of convolutional blocks. It can be done more simply, though.
The first simplification is so-called "transfer learning", where I built the contracting phase of the model on top of an existing classification model. That was the topic of the previous installment.
I can take the simplification further and look for a ready-made solution for the entire segmentation model. If you search the internet for a while, you will find that several people have already solved this problem before you and have made their solutions freely available. You only need to pick the one that fits the task at hand.
You always need to start by asking which framework you want to build on. I use TensorFlow 2, and the choice of library has to follow from that. One option, which I use in this article, is the TensorFlow Advanced Segmentation Models library.
If you look at the project's landing page, you will see that it currently supports 14 types of semantic segmentation models and 25 classification models that can serve as a "backbone". I won't list them here, but there is certainly plenty to choose from. For my quick experiment I picked a classic combination: a U-Net model with a VGG16 backbone.
To keep the promise from the introduction that this will be quick, I will leave out the preparatory parts of the code in this article. I would only be repeating the same thing yet again and holding you up unnecessarily. If you are interested, please have a look at the previous articles.
The starting point is loading the data from the dataset. Array X holds 3929 samples at a resolution of 128×128 pixels with three channels. Array Y holds the corresponding masks at the same resolution with a single grayscale channel (in this case containing only the values 0 and 1).
# get_image_data() and image_paths come from the preparation code in previous articles
X, Y = get_image_data(image_paths)
print(f"X: {X.shape}")
print(f"Y: {Y.shape}")

X: (3929, 128, 128, 3)
Y: (3929, 128, 128, 1)
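Because the preparation code is omitted, here is a minimal sketch of a sanity check one might run on arrays shaped like X and Y above before training. It assumes NumPy arrays in (N, H, W, C) layout and binary masks; the function name check_segmentation_data is hypothetical and not part of the original code, and the demo runs on random synthetic data, not on the real dataset.

```python
import numpy as np


def check_segmentation_data(X: np.ndarray, Y: np.ndarray) -> None:
    """Hypothetical sanity check for image/mask pairs in (N, H, W, C) layout."""
    # Same number of samples and the same spatial resolution
    assert X.ndim == 4 and Y.ndim == 4, "expected (N, H, W, C) arrays"
    assert X.shape[:3] == Y.shape[:3], f"shape mismatch: {X.shape} vs {Y.shape}"
    # RGB images, single-channel masks
    assert X.shape[-1] == 3 and Y.shape[-1] == 1, "expected 3-channel images and 1-channel masks"
    # A sigmoid output with binary cross-entropy expects masks containing only 0 and 1
    values = np.unique(Y)
    assert set(values.tolist()) <= {0, 1}, f"non-binary mask values: {values[:10]}"


# Synthetic stand-in for the real data (4 samples instead of 3929)
rng = np.random.default_rng(0)
X_demo = rng.random((4, 128, 128, 3), dtype=np.float32)
Y_demo = (rng.random((4, 128, 128, 1)) > 0.9).astype(np.float32)
check_segmentation_data(X_demo, Y_demo)
print("ok")
```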
First I need to install the library into my environment. Since the source is available on GitHub, this is quite simple:
In [6]:
pip install -U git+https://github.com/JanMarcelKezmann/TensorFlow-Advanced-Segmentation-Models.git

Collecting git+https://github.com/JanMarcelKezmann/TensorFlow-Advanced-Segmentation-Models.git
Cloning https://github.com/JanMarcelKezmann/TensorFlow-Advanced-Segmentation-Models.git to /tmp/pip-req-build-vmq5ym1x
Resolved https://github.com/JanMarcelKezmann/TensorFlow-Advanced-Segmentation-Models.git to commit 3714839ee49759b26e2b0ae3d3a0aa37b00df962
Requirement already satisfied: tensorflow>=2.0.0 in /opt/conda/lib/python3.10/site-packages (from tensorflow_advanced_segmentation_models==0.4.10) (2.15.0)
...
Successfully built tensorflow_advanced_segmentation_models
Installing collected packages: tensorflow_advanced_segmentation_models
Successfully installed tensorflow_advanced_segmentation_models-0.4.10
Note: you may need to restart the kernel to use updated packages.
Once installed, the library can be imported into the Python environment:
In [7]:
import tensorflow_advanced_segmentation_models as tasm
Creating the model takes essentially two lines of code. In the first line I create an instance of the classification model that serves as the backbone. In my case I used VGG16, but I do not load any pretrained weights, so the backbone is randomly initialized and will be trained in full along with the rest of the model.
In the second line I create the U-Net model itself. The segmentation masks contain only a single class, so the final activation is the sigmoid function. I also have to pass references to the classification model: the base_model and layers variables (the latter as output_layers). That is all the basic setup needs.
I also plotted a diagram of the model to give you a better idea of the result:
In [8]:
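from tensorflow.keras.utils import plot_model
from IPython.display import Image

# IMAGE_SIZE is defined in the omitted preparation code; for this data it is 128 x 128
BACKBONE_NAME = 'vgg16'
WEIGHTS = None  # no pretrained weights: the backbone starts from random initialization
HEIGHT, WIDTH = IMAGE_SIZE

# Backbone model plus the intermediate layers the U-Net decoder connects to
base_model, layers, layer_names = tasm.create_base_model(name=BACKBONE_NAME, weights=WEIGHTS, height=HEIGHT, width=WIDTH)

# U-Net with one output class and sigmoid activation; the backbone is trained as well
model = tasm.UNet(n_classes=1, base_model=base_model, output_layers=layers, height=HEIGHT, width=WIDTH, final_activation='sigmoid', backbone_trainable=True).model()

plot_model(model, to_file=f'/kaggle/working/model/{model.name}.png', show_shapes=True, show_layer_names=True, expand_nested=True)
Image(retina=True, filename=f'/kaggle/working/model/{model.name}.png')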
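Why a single output channel with a sigmoid is enough here: the sigmoid turns each pixel's raw output (logit) into an independent probability of belonging to the foreground, and a 0.5 threshold, as used later in the prediction step, turns it into a binary mask. The NumPy sketch below uses made-up logits for a tiny 2×2 map and only illustrates the math, not the library's internals. With more than one class, the usual choice would instead be softmax across the class channels.

```python
import numpy as np


def sigmoid(z: np.ndarray) -> np.ndarray:
    """Element-wise logistic function mapping logits to probabilities in (0, 1)."""
    return 1.0 / (1.0 + np.exp(-z))


# Made-up logits for a 2x2 single-channel output map, shape (H, W, 1)
logits = np.array([[[-3.0], [0.2]],
                   [[2.5], [-0.1]]])
probs = sigmoid(logits)                   # per-pixel foreground probability
mask = (probs > 0.5).astype(np.float64)   # same thresholding as in the prediction step
print(mask[..., 0])
```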
The usual data preparation:
In [9]:
from sklearn.model_selection import train_test_split

# 80 % of the samples for training, 20 % held out for testing
x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.2)
print(f"x_train: {x_train.shape}, y_train: {y_train.shape}")
print(f"x_test: {x_test.shape}, y_test: {y_test.shape}")

x_train: (3143, 128, 128, 3), y_train: (3143, 128, 128, 1)
x_test: (786, 128, 128, 3), y_test: (786, 128, 128, 1)
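The sample counts above, and the 79/79 and 25/25 step counts that appear in the training and prediction logs, follow from simple arithmetic. This sketch assumes scikit-learn's rounding for a float test_size (the test part is rounded up), Keras's rounding for validation_split (the training part is rounded down and validation uses the last samples), and the Keras default batch size of 32, since the fit() call does not set batch_size. The function split_sizes is a hypothetical helper for illustration only.

```python
import math


def split_sizes(n_samples: int, test_size: float, validation_split: float, batch_size: int = 32) -> dict:
    """Reproduce the sample and step counts for the split used in this article."""
    n_test = math.ceil(test_size * n_samples)                    # scikit-learn rounds the test part up
    n_train_full = n_samples - n_test
    n_fit = math.floor(n_train_full * (1.0 - validation_split))  # Keras trains on the first part...
    n_val = n_train_full - n_fit                                 # ...and validates on the last samples
    return {
        "train": n_train_full,
        "test": n_test,
        "fit": n_fit,
        "val": n_val,
        "steps_per_epoch": math.ceil(n_fit / batch_size),
        "predict_steps": math.ceil(n_test / batch_size),
    }


print(split_sizes(3929, test_size=0.2, validation_split=0.2))
```

So each epoch actually trains on 2514 images and validates on 629, which is where the 79 batches per epoch come from.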
In [11]:
# bce_dice_loss, dice_coefficient and jaccard_index are defined in the previous articles
model.compile(optimizer="adam", loss=bce_dice_loss, metrics=[dice_coefficient, jaccard_index])
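The loss and metric functions come from the earlier articles and are not repeated here. For orientation, below is a NumPy sketch of one common formulation: Dice = 2|A∩B| / (|A| + |B|), Jaccard (IoU) = |A∩B| / |A∪B|, and a combined loss of binary cross-entropy plus (1 − Dice). The smoothing constant and the exact way the two loss terms are combined are assumptions; the author's Keras implementations may differ in these details.

```python
import numpy as np

SMOOTH = 1e-6  # assumed smoothing constant, avoids division by zero on empty masks
EPS = 1e-7     # clipping for the logarithm in binary cross-entropy


def dice_coefficient(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Soft Dice: 2 * |A ∩ B| / (|A| + |B|), computed over all pixels."""
    intersection = np.sum(y_true * y_pred)
    return float((2.0 * intersection + SMOOTH) / (np.sum(y_true) + np.sum(y_pred) + SMOOTH))


def jaccard_index(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Soft Jaccard / IoU: |A ∩ B| / |A ∪ B|."""
    intersection = np.sum(y_true * y_pred)
    union = np.sum(y_true) + np.sum(y_pred) - intersection
    return float((intersection + SMOOTH) / (union + SMOOTH))


def bce_dice_loss(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Mean binary cross-entropy plus (1 - Dice)."""
    p = np.clip(y_pred, EPS, 1.0 - EPS)
    bce = -np.mean(y_true * np.log(p) + (1.0 - y_true) * np.log(1.0 - p))
    return float(bce + (1.0 - dice_coefficient(y_true, y_pred)))


# One of two foreground pixels hit, plus one false positive
y_true = np.array([1.0, 1.0, 0.0, 0.0])
y_pred = np.array([1.0, 0.0, 1.0, 0.0])
print(dice_coefficient(y_true, y_pred), jaccard_index(y_true, y_pred))
```

For hard 0/1 masks the two metrics are tied by J = D / (2 − D), which is why the Jaccard values in the log are always lower than the Dice values.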
In [12]:
from tensorflow import keras

MODEL_CHECKPOINT = f"/kaggle/working/model/{model.name}.ckpt"
EPOCHS = 100

callbacks_list = [
    # Stop when the validation Dice coefficient has not improved for 20 epochs
    keras.callbacks.EarlyStopping(monitor='val_dice_coefficient', mode='max', patience=20),
    # Save only the model with the best validation Dice coefficient so far
    keras.callbacks.ModelCheckpoint(filepath=MODEL_CHECKPOINT, monitor='val_dice_coefficient', save_best_only=True, mode='max', verbose=1)
]

history = model.fit(
    x=x_train,
    y=y_train,
    epochs=EPOCHS,
    callbacks=callbacks_list,
    validation_split=0.2,  # the last 20 % of x_train/y_train is used for validation
    verbose=1)

Epoch 1/100
79/79 [==============================] - ETA: 0s - loss: 0.8577 - dice_coefficient: 0.2310 - jaccard_index: 0.1394
Epoch 1: val_dice_coefficient improved from -inf to 0.48029, saving model to /kaggle/working/model/model_1.ckpt
79/79 [==============================] - 56s 579ms/step - loss: 0.8577 - dice_coefficient: 0.2310 - jaccard_index: 0.1394 - val_loss: 0.5569 - val_dice_coefficient: 0.4803 - val_jaccard_index: 0.3230
Epoch 2/100
79/79 [==============================] - ETA: 0s - loss: 0.5409 - dice_coefficient: 0.4935 - jaccard_index: 0.3354
Epoch 2: val_dice_coefficient improved from 0.48029 to 0.56807, saving model to /kaggle/working/model/model_1.ckpt
79/79 [==============================] - 44s 560ms/step - loss: 0.5409 - dice_coefficient: 0.4935 - jaccard_index: 0.3354 - val_loss: 0.4688 - val_dice_coefficient: 0.5681 - val_jaccard_index: 0.4054
Epoch 3/100
79/79 [==============================] - ETA: 0s - loss: 0.4647 - dice_coefficient: 0.5669 - jaccard_index: 0.4036
Epoch 3: val_dice_coefficient improved from 0.56807 to 0.65534, saving model to /kaggle/working/model/model_1.ckpt
79/79 [==============================] - 44s 560ms/step - loss: 0.4647 - dice_coefficient: 0.5669 - jaccard_index: 0.4036 - val_loss: 0.3744 - val_dice_coefficient: 0.6553 - val_jaccard_index: 0.4945
Epoch 4/100
79/79 [==============================] - ETA: 0s - loss: 0.4040 - dice_coefficient: 0.6275 - jaccard_index: 0.4642
Epoch 4: val_dice_coefficient did not improve from 0.65534
79/79 [==============================] - 38s 481ms/step - loss: 0.4040 - dice_coefficient: 0.6275 - jaccard_index: 0.4642 - val_loss: 0.3907 - val_dice_coefficient: 0.6449 - val_jaccard_index: 0.4825
Epoch 5/100
79/79 [==============================] - ETA: 0s - loss: 0.3982 - dice_coefficient: 0.6313 - jaccard_index: 0.4701
Epoch 5: val_dice_coefficient improved from 0.65534 to 0.67167, saving model to /kaggle/working/model/model_1.ckpt
79/79 [==============================] - 44s 558ms/step - loss: 0.3982 - dice_coefficient: 0.6313 - jaccard_index: 0.4701 - val_loss: 0.3584 - val_dice_coefficient: 0.6717 - val_jaccard_index: 0.5178
Epoch 6/100
79/79 [==============================] - ETA: 0s - loss: 0.3724 - dice_coefficient: 0.6558 - jaccard_index: 0.4969
Epoch 6: val_dice_coefficient improved from 0.67167 to 0.67987, saving model to /kaggle/working/model/model_1.ckpt
79/79 [==============================] - 44s 557ms/step - loss: 0.3724 - dice_coefficient: 0.6558 - jaccard_index: 0.4969 - val_loss: 0.3510 - val_dice_coefficient: 0.6799 - val_jaccard_index: 0.5225
Epoch 7/100
79/79 [==============================] - ETA: 0s - loss: 0.3582 - dice_coefficient: 0.6691 - jaccard_index: 0.5102
Epoch 7: val_dice_coefficient improved from 0.67987 to 0.73302, saving model to /kaggle/working/model/model_1.ckpt
79/79 [==============================] - 44s 558ms/step - loss: 0.3582 - dice_coefficient: 0.6691 - jaccard_index: 0.5102 - val_loss: 0.2931 - val_dice_coefficient: 0.7330 - val_jaccard_index: 0.5856
Epoch 8/100
79/79 [==============================] - ETA: 0s - loss: 0.3823 - dice_coefficient: 0.6455 - jaccard_index: 0.4851
Epoch 8: val_dice_coefficient did not improve from 0.73302
79/79 [==============================] - 36s 450ms/step - loss: 0.3823 - dice_coefficient: 0.6455 - jaccard_index: 0.4851 - val_loss: 0.3242 - val_dice_coefficient: 0.7038 - val_jaccard_index: 0.5485
Epoch 9/100
79/79 [==============================] - ETA: 0s - loss: 0.3468 - dice_coefficient: 0.6786 - jaccard_index: 0.5200
Epoch 9: val_dice_coefficient did not improve from 0.73302
79/79 [==============================] - 36s 450ms/step - loss: 0.3468 - dice_coefficient: 0.6786 - jaccard_index: 0.5200 - val_loss: 0.2941 - val_dice_coefficient: 0.7313 - val_jaccard_index: 0.5835
Epoch 10/100
79/79 [==============================] - ETA: 0s - loss: 0.3455 - dice_coefficient: 0.6823 - jaccard_index: 0.5247
Epoch 10: val_dice_coefficient did not improve from 0.73302
79/79 [==============================] - 35s 449ms/step - loss: 0.3455 - dice_coefficient: 0.6823 - jaccard_index: 0.5247 - val_loss: 0.3045 - val_dice_coefficient: 0.7217 - val_jaccard_index: 0.5713
...
Epoch 86/100
79/79 [==============================] - ETA: 0s - loss: 0.1400 - dice_coefficient: 0.8713 - jaccard_index: 0.7744
Epoch 86: val_dice_coefficient did not improve from 0.87112
79/79 [==============================] - 35s 449ms/step - loss: 0.1400 - dice_coefficient: 0.8713 - jaccard_index: 0.7744 - val_loss: 0.1676 - val_dice_coefficient: 0.8509 - val_jaccard_index: 0.7424
Epoch 87/100
79/79 [==============================] - ETA: 0s - loss: 0.1102 - dice_coefficient: 0.8978 - jaccard_index: 0.8163
Epoch 87: val_dice_coefficient did not improve from 0.87112
79/79 [==============================] - 35s 449ms/step - loss: 0.1102 - dice_coefficient: 0.8978 - jaccard_index: 0.8163 - val_loss: 0.1594 - val_dice_coefficient: 0.8594 - val_jaccard_index: 0.7551
Epoch 88/100
79/79 [==============================] - ETA: 0s - loss: 0.1045 - dice_coefficient: 0.9018 - jaccard_index: 0.8240
Epoch 88: val_dice_coefficient did not improve from 0.87112
79/79 [==============================] - 36s 450ms/step - loss: 0.1045 - dice_coefficient: 0.9018 - jaccard_index: 0.8240 - val_loss: 0.1506 - val_dice_coefficient: 0.8683 - val_jaccard_index: 0.7689
Epoch 89/100
79/79 [==============================] - ETA: 0s - loss: 0.1066 - dice_coefficient: 0.9007 - jaccard_index: 0.8218
Epoch 89: val_dice_coefficient did not improve from 0.87112
79/79 [==============================] - 35s 450ms/step - loss: 0.1066 - dice_coefficient: 0.9007 - jaccard_index: 0.8218 - val_loss: 0.1693 - val_dice_coefficient: 0.8492 - val_jaccard_index: 0.7399
Epoch 90/100
79/79 [==============================] - ETA: 0s - loss: 0.1097 - dice_coefficient: 0.8978 - jaccard_index: 0.8168
Epoch 90: val_dice_coefficient did not improve from 0.87112
79/79 [==============================] - 35s 449ms/step - loss: 0.1097 - dice_coefficient: 0.8978 - jaccard_index: 0.8168 - val_loss: 0.1559 - val_dice_coefficient: 0.8622 - val_jaccard_index: 0.7592
Epoch 91/100
79/79 [==============================] - ETA: 0s - loss: 0.1056 - dice_coefficient: 0.9004 - jaccard_index: 0.8232
Epoch 91: val_dice_coefficient did not improve from 0.87112
79/79 [==============================] - 35s 449ms/step - loss: 0.1056 - dice_coefficient: 0.9004 - jaccard_index: 0.8232 - val_loss: 0.1478 - val_dice_coefficient: 0.8693 - val_jaccard_index: 0.7704
Epoch 92/100
79/79 [==============================] - ETA: 0s - loss: 0.2382 - dice_coefficient: 0.7839 - jaccard_index: 0.6576
Epoch 92: val_dice_coefficient did not improve from 0.87112
79/79 [==============================] - 35s 450ms/step - loss: 0.2382 - dice_coefficient: 0.7839 - jaccard_index: 0.6576 - val_loss: 0.2808 - val_dice_coefficient: 0.7477 - val_jaccard_index: 0.5999
Epoch 93/100
79/79 [==============================] - ETA: 0s - loss: 0.2632 - dice_coefficient: 0.7616 - jaccard_index: 0.6214
Epoch 93: val_dice_coefficient did not improve from 0.87112
79/79 [==============================] - 36s 450ms/step - loss: 0.2632 - dice_coefficient: 0.7616 - jaccard_index: 0.6214 - val_loss: 0.2210 - val_dice_coefficient: 0.8013 - val_jaccard_index: 0.6716
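The training log shows that I was already reaching the limits of the model as far as validation results are concerned: in epochs 86–93 the validation Dice coefficient stayed below its best value of 0.87112, and the log ends at epoch 93, presumably where EarlyStopping (patience=20) ended the run.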
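To make the log easier to read: EarlyStopping with mode='max' and patience=20 stops once the monitored value has gone 20 consecutive epochs without exceeding its best value so far, and ModelCheckpoint with save_best_only=True saves exactly at those best epochs. The pure-Python sketch below mimics that bookkeeping on a list of validation scores. It is a simplified model of the two callbacks (min_delta of 0, no baseline), not Keras's actual code; the function name simulate_early_stopping is hypothetical. The example scores are the val_dice_coefficient values of the first ten epochs as printed in the log, with a small patience of 3 so that stopping actually triggers.

```python
from typing import List, Tuple


def simulate_early_stopping(val_scores: List[float], patience: int) -> Tuple[int, int, List[int]]:
    """Return (last_epoch_run, best_epoch, checkpoint_epochs), all 1-based.

    Simplified model of EarlyStopping(mode='max') combined with
    ModelCheckpoint(save_best_only=True, mode='max').
    """
    best = float("-inf")
    best_epoch = 0
    wait = 0
    checkpoints = []
    for epoch, score in enumerate(val_scores, start=1):
        if score > best:          # improvement: checkpoint is saved, counter resets
            best, best_epoch, wait = score, epoch, 0
            checkpoints.append(epoch)
        else:                     # no improvement: one more epoch of waiting
            wait += 1
            if wait >= patience:  # patience exhausted: training stops after this epoch
                return epoch, best_epoch, checkpoints
    return len(val_scores), best_epoch, checkpoints


# val_dice_coefficient for epochs 1-10 as printed in the training log
scores = [0.4803, 0.5681, 0.6553, 0.6449, 0.6717, 0.6799, 0.7330, 0.7038, 0.7313, 0.7217]
print(simulate_early_stopping(scores, patience=3))
```

The checkpoint epochs it reports (1, 2, 3, 5, 6, 7) match the "improved ... saving model" lines in the log, and the weights loaded in the next step are the ones from the best epoch.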
In [13]:
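import matplotlib.pyplot as plt
import seaborn as sns

# Left: training and validation loss; right: all remaining metrics (Dice, Jaccard and their val_ variants)
fig, ax = plt.subplots(1, 2, figsize=(16, 4))
sns.lineplot(data={k: history.history[k] for k in ('loss', 'val_loss')}, ax=ax[0])
sns.lineplot(data={k: history.history[k] for k in history.history.keys() if k not in ('loss', 'val_loss')}, ax=ax[1])
plt.show()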
In [14]:
model.load_weights(f"/kaggle/working/model/{model.name}.ckpt")
Out[14]:
<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x7d33381abd90>
Here I will stick to the basics, just enough to show that the network has actually learned something…
In [15]:
import numpy as np

y_pred = model.predict(x_test)
# Binarize the per-pixel sigmoid probabilities at 0.5
y_pred = (y_pred > 0.5).astype(np.float64)

25/25 [==============================] - 4s 133ms/step
In [16]:
# Show random test samples (up to 20 draws) that contain at least one foreground pixel
for _ in range(20):
    i = np.random.randint(len(y_test))
    if y_test[i].sum() > 0:
        plt.figure(figsize=(8, 8))
        plt.subplot(1, 3, 1)
        plt.imshow(x_test[i])
        plt.title('Original Image')
        plt.subplot(1, 3, 2)
        plt.imshow(y_test[i])
        plt.title('Original Mask')
        plt.subplot(1, 3, 3)
        plt.imshow(y_pred[i])
        plt.title('Prediction')
        plt.show()
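The loop above draws 20 random indices independently, so it can show the same image twice and silently skips samples with empty masks, which means fewer than 20 figures may appear. If you want up to k distinct examples with non-empty masks, the selection can be done up front. A NumPy sketch; the helper name pick_nonempty is hypothetical and the demo uses synthetic masks:

```python
import numpy as np


def pick_nonempty(y: np.ndarray, k: int, seed: int = 0) -> np.ndarray:
    """Return up to k distinct indices of masks that contain at least one foreground pixel."""
    # Sum each mask over all its pixels; a non-zero sum means the mask is not empty
    nonempty = np.flatnonzero(y.reshape(len(y), -1).sum(axis=1) > 0)
    rng = np.random.default_rng(seed)
    return rng.choice(nonempty, size=min(k, len(nonempty)), replace=False)


# Synthetic masks: only samples 1 and 3 contain foreground pixels
y_demo = np.zeros((5, 4, 4, 1))
y_demo[1, 0, 0, 0] = 1.0
y_demo[3, 2, 2, 0] = 1.0
print(sorted(pick_nonempty(y_demo, k=20).tolist()))
```

In the article's notebook the loop header would then become `for i in pick_nonempty(y_test, 20):`, with the same plotting code inside and no emptiness check.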
The author works as an IT architect. In recent years he has focused on integration and communication projects in healthcare. His hobbies also include paragliding and mountain biking.