This article follows on from the previous two: Rozpoznání zápalu plic z RTG snímků – díl první (Recognizing pneumonia from X-ray images, part one) and Rozpoznání zápalu plic z RTG snímků – díl druhý (part two).
ResNet is one of the most successful and most popular classification models built on convolutional layers. Its authors managed to address the problem of gradients degrading in models with many layers. Their solution is to add a "shortcut" for the data passing through the network, so that each block computes its convolutional output and then adds its own input back to it. There are many documents describing the principle of this model and explaining why it actually works. One example is: Detailed Explanation of Resnet CNN Model.
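To make the effect of the shortcut a little more tangible, here is a minimal sketch on a toy scalar "network". It is only an illustration under strong assumptions (every layer has the same small local derivative, and the function name `gradient_through_stack` is made up for this example); it is not how ResNet gradients are actually computed.

```python
# Toy scalar model: each "layer" contributes a local derivative d.
# Plain stack:     y = f(x)      -> dy/dx = d
# Residual stack:  y = x + f(x)  -> dy/dx = 1 + d

def gradient_through_stack(local_derivative, depth, residual):
    """Multiply the per-layer derivatives along a chain of `depth` layers."""
    per_layer = (1.0 + local_derivative) if residual else local_derivative
    grad = 1.0
    for _ in range(depth):
        grad *= per_layer
    return grad


plain = gradient_through_stack(0.01, depth=50, residual=False)
skip = gradient_through_stack(0.01, depth=50, residual=True)
print(f"plain stack: {plain:.3e}, with shortcut: {skip:.3f}")
```

Without the shortcut the product of fifty small derivatives practically disappears, while the added identity path keeps the gradient at a usable magnitude.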
Just to give a rough idea, this is how the ResNet50 model is depicted schematically:
I will show the function implementing the model straight away, followed by a few remarks:
In [19]:
# Assumes `layers`, `Input` and `Model` were imported from Keras in the earlier parts of this series.
def create_model_ResNet50(X_shape, classes=2, name="ResNet50"):

    # Classification head: a stack of Dense + Dropout layers
    def mlp(x, hidden_units, activation='relu', dropout_rate=0.3, name=""):
        for i, units in enumerate(hidden_units):
            x = layers.Dense(units, activation=activation, name=f"{name}_{i}_dense")(x)
            x = layers.Dropout(dropout_rate, name=f"{name}_{i}_dropout")(x)
        return x

    # Convolution + batch normalization + optional activation
    def conv_block(x, *, filters, kernel_size, strides=(1, 1), padding='same', activation='relu', name=""):
        x = layers.Conv2D(filters=filters, kernel_size=kernel_size, strides=strides, padding=padding, name=f"{name}_conv")(x)
        x = layers.BatchNormalization(name=f"{name}_norm")(x)
        if activation:
            x = layers.Activation(activation, name=f"{name}_actn")(x)
        return x

    # Bottleneck block whose shortcut is the unchanged input
    def identity_block(x, *, filters, name=""):
        shortcut = x
        x = conv_block(x, filters=filters, kernel_size=(1, 1), name=f"{name}_cb1")
        x = conv_block(x, filters=filters, kernel_size=(3, 3), name=f"{name}_cb2")
        x = conv_block(x, filters=filters * 4, kernel_size=(1, 1), activation='', name=f"{name}_cb3")
        x = layers.Add(name=f"{name}_add")([x, shortcut])
        x = layers.Activation('relu', name=f"{name}_actn")(x)
        return x

    # Bottleneck block whose shortcut goes through its own 1x1 convolution
    def projection_block(x, *, filters, strides, name=""):
        shortcut = x
        x = conv_block(x, filters=filters, kernel_size=(1, 1), strides=strides, name=f"{name}_cb1")
        x = conv_block(x, filters=filters, kernel_size=(3, 3), name=f"{name}_cb2")
        x = conv_block(x, filters=filters * 4, kernel_size=(1, 1), activation='', name=f"{name}_cb3")
        shortcut = conv_block(shortcut, filters=filters * 4, kernel_size=(1, 1), strides=strides, activation='', name=f"{name}_cb4")
        x = layers.Add(name=f"{name}_add")([x, shortcut])
        x = layers.Activation('relu', name=f"{name}_actn")(x)
        return x

    inputs = Input(X_shape[-3:], name='inputs')

    # === Stage 1 ===
    x = conv_block(inputs, filters=64, kernel_size=(7, 7), strides=(2, 2), padding='same', name="stg1_cb1")
    x = layers.MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding='same', name="stg1_maxpool")(x)

    # === Stage 2 ===
    x = projection_block(x, filters=64, strides=(1, 1), name="stg2_pb")
    x = identity_block(x, filters=64, name="stg2_ib1")
    x = identity_block(x, filters=64, name="stg2_ib2")

    # === Stage 3 ===
    x = projection_block(x, filters=128, strides=(2, 2), name="stg3_pb")
    x = identity_block(x, filters=128, name="stg3_ib1")
    x = identity_block(x, filters=128, name="stg3_ib2")
    x = identity_block(x, filters=128, name="stg3_ib3")

    # === Stage 4 ===
    x = projection_block(x, filters=256, strides=(2, 2), name="stg4_pb")
    x = identity_block(x, filters=256, name="stg4_ib1")
    x = identity_block(x, filters=256, name="stg4_ib2")
    x = identity_block(x, filters=256, name="stg4_ib3")
    x = identity_block(x, filters=256, name="stg4_ib4")
    x = identity_block(x, filters=256, name="stg4_ib5")

    # === Stage 5 ===
    x = projection_block(x, filters=512, strides=(2, 2), name="stg5_pb")
    x = identity_block(x, filters=512, name="stg5_ib1")
    x = identity_block(x, filters=512, name="stg5_ib2")

    x = layers.GlobalAveragePooling2D(name="stg5_globaver")(x)
    x = mlp(x, (1024, 512), name="dense")
    outputs = layers.Dense(classes, activation='softmax', name='outputs')(x)

    return Model(inputs=inputs, outputs=outputs, name=name)
The model is assembled step by step from so-called projection blocks and identity blocks. The difference is that a projection block changes the shape of the tensor: it always changes the number of channels, and in stages 3 to 5 it also reduces the spatial dimensions by changing the convolution stride (the strides parameter). So that the tensor coming out of the convolutional layers can be added to the shortcut at the end of the block, one more convolutional layer has to be placed in the shortcut itself. An identity block leaves the shape unchanged, so its shortcut needs no such layer.
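The choice between the two block types can be written down as a simple shape check. The following is a minimal sketch, assuming the bottleneck expansion factor of 4 used in the function above; the helper name `needs_projection` is hypothetical and does not appear in the model code.

```python
def needs_projection(in_channels, filters, strides=(1, 1), expansion=4):
    """Return True when the shortcut cannot be added to the block output as-is."""
    out_channels = filters * expansion          # channels after the last 1x1 convolution
    changes_spatial = tuple(strides) != (1, 1)  # a stride > 1 shrinks height and width
    changes_channels = in_channels != out_channels
    return changes_spatial or changes_channels


# First block of stage 2: 64 channels in, 256 out, stride 1 -> projection needed
print(needs_projection(64, 64, (1, 1)))
# Later blocks of stage 2: 256 channels in, 256 out -> plain identity shortcut
print(needs_projection(256, 64, (1, 1)))
# First block of stage 3: stride 2 and 256 -> 512 channels -> projection needed
print(needs_projection(256, 128, (2, 2)))
```

This is why stage 2 also starts with a projection block even though its stride is (1, 1): the channel count grows from 64 to 256.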
Both block types are then stacked into so-called stages, but the principle is the same as in all classification models: the spatial dimensions shrink while the feature dimension grows.
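How the dimensions develop through the stages can be traced with a few lines of arithmetic. This is a rough sketch that assumes a square 224 x 224 input and `padding='same'` everywhere, as in the model function; the helper names are made up for this illustration.

```python
import math


def same_padding_output(size, stride):
    """Spatial output size of a conv or pooling layer with padding='same'."""
    return math.ceil(size / stride)


def trace_resnet50_shapes(input_size=224):
    """Return (stage name, spatial size, channels) at the end of each stage."""
    shapes = []
    size = same_padding_output(input_size, 2)  # 7x7 convolution, stride 2
    size = same_padding_output(size, 2)        # 3x3 max pooling, stride 2
    shapes.append(("stage 1", size, 64))
    stage_config = [(64, 1), (128, 2), (256, 2), (512, 2)]  # (filters, stride)
    for stage, (filters, stride) in enumerate(stage_config, start=2):
        size = same_padding_output(size, stride)
        shapes.append((f"stage {stage}", size, filters * 4))
    return shapes


for name, size, channels in trace_resnet50_shapes():
    print(f"{name}: {size} x {size} x {channels}")
```

The resulting sizes (56, 56, 28, 14, 7) and channel counts (64, 256, 512, 1024, 2048) agree with the output shapes in the model summary.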
Also worth noting is the heavy use of convolutional layers with a (1, 1) kernel, which looks somewhat odd at first sight. One reason for using them is to change the size of the feature dimension (here shrinking it at the start of a block and expanding it again at the end) without touching the spatial dimensions; another is to add further nonlinearity to the model.
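The saving brought by the (1, 1) convolutions shows up when the parameters are counted. Below is a small sketch using the usual formula for a convolution with bias (kernel height x kernel width x input channels x output channels, plus one bias per output channel); the function name `conv2d_params` is only illustrative.

```python
def conv2d_params(kernel, in_channels, out_channels):
    """Number of weights and biases in a standard 2D convolution."""
    kernel_h, kernel_w = kernel
    return kernel_h * kernel_w * in_channels * out_channels + out_channels


# Identity block of stage 2: 256 -> 64 -> 64 -> 256 channels
bottleneck = (
    conv2d_params((1, 1), 256, 64)    # squeeze the feature dimension
    + conv2d_params((3, 3), 64, 64)   # spatial convolution on few channels
    + conv2d_params((1, 1), 64, 256)  # expand back
)

# A single 3x3 convolution working directly on 256 channels
direct = conv2d_params((3, 3), 256, 256)

print(bottleneck, direct)
```

The three individual values (16448, 36928 and 16640) are the same numbers the model summary reports for the `stg2_ib1_*_conv` layers, and together they stay far below the cost of one full 3x3 convolution on 256 channels.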
The end of the model is, as usual, a multi-layer perceptron followed by the final classification layer.
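Between the last stage and the perceptron sits a global average pooling layer, which turns each feature map into a single number. A pure-Python sketch of that operation on a tiny nested list follows; it only illustrates the idea and is not the Keras implementation.

```python
def global_average_pooling(feature_map):
    """Average an H x W x C nested list over H and W, returning C values."""
    height = len(feature_map)
    width = len(feature_map[0])
    channels = len(feature_map[0][0])
    sums = [0.0] * channels
    for row in feature_map:
        for pixel in row:
            for channel, value in enumerate(pixel):
                sums[channel] += value
    return [total / (height * width) for total in sums]


# A 2 x 2 spatial grid with 3 channels
fmap = [
    [[1, 10, 0], [3, 20, 0]],
    [[5, 30, 0], [7, 40, 4]],
]
pooled = global_average_pooling(fmap)
print(pooled)
```

In the model the same reduction turns the (7, 7, 2048) tensor into a vector of 2048 values that feeds the dense layers.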
Before the actual experiment I again prepare the data. Since I am training the whole model myself, I use grayscale images:
In [20]:
x_train, x_valid, y_train, y_valid = train_test_split(*get_datasource(DATA_TRAIN, DATA_VALID), test_size=0.2)
x_test, y_test = get_datasource(DATA_TEST)

# Grayscale images have no channel axis; append one so the shape is (N, H, W, 1)
x_train = np.expand_dims(x_train, axis=-1)
x_valid = np.expand_dims(x_valid, axis=-1)
x_test = np.expand_dims(x_test, axis=-1)

datagen = ImageDataGenerator(
    rotation_range=30,
    zoom_range=0.2,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
    vertical_flip=False)

datagen.fit(x_train)
And now the evaluation of the model itself. As you can see, there really are a lot of layers:
In [21]:
evaluate_model(create_model_ResNet50(x_train.shape, 2), forced_training=True) === MODEL EVALUATION ================================================= Model: "ResNet50" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== inputs (InputLayer) [(None, 224, 224, 1)] 0 [] stg1_cb1_conv (Conv2D) (None, 112, 112, 64) 3200 ['inputs[0][0]'] stg1_cb1_norm (BatchNormal (None, 112, 112, 64) 256 ['stg1_cb1_conv[0][0]'] ization) stg1_cb1_actn (Activation) (None, 112, 112, 64) 0 ['stg1_cb1_norm[0][0]'] stg1_maxpool (MaxPooling2D (None, 56, 56, 64) 0 ['stg1_cb1_actn[0][0]'] ) stg2_pb_cb1_conv (Conv2D) (None, 56, 56, 64) 4160 ['stg1_maxpool[0][0]'] stg2_pb_cb1_norm (BatchNor (None, 56, 56, 64) 256 ['stg2_pb_cb1_conv[0][0]'] malization) stg2_pb_cb1_actn (Activati (None, 56, 56, 64) 0 ['stg2_pb_cb1_norm[0][0]'] on) stg2_pb_cb2_conv (Conv2D) (None, 56, 56, 64) 36928 ['stg2_pb_cb1_actn[0][0]'] stg2_pb_cb2_norm (BatchNor (None, 56, 56, 64) 256 ['stg2_pb_cb2_conv[0][0]'] malization) stg2_pb_cb2_actn (Activati (None, 56, 56, 64) 0 ['stg2_pb_cb2_norm[0][0]'] on) stg2_pb_cb3_conv (Conv2D) (None, 56, 56, 256) 16640 ['stg2_pb_cb2_actn[0][0]'] stg2_pb_cb4_conv (Conv2D) (None, 56, 56, 256) 16640 ['stg1_maxpool[0][0]'] stg2_pb_cb3_norm (BatchNor (None, 56, 56, 256) 1024 ['stg2_pb_cb3_conv[0][0]'] malization) stg2_pb_cb4_norm (BatchNor (None, 56, 56, 256) 1024 ['stg2_pb_cb4_conv[0][0]'] malization) stg2_pb_add (Add) (None, 56, 56, 256) 0 ['stg2_pb_cb3_norm[0][0]', 'stg2_pb_cb4_norm[0][0]'] stg2_pb_actn (Activation) (None, 56, 56, 256) 0 ['stg2_pb_add[0][0]'] stg2_ib1_cb1_conv (Conv2D) (None, 56, 56, 64) 16448 ['stg2_pb_actn[0][0]'] stg2_ib1_cb1_norm (BatchNo (None, 56, 56, 64) 256 ['stg2_ib1_cb1_conv[0][0]'] rmalization) stg2_ib1_cb1_actn (Activat (None, 56, 56, 64) 0 ['stg2_ib1_cb1_norm[0][0]'] ion) 
stg2_ib1_cb2_conv (Conv2D) (None, 56, 56, 64) 36928 ['stg2_ib1_cb1_actn[0][0]'] stg2_ib1_cb2_norm (BatchNo (None, 56, 56, 64) 256 ['stg2_ib1_cb2_conv[0][0]'] rmalization) stg2_ib1_cb2_actn (Activat (None, 56, 56, 64) 0 ['stg2_ib1_cb2_norm[0][0]'] ion) stg2_ib1_cb3_conv (Conv2D) (None, 56, 56, 256) 16640 ['stg2_ib1_cb2_actn[0][0]'] stg2_ib1_cb3_norm (BatchNo (None, 56, 56, 256) 1024 ['stg2_ib1_cb3_conv[0][0]'] rmalization) stg2_ib1_add (Add) (None, 56, 56, 256) 0 ['stg2_ib1_cb3_norm[0][0]', 'stg2_pb_actn[0][0]'] stg2_ib1_actn (Activation) (None, 56, 56, 256) 0 ['stg2_ib1_add[0][0]'] stg2_ib2_cb1_conv (Conv2D) (None, 56, 56, 64) 16448 ['stg2_ib1_actn[0][0]'] stg2_ib2_cb1_norm (BatchNo (None, 56, 56, 64) 256 ['stg2_ib2_cb1_conv[0][0]'] rmalization) stg2_ib2_cb1_actn (Activat (None, 56, 56, 64) 0 ['stg2_ib2_cb1_norm[0][0]'] ion) stg2_ib2_cb2_conv (Conv2D) (None, 56, 56, 64) 36928 ['stg2_ib2_cb1_actn[0][0]'] stg2_ib2_cb2_norm (BatchNo (None, 56, 56, 64) 256 ['stg2_ib2_cb2_conv[0][0]'] rmalization) stg2_ib2_cb2_actn (Activat (None, 56, 56, 64) 0 ['stg2_ib2_cb2_norm[0][0]'] ion) stg2_ib2_cb3_conv (Conv2D) (None, 56, 56, 256) 16640 ['stg2_ib2_cb2_actn[0][0]'] stg2_ib2_cb3_norm (BatchNo (None, 56, 56, 256) 1024 ['stg2_ib2_cb3_conv[0][0]'] rmalization) stg2_ib2_add (Add) (None, 56, 56, 256) 0 ['stg2_ib2_cb3_norm[0][0]', 'stg2_ib1_actn[0][0]'] stg2_ib2_actn (Activation) (None, 56, 56, 256) 0 ['stg2_ib2_add[0][0]'] stg3_pb_cb1_conv (Conv2D) (None, 28, 28, 128) 32896 ['stg2_ib2_actn[0][0]'] stg3_pb_cb1_norm (BatchNor (None, 28, 28, 128) 512 ['stg3_pb_cb1_conv[0][0]'] malization) stg3_pb_cb1_actn (Activati (None, 28, 28, 128) 0 ['stg3_pb_cb1_norm[0][0]'] on) stg3_pb_cb2_conv (Conv2D) (None, 28, 28, 128) 147584 ['stg3_pb_cb1_actn[0][0]'] stg3_pb_cb2_norm (BatchNor (None, 28, 28, 128) 512 ['stg3_pb_cb2_conv[0][0]'] malization) stg3_pb_cb2_actn (Activati (None, 28, 28, 128) 0 ['stg3_pb_cb2_norm[0][0]'] on) stg3_pb_cb3_conv (Conv2D) (None, 28, 28, 512) 66048 
['stg3_pb_cb2_actn[0][0]'] stg3_pb_cb4_conv (Conv2D) (None, 28, 28, 512) 131584 ['stg2_ib2_actn[0][0]'] stg3_pb_cb3_norm (BatchNor (None, 28, 28, 512) 2048 ['stg3_pb_cb3_conv[0][0]'] malization) stg3_pb_cb4_norm (BatchNor (None, 28, 28, 512) 2048 ['stg3_pb_cb4_conv[0][0]'] malization) stg3_pb_add (Add) (None, 28, 28, 512) 0 ['stg3_pb_cb3_norm[0][0]', 'stg3_pb_cb4_norm[0][0]'] stg3_pb_actn (Activation) (None, 28, 28, 512) 0 ['stg3_pb_add[0][0]'] stg3_ib1_cb1_conv (Conv2D) (None, 28, 28, 128) 65664 ['stg3_pb_actn[0][0]'] stg3_ib1_cb1_norm (BatchNo (None, 28, 28, 128) 512 ['stg3_ib1_cb1_conv[0][0]'] rmalization) stg3_ib1_cb1_actn (Activat (None, 28, 28, 128) 0 ['stg3_ib1_cb1_norm[0][0]'] ion) stg3_ib1_cb2_conv (Conv2D) (None, 28, 28, 128) 147584 ['stg3_ib1_cb1_actn[0][0]'] stg3_ib1_cb2_norm (BatchNo (None, 28, 28, 128) 512 ['stg3_ib1_cb2_conv[0][0]'] rmalization) stg3_ib1_cb2_actn (Activat (None, 28, 28, 128) 0 ['stg3_ib1_cb2_norm[0][0]'] ion) stg3_ib1_cb3_conv (Conv2D) (None, 28, 28, 512) 66048 ['stg3_ib1_cb2_actn[0][0]'] stg3_ib1_cb3_norm (BatchNo (None, 28, 28, 512) 2048 ['stg3_ib1_cb3_conv[0][0]'] rmalization) stg3_ib1_add (Add) (None, 28, 28, 512) 0 ['stg3_ib1_cb3_norm[0][0]', 'stg3_pb_actn[0][0]'] stg3_ib1_actn (Activation) (None, 28, 28, 512) 0 ['stg3_ib1_add[0][0]'] stg3_ib2_cb1_conv (Conv2D) (None, 28, 28, 128) 65664 ['stg3_ib1_actn[0][0]'] stg3_ib2_cb1_norm (BatchNo (None, 28, 28, 128) 512 ['stg3_ib2_cb1_conv[0][0]'] rmalization) stg3_ib2_cb1_actn (Activat (None, 28, 28, 128) 0 ['stg3_ib2_cb1_norm[0][0]'] ion) stg3_ib2_cb2_conv (Conv2D) (None, 28, 28, 128) 147584 ['stg3_ib2_cb1_actn[0][0]'] stg3_ib2_cb2_norm (BatchNo (None, 28, 28, 128) 512 ['stg3_ib2_cb2_conv[0][0]'] rmalization) stg3_ib2_cb2_actn (Activat (None, 28, 28, 128) 0 ['stg3_ib2_cb2_norm[0][0]'] ion) stg3_ib2_cb3_conv (Conv2D) (None, 28, 28, 512) 66048 ['stg3_ib2_cb2_actn[0][0]'] stg3_ib2_cb3_norm (BatchNo (None, 28, 28, 512) 2048 ['stg3_ib2_cb3_conv[0][0]'] rmalization) stg3_ib2_add (Add) (None, 
28, 28, 512) 0 ['stg3_ib2_cb3_norm[0][0]', 'stg3_ib1_actn[0][0]'] stg3_ib2_actn (Activation) (None, 28, 28, 512) 0 ['stg3_ib2_add[0][0]'] stg3_ib3_cb1_conv (Conv2D) (None, 28, 28, 128) 65664 ['stg3_ib2_actn[0][0]'] stg3_ib3_cb1_norm (BatchNo (None, 28, 28, 128) 512 ['stg3_ib3_cb1_conv[0][0]'] rmalization) stg3_ib3_cb1_actn (Activat (None, 28, 28, 128) 0 ['stg3_ib3_cb1_norm[0][0]'] ion) stg3_ib3_cb2_conv (Conv2D) (None, 28, 28, 128) 147584 ['stg3_ib3_cb1_actn[0][0]'] stg3_ib3_cb2_norm (BatchNo (None, 28, 28, 128) 512 ['stg3_ib3_cb2_conv[0][0]'] rmalization) stg3_ib3_cb2_actn (Activat (None, 28, 28, 128) 0 ['stg3_ib3_cb2_norm[0][0]'] ion) stg3_ib3_cb3_conv (Conv2D) (None, 28, 28, 512) 66048 ['stg3_ib3_cb2_actn[0][0]'] stg3_ib3_cb3_norm (BatchNo (None, 28, 28, 512) 2048 ['stg3_ib3_cb3_conv[0][0]'] rmalization) stg3_ib3_add (Add) (None, 28, 28, 512) 0 ['stg3_ib3_cb3_norm[0][0]', 'stg3_ib2_actn[0][0]'] stg3_ib3_actn (Activation) (None, 28, 28, 512) 0 ['stg3_ib3_add[0][0]'] stg4_pb_cb1_conv (Conv2D) (None, 14, 14, 256) 131328 ['stg3_ib3_actn[0][0]'] stg4_pb_cb1_norm (BatchNor (None, 14, 14, 256) 1024 ['stg4_pb_cb1_conv[0][0]'] malization) stg4_pb_cb1_actn (Activati (None, 14, 14, 256) 0 ['stg4_pb_cb1_norm[0][0]'] on) stg4_pb_cb2_conv (Conv2D) (None, 14, 14, 256) 590080 ['stg4_pb_cb1_actn[0][0]'] stg4_pb_cb2_norm (BatchNor (None, 14, 14, 256) 1024 ['stg4_pb_cb2_conv[0][0]'] malization) stg4_pb_cb2_actn (Activati (None, 14, 14, 256) 0 ['stg4_pb_cb2_norm[0][0]'] on) stg4_pb_cb3_conv (Conv2D) (None, 14, 14, 1024) 263168 ['stg4_pb_cb2_actn[0][0]'] stg4_pb_cb4_conv (Conv2D) (None, 14, 14, 1024) 525312 ['stg3_ib3_actn[0][0]'] stg4_pb_cb3_norm (BatchNor (None, 14, 14, 1024) 4096 ['stg4_pb_cb3_conv[0][0]'] malization) stg4_pb_cb4_norm (BatchNor (None, 14, 14, 1024) 4096 ['stg4_pb_cb4_conv[0][0]'] malization) stg4_pb_add (Add) (None, 14, 14, 1024) 0 ['stg4_pb_cb3_norm[0][0]', 'stg4_pb_cb4_norm[0][0]'] stg4_pb_actn (Activation) (None, 14, 14, 1024) 0 ['stg4_pb_add[0][0]'] 
stg4_ib1_cb1_conv (Conv2D) (None, 14, 14, 256) 262400 ['stg4_pb_actn[0][0]'] stg4_ib1_cb1_norm (BatchNo (None, 14, 14, 256) 1024 ['stg4_ib1_cb1_conv[0][0]'] rmalization) stg4_ib1_cb1_actn (Activat (None, 14, 14, 256) 0 ['stg4_ib1_cb1_norm[0][0]'] ion) stg4_ib1_cb2_conv (Conv2D) (None, 14, 14, 256) 590080 ['stg4_ib1_cb1_actn[0][0]'] stg4_ib1_cb2_norm (BatchNo (None, 14, 14, 256) 1024 ['stg4_ib1_cb2_conv[0][0]'] rmalization) stg4_ib1_cb2_actn (Activat (None, 14, 14, 256) 0 ['stg4_ib1_cb2_norm[0][0]'] ion) stg4_ib1_cb3_conv (Conv2D) (None, 14, 14, 1024) 263168 ['stg4_ib1_cb2_actn[0][0]'] stg4_ib1_cb3_norm (BatchNo (None, 14, 14, 1024) 4096 ['stg4_ib1_cb3_conv[0][0]'] rmalization) stg4_ib1_add (Add) (None, 14, 14, 1024) 0 ['stg4_ib1_cb3_norm[0][0]', 'stg4_pb_actn[0][0]'] stg4_ib1_actn (Activation) (None, 14, 14, 1024) 0 ['stg4_ib1_add[0][0]'] stg4_ib2_cb1_conv (Conv2D) (None, 14, 14, 256) 262400 ['stg4_ib1_actn[0][0]'] stg4_ib2_cb1_norm (BatchNo (None, 14, 14, 256) 1024 ['stg4_ib2_cb1_conv[0][0]'] rmalization) stg4_ib2_cb1_actn (Activat (None, 14, 14, 256) 0 ['stg4_ib2_cb1_norm[0][0]'] ion) stg4_ib2_cb2_conv (Conv2D) (None, 14, 14, 256) 590080 ['stg4_ib2_cb1_actn[0][0]'] stg4_ib2_cb2_norm (BatchNo (None, 14, 14, 256) 1024 ['stg4_ib2_cb2_conv[0][0]'] rmalization) stg4_ib2_cb2_actn (Activat (None, 14, 14, 256) 0 ['stg4_ib2_cb2_norm[0][0]'] ion) stg4_ib2_cb3_conv (Conv2D) (None, 14, 14, 1024) 263168 ['stg4_ib2_cb2_actn[0][0]'] stg4_ib2_cb3_norm (BatchNo (None, 14, 14, 1024) 4096 ['stg4_ib2_cb3_conv[0][0]'] rmalization) stg4_ib2_add (Add) (None, 14, 14, 1024) 0 ['stg4_ib2_cb3_norm[0][0]', 'stg4_ib1_actn[0][0]'] stg4_ib2_actn (Activation) (None, 14, 14, 1024) 0 ['stg4_ib2_add[0][0]'] stg4_ib3_cb1_conv (Conv2D) (None, 14, 14, 256) 262400 ['stg4_ib2_actn[0][0]'] stg4_ib3_cb1_norm (BatchNo (None, 14, 14, 256) 1024 ['stg4_ib3_cb1_conv[0][0]'] rmalization) stg4_ib3_cb1_actn (Activat (None, 14, 14, 256) 0 ['stg4_ib3_cb1_norm[0][0]'] ion) stg4_ib3_cb2_conv (Conv2D) (None, 14, 14, 
256) 590080 ['stg4_ib3_cb1_actn[0][0]'] stg4_ib3_cb2_norm (BatchNo (None, 14, 14, 256) 1024 ['stg4_ib3_cb2_conv[0][0]'] rmalization) stg4_ib3_cb2_actn (Activat (None, 14, 14, 256) 0 ['stg4_ib3_cb2_norm[0][0]'] ion) stg4_ib3_cb3_conv (Conv2D) (None, 14, 14, 1024) 263168 ['stg4_ib3_cb2_actn[0][0]'] stg4_ib3_cb3_norm (BatchNo (None, 14, 14, 1024) 4096 ['stg4_ib3_cb3_conv[0][0]'] rmalization) stg4_ib3_add (Add) (None, 14, 14, 1024) 0 ['stg4_ib3_cb3_norm[0][0]', 'stg4_ib2_actn[0][0]'] stg4_ib3_actn (Activation) (None, 14, 14, 1024) 0 ['stg4_ib3_add[0][0]'] stg4_ib4_cb1_conv (Conv2D) (None, 14, 14, 256) 262400 ['stg4_ib3_actn[0][0]'] stg4_ib4_cb1_norm (BatchNo (None, 14, 14, 256) 1024 ['stg4_ib4_cb1_conv[0][0]'] rmalization) stg4_ib4_cb1_actn (Activat (None, 14, 14, 256) 0 ['stg4_ib4_cb1_norm[0][0]'] ion) stg4_ib4_cb2_conv (Conv2D) (None, 14, 14, 256) 590080 ['stg4_ib4_cb1_actn[0][0]'] stg4_ib4_cb2_norm (BatchNo (None, 14, 14, 256) 1024 ['stg4_ib4_cb2_conv[0][0]'] rmalization) stg4_ib4_cb2_actn (Activat (None, 14, 14, 256) 0 ['stg4_ib4_cb2_norm[0][0]'] ion) stg4_ib4_cb3_conv (Conv2D) (None, 14, 14, 1024) 263168 ['stg4_ib4_cb2_actn[0][0]'] stg4_ib4_cb3_norm (BatchNo (None, 14, 14, 1024) 4096 ['stg4_ib4_cb3_conv[0][0]'] rmalization) stg4_ib4_add (Add) (None, 14, 14, 1024) 0 ['stg4_ib4_cb3_norm[0][0]', 'stg4_ib3_actn[0][0]'] stg4_ib4_actn (Activation) (None, 14, 14, 1024) 0 ['stg4_ib4_add[0][0]'] stg4_ib5_cb1_conv (Conv2D) (None, 14, 14, 256) 262400 ['stg4_ib4_actn[0][0]'] stg4_ib5_cb1_norm (BatchNo (None, 14, 14, 256) 1024 ['stg4_ib5_cb1_conv[0][0]'] rmalization) stg4_ib5_cb1_actn (Activat (None, 14, 14, 256) 0 ['stg4_ib5_cb1_norm[0][0]'] ion) stg4_ib5_cb2_conv (Conv2D) (None, 14, 14, 256) 590080 ['stg4_ib5_cb1_actn[0][0]'] stg4_ib5_cb2_norm (BatchNo (None, 14, 14, 256) 1024 ['stg4_ib5_cb2_conv[0][0]'] rmalization) stg4_ib5_cb2_actn (Activat (None, 14, 14, 256) 0 ['stg4_ib5_cb2_norm[0][0]'] ion) stg4_ib5_cb3_conv (Conv2D) (None, 14, 14, 1024) 263168 
['stg4_ib5_cb2_actn[0][0]'] stg4_ib5_cb3_norm (BatchNo (None, 14, 14, 1024) 4096 ['stg4_ib5_cb3_conv[0][0]'] rmalization) stg4_ib5_add (Add) (None, 14, 14, 1024) 0 ['stg4_ib5_cb3_norm[0][0]', 'stg4_ib4_actn[0][0]'] stg4_ib5_actn (Activation) (None, 14, 14, 1024) 0 ['stg4_ib5_add[0][0]'] stg5_pb_cb1_conv (Conv2D) (None, 7, 7, 512) 524800 ['stg4_ib5_actn[0][0]'] stg5_pb_cb1_norm (BatchNor (None, 7, 7, 512) 2048 ['stg5_pb_cb1_conv[0][0]'] malization) stg5_pb_cb1_actn (Activati (None, 7, 7, 512) 0 ['stg5_pb_cb1_norm[0][0]'] on) stg5_pb_cb2_conv (Conv2D) (None, 7, 7, 512) 2359808 ['stg5_pb_cb1_actn[0][0]'] stg5_pb_cb2_norm (BatchNor (None, 7, 7, 512) 2048 ['stg5_pb_cb2_conv[0][0]'] malization) stg5_pb_cb2_actn (Activati (None, 7, 7, 512) 0 ['stg5_pb_cb2_norm[0][0]'] on) stg5_pb_cb3_conv (Conv2D) (None, 7, 7, 2048) 1050624 ['stg5_pb_cb2_actn[0][0]'] stg5_pb_cb4_conv (Conv2D) (None, 7, 7, 2048) 2099200 ['stg4_ib5_actn[0][0]'] stg5_pb_cb3_norm (BatchNor (None, 7, 7, 2048) 8192 ['stg5_pb_cb3_conv[0][0]'] malization) stg5_pb_cb4_norm (BatchNor (None, 7, 7, 2048) 8192 ['stg5_pb_cb4_conv[0][0]'] malization) stg5_pb_add (Add) (None, 7, 7, 2048) 0 ['stg5_pb_cb3_norm[0][0]', 'stg5_pb_cb4_norm[0][0]'] stg5_pb_actn (Activation) (None, 7, 7, 2048) 0 ['stg5_pb_add[0][0]'] stg5_ib1_cb1_conv (Conv2D) (None, 7, 7, 512) 1049088 ['stg5_pb_actn[0][0]'] stg5_ib1_cb1_norm (BatchNo (None, 7, 7, 512) 2048 ['stg5_ib1_cb1_conv[0][0]'] rmalization) stg5_ib1_cb1_actn (Activat (None, 7, 7, 512) 0 ['stg5_ib1_cb1_norm[0][0]'] ion) stg5_ib1_cb2_conv (Conv2D) (None, 7, 7, 512) 2359808 ['stg5_ib1_cb1_actn[0][0]'] stg5_ib1_cb2_norm (BatchNo (None, 7, 7, 512) 2048 ['stg5_ib1_cb2_conv[0][0]'] rmalization) stg5_ib1_cb2_actn (Activat (None, 7, 7, 512) 0 ['stg5_ib1_cb2_norm[0][0]'] ion) stg5_ib1_cb3_conv (Conv2D) (None, 7, 7, 2048) 1050624 ['stg5_ib1_cb2_actn[0][0]'] stg5_ib1_cb3_norm (BatchNo (None, 7, 7, 2048) 8192 ['stg5_ib1_cb3_conv[0][0]'] rmalization) stg5_ib1_add (Add) (None, 7, 7, 2048) 0 
['stg5_ib1_cb3_norm[0][0]', 'stg5_pb_actn[0][0]'] stg5_ib1_actn (Activation) (None, 7, 7, 2048) 0 ['stg5_ib1_add[0][0]'] stg5_ib2_cb1_conv (Conv2D) (None, 7, 7, 512) 1049088 ['stg5_ib1_actn[0][0]'] stg5_ib2_cb1_norm (BatchNo (None, 7, 7, 512) 2048 ['stg5_ib2_cb1_conv[0][0]'] rmalization) stg5_ib2_cb1_actn (Activat (None, 7, 7, 512) 0 ['stg5_ib2_cb1_norm[0][0]'] ion) stg5_ib2_cb2_conv (Conv2D) (None, 7, 7, 512) 2359808 ['stg5_ib2_cb1_actn[0][0]'] stg5_ib2_cb2_norm (BatchNo (None, 7, 7, 512) 2048 ['stg5_ib2_cb2_conv[0][0]'] rmalization) stg5_ib2_cb2_actn (Activat (None, 7, 7, 512) 0 ['stg5_ib2_cb2_norm[0][0]'] ion) stg5_ib2_cb3_conv (Conv2D) (None, 7, 7, 2048) 1050624 ['stg5_ib2_cb2_actn[0][0]'] stg5_ib2_cb3_norm (BatchNo (None, 7, 7, 2048) 8192 ['stg5_ib2_cb3_conv[0][0]'] rmalization) stg5_ib2_add (Add) (None, 7, 7, 2048) 0 ['stg5_ib2_cb3_norm[0][0]', 'stg5_ib1_actn[0][0]'] stg5_ib2_actn (Activation) (None, 7, 7, 2048) 0 ['stg5_ib2_add[0][0]'] stg5_globaver (GlobalAvera (None, 2048) 0 ['stg5_ib2_actn[0][0]'] gePooling2D) dense_0_dense (Dense) (None, 1024) 2098176 ['stg5_globaver[0][0]'] dense_0_dropout (Dropout) (None, 1024) 0 ['dense_0_dense[0][0]'] dense_1_dense (Dense) (None, 512) 524800 ['dense_0_dropout[0][0]'] dense_1_dropout (Dropout) (None, 512) 0 ['dense_1_dense[0][0]'] outputs (Dense) (None, 2) 1026 ['dense_1_dropout[0][0]'] ================================================================================================== Total params: 26205442 (99.97 MB) Trainable params: 26152322 (99.76 MB) Non-trainable params: 53120 (207.50 KB) __________________________________________________________________________________________________ --- Model training --------------------------------------------------- Epoch 1/40 131/131 [==============================] - ETA: 0s - loss: 0.5744 - accuracy: 0.8057 - auc: 0.8881 Epoch 1: val_auc improved from -inf to 0.72875, saving model to /kaggle/working/model/ResNet50.ckpt 131/131 [==============================] - 89s 
400ms/step - loss: 0.5744 - accuracy: 0.8057 - auc: 0.8881 - val_loss: 2.0084 - val_accuracy: 0.7287 - val_auc: 0.7287 Epoch 2/40 131/131 [==============================] - ETA: 0s - loss: 0.3045 - accuracy: 0.8863 - auc: 0.9525 Epoch 2: val_auc did not improve from 0.72875 131/131 [==============================] - 25s 193ms/step - loss: 0.3045 - accuracy: 0.8863 - auc: 0.9525 - val_loss: 4.7596 - val_accuracy: 0.7287 - val_auc: 0.7287 Epoch 3/40 131/131 [==============================] - ETA: 0s - loss: 0.2752 - accuracy: 0.8879 - auc: 0.9552 Epoch 3: val_auc did not improve from 0.72875 131/131 [==============================] - 25s 190ms/step - loss: 0.2752 - accuracy: 0.8879 - auc: 0.9552 - val_loss: 4.9704 - val_accuracy: 0.7287 - val_auc: 0.7287 Epoch 4/40 131/131 [==============================] - ETA: 0s - loss: 0.2286 - accuracy: 0.9102 - auc: 0.9685 Epoch 4: val_auc improved from 0.72875 to 0.88451, saving model to /kaggle/working/model/ResNet50.ckpt 131/131 [==============================] - 49s 374ms/step - loss: 0.2286 - accuracy: 0.9102 - auc: 0.9685 - val_loss: 0.7117 - val_accuracy: 0.8262 - val_auc: 0.8845 Epoch 5/40 131/131 [==============================] - ETA: 0s - loss: 0.2202 - accuracy: 0.9114 - auc: 0.9707 Epoch 5: val_auc did not improve from 0.88451 131/131 [==============================] - 25s 190ms/step - loss: 0.2202 - accuracy: 0.9114 - auc: 0.9707 - val_loss: 1.8955 - val_accuracy: 0.7287 - val_auc: 0.7902 Epoch 6/40 131/131 [==============================] - ETA: 0s - loss: 0.2126 - accuracy: 0.9211 - auc: 0.9726 Epoch 6: val_auc did not improve from 0.88451 131/131 [==============================] - 25s 190ms/step - loss: 0.2126 - accuracy: 0.9211 - auc: 0.9726 - val_loss: 1.7594 - val_accuracy: 0.7287 - val_auc: 0.8051 Epoch 7/40 131/131 [==============================] - ETA: 0s - loss: 0.1946 - accuracy: 0.9271 - auc: 0.9766 Epoch 7: val_auc did not improve from 0.88451 131/131 [==============================] - 25s 190ms/step 
- loss: 0.1946 - accuracy: 0.9271 - auc: 0.9766 - val_loss: 3.1147 - val_accuracy: 0.7287 - val_auc: 0.748 Epoch 8/40 131/131 [==============================] - ETA: 0s - loss: 0.1843 - accuracy: 0.9376 - auc: 0.9785 Epoch 8: val_auc did not improve from 0.88451 131/131 [==============================] - 25s 191ms/step - loss: 0.1843 - accuracy: 0.9376 - auc: 0.9785 - val_loss: 0.5455 - val_accuracy: 0.7517 - val_auc: 0.8557 Epoch 9/40 131/131 [==============================] - ETA: 0s - loss: 0.1847 - accuracy: 0.9324 - auc: 0.9796 Epoch 9: val_auc did not improve from 0.88451 131/131 [==============================] - 25s 191ms/step - loss: 0.1847 - accuracy: 0.9324 - auc: 0.9796 - val_loss: 7.3896 - val_accuracy: 0.2970 - val_auc: 0.2984 Epoch 10/40 131/131 [==============================] - ETA: 0s - loss: 0.1710 - accuracy: 0.9345 - auc: 0.9817 Epoch 10: val_auc did not improve from 0.88451 131/131 [==============================] - 25s 191ms/step - loss: 0.1710 - accuracy: 0.9345 - auc: 0.9817 - val_loss: 0.5625 - val_accuracy: 0.7851 - val_auc: 0.8507 Epoch 11/40 131/131 [==============================] - ETA: 0s - loss: 0.1695 - accuracy: 0.9393 - auc: 0.9819 Epoch 11: val_auc did not improve from 0.88451 131/131 [==============================] - 25s 191ms/step - loss: 0.1695 - accuracy: 0.9393 - auc: 0.9819 - val_loss: 28.2723 - val_accuracy: 0.2837 - val_auc: 0.2830 Epoch 12/40 131/131 [==============================] - ETA: 0s - loss: 0.1653 - accuracy: 0.9376 - auc: 0.9829 Epoch 12: val_auc did not improve from 0.88451 131/131 [==============================] - 25s 190ms/step - loss: 0.1653 - accuracy: 0.9376 - auc: 0.9829 - val_loss: 1.9131 - val_accuracy: 0.4976 - val_auc: 0.5622 Epoch 13/40 131/131 [==============================] - ETA: 0s - loss: 0.1652 - accuracy: 0.9357 - auc: 0.9828 Epoch 13: val_auc did not improve from 0.88451 131/131 [==============================] - 25s 191ms/step - loss: 0.1652 - accuracy: 0.9357 - auc: 0.9828 - val_loss: 
2.3911 - val_accuracy: 0.4460 - val_auc: 0.4168 Epoch 14/40 131/131 [==============================] - ETA: 0s - loss: 0.1659 - accuracy: 0.9407 - auc: 0.9825 Epoch 14: val_auc did not improve from 0.88451 131/131 [==============================] - 25s 190ms/step - loss: 0.1659 - accuracy: 0.9407 - auc: 0.9825 - val_loss: 0.9757 - val_accuracy: 0.7125 - val_auc: 0.7713 --- Training history -------------------------------------------------
--- Test Predictions and Metrics -------------------------------------
              precision    recall  f1-score   support

      NORMAL       0.60      0.86      0.70       234
   PNEUMONIA       0.89      0.65      0.75       390

    accuracy                           0.73       624
   macro avg       0.74      0.76      0.73       624
weighted avg       0.78      0.73      0.73       624

=== MODEL EVALUATION FINISHED ========================================
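For readers who want to see where the numbers in such a report come from, here is a minimal sketch that derives precision, recall, F1 and accuracy from raw counts. The four counts are an assumption: they were picked so that the rounded results agree with the report above and with the supports 234 and 390, but the actual confusion matrix is not part of the output, so treat them as hypothetical.

```python
def report_from_counts(normal_ok, normal_missed, pneumonia_ok, pneumonia_missed):
    """Per-class (precision, recall, f1) and overall accuracy, rounded to 2 places."""

    def scores(tp, fp, fn):
        precision = tp / (tp + fp)
        recall = tp / (tp + fn)
        f1 = 2 * precision * recall / (precision + recall)
        return round(precision, 2), round(recall, 2), round(f1, 2)

    total = normal_ok + normal_missed + pneumonia_ok + pneumonia_missed
    return {
        # A pneumonia image labelled NORMAL is a false positive for NORMAL
        "NORMAL": scores(normal_ok, pneumonia_missed, normal_missed),
        # A normal image labelled PNEUMONIA is a false positive for PNEUMONIA
        "PNEUMONIA": scores(pneumonia_ok, normal_missed, pneumonia_missed),
        "accuracy": round((normal_ok + pneumonia_ok) / total, 2),
    }


# Hypothetical counts, chosen only to be consistent with the rounded report
report = report_from_counts(normal_ok=201, normal_missed=33,
                            pneumonia_ok=254, pneumonia_missed=136)
print(report)
```

Under these assumed counts, roughly a third of the pneumonia images would be classified as normal, which is what the recall of 0.65 for the PNEUMONIA class expresses.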
Just as with the previous model, I will also try a variant with the pre-trained ResNet50 model as shipped with Keras. For completeness, these models are pre-trained on the ImageNet dataset.
In [22]:
from keras.applications import ResNet50

def create_model_ResNet50Trans(X_shape, classes=2, name="ResNet50Trans"):

    # Classification head: a stack of Dense + Dropout layers
    def mlp(x, hidden_units, activation='relu', dropout_rate=0.3, name=""):
        for i, units in enumerate(hidden_units):
            x = layers.Dense(units, activation=activation, name=f"{name}_{i}_dense")(x)
            x = layers.Dropout(dropout_rate, name=f"{name}_{i}_dropout")(x)
        return x

    # Pre-trained convolutional base without its original classifier, kept frozen
    base_model = ResNet50(include_top=False, input_shape=tuple(X_shape)[-3:])
    base_model.trainable = False

    inputs = Input(X_shape[-3:], name='inputs')
    x = base_model(inputs, training=False)
    x = layers.GlobalAveragePooling2D(name="global_average")(x)
    x = mlp(x, (1024, 512), name="dense")
    outputs = layers.Dense(classes, activation='softmax', name='outputs')(x)

    return Model(inputs=inputs, outputs=outputs, name=name)
So this is the ResNet50 model from the Keras Applications package, to which I added classification layers according to my own needs.
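Because the base model is frozen, only the added classification layers are trained. A quick sketch of how many parameters that involves, using the standard dense-layer formula (inputs x units weights plus one bias per unit); the size of the frozen base is taken from the model summary printed during evaluation, and `dense_params` is just an illustrative helper.

```python
def dense_params(inputs, units):
    """Weights plus biases of a fully connected layer."""
    return inputs * units + units


# Head added on top of the 2048 pooled features: 1024 -> 512 -> 2 classes
head = dense_params(2048, 1024) + dense_params(1024, 512) + dense_params(512, 2)

frozen_base = 23587712  # parameter count of the resnet50 layer in the summary
total = head + frozen_base

print(f"trainable: {head}, frozen: {frozen_base}, total: {total}")
print(f"trainable share: {head / total:.1%}")
```

So only about a tenth of all parameters is updated during training, which matches the trainable and non-trainable totals in the summary.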
I prepare the image data as RGB images (this is required by the base model):
In [23]:
x_train, x_valid, y_train, y_valid = train_test_split(*get_datasource(DATA_TRAIN, DATA_VALID, flag=cv2.IMREAD_COLOR), test_size=0.2)
x_test, y_test = get_datasource(DATA_TEST, flag=cv2.IMREAD_COLOR)

datagen = ImageDataGenerator(
    rotation_range=30,
    zoom_range=0.2,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
    vertical_flip=False)

datagen.fit(x_train)
And now the actual evaluation of the model:
In [24]:
evaluate_model(create_model_ResNet50Trans(x_train.shape, 2), forced_training=True)

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
94765736/94765736 [==============================] - 0s 0us/step
=== MODEL EVALUATION =================================================
Model: "ResNet50Trans"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 inputs (InputLayer)         [(None, 224, 224, 3)]     0
 resnet50 (Functional)       (None, 7, 7, 2048)        23587712
 global_average (GlobalAver  (None, 2048)              0
 agePooling2D)
 dense_0_dense (Dense)       (None, 1024)              2098176
 dense_0_dropout (Dropout)   (None, 1024)              0
 dense_1_dense (Dense)       (None, 512)               524800
 dense_1_dropout (Dropout)   (None, 512)               0
 outputs (Dense)             (None, 2)                 1026
=================================================================
Total params: 26211714 (99.99 MB)
Trainable params: 2624002 (10.01 MB)
Non-trainable params: 23587712 (89.98 MB)
_________________________________________________________________
--- Model training ---------------------------------------------------
Epoch 1/40
131/131 [==============================] - ETA: 0s - loss: 0.6163 - accuracy: 0.7250 - auc: 0.7551
Epoch 1: val_auc improved from -inf to 0.86753, saving model to /kaggle/working/model/ResNet50Trans.ckpt
131/131 [==============================] - 76s 550ms/step - loss: 0.6163 - accuracy: 0.7250 - auc: 0.7551 - val_loss: 0.5485 - val_accuracy: 0.7326 - val_auc: 0.8675
Epoch 2/40
131/131 [==============================] - ETA: 0s - loss: 0.5344 - accuracy: 0.7462 - auc: 0.8092
Epoch 2: val_auc improved from 0.86753 to 0.88226, saving model to /kaggle/working/model/ResNet50Trans.ckpt
131/131 [==============================] - 69s 530ms/step - loss: 0.5344 - accuracy: 0.7462 - auc: 0.8092 - val_loss: 0.5178 - val_accuracy: 0.7326 - val_auc: 0.8823
Epoch 3/40
131/131 [==============================] - ETA: 0s - loss: 0.4697 - accuracy: 0.7730 - auc: 0.8585
Epoch 3: val_auc improved from 0.88226 to 0.88246, saving model to /kaggle/working/model/ResNet50Trans.ckpt
131/131 [==============================] - 70s 536ms/step - loss: 0.4697 - accuracy: 0.7730 - auc: 0.8585 - val_loss: 0.4488 - val_accuracy: 0.7612 - val_auc: 0.8825
Epoch 4/40
131/131 [==============================] - ETA: 0s - loss: 0.4384 - accuracy: 0.7771 - auc: 0.8749
Epoch 4: val_auc improved from 0.88246 to 0.89494, saving model to /kaggle/working/model/ResNet50Trans.ckpt
131/131 [==============================] - 69s 531ms/step - loss: 0.4384 - accuracy: 0.7771 - auc: 0.8749 - val_loss: 0.4043 - val_accuracy: 0.7947 - val_auc: 0.8949
Epoch 5/40
131/131 [==============================] - ETA: 0s - loss: 0.4407 - accuracy: 0.7795 - auc: 0.8732
Epoch 5: val_auc did not improve from 0.89494
131/131 [==============================] - 51s 390ms/step - loss: 0.4407 - accuracy: 0.7795 - auc: 0.8732 - val_loss: 0.4147 - val_accuracy: 0.7937 - val_auc: 0.8945
Epoch 6/40
131/131 [==============================] - ETA: 0s - loss: 0.4329 - accuracy: 0.7845 - auc: 0.8787
Epoch 6: val_auc improved from 0.89494 to 0.91911, saving model to /kaggle/working/model/ResNet50Trans.ckpt
131/131 [==============================] - 70s 532ms/step - loss: 0.4329 - accuracy: 0.7845 - auc: 0.8787 - val_loss: 0.4018 - val_accuracy: 0.8185 - val_auc: 0.9191
Epoch 7/40
131/131 [==============================] - ETA: 0s - loss: 0.4281 - accuracy: 0.7720 - auc: 0.8767
Epoch 7: val_auc did not improve from 0.91911
131/131 [==============================] - 52s 395ms/step - loss: 0.4281 - accuracy: 0.7720 - auc: 0.8767 - val_loss: 0.4607 - val_accuracy: 0.7326 - val_auc: 0.8788
Epoch 8/40
131/131 [==============================] - ETA: 0s - loss: 0.4918 - accuracy: 0.7388 - auc: 0.8300
Epoch 8: val_auc did not improve from 0.91911
131/131 [==============================] - 51s 388ms/step - loss: 0.4918 - accuracy: 0.7388 - auc: 0.8300 - val_loss: 0.4640 - val_accuracy: 0.7326 - val_auc: 0.8811
Epoch 9/40
131/131 [==============================] - ETA: 0s - loss: 0.4865 - accuracy: 0.7446 - auc: 0.8231
Epoch 9: val_auc did not improve from 0.91911
131/131 [==============================] - 52s 394ms/step - loss: 0.4865 - accuracy: 0.7446 - auc: 0.8231 - val_loss: 0.5245 - val_accuracy: 0.7326 - val_auc: 0.8436
Epoch 10/40
131/131 [==============================] - ETA: 0s - loss: 0.4874 - accuracy: 0.7446 - auc: 0.8295
Epoch 10: val_auc did not improve from 0.91911
131/131 [==============================] - 52s 394ms/step - loss: 0.4874 - accuracy: 0.7446 - auc: 0.8295 - val_loss: 0.4076 - val_accuracy: 0.7326 - val_auc: 0.8829
Epoch 11/40
131/131 [==============================] - ETA: 0s - loss: 0.4764 - accuracy: 0.7448 - auc: 0.8345
Epoch 11: val_auc did not improve from 0.91911
131/131 [==============================] - 52s 393ms/step - loss: 0.4764 - accuracy: 0.7448 - auc: 0.8345 - val_loss: 0.4205 - val_accuracy: 0.7326 - val_auc: 0.8807
Epoch 12/40
131/131 [==============================] - ETA: 0s - loss: 0.4721 - accuracy: 0.7446 - auc: 0.8337
Epoch 12: val_auc did not improve from 0.91911
131/131 [==============================] - 51s 389ms/step - loss: 0.4721 - accuracy: 0.7446 - auc: 0.8337 - val_loss: 0.4012 - val_accuracy: 0.7326 - val_auc: 0.8916
Epoch 13/40
131/131 [==============================] - ETA: 0s - loss: 0.5537 - accuracy: 0.7448 - auc: 0.7788
Epoch 13: val_auc did not improve from 0.91911
131/131 [==============================] - 51s 392ms/step - loss: 0.5537 - accuracy: 0.7448 - auc: 0.7788 - val_loss: 0.5454 - val_accuracy: 0.7326 - val_auc: 0.8753
Epoch 14/40
131/131 [==============================] - ETA: 0s - loss: 0.5676 - accuracy: 0.7453 - auc: 0.7499
Epoch 14: val_auc did not improve from 0.91911
131/131 [==============================] - 51s 391ms/step - loss: 0.5676 - accuracy: 0.7453 - auc: 0.7499 - val_loss: 0.5817 - val_accuracy: 0.7326 - val_auc: 0.7326
Epoch 15/40
131/131 [==============================] - ETA: 0s - loss: 0.5687 - accuracy: 0.7448 - auc: 0.7436
Epoch 15: val_auc did not improve from 0.91911
131/131 [==============================] - 51s 392ms/step - loss: 0.5687 - accuracy: 0.7448 - auc: 0.7436 - val_loss: 0.5477 - val_accuracy: 0.7335 - val_auc: 0.8862
Epoch 16/40
131/131 [==============================] - ETA: 0s - loss: 0.5077 - accuracy: 0.7740 - auc: 0.8281
Epoch 16: val_auc did not improve from 0.91911
131/131 [==============================] - 51s 388ms/step - loss: 0.5077 - accuracy: 0.7740 - auc: 0.8281 - val_loss: 0.5728 - val_accuracy: 0.7364 - val_auc: 0.8177
--- Training history -------------------------------------------------
--- Test Predictions and Metrics -------------------------------------
              precision    recall  f1-score   support

      NORMAL       0.70      0.78      0.74       234
   PNEUMONIA       0.86      0.80      0.83       390

    accuracy                           0.79       624
   macro avg       0.78      0.79      0.78       624
weighted avg       0.80      0.79      0.80       624

=== MODEL EVALUATION FINISHED ========================================
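Training was set to 40 epochs but ended after 16, ten epochs after the last improvement of val_auc. That pattern is consistent with early stopping on val_auc with a patience of 10. The patience value and the stopping rule are my inference from the log, since `evaluate_model` comes from an earlier part of the series and is not shown here. A minimal sketch of such a rule, with a hypothetical helper name, applied to the val_auc values printed above:

```python
def early_stopping(val_scores, patience):
    """Return (best epoch, last epoch run) for a 'higher is better' metric."""
    best_epoch, best_score = 0, float("-inf")
    for epoch, score in enumerate(val_scores, start=1):
        if score > best_score:
            best_epoch, best_score = epoch, score
        elif epoch - best_epoch >= patience:
            return best_epoch, epoch  # no improvement for `patience` epochs
    return best_epoch, len(val_scores)


# val_auc per epoch as printed in the training log of the pre-trained variant
val_auc = [0.8675, 0.8823, 0.8825, 0.8949, 0.8945, 0.9191, 0.8788, 0.8811,
           0.8436, 0.8829, 0.8807, 0.8916, 0.8753, 0.7326, 0.8862, 0.8177]

best, stopped = early_stopping(val_auc, patience=10)
print(f"best epoch: {best}, training stopped after epoch: {stopped}")
```

Combined with the checkpoint saved whenever val_auc improves, this would mean the weights from the best epoch are the ones worth evaluating, not those from the last epoch.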
Next time I would like to venture into a somewhat different area, namely the use of transformer models for image classification.
The author works as an IT architect. In recent years he has focused on integration and communication projects in healthcare. His hobbies also include paragliding and mountain biking.