Detecting pneumonia from X-ray images – part three

15. 3. 2024 0:00 Jiří Raška

This article follows up on the previous two: Detecting pneumonia from X-ray images – part one and Detecting pneumonia from X-ray images – part two.


ResNet50 Model

It is one of the most successful and also most popular classification models built on convolutional layers. Its authors managed to solve the problem of degrading gradients in networks with very many layers. Their solution is to add a "shortcut" for the data flowing through the network. There are many documents describing the principle of this model and explaining why it actually works; one of them is, for example, Detailed Explanation of Resnet CNN Model.
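To illustrate the core idea, here is a minimal sketch of a residual block (my own illustration, not the exact block used later in this article): the convolutional branch learns a residual F(x) while the shortcut carries the input x unchanged, so the block outputs F(x) + x. The sketch assumes the input already has `filters` channels, so the addition works without any projection:

from keras import layers

# Minimal residual-block sketch (illustration only).
def residual_sketch(x, filters):
    shortcut = x                                                   # the "shortcut"
    y = layers.Conv2D(filters, (3, 3), padding='same', activation='relu')(x)
    y = layers.Conv2D(filters, (3, 3), padding='same')(y)
    y = layers.Add()([y, shortcut])                                # F(x) + x
    return layers.Activation('relu')(y)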

Just to give a rough idea, this is a schematic view of the ResNet50 model:

[figure: ResNet50 architecture schematic]

ResNet50 trained from scratch

Let me show the function implementing the model right away, with a few remarks afterwards:

In [19]:


def create_model_ResNet50(X_shape, classes=2, name="ResNet50"):

    def mlp(x, hidden_units, activation='relu', dropout_rate=0.3, name=""):
        for i, units in enumerate(hidden_units):
            x = layers.Dense(units, activation=activation, name=f"{name}_{i}_dense")(x)
            x = layers.Dropout(dropout_rate, name=f"{name}_{i}_dropout")(x)
        return x

    def conv_block(x, *, filters, kernel_size, strides=(1, 1), padding='same', activation='relu', name=""):
        x = layers.Conv2D(filters=filters, kernel_size=kernel_size, strides=strides, padding=padding, name=f"{name}_conv")(x)
        x = layers.BatchNormalization(name=f"{name}_norm")(x)
        if activation:
            x = layers.Activation(activation, name=f"{name}_actn")(x)
        return x

    def identity_block(x, *, filters, name=""):
        shortcut = x
        x = conv_block(x, filters=filters, kernel_size=(1, 1), name=f"{name}_cb1")
        x = conv_block(x, filters=filters, kernel_size=(3, 3), name=f"{name}_cb2")
        x = conv_block(x, filters=filters * 4, kernel_size=(1, 1), activation='', name=f"{name}_cb3")
        x = layers.Add(name=f"{name}_add")([x, shortcut])
        x = layers.Activation('relu', name=f"{name}_actn")(x)
        return x

    def projection_block(x, *, filters, strides, name=""):
        shortcut = x
        x = conv_block(x, filters=filters, kernel_size=(1, 1), strides=strides, name=f"{name}_cb1")
        x = conv_block(x, filters=filters, kernel_size=(3, 3), name=f"{name}_cb2")
        x = conv_block(x, filters=filters * 4, kernel_size=(1, 1), activation='', name=f"{name}_cb3")
        shortcut = conv_block(shortcut, filters=filters * 4, kernel_size=(1, 1), strides=strides, activation='', name=f"{name}_cb4")
        x = layers.Add(name=f"{name}_add")([x, shortcut])
        x = layers.Activation('relu', name=f"{name}_actn")(x)
        return x

    inputs = Input(X_shape[-3:], name='inputs')

    # === Stage 1 ===
    x = conv_block(inputs, filters=64, kernel_size=(7, 7), strides=(2, 2), padding='same', name="stg1_cb1")
    x = layers.MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding='same', name="stg1_maxpool")(x)

    # === Stage 2 ===
    x = projection_block(x, filters=64, strides=(1, 1), name="stg2_pb")
    x = identity_block(x, filters=64, name="stg2_ib1")
    x = identity_block(x, filters=64, name="stg2_ib2")

    # === Stage 3 ===
    x = projection_block(x, filters=128, strides=(2, 2), name="stg3_pb")
    x = identity_block(x, filters=128, name="stg3_ib1")
    x = identity_block(x, filters=128, name="stg3_ib2")
    x = identity_block(x, filters=128, name="stg3_ib3")

    # === Stage 4 ===
    x = projection_block(x, filters=256, strides=(2, 2), name="stg4_pb")
    x = identity_block(x, filters=256, name="stg4_ib1")
    x = identity_block(x, filters=256, name="stg4_ib2")
    x = identity_block(x, filters=256, name="stg4_ib3")
    x = identity_block(x, filters=256, name="stg4_ib4")
    x = identity_block(x, filters=256, name="stg4_ib5")

    # === Stage 5 ===
    x = projection_block(x, filters=512, strides=(2, 2), name="stg5_pb")
    x = identity_block(x, filters=512, name="stg5_ib1")
    x = identity_block(x, filters=512, name="stg5_ib2")

    x = layers.GlobalAveragePooling2D(name=f"stg5_globaver")(x)

    x = mlp(x, (1024, 512), name="dense")
    outputs = layers.Dense(classes, activation='softmax', name='outputs')(x)

    return Model(inputs=inputs, outputs=outputs, name=name)

The model is built up from so-called "projection blocks" and "identity blocks". The difference between them is that a projection block reduces the spatial dimensions, which it does by changing the convolution stride (the strides parameter). And so that the tensor flowing through the convolutional layers can be added to the shortcut at the end of the block, one more convolutional layer has to be inserted into the shortcut as well. That is the whole difference.

Both block types are then stacked into so-called stages, but the principle remains the same as in all the classification models – the spatial dimensions shrink while the feature dimension grows.

Also worth noting is the heavy use of convolutional layers with a kernel size of (1, 1). At first glance this looks somewhat odd. One reason for using them is to expand the feature dimension without touching the spatial dimensions, and another is to add extra non-linearity to the model.
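As a quick sanity check (a standalone snippet of mine, not part of the notebook), a convolution with a (1, 1) kernel really does leave the spatial dimensions untouched and only remaps the feature dimension:

import numpy as np
from keras import layers

# 64 feature maps in, 256 feature maps out, spatial size 56x56 unchanged.
x = np.random.rand(1, 56, 56, 64).astype("float32")
y = layers.Conv2D(filters=256, kernel_size=(1, 1))(x)
print(y.shape)   # -> (1, 56, 56, 256)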

The tail of the model is, as usual, handled by a multi-layer perceptron and the final classification layer.

Before the actual experiment I prepare the data again. Since I am training the whole model from scratch, I will use grayscale images:

In [20]:


x_train, x_valid, y_train, y_valid = train_test_split(*get_datasource(DATA_TRAIN, DATA_VALID), test_size=0.2)
x_test, y_test = get_datasource(DATA_TEST)

x_train = np.expand_dims(x_train, axis=-1)
x_valid = np.expand_dims(x_valid, axis=-1)
x_test  = np.expand_dims(x_test, axis=-1)

datagen = ImageDataGenerator(
        rotation_range = 30,
        zoom_range = 0.2,
        width_shift_range=0.1,
        height_shift_range=0.1,
        horizontal_flip = True,
        vertical_flip=False)

datagen.fit(x_train)

And now the evaluation of the model itself. You can see that there really are a lot of layers:

In [21]:


evaluate_model(create_model_ResNet50(x_train.shape, 2), forced_training=True)
=== MODEL EVALUATION =================================================

Model: "ResNet50"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to
==================================================================================================
 inputs (InputLayer)         [(None, 224, 224, 1)]        0         []
 stg1_cb1_conv (Conv2D)      (None, 112, 112, 64)         3200      ['inputs[0][0]']
 stg1_cb1_norm (BatchNormal  (None, 112, 112, 64)         256       ['stg1_cb1_conv[0][0]']
 ization)
 stg1_cb1_actn (Activation)  (None, 112, 112, 64)         0         ['stg1_cb1_norm[0][0]']
 stg1_maxpool (MaxPooling2D  (None, 56, 56, 64)           0         ['stg1_cb1_actn[0][0]']
 )
 stg2_pb_cb1_conv (Conv2D)   (None, 56, 56, 64)           4160      ['stg1_maxpool[0][0]']
 stg2_pb_cb1_norm (BatchNor  (None, 56, 56, 64)           256       ['stg2_pb_cb1_conv[0][0]']
 malization)
 stg2_pb_cb1_actn (Activati  (None, 56, 56, 64)           0         ['stg2_pb_cb1_norm[0][0]']
 on)
 stg2_pb_cb2_conv (Conv2D)   (None, 56, 56, 64)           36928     ['stg2_pb_cb1_actn[0][0]']
 stg2_pb_cb2_norm (BatchNor  (None, 56, 56, 64)           256       ['stg2_pb_cb2_conv[0][0]']
 malization)
 stg2_pb_cb2_actn (Activati  (None, 56, 56, 64)           0         ['stg2_pb_cb2_norm[0][0]']
 on)
 stg2_pb_cb3_conv (Conv2D)   (None, 56, 56, 256)          16640     ['stg2_pb_cb2_actn[0][0]']
 stg2_pb_cb4_conv (Conv2D)   (None, 56, 56, 256)          16640     ['stg1_maxpool[0][0]']
 stg2_pb_cb3_norm (BatchNor  (None, 56, 56, 256)          1024      ['stg2_pb_cb3_conv[0][0]']
 malization)
 stg2_pb_cb4_norm (BatchNor  (None, 56, 56, 256)          1024      ['stg2_pb_cb4_conv[0][0]']
 malization)
 stg2_pb_add (Add)           (None, 56, 56, 256)          0         ['stg2_pb_cb3_norm[0][0]',
                                                                     'stg2_pb_cb4_norm[0][0]']
 stg2_pb_actn (Activation)   (None, 56, 56, 256)          0         ['stg2_pb_add[0][0]']
 stg2_ib1_cb1_conv (Conv2D)  (None, 56, 56, 64)           16448     ['stg2_pb_actn[0][0]']
 stg2_ib1_cb1_norm (BatchNo  (None, 56, 56, 64)           256       ['stg2_ib1_cb1_conv[0][0]']
 rmalization)
 stg2_ib1_cb1_actn (Activat  (None, 56, 56, 64)           0         ['stg2_ib1_cb1_norm[0][0]']
 ion)
 stg2_ib1_cb2_conv (Conv2D)  (None, 56, 56, 64)           36928     ['stg2_ib1_cb1_actn[0][0]']
 stg2_ib1_cb2_norm (BatchNo  (None, 56, 56, 64)           256       ['stg2_ib1_cb2_conv[0][0]']
 rmalization)
 stg2_ib1_cb2_actn (Activat  (None, 56, 56, 64)           0         ['stg2_ib1_cb2_norm[0][0]']
 ion)
 stg2_ib1_cb3_conv (Conv2D)  (None, 56, 56, 256)          16640     ['stg2_ib1_cb2_actn[0][0]']
 stg2_ib1_cb3_norm (BatchNo  (None, 56, 56, 256)          1024      ['stg2_ib1_cb3_conv[0][0]']
 rmalization)
 stg2_ib1_add (Add)          (None, 56, 56, 256)          0         ['stg2_ib1_cb3_norm[0][0]',
                                                                     'stg2_pb_actn[0][0]']
 stg2_ib1_actn (Activation)  (None, 56, 56, 256)          0         ['stg2_ib1_add[0][0]']
 stg2_ib2_cb1_conv (Conv2D)  (None, 56, 56, 64)           16448     ['stg2_ib1_actn[0][0]']
 stg2_ib2_cb1_norm (BatchNo  (None, 56, 56, 64)           256       ['stg2_ib2_cb1_conv[0][0]']
 rmalization)
 stg2_ib2_cb1_actn (Activat  (None, 56, 56, 64)           0         ['stg2_ib2_cb1_norm[0][0]']
 ion)
 stg2_ib2_cb2_conv (Conv2D)  (None, 56, 56, 64)           36928     ['stg2_ib2_cb1_actn[0][0]']
 stg2_ib2_cb2_norm (BatchNo  (None, 56, 56, 64)           256       ['stg2_ib2_cb2_conv[0][0]']
 rmalization)
 stg2_ib2_cb2_actn (Activat  (None, 56, 56, 64)           0         ['stg2_ib2_cb2_norm[0][0]']
 ion)
 stg2_ib2_cb3_conv (Conv2D)  (None, 56, 56, 256)          16640     ['stg2_ib2_cb2_actn[0][0]']
 stg2_ib2_cb3_norm (BatchNo  (None, 56, 56, 256)          1024      ['stg2_ib2_cb3_conv[0][0]']
 rmalization)
 stg2_ib2_add (Add)          (None, 56, 56, 256)          0         ['stg2_ib2_cb3_norm[0][0]',
                                                                     'stg2_ib1_actn[0][0]']
 stg2_ib2_actn (Activation)  (None, 56, 56, 256)          0         ['stg2_ib2_add[0][0]']
 stg3_pb_cb1_conv (Conv2D)   (None, 28, 28, 128)          32896     ['stg2_ib2_actn[0][0]']
 stg3_pb_cb1_norm (BatchNor  (None, 28, 28, 128)          512       ['stg3_pb_cb1_conv[0][0]']
 malization)
 stg3_pb_cb1_actn (Activati  (None, 28, 28, 128)          0         ['stg3_pb_cb1_norm[0][0]']
 on)
 stg3_pb_cb2_conv (Conv2D)   (None, 28, 28, 128)          147584    ['stg3_pb_cb1_actn[0][0]']
 stg3_pb_cb2_norm (BatchNor  (None, 28, 28, 128)          512       ['stg3_pb_cb2_conv[0][0]']
 malization)
 stg3_pb_cb2_actn (Activati  (None, 28, 28, 128)          0         ['stg3_pb_cb2_norm[0][0]']
 on)
 stg3_pb_cb3_conv (Conv2D)   (None, 28, 28, 512)          66048     ['stg3_pb_cb2_actn[0][0]']
 stg3_pb_cb4_conv (Conv2D)   (None, 28, 28, 512)          131584    ['stg2_ib2_actn[0][0]']
 stg3_pb_cb3_norm (BatchNor  (None, 28, 28, 512)          2048      ['stg3_pb_cb3_conv[0][0]']
 malization)
 stg3_pb_cb4_norm (BatchNor  (None, 28, 28, 512)          2048      ['stg3_pb_cb4_conv[0][0]']
 malization)
 stg3_pb_add (Add)           (None, 28, 28, 512)          0         ['stg3_pb_cb3_norm[0][0]',
                                                                     'stg3_pb_cb4_norm[0][0]']
 stg3_pb_actn (Activation)   (None, 28, 28, 512)          0         ['stg3_pb_add[0][0]']
 stg3_ib1_cb1_conv (Conv2D)  (None, 28, 28, 128)          65664     ['stg3_pb_actn[0][0]']
 stg3_ib1_cb1_norm (BatchNo  (None, 28, 28, 128)          512       ['stg3_ib1_cb1_conv[0][0]']
 rmalization)
 stg3_ib1_cb1_actn (Activat  (None, 28, 28, 128)          0         ['stg3_ib1_cb1_norm[0][0]']
 ion)
 stg3_ib1_cb2_conv (Conv2D)  (None, 28, 28, 128)          147584    ['stg3_ib1_cb1_actn[0][0]']
 stg3_ib1_cb2_norm (BatchNo  (None, 28, 28, 128)          512       ['stg3_ib1_cb2_conv[0][0]']
 rmalization)
 stg3_ib1_cb2_actn (Activat  (None, 28, 28, 128)          0         ['stg3_ib1_cb2_norm[0][0]']
 ion)
 stg3_ib1_cb3_conv (Conv2D)  (None, 28, 28, 512)          66048     ['stg3_ib1_cb2_actn[0][0]']
 stg3_ib1_cb3_norm (BatchNo  (None, 28, 28, 512)          2048      ['stg3_ib1_cb3_conv[0][0]']
 rmalization)
 stg3_ib1_add (Add)          (None, 28, 28, 512)          0         ['stg3_ib1_cb3_norm[0][0]',
                                                                     'stg3_pb_actn[0][0]']
 stg3_ib1_actn (Activation)  (None, 28, 28, 512)          0         ['stg3_ib1_add[0][0]']
 stg3_ib2_cb1_conv (Conv2D)  (None, 28, 28, 128)          65664     ['stg3_ib1_actn[0][0]']
 stg3_ib2_cb1_norm (BatchNo  (None, 28, 28, 128)          512       ['stg3_ib2_cb1_conv[0][0]']
 rmalization)
 stg3_ib2_cb1_actn (Activat  (None, 28, 28, 128)          0         ['stg3_ib2_cb1_norm[0][0]']
 ion)
 stg3_ib2_cb2_conv (Conv2D)  (None, 28, 28, 128)          147584    ['stg3_ib2_cb1_actn[0][0]']
 stg3_ib2_cb2_norm (BatchNo  (None, 28, 28, 128)          512       ['stg3_ib2_cb2_conv[0][0]']
 rmalization)
 stg3_ib2_cb2_actn (Activat  (None, 28, 28, 128)          0         ['stg3_ib2_cb2_norm[0][0]']
 ion)
 stg3_ib2_cb3_conv (Conv2D)  (None, 28, 28, 512)          66048     ['stg3_ib2_cb2_actn[0][0]']
 stg3_ib2_cb3_norm (BatchNo  (None, 28, 28, 512)          2048      ['stg3_ib2_cb3_conv[0][0]']
 rmalization)
 stg3_ib2_add (Add)          (None, 28, 28, 512)          0         ['stg3_ib2_cb3_norm[0][0]',
                                                                     'stg3_ib1_actn[0][0]']
 stg3_ib2_actn (Activation)  (None, 28, 28, 512)          0         ['stg3_ib2_add[0][0]']
 stg3_ib3_cb1_conv (Conv2D)  (None, 28, 28, 128)          65664     ['stg3_ib2_actn[0][0]']
 stg3_ib3_cb1_norm (BatchNo  (None, 28, 28, 128)          512       ['stg3_ib3_cb1_conv[0][0]']
 rmalization)
 stg3_ib3_cb1_actn (Activat  (None, 28, 28, 128)          0         ['stg3_ib3_cb1_norm[0][0]']
 ion)
 stg3_ib3_cb2_conv (Conv2D)  (None, 28, 28, 128)          147584    ['stg3_ib3_cb1_actn[0][0]']
 stg3_ib3_cb2_norm (BatchNo  (None, 28, 28, 128)          512       ['stg3_ib3_cb2_conv[0][0]']
 rmalization)
 stg3_ib3_cb2_actn (Activat  (None, 28, 28, 128)          0         ['stg3_ib3_cb2_norm[0][0]']
 ion)
 stg3_ib3_cb3_conv (Conv2D)  (None, 28, 28, 512)          66048     ['stg3_ib3_cb2_actn[0][0]']
 stg3_ib3_cb3_norm (BatchNo  (None, 28, 28, 512)          2048      ['stg3_ib3_cb3_conv[0][0]']
 rmalization)
 stg3_ib3_add (Add)          (None, 28, 28, 512)          0         ['stg3_ib3_cb3_norm[0][0]',
                                                                     'stg3_ib2_actn[0][0]']
 stg3_ib3_actn (Activation)  (None, 28, 28, 512)          0         ['stg3_ib3_add[0][0]']
 stg4_pb_cb1_conv (Conv2D)   (None, 14, 14, 256)          131328    ['stg3_ib3_actn[0][0]']
 stg4_pb_cb1_norm (BatchNor  (None, 14, 14, 256)          1024      ['stg4_pb_cb1_conv[0][0]']
 malization)
 stg4_pb_cb1_actn (Activati  (None, 14, 14, 256)          0         ['stg4_pb_cb1_norm[0][0]']
 on)
 stg4_pb_cb2_conv (Conv2D)   (None, 14, 14, 256)          590080    ['stg4_pb_cb1_actn[0][0]']
 stg4_pb_cb2_norm (BatchNor  (None, 14, 14, 256)          1024      ['stg4_pb_cb2_conv[0][0]']
 malization)
 stg4_pb_cb2_actn (Activati  (None, 14, 14, 256)          0         ['stg4_pb_cb2_norm[0][0]']
 on)
 stg4_pb_cb3_conv (Conv2D)   (None, 14, 14, 1024)         263168    ['stg4_pb_cb2_actn[0][0]']
 stg4_pb_cb4_conv (Conv2D)   (None, 14, 14, 1024)         525312    ['stg3_ib3_actn[0][0]']
 stg4_pb_cb3_norm (BatchNor  (None, 14, 14, 1024)         4096      ['stg4_pb_cb3_conv[0][0]']
 malization)
 stg4_pb_cb4_norm (BatchNor  (None, 14, 14, 1024)         4096      ['stg4_pb_cb4_conv[0][0]']
 malization)
 stg4_pb_add (Add)           (None, 14, 14, 1024)         0         ['stg4_pb_cb3_norm[0][0]',
                                                                     'stg4_pb_cb4_norm[0][0]']
 stg4_pb_actn (Activation)   (None, 14, 14, 1024)         0         ['stg4_pb_add[0][0]']
 stg4_ib1_cb1_conv (Conv2D)  (None, 14, 14, 256)          262400    ['stg4_pb_actn[0][0]']
 stg4_ib1_cb1_norm (BatchNo  (None, 14, 14, 256)          1024      ['stg4_ib1_cb1_conv[0][0]']
 rmalization)
 stg4_ib1_cb1_actn (Activat  (None, 14, 14, 256)          0         ['stg4_ib1_cb1_norm[0][0]']
 ion)
 stg4_ib1_cb2_conv (Conv2D)  (None, 14, 14, 256)          590080    ['stg4_ib1_cb1_actn[0][0]']
 stg4_ib1_cb2_norm (BatchNo  (None, 14, 14, 256)          1024      ['stg4_ib1_cb2_conv[0][0]']
 rmalization)
 stg4_ib1_cb2_actn (Activat  (None, 14, 14, 256)          0         ['stg4_ib1_cb2_norm[0][0]']
 ion)
 stg4_ib1_cb3_conv (Conv2D)  (None, 14, 14, 1024)         263168    ['stg4_ib1_cb2_actn[0][0]']
 stg4_ib1_cb3_norm (BatchNo  (None, 14, 14, 1024)         4096      ['stg4_ib1_cb3_conv[0][0]']
 rmalization)
 stg4_ib1_add (Add)          (None, 14, 14, 1024)         0         ['stg4_ib1_cb3_norm[0][0]',
                                                                     'stg4_pb_actn[0][0]']
 stg4_ib1_actn (Activation)  (None, 14, 14, 1024)         0         ['stg4_ib1_add[0][0]']
 stg4_ib2_cb1_conv (Conv2D)  (None, 14, 14, 256)          262400    ['stg4_ib1_actn[0][0]']
 stg4_ib2_cb1_norm (BatchNo  (None, 14, 14, 256)          1024      ['stg4_ib2_cb1_conv[0][0]']
 rmalization)
 stg4_ib2_cb1_actn (Activat  (None, 14, 14, 256)          0         ['stg4_ib2_cb1_norm[0][0]']
 ion)
 stg4_ib2_cb2_conv (Conv2D)  (None, 14, 14, 256)          590080    ['stg4_ib2_cb1_actn[0][0]']
 stg4_ib2_cb2_norm (BatchNo  (None, 14, 14, 256)          1024      ['stg4_ib2_cb2_conv[0][0]']
 rmalization)
 stg4_ib2_cb2_actn (Activat  (None, 14, 14, 256)          0         ['stg4_ib2_cb2_norm[0][0]']
 ion)
 stg4_ib2_cb3_conv (Conv2D)  (None, 14, 14, 1024)         263168    ['stg4_ib2_cb2_actn[0][0]']
 stg4_ib2_cb3_norm (BatchNo  (None, 14, 14, 1024)         4096      ['stg4_ib2_cb3_conv[0][0]']
 rmalization)
 stg4_ib2_add (Add)          (None, 14, 14, 1024)         0         ['stg4_ib2_cb3_norm[0][0]',
                                                                     'stg4_ib1_actn[0][0]']
 stg4_ib2_actn (Activation)  (None, 14, 14, 1024)         0         ['stg4_ib2_add[0][0]']
 stg4_ib3_cb1_conv (Conv2D)  (None, 14, 14, 256)          262400    ['stg4_ib2_actn[0][0]']
 stg4_ib3_cb1_norm (BatchNo  (None, 14, 14, 256)          1024      ['stg4_ib3_cb1_conv[0][0]']
 rmalization)
 stg4_ib3_cb1_actn (Activat  (None, 14, 14, 256)          0         ['stg4_ib3_cb1_norm[0][0]']
 ion)
 stg4_ib3_cb2_conv (Conv2D)  (None, 14, 14, 256)          590080    ['stg4_ib3_cb1_actn[0][0]']
 stg4_ib3_cb2_norm (BatchNo  (None, 14, 14, 256)          1024      ['stg4_ib3_cb2_conv[0][0]']
 rmalization)
 stg4_ib3_cb2_actn (Activat  (None, 14, 14, 256)          0         ['stg4_ib3_cb2_norm[0][0]']
 ion)
 stg4_ib3_cb3_conv (Conv2D)  (None, 14, 14, 1024)         263168    ['stg4_ib3_cb2_actn[0][0]']
 stg4_ib3_cb3_norm (BatchNo  (None, 14, 14, 1024)         4096      ['stg4_ib3_cb3_conv[0][0]']
 rmalization)
 stg4_ib3_add (Add)          (None, 14, 14, 1024)         0         ['stg4_ib3_cb3_norm[0][0]',
                                                                     'stg4_ib2_actn[0][0]']
 stg4_ib3_actn (Activation)  (None, 14, 14, 1024)         0         ['stg4_ib3_add[0][0]']
 stg4_ib4_cb1_conv (Conv2D)  (None, 14, 14, 256)          262400    ['stg4_ib3_actn[0][0]']
 stg4_ib4_cb1_norm (BatchNo  (None, 14, 14, 256)          1024      ['stg4_ib4_cb1_conv[0][0]']
 rmalization)
 stg4_ib4_cb1_actn (Activat  (None, 14, 14, 256)          0         ['stg4_ib4_cb1_norm[0][0]']
 ion)
 stg4_ib4_cb2_conv (Conv2D)  (None, 14, 14, 256)          590080    ['stg4_ib4_cb1_actn[0][0]']
 stg4_ib4_cb2_norm (BatchNo  (None, 14, 14, 256)          1024      ['stg4_ib4_cb2_conv[0][0]']
 rmalization)
 stg4_ib4_cb2_actn (Activat  (None, 14, 14, 256)          0         ['stg4_ib4_cb2_norm[0][0]']
 ion)
 stg4_ib4_cb3_conv (Conv2D)  (None, 14, 14, 1024)         263168    ['stg4_ib4_cb2_actn[0][0]']
 stg4_ib4_cb3_norm (BatchNo  (None, 14, 14, 1024)         4096      ['stg4_ib4_cb3_conv[0][0]']
 rmalization)
 stg4_ib4_add (Add)          (None, 14, 14, 1024)         0         ['stg4_ib4_cb3_norm[0][0]',
                                                                     'stg4_ib3_actn[0][0]']
 stg4_ib4_actn (Activation)  (None, 14, 14, 1024)         0         ['stg4_ib4_add[0][0]']
 stg4_ib5_cb1_conv (Conv2D)  (None, 14, 14, 256)          262400    ['stg4_ib4_actn[0][0]']
 stg4_ib5_cb1_norm (BatchNo  (None, 14, 14, 256)          1024      ['stg4_ib5_cb1_conv[0][0]']
 rmalization)
 stg4_ib5_cb1_actn (Activat  (None, 14, 14, 256)          0         ['stg4_ib5_cb1_norm[0][0]']
 ion)
 stg4_ib5_cb2_conv (Conv2D)  (None, 14, 14, 256)          590080    ['stg4_ib5_cb1_actn[0][0]']
 stg4_ib5_cb2_norm (BatchNo  (None, 14, 14, 256)          1024      ['stg4_ib5_cb2_conv[0][0]']
 rmalization)
 stg4_ib5_cb2_actn (Activat  (None, 14, 14, 256)          0         ['stg4_ib5_cb2_norm[0][0]']
 ion)
 stg4_ib5_cb3_conv (Conv2D)  (None, 14, 14, 1024)         263168    ['stg4_ib5_cb2_actn[0][0]']
 stg4_ib5_cb3_norm (BatchNo  (None, 14, 14, 1024)         4096      ['stg4_ib5_cb3_conv[0][0]']
 rmalization)
 stg4_ib5_add (Add)          (None, 14, 14, 1024)         0         ['stg4_ib5_cb3_norm[0][0]',
                                                                     'stg4_ib4_actn[0][0]']
 stg4_ib5_actn (Activation)  (None, 14, 14, 1024)         0         ['stg4_ib5_add[0][0]']
 stg5_pb_cb1_conv (Conv2D)   (None, 7, 7, 512)            524800    ['stg4_ib5_actn[0][0]']
 stg5_pb_cb1_norm (BatchNor  (None, 7, 7, 512)            2048      ['stg5_pb_cb1_conv[0][0]']
 malization)
 stg5_pb_cb1_actn (Activati  (None, 7, 7, 512)            0         ['stg5_pb_cb1_norm[0][0]']
 on)
 stg5_pb_cb2_conv (Conv2D)   (None, 7, 7, 512)            2359808   ['stg5_pb_cb1_actn[0][0]']
 stg5_pb_cb2_norm (BatchNor  (None, 7, 7, 512)            2048      ['stg5_pb_cb2_conv[0][0]']
 malization)
 stg5_pb_cb2_actn (Activati  (None, 7, 7, 512)            0         ['stg5_pb_cb2_norm[0][0]']
 on)
 stg5_pb_cb3_conv (Conv2D)   (None, 7, 7, 2048)           1050624   ['stg5_pb_cb2_actn[0][0]']
 stg5_pb_cb4_conv (Conv2D)   (None, 7, 7, 2048)           2099200   ['stg4_ib5_actn[0][0]']
 stg5_pb_cb3_norm (BatchNor  (None, 7, 7, 2048)           8192      ['stg5_pb_cb3_conv[0][0]']
 malization)
 stg5_pb_cb4_norm (BatchNor  (None, 7, 7, 2048)           8192      ['stg5_pb_cb4_conv[0][0]']
 malization)
 stg5_pb_add (Add)           (None, 7, 7, 2048)           0         ['stg5_pb_cb3_norm[0][0]',
                                                                     'stg5_pb_cb4_norm[0][0]']
 stg5_pb_actn (Activation)   (None, 7, 7, 2048)           0         ['stg5_pb_add[0][0]']
 stg5_ib1_cb1_conv (Conv2D)  (None, 7, 7, 512)            1049088   ['stg5_pb_actn[0][0]']
 stg5_ib1_cb1_norm (BatchNo  (None, 7, 7, 512)            2048      ['stg5_ib1_cb1_conv[0][0]']
 rmalization)
 stg5_ib1_cb1_actn (Activat  (None, 7, 7, 512)            0         ['stg5_ib1_cb1_norm[0][0]']
 ion)
 stg5_ib1_cb2_conv (Conv2D)  (None, 7, 7, 512)            2359808   ['stg5_ib1_cb1_actn[0][0]']
 stg5_ib1_cb2_norm (BatchNo  (None, 7, 7, 512)            2048      ['stg5_ib1_cb2_conv[0][0]']
 rmalization)
 stg5_ib1_cb2_actn (Activat  (None, 7, 7, 512)            0         ['stg5_ib1_cb2_norm[0][0]']
 ion)
 stg5_ib1_cb3_conv (Conv2D)  (None, 7, 7, 2048)           1050624   ['stg5_ib1_cb2_actn[0][0]']
 stg5_ib1_cb3_norm (BatchNo  (None, 7, 7, 2048)           8192      ['stg5_ib1_cb3_conv[0][0]']
 rmalization)
 stg5_ib1_add (Add)          (None, 7, 7, 2048)           0         ['stg5_ib1_cb3_norm[0][0]',
                                                                     'stg5_pb_actn[0][0]']
 stg5_ib1_actn (Activation)  (None, 7, 7, 2048)           0         ['stg5_ib1_add[0][0]']
 stg5_ib2_cb1_conv (Conv2D)  (None, 7, 7, 512)            1049088   ['stg5_ib1_actn[0][0]']
 stg5_ib2_cb1_norm (BatchNo  (None, 7, 7, 512)            2048      ['stg5_ib2_cb1_conv[0][0]']
 rmalization)
 stg5_ib2_cb1_actn (Activat  (None, 7, 7, 512)            0         ['stg5_ib2_cb1_norm[0][0]']
 ion)
 stg5_ib2_cb2_conv (Conv2D)  (None, 7, 7, 512)            2359808   ['stg5_ib2_cb1_actn[0][0]']
 stg5_ib2_cb2_norm (BatchNo  (None, 7, 7, 512)            2048      ['stg5_ib2_cb2_conv[0][0]']
 rmalization)
 stg5_ib2_cb2_actn (Activat  (None, 7, 7, 512)            0         ['stg5_ib2_cb2_norm[0][0]']
 ion)
 stg5_ib2_cb3_conv (Conv2D)  (None, 7, 7, 2048)           1050624   ['stg5_ib2_cb2_actn[0][0]']
 stg5_ib2_cb3_norm (BatchNo  (None, 7, 7, 2048)           8192      ['stg5_ib2_cb3_conv[0][0]']
 rmalization)
 stg5_ib2_add (Add)          (None, 7, 7, 2048)           0         ['stg5_ib2_cb3_norm[0][0]',
                                                                     'stg5_ib1_actn[0][0]']
 stg5_ib2_actn (Activation)  (None, 7, 7, 2048)           0         ['stg5_ib2_add[0][0]']
 stg5_globaver (GlobalAvera  (None, 2048)                 0         ['stg5_ib2_actn[0][0]']
 gePooling2D)
 dense_0_dense (Dense)       (None, 1024)                 2098176   ['stg5_globaver[0][0]']
 dense_0_dropout (Dropout)   (None, 1024)                 0         ['dense_0_dense[0][0]']
 dense_1_dense (Dense)       (None, 512)                  524800    ['dense_0_dropout[0][0]']
 dense_1_dropout (Dropout)   (None, 512)                  0         ['dense_1_dense[0][0]']
 outputs (Dense)             (None, 2)                    1026      ['dense_1_dropout[0][0]']
==================================================================================================
Total params: 26205442 (99.97 MB)
Trainable params: 26152322 (99.76 MB)
Non-trainable params: 53120 (207.50 KB)
__________________________________________________________________________________________________

--- Model training ---------------------------------------------------

Epoch 1/40
131/131 [==============================] - ETA: 0s - loss: 0.5744 - accuracy: 0.8057 - auc: 0.8881
Epoch 1: val_auc improved from -inf to 0.72875, saving model to /kaggle/working/model/ResNet50.ckpt
131/131 [==============================] - 89s 400ms/step - loss: 0.5744 - accuracy: 0.8057 - auc: 0.8881 - val_loss: 2.0084 - val_accuracy: 0.7287 - val_auc: 0.7287
Epoch 2/40
131/131 [==============================] - ETA: 0s - loss: 0.3045 - accuracy: 0.8863 - auc: 0.9525
Epoch 2: val_auc did not improve from 0.72875
131/131 [==============================] - 25s 193ms/step - loss: 0.3045 - accuracy: 0.8863 - auc: 0.9525 - val_loss: 4.7596 - val_accuracy: 0.7287 - val_auc: 0.7287
Epoch 3/40
131/131 [==============================] - ETA: 0s - loss: 0.2752 - accuracy: 0.8879 - auc: 0.9552
Epoch 3: val_auc did not improve from 0.72875
131/131 [==============================] - 25s 190ms/step - loss: 0.2752 - accuracy: 0.8879 - auc: 0.9552 - val_loss: 4.9704 - val_accuracy: 0.7287 - val_auc: 0.7287
Epoch 4/40
131/131 [==============================] - ETA: 0s - loss: 0.2286 - accuracy: 0.9102 - auc: 0.9685
Epoch 4: val_auc improved from 0.72875 to 0.88451, saving model to /kaggle/working/model/ResNet50.ckpt
131/131 [==============================] - 49s 374ms/step - loss: 0.2286 - accuracy: 0.9102 - auc: 0.9685 - val_loss: 0.7117 - val_accuracy: 0.8262 - val_auc: 0.8845
Epoch 5/40
131/131 [==============================] - ETA: 0s - loss: 0.2202 - accuracy: 0.9114 - auc: 0.9707
Epoch 5: val_auc did not improve from 0.88451
131/131 [==============================] - 25s 190ms/step - loss: 0.2202 - accuracy: 0.9114 - auc: 0.9707 - val_loss: 1.8955 - val_accuracy: 0.7287 - val_auc: 0.7902
Epoch 6/40
131/131 [==============================] - ETA: 0s - loss: 0.2126 - accuracy: 0.9211 - auc: 0.9726
Epoch 6: val_auc did not improve from 0.88451
131/131 [==============================] - 25s 190ms/step - loss: 0.2126 - accuracy: 0.9211 - auc: 0.9726 - val_loss: 1.7594 - val_accuracy: 0.7287 - val_auc: 0.8051
Epoch 7/40
131/131 [==============================] - ETA: 0s - loss: 0.1946 - accuracy: 0.9271 - auc: 0.9766
Epoch 7: val_auc did not improve from 0.88451
131/131 [==============================] - 25s 190ms/step - loss: 0.1946 - accuracy: 0.9271 - auc: 0.9766 - val_loss: 3.1147 - val_accuracy: 0.7287 - val_auc: 0.748
Epoch 8/40
131/131 [==============================] - ETA: 0s - loss: 0.1843 - accuracy: 0.9376 - auc: 0.9785
Epoch 8: val_auc did not improve from 0.88451
131/131 [==============================] - 25s 191ms/step - loss: 0.1843 - accuracy: 0.9376 - auc: 0.9785 - val_loss: 0.5455 - val_accuracy: 0.7517 - val_auc: 0.8557
Epoch 9/40
131/131 [==============================] - ETA: 0s - loss: 0.1847 - accuracy: 0.9324 - auc: 0.9796
Epoch 9: val_auc did not improve from 0.88451
131/131 [==============================] - 25s 191ms/step - loss: 0.1847 - accuracy: 0.9324 - auc: 0.9796 - val_loss: 7.3896 - val_accuracy: 0.2970 - val_auc: 0.2984
Epoch 10/40
131/131 [==============================] - ETA: 0s - loss: 0.1710 - accuracy: 0.9345 - auc: 0.9817
Epoch 10: val_auc did not improve from 0.88451
131/131 [==============================] - 25s 191ms/step - loss: 0.1710 - accuracy: 0.9345 - auc: 0.9817 - val_loss: 0.5625 - val_accuracy: 0.7851 - val_auc: 0.8507
Epoch 11/40
131/131 [==============================] - ETA: 0s - loss: 0.1695 - accuracy: 0.9393 - auc: 0.9819
Epoch 11: val_auc did not improve from 0.88451
131/131 [==============================] - 25s 191ms/step - loss: 0.1695 - accuracy: 0.9393 - auc: 0.9819 - val_loss: 28.2723 - val_accuracy: 0.2837 - val_auc: 0.2830
Epoch 12/40
131/131 [==============================] - ETA: 0s - loss: 0.1653 - accuracy: 0.9376 - auc: 0.9829
Epoch 12: val_auc did not improve from 0.88451
131/131 [==============================] - 25s 190ms/step - loss: 0.1653 - accuracy: 0.9376 - auc: 0.9829 - val_loss: 1.9131 - val_accuracy: 0.4976 - val_auc: 0.5622
Epoch 13/40
131/131 [==============================] - ETA: 0s - loss: 0.1652 - accuracy: 0.9357 - auc: 0.9828
Epoch 13: val_auc did not improve from 0.88451
131/131 [==============================] - 25s 191ms/step - loss: 0.1652 - accuracy: 0.9357 - auc: 0.9828 - val_loss: 2.3911 - val_accuracy: 0.4460 - val_auc: 0.4168
Epoch 14/40
131/131 [==============================] - ETA: 0s - loss: 0.1659 - accuracy: 0.9407 - auc: 0.9825
Epoch 14: val_auc did not improve from 0.88451
131/131 [==============================] - 25s 190ms/step - loss: 0.1659 - accuracy: 0.9407 - auc: 0.9825 - val_loss: 0.9757 - val_accuracy: 0.7125 - val_auc: 0.7713

--- Training history -------------------------------------------------

[figure: training history]


--- Test Predictions and Metrics -------------------------------------

[figure: test predictions and metrics]


              precision    recall  f1-score   support

      NORMAL       0.60      0.86      0.70       234
   PNEUMONIA       0.89      0.65      0.75       390

    accuracy                           0.73       624
   macro avg       0.74      0.76      0.73       624
weighted avg       0.78      0.73      0.73       624


=== MODEL EVALUATION FINISHED ========================================

ResNet50 transfer learning

And just as with the previous model, I will also try a variant with an already pretrained ResNet50 model as available in the Keras distribution. Just to add, the models in Keras Applications are pretrained on the ImageNet dataset.

In [22]:


from keras.applications import ResNet50

def create_model_ResNet50Trans(X_shape, classes=2, name="ResNet50Trans"):

    def mlp(x, hidden_units, activation='relu', dropout_rate=0.3, name=""):
        for i, units in enumerate(hidden_units):
            x = layers.Dense(units, activation=activation, name=f"{name}_{i}_dense")(x)
            x = layers.Dropout(dropout_rate, name=f"{name}_{i}_dropout")(x)
        return x

    base_model = ResNet50(include_top=False, input_shape=tuple(X_shape)[-3:])
    base_model.trainable = False

    inputs = Input(X_shape[-3:], name='inputs')

    x = base_model(inputs, training=False)

    x = layers.GlobalAveragePooling2D(name=f"global_average")(x)

    x = mlp(x, (1024, 512), name="dense")
    outputs = layers.Dense(classes, activation='softmax', name='outputs')(x)

    return Model(inputs=inputs, outputs=outputs, name=name)

So it is the ResNet50 model from the Keras Applications package, which I have extended with classification layers to suit my own needs.

I prepare the image data as RGB images (this is required because of the base model):

In [23]:


x_train, x_valid, y_train, y_valid = train_test_split(*get_datasource(DATA_TRAIN, DATA_VALID, flag=cv2.IMREAD_COLOR), test_size=0.2)
x_test, y_test = get_datasource(DATA_TEST, flag=cv2.IMREAD_COLOR)

datagen = ImageDataGenerator(
        rotation_range = 30,
        zoom_range = 0.2,
        width_shift_range=0.1,
        height_shift_range=0.1,
        horizontal_flip = True,
        vertical_flip=False)

datagen.fit(x_train)
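One side note from me (not something this notebook does): the ImageNet weights of ResNet50 were trained on inputs normalized by keras.applications.resnet50.preprocess_input (RGB to BGR plus mean subtraction). If the get_datasource helper scales the pixels differently, it may be worth trying the matching preprocessing, roughly like this:

from keras.applications.resnet50 import preprocess_input

# Hypothetical tweak: apply the preprocessing matching the ImageNet weights.
# Whether it helps depends on how get_datasource already scales the images.
x_train = preprocess_input(x_train.astype("float32"))
x_valid = preprocess_input(x_valid.astype("float32"))
x_test  = preprocess_input(x_test.astype("float32"))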

And now the evaluation of the model itself:

In [24]:


evaluate_model(create_model_ResNet50Trans(x_train.shape, 2), forced_training=True)
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
94765736/94765736 [==============================] - 0s 0us/step
=== MODEL EVALUATION =================================================

Model: "ResNet50Trans"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 inputs (InputLayer)         [(None, 224, 224, 3)]     0
 resnet50 (Functional)       (None, 7, 7, 2048)        23587712
 global_average (GlobalAver  (None, 2048)              0
 agePooling2D)
 dense_0_dense (Dense)       (None, 1024)              2098176
 dense_0_dropout (Dropout)   (None, 1024)              0
 dense_1_dense (Dense)       (None, 512)               524800
 dense_1_dropout (Dropout)   (None, 512)               0
 outputs (Dense)             (None, 2)                 1026
=================================================================
Total params: 26211714 (99.99 MB)
Trainable params: 2624002 (10.01 MB)
Non-trainable params: 23587712 (89.98 MB)
_________________________________________________________________

--- Model training ---------------------------------------------------

Epoch 1/40
131/131 [==============================] - ETA: 0s - loss: 0.6163 - accuracy: 0.7250 - auc: 0.7551
Epoch 1: val_auc improved from -inf to 0.86753, saving model to /kaggle/working/model/ResNet50Trans.ckpt
131/131 [==============================] - 76s 550ms/step - loss: 0.6163 - accuracy: 0.7250 - auc: 0.7551 - val_loss: 0.5485 - val_accuracy: 0.7326 - val_auc: 0.8675
Epoch 2/40
131/131 [==============================] - ETA: 0s - loss: 0.5344 - accuracy: 0.7462 - auc: 0.8092
Epoch 2: val_auc improved from 0.86753 to 0.88226, saving model to /kaggle/working/model/ResNet50Trans.ckpt
131/131 [==============================] - 69s 530ms/step - loss: 0.5344 - accuracy: 0.7462 - auc: 0.8092 - val_loss: 0.5178 - val_accuracy: 0.7326 - val_auc: 0.8823
Epoch 3/40
131/131 [==============================] - ETA: 0s - loss: 0.4697 - accuracy: 0.7730 - auc: 0.8585
Epoch 3: val_auc improved from 0.88226 to 0.88246, saving model to /kaggle/working/model/ResNet50Trans.ckpt
131/131 [==============================] - 70s 536ms/step - loss: 0.4697 - accuracy: 0.7730 - auc: 0.8585 - val_loss: 0.4488 - val_accuracy: 0.7612 - val_auc: 0.8825
Epoch 4/40
131/131 [==============================] - ETA: 0s - loss: 0.4384 - accuracy: 0.7771 - auc: 0.8749
Epoch 4: val_auc improved from 0.88246 to 0.89494, saving model to /kaggle/working/model/ResNet50Trans.ckpt
131/131 [==============================] - 69s 531ms/step - loss: 0.4384 - accuracy: 0.7771 - auc: 0.8749 - val_loss: 0.4043 - val_accuracy: 0.7947 - val_auc: 0.8949
Epoch 5/40
131/131 [==============================] - ETA: 0s - loss: 0.4407 - accuracy: 0.7795 - auc: 0.8732
Epoch 5: val_auc did not improve from 0.89494
131/131 [==============================] - 51s 390ms/step - loss: 0.4407 - accuracy: 0.7795 - auc: 0.8732 - val_loss: 0.4147 - val_accuracy: 0.7937 - val_auc: 0.8945
Epoch 6/40
131/131 [==============================] - ETA: 0s - loss: 0.4329 - accuracy: 0.7845 - auc: 0.8787
Epoch 6: val_auc improved from 0.89494 to 0.91911, saving model to /kaggle/working/model/ResNet50Trans.ckpt
131/131 [==============================] - 70s 532ms/step - loss: 0.4329 - accuracy: 0.7845 - auc: 0.8787 - val_loss: 0.4018 - val_accuracy: 0.8185 - val_auc: 0.9191
Epoch 7/40
131/131 [==============================] - ETA: 0s - loss: 0.4281 - accuracy: 0.7720 - auc: 0.8767
Epoch 7: val_auc did not improve from 0.91911
131/131 [==============================] - 52s 395ms/step - loss: 0.4281 - accuracy: 0.7720 - auc: 0.8767 - val_loss: 0.4607 - val_accuracy: 0.7326 - val_auc: 0.8788
Epoch 8/40
131/131 [==============================] - ETA: 0s - loss: 0.4918 - accuracy: 0.7388 - auc: 0.8300
Epoch 8: val_auc did not improve from 0.91911
131/131 [==============================] - 51s 388ms/step - loss: 0.4918 - accuracy: 0.7388 - auc: 0.8300 - val_loss: 0.4640 - val_accuracy: 0.7326 - val_auc: 0.8811
Epoch 9/40
131/131 [==============================] - ETA: 0s - loss: 0.4865 - accuracy: 0.7446 - auc: 0.8231
Epoch 9: val_auc did not improve from 0.91911
131/131 [==============================] - 52s 394ms/step - loss: 0.4865 - accuracy: 0.7446 - auc: 0.8231 - val_loss: 0.5245 - val_accuracy: 0.7326 - val_auc: 0.8436
Epoch 10/40
131/131 [==============================] - ETA: 0s - loss: 0.4874 - accuracy: 0.7446 - auc: 0.8295
Epoch 10: val_auc did not improve from 0.91911
131/131 [==============================] - 52s 394ms/step - loss: 0.4874 - accuracy: 0.7446 - auc: 0.8295 - val_loss: 0.4076 - val_accuracy: 0.7326 - val_auc: 0.8829
Epoch 11/40
131/131 [==============================] - ETA: 0s - loss: 0.4764 - accuracy: 0.7448 - auc: 0.8345
Epoch 11: val_auc did not improve from 0.91911
131/131 [==============================] - 52s 393ms/step - loss: 0.4764 - accuracy: 0.7448 - auc: 0.8345 - val_loss: 0.4205 - val_accuracy: 0.7326 - val_auc: 0.8807
Epoch 12/40
131/131 [==============================] - ETA: 0s - loss: 0.4721 - accuracy: 0.7446 - auc: 0.8337
Epoch 12: val_auc did not improve from 0.91911
131/131 [==============================] - 51s 389ms/step - loss: 0.4721 - accuracy: 0.7446 - auc: 0.8337 - val_loss: 0.4012 - val_accuracy: 0.7326 - val_auc: 0.8916
Epoch 13/40
131/131 [==============================] - ETA: 0s - loss: 0.5537 - accuracy: 0.7448 - auc: 0.7788
Epoch 13: val_auc did not improve from 0.91911
131/131 [==============================] - 51s 392ms/step - loss: 0.5537 - accuracy: 0.7448 - auc: 0.7788 - val_loss: 0.5454 - val_accuracy: 0.7326 - val_auc: 0.8753
Epoch 14/40
131/131 [==============================] - ETA: 0s - loss: 0.5676 - accuracy: 0.7453 - auc: 0.7499
Epoch 14: val_auc did not improve from 0.91911
131/131 [==============================] - 51s 391ms/step - loss: 0.5676 - accuracy: 0.7453 - auc: 0.7499 - val_loss: 0.5817 - val_accuracy: 0.7326 - val_auc: 0.7326
Epoch 15/40
131/131 [==============================] - ETA: 0s - loss: 0.5687 - accuracy: 0.7448 - auc: 0.7436
Epoch 15: val_auc did not improve from 0.91911
131/131 [==============================] - 51s 392ms/step - loss: 0.5687 - accuracy: 0.7448 - auc: 0.7436 - val_loss: 0.5477 - val_accuracy: 0.7335 - val_auc: 0.8862
Epoch 16/40
131/131 [==============================] - ETA: 0s - loss: 0.5077 - accuracy: 0.7740 - auc: 0.8281
Epoch 16: val_auc did not improve from 0.91911
131/131 [==============================] - 51s 388ms/step - loss: 0.5077 - accuracy: 0.7740 - auc: 0.8281 - val_loss: 0.5728 - val_accuracy: 0.7364 - val_auc: 0.8177

--- Training history -------------------------------------------------

[figure: training history]


--- Test Predictions and Metrics -------------------------------------

[figure: test predictions and metrics]


              precision    recall  f1-score   support

      NORMAL       0.70      0.78      0.74       234
   PNEUMONIA       0.86      0.80      0.83       390

    accuracy                           0.79       624
   macro avg       0.78      0.79      0.78       624
weighted avg       0.80      0.79      0.80       624


=== MODEL EVALUATION FINISHED ========================================
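Just as a side note, a common follow-up step with transfer learning (which I do not do in this article) is to fine-tune the pretrained backbone once the new classification head has converged: unfreeze the base model and continue training with a much smaller learning rate. A rough sketch, reusing the checkpoint path and the layer name from the model summary above:

import keras

# Rebuild the model and restore the best weights saved while training the head.
model = create_model_ResNet50Trans(x_train.shape, 2)
model.load_weights('/kaggle/working/model/ResNet50Trans.ckpt')

model.get_layer('resnet50').trainable = True      # unfreeze the pretrained backbone

# Recompile with a low learning rate so the pretrained weights are only nudged;
# the loss and metrics should match whatever evaluate_model normally uses.
model.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-5),
              loss='categorical_crossentropy',
              metrics=['accuracy'])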

Next time I would like to venture into a somewhat different area, namely the use of transformer models for image classification.
