This article follows up on the previous one: Electrocardiogram (ECG) Classification, part one.
What can you expect from me in this chapter?
First, I will build a model for each input dataset (X, Y and S) separately. For dataset X these will be fully connected (dense) layers. For dataset Y, as the nature of ECG waveforms suggests, they will be 1D convolutional networks. For dataset S they will be separate 2D convolutional layers, one set for each lead.
Then I will try to combine these models, with the goal of improving the overall results of the classifier.
Before I get to the individual models, I need to prepare some functions for evaluating the results:
In [19]:
```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import sklearn.metrics


def model_quality_reporting(test, pred, *, model_name='model', labels=("NORM", "MI", "STTC", "CD", "HYP")):
    def print_confusion_matrix(confusion_matrix, axes, class_label, class_names, fontsize=14):
        # Draw one 2x2 confusion matrix (N = not in the class, Y = in the class) as a heatmap
        df_cm = pd.DataFrame(confusion_matrix, index=class_names, columns=class_names)
        try:
            heatmap = sns.heatmap(df_cm, annot=True, fmt="d", cbar=False, ax=axes)
        except ValueError:
            raise ValueError("Confusion matrix values must be integers.")
        heatmap.yaxis.set_ticklabels(heatmap.yaxis.get_ticklabels(), rotation=0, ha='right', fontsize=fontsize)
        heatmap.xaxis.set_ticklabels(heatmap.xaxis.get_ticklabels(), rotation=45, ha='right', fontsize=fontsize)
        axes.set_ylabel('True label')
        axes.set_xlabel('Predicted label')
        axes.set_title("Class - " + class_label)

    # Multi-label problem: one binary confusion matrix per class
    fig, ax = plt.subplots(1, 5, figsize=(16, 3))
    for axes, cfs_matrix, label in zip(ax.flatten(), sklearn.metrics.multilabel_confusion_matrix(test, pred), labels):
        print_confusion_matrix(cfs_matrix, axes, label, ["N", "Y"])
    fig.tight_layout()
    plt.show()

    # Human-readable report ...
    print(sklearn.metrics.classification_report(test, pred, target_names=labels, zero_division=0))
    # ... and the same report as a dict, keeping only the per-class rows
    m = sklearn.metrics.classification_report(test, pred, target_names=labels, output_dict=True, zero_division=0)
    m = {(model_name, k): m[k] for k in labels}
    return pd.DataFrame.from_dict(m, orient='index')


# Per-class metrics of all models, indexed by (model, group)
metrics = pd.DataFrame(columns=("precision", "recall", "f1-score", "support"),
                       index=pd.MultiIndex.from_tuples([], names=('model', 'group')))
```
The `model_quality_reporting` function draws a so-called confusion matrix for each class. It also prints a detailed classification report. In that report I am mainly interested in the Precision, Recall and F1-score metrics for all classes. The per-class part of the report is also the function's return value. I will use it to store the results and, at the end, to compare the performance of the individual models. The overall results are aggregated in the `metrics` variable.
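The following minimal NumPy sketch shows what the two sklearn calls compute for each class. The helper name `per_class_report` is hypothetical and not part of the article's code. For every class (column) it builds the 2×2 matrix in the same layout that `sklearn.metrics.multilabel_confusion_matrix` uses, `[[TN, FP], [FN, TP]]`. From that matrix it derives precision, recall and F1-score, returning 0 wherever a denominator is zero, as `zero_division=0` does.

```python
import numpy as np


def per_class_report(test, pred):
    """Per-class confusion matrices and precision/recall/F1 for multi-label 0/1 arrays.

    Hypothetical helper mirroring what multilabel_confusion_matrix and
    classification_report compute column by column.
    """
    test = np.asarray(test).astype(bool)
    pred = np.asarray(pred).astype(bool)
    tp = np.sum(test & pred, axis=0)
    fp = np.sum(~test & pred, axis=0)
    fn = np.sum(test & ~pred, axis=0)
    tn = np.sum(~test & ~pred, axis=0)
    # Same layout as sklearn: cm[i] = [[TN, FP], [FN, TP]]
    cm = np.stack([np.stack([tn, fp], axis=-1), np.stack([fn, tp], axis=-1)], axis=1)

    def safe_div(a, b):
        # Return 0 where the denominator is 0 (like zero_division=0)
        return np.divide(a, b, out=np.zeros(np.shape(a), dtype=float), where=b != 0)

    precision = safe_div(tp, tp + fp)
    recall = safe_div(tp, tp + fn)
    f1 = safe_div(2 * precision * recall, precision + recall)
    support = tp + fn
    return cm, precision, recall, f1, support


cm, p, r, f1, s = per_class_report([[1, 0], [1, 1], [0, 1], [0, 0]],
                                   [[1, 0], [0, 1], [1, 1], [0, 0]])
print(cm.tolist(), p, r, f1, s)
```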
A reference model: I did not mention it in the introduction, but it is always a good idea to have one. At the very least it gives me a baseline for checking whether my models have actually learned anything.
As the reference model I will use randomly generated classes, with probabilities matching the class frequencies in the training set `Z_train`.
In [20]:
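```python
# Relative frequency of each class in the training set
Z_prob = Z_train.sum(axis=0) / Z_train.shape[0]
# For each sample and class, assign the class with probability Z_prob[i]
Z_pred = np.random.uniform(size=Z_test.shape)
for i in range(Z_pred.shape[-1]):
    Z_pred[:, i] = (Z_pred[:, i] < Z_prob[i]).astype('float64')
```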
And now the evaluation of this model:
In [21]:
```python
metrics = pd.concat([metrics.drop('REF_model', level=0, errors='ignore'),
                     model_quality_reporting(Z_test, Z_pred, model_name='REF_model')])
```
```text
              precision    recall  f1-score   support

        NORM       0.45      0.45      0.45       964
          MI       0.26      0.27      0.27       553
        STTC       0.23      0.23      0.23       523
          CD       0.20      0.20      0.20       498
         HYP       0.15      0.13      0.14       263

   micro avg       0.30      0.30      0.30      2801
   macro avg       0.26      0.26      0.26      2801
weighted avg       0.30      0.30      0.30      2801
 samples avg       0.24      0.32      0.25      2801
```
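The results above show that this model is not very good at recognizing anything, which is what a random baseline should look like. For each class, precision and recall stay roughly at the level of that class's frequency in the data.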
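Why does a random baseline behave like this? Suppose a class is predicted with probability p, independently of the true labels. Then the expected precision equals the share of positives in the evaluated set (the prevalence), and the expected recall equals p. The baseline takes p from the training frequencies, so both values end up close to the class frequency. The quick simulation below illustrates this. The prevalence 0.4 and p = 0.2 are illustrative values chosen so that the two effects are distinguishable; they are not taken from the dataset.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 200_000
prevalence = 0.4   # assumed share of positive samples (illustrative)
p = 0.2            # probability with which the baseline predicts the class (illustrative)

y_true = rng.uniform(size=n) < prevalence
y_pred = rng.uniform(size=n) < p   # drawn independently of y_true

tp = np.sum(y_true & y_pred)
precision = tp / np.sum(y_pred)    # tends to the prevalence
recall = tp / np.sum(y_true)       # tends to p
print(f"precision={precision:.3f} recall={recall:.3f}")
```

With p equal to the prevalence, as in the reference model, both numbers meet at the class frequency.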
Next, I prepare a function that builds the part of the model for dataset X:
In [22]:
```python
def create_X_model(X, *, units=32, dropouts=0.3):
    # Per-feature standardization of the inputs
    X = keras.layers.Normalization(axis=-1, name='X_norm')(X)
    # Two fully connected ReLU layers, each followed by Dropout
    X = keras.layers.Dense(units, activation='relu', name='X_dense_1')(X)
    X = keras.layers.Dropout(dropouts, name='X_drop_1')(X)
    X = keras.layers.Dense(units, activation='relu', name='X_dense_2')(X)
    X = keras.layers.Dropout(dropouts, name='X_drop_2')(X)
    return X
```
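The function creates two fully connected layers with the ReLU activation function. I also added two `Dropout` layers. During training they randomly switch off part of the neurons and so act as regularization against overfitting. All of this is preceded by one more layer, `Normalization`, which standardizes the input data along the last axis of the dataset, i.e. separately for each feature. Keep in mind that this layer learns its mean and variance only when its `adapt()` method is called on the training data.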
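What `Normalization(axis=-1)` computes can be sketched in NumPy. `adapt()` stores the mean and variance of every feature (column) of the data it is given. The layer then returns `(x - mean) / sqrt(variance)`. The sketch below uses random data with 7 features, the same count as in dataset X. The small `eps` guard against zero variance is my assumption for this sketch.

```python
import numpy as np

rng = np.random.default_rng(42)
# Illustrative data: 1000 samples, 7 features on very different scales
X = rng.normal(loc=[60, 170, 75, 1, 0, 0, 0], scale=[15, 10, 12, 0.5, 1, 1, 1], size=(1000, 7))

# "adapt": statistics per feature, i.e. over all axes except the last one
mean = X.mean(axis=0)
var = X.var(axis=0)

# "call": standardize every feature, guarding against zero variance
eps = 1e-7
X_norm = (X - mean) / np.maximum(np.sqrt(var), eps)

print(X_norm.mean(axis=0).round(6), X_norm.std(axis=0).round(6))
```

After this step every feature has mean 0 and standard deviation 1, so features measured in large units no longer dominate the first dense layer.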
In [23]:
```python
def create_model01(X_shape, Z_shape):
    X_inputs = keras.Input(X_shape[1:], name='X_inputs')
    # Feature extractor for dataset X
    X = create_X_model(X_inputs, units=64)
    # Classification head: two dense layers followed by Dropout
    X = keras.layers.Dense(64, activation='relu', name='Z_dense_1')(X)
    X = keras.layers.Dense(64, activation='relu', name='Z_dense_2')(X)
    X = keras.layers.Dropout(0.5, name='Z_drop_1')(X)
    # One sigmoid output per class (multi-label classification)
    outputs = keras.layers.Dense(Z_shape[-1], activation='sigmoid', name='Z_outputs')(X)
    model = keras.Model(inputs=X_inputs, outputs=outputs, name='X_model')
    return model
```
I then build the whole model by calling `create_X_model` and extending the result with two more fully connected layers, followed by Dropout regularization. The output is the last layer, whose number of neurons equals the number of classes (5 in my case). Its activation function is `Sigmoid`. Each output is therefore the probability that the sample belongs to the given class, independently of the other classes (remember that a sample can belong to more than one class).
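The example below contrasts the sigmoid output used here with a softmax, which would suit single-label classification. It uses one made-up vector of logits (illustrative values only). Sigmoid turns every class score into an independent probability, so several classes can exceed 0.5 at once. Softmax probabilities always sum to 1.

```python
import numpy as np

logits = np.array([2.0, 1.0, -1.0, -3.0, 0.5])    # made-up scores for NORM, MI, STTC, CD, HYP

sigmoid = 1.0 / (1.0 + np.exp(-logits))            # independent probability per class
softmax = np.exp(logits) / np.exp(logits).sum()    # probabilities competing for a single class

multi_label = (sigmoid > 0.5).astype(int)          # a sample may receive several labels
print(multi_label)                                 # [1 1 0 0 1]
print(sigmoid.round(3), sigmoid.sum().round(3))
print(softmax.round(3), softmax.sum().round(6))
```

Thresholding at 0.5 is also what `np.around` does later when the predicted probabilities are turned into 0/1 labels. Note that `np.around` rounds an exact 0.5 to 0.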
Now I can compile the model.
It is worth mentioning that I use binary cross-entropy as the loss function, i.e. the quantity being optimized. As the metrics monitored during training I use `binary_accuracy`, `precision` and `recall`.
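The loss and the first metric can be sketched in NumPy. This is my own re-implementation for illustration, not the Keras source. Binary cross-entropy averages −[y·log p + (1 − y)·log(1 − p)] over all classes and samples. Binary accuracy is the share of individual (sample, class) decisions that are correct after thresholding at 0.5. The clipping constant `eps` is an assumption of this sketch.

```python
import numpy as np


def binary_crossentropy(y_true, y_prob, eps=1e-7):
    """Mean binary cross-entropy over classes and samples (illustrative sketch)."""
    y_true = np.asarray(y_true, dtype=float)
    # Clip probabilities to avoid log(0)
    y_prob = np.clip(np.asarray(y_prob, dtype=float), eps, 1 - eps)
    per_class = -(y_true * np.log(y_prob) + (1 - y_true) * np.log(1 - y_prob))
    return per_class.mean(axis=-1).mean()


def binary_accuracy(y_true, y_prob, threshold=0.5):
    """Share of individual (sample, class) decisions that are correct."""
    return np.mean((np.asarray(y_prob) > threshold) == (np.asarray(y_true) == 1))


y_true = [[1, 0, 0, 1, 0]]
y_prob = [[0.9, 0.2, 0.1, 0.4, 0.05]]
print(binary_crossentropy(y_true, y_prob), binary_accuracy(y_true, y_prob))
```

Binary accuracy counts every (sample, class) pair separately. Because most entries of a multi-label matrix are 0, it can look high even when a class is barely detected. That is one reason to track precision and recall as well: in the training log of the first model, binary accuracy reaches about 0.79 while recall stays at roughly 0.3 or less.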
In [24]:
```python
model01 = create_model01(X_train.shape, Z_train.shape)
model01.compile(optimizer='adam', loss='binary_crossentropy',
                metrics=['binary_accuracy', 'Precision', 'Recall'])
model01.summary()
```
```text
Model: "X_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 X_inputs (InputLayer)       [(None, 7)]               0
 X_norm (Normalization)      (None, 7)                 15
 X_dense_1 (Dense)           (None, 64)                512
 X_drop_1 (Dropout)          (None, 64)                0
 X_dense_2 (Dense)           (None, 64)                4160
 X_drop_2 (Dropout)          (None, 64)                0
 Z_dense_1 (Dense)           (None, 64)                4160
 Z_dense_2 (Dense)           (None, 64)                4160
 Z_drop_1 (Dropout)          (None, 64)                0
 Z_outputs (Dense)           (None, 5)                 325
=================================================================
Total params: 13332 (52.08 KB)
Trainable params: 13317 (52.02 KB)
Non-trainable params: 15 (64.00 Byte)
_________________________________________________________________
```
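The parameter counts in the summary can be checked by hand. A Dense layer has n_in × n_out weights plus n_out biases. The Normalization layer keeps a mean and a variance for each of the 7 features plus one sample counter, which gives the 15 non-trainable parameters. The helper name `dense_params` below is mine.

```python
def dense_params(n_in, n_out):
    # weight matrix + bias vector
    return n_in * n_out + n_out


n_features, units, n_classes = 7, 64, 5
norm = n_features + n_features + 1        # means, variances and one counter (non-trainable)
layers = {
    'X_norm': norm,
    'X_dense_1': dense_params(n_features, units),
    'X_dense_2': dense_params(units, units),
    'Z_dense_1': dense_params(units, units),
    'Z_dense_2': dense_params(units, units),
    'Z_outputs': dense_params(units, n_classes),
}
total = sum(layers.values())
print(total, total - norm)                # 13332 13317
```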
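For training I use early stopping, which stops training once the monitored metric on the validation set stops improving. Here that metric is `val_binary_accuracy`, and with `patience=10` training ends after 10 epochs without improvement. The best model according to the same metric is continuously saved to a file (with `save_best_only=True`, `ModelCheckpoint` stores the whole model, not just the weights), so that I can load it back after training.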
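The stopping rule can be replayed in plain Python. The sketch below assumes the behaviour of `EarlyStopping` with `min_delta=0` and "higher is better" for an accuracy metric. An epoch counts as an improvement only if the value strictly exceeds the best so far, and training stops once `patience` epochs in a row bring no improvement. The `val_binary_accuracy` values are copied from the training log of model01 in this article. The helper `early_stopping` is hypothetical.

```python
def early_stopping(values, patience):
    """Replay an early-stopping rule (maximize, min_delta=0) on per-epoch values.

    Returns (stopped_epoch, best_epoch), both 1-based.
    """
    best, best_epoch, wait = float('-inf'), 0, 0
    for epoch, value in enumerate(values, start=1):
        if value > best:
            best, best_epoch, wait = value, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return len(values), best_epoch


# val_binary_accuracy of model01, epochs 1-25, from the training log
val_binary_accuracy = [0.7691, 0.7668, 0.7689, 0.7788, 0.7808, 0.7874, 0.7919, 0.7915, 0.7921, 0.7929,
                       0.7937, 0.7882, 0.7920, 0.7939, 0.7944, 0.7907, 0.7940, 0.7940, 0.7918, 0.7920,
                       0.7934, 0.7927, 0.7937, 0.7936, 0.7938]
print(early_stopping(val_binary_accuracy, patience=10))   # (25, 15)
```

This matches the log: the best value (0.7944) comes in epoch 15, and ten epochs later, in epoch 25, training stops. The checkpoint reloaded afterwards is therefore the epoch-15 model.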
In [25]:
```python
MODEL_CHECKPOINT = '/kaggle/working/model/model01.ckpt'

callbacks_list = [
    # Stop after 10 epochs without improvement of val_binary_accuracy
    keras.callbacks.EarlyStopping(monitor='val_binary_accuracy', patience=10),
    # Keep only the best model (by val_binary_accuracy) on disk
    keras.callbacks.ModelCheckpoint(filepath=MODEL_CHECKPOINT, monitor='val_binary_accuracy', save_best_only=True),
]

history = model01.fit(X_train, Z_train, epochs=40, batch_size=32,
                      callbacks=callbacks_list, validation_data=(X_valid, Z_valid))

# Reload the best checkpoint and store it in the native Keras format
model01 = keras.models.load_model(MODEL_CHECKPOINT)
model01.save('/kaggle/working/model/model01.keras')
```
```text
Epoch 1/40
546/546 [==============================] - 10s 8ms/step - loss: 0.6957 - binary_accuracy: 0.7139 - precision: 0.3742 - recall: 0.1823 - val_loss: 0.5378 - val_binary_accuracy: 0.7691 - val_precision: 0.6978 - val_recall: 0.1705
Epoch 2/40
546/546 [==============================] - 3s 5ms/step - loss: 0.5456 - binary_accuracy: 0.7556 - precision: 0.6127 - recall: 0.1110 - val_loss: 0.5235 - val_binary_accuracy: 0.7668 - val_precision: 0.7261 - val_recall: 0.1409
Epoch 3/40
546/546 [==============================] - 3s 5ms/step - loss: 0.5307 - binary_accuracy: 0.7616 - precision: 0.6862 - recall: 0.1188 - val_loss: 0.5134 - val_binary_accuracy: 0.7689 - val_precision: 0.6845 - val_recall: 0.1780
Epoch 4/40
546/546 [==============================] - 4s 7ms/step - loss: 0.5200 - binary_accuracy: 0.7673 - precision: 0.7002 - recall: 0.1521 - val_loss: 0.5083 - val_binary_accuracy: 0.7788 - val_precision: 0.7194 - val_recall: 0.2205
Epoch 5/40
546/546 [==============================] - 4s 7ms/step - loss: 0.5092 - binary_accuracy: 0.7729 - precision: 0.7198 - recall: 0.1779 - val_loss: 0.4921 - val_binary_accuracy: 0.7808 - val_precision: 0.7412 - val_recall: 0.2187
Epoch 6/40
546/546 [==============================] - 4s 7ms/step - loss: 0.4979 - binary_accuracy: 0.7786 - precision: 0.7283 - recall: 0.2092 - val_loss: 0.4868 - val_binary_accuracy: 0.7874 - val_precision: 0.7511 - val_recall: 0.2519
Epoch 7/40
546/546 [==============================] - 4s 8ms/step - loss: 0.4883 - binary_accuracy: 0.7837 - precision: 0.7276 - recall: 0.2418 - val_loss: 0.4808 - val_binary_accuracy: 0.7919 - val_precision: 0.7225 - val_recall: 0.3018
Epoch 8/40
546/546 [==============================] - 3s 5ms/step - loss: 0.4826 - binary_accuracy: 0.7880 - precision: 0.7391 - recall: 0.2600 - val_loss: 0.4768 - val_binary_accuracy: 0.7915 - val_precision: 0.7258 - val_recall: 0.2965
Epoch 9/40
546/546 [==============================] - 4s 7ms/step - loss: 0.4808 - binary_accuracy: 0.7886 - precision: 0.7424 - recall: 0.2610 - val_loss: 0.4764 - val_binary_accuracy: 0.7921 - val_precision: 0.7280 - val_recall: 0.2979
Epoch 10/40
546/546 [==============================] - 4s 7ms/step - loss: 0.4798 - binary_accuracy: 0.7888 - precision: 0.7503 - recall: 0.2565 - val_loss: 0.4772 - val_binary_accuracy: 0.7929 - val_precision: 0.7445 - val_recall: 0.2890
Epoch 11/40
546/546 [==============================] - 4s 7ms/step - loss: 0.4782 - binary_accuracy: 0.7902 - precision: 0.7456 - recall: 0.2683 - val_loss: 0.4735 - val_binary_accuracy: 0.7937 - val_precision: 0.7422 - val_recall: 0.2958
Epoch 12/40
546/546 [==============================] - 3s 5ms/step - loss: 0.4769 - binary_accuracy: 0.7905 - precision: 0.7545 - recall: 0.2638 - val_loss: 0.4784 - val_binary_accuracy: 0.7882 - val_precision: 0.6823 - val_recall: 0.3211
Epoch 13/40
546/546 [==============================] - 3s 5ms/step - loss: 0.4762 - binary_accuracy: 0.7909 - precision: 0.7542 - recall: 0.2664 - val_loss: 0.4739 - val_binary_accuracy: 0.7920 - val_precision: 0.7186 - val_recall: 0.3061
Epoch 14/40
546/546 [==============================] - 4s 7ms/step - loss: 0.4747 - binary_accuracy: 0.7910 - precision: 0.7520 - recall: 0.2680 - val_loss: 0.4722 - val_binary_accuracy: 0.7939 - val_precision: 0.7512 - val_recall: 0.2897
Epoch 15/40
546/546 [==============================] - 4s 7ms/step - loss: 0.4735 - binary_accuracy: 0.7911 - precision: 0.7537 - recall: 0.2679 - val_loss: 0.4720 - val_binary_accuracy: 0.7944 - val_precision: 0.7577 - val_recall: 0.2879
Epoch 16/40
546/546 [==============================] - 3s 5ms/step - loss: 0.4731 - binary_accuracy: 0.7917 - precision: 0.7621 - recall: 0.2655 - val_loss: 0.4733 - val_binary_accuracy: 0.7907 - val_precision: 0.7089 - val_recall: 0.3075
Epoch 17/40
546/546 [==============================] - 3s 5ms/step - loss: 0.4726 - binary_accuracy: 0.7923 - precision: 0.7671 - recall: 0.2654 - val_loss: 0.4723 - val_binary_accuracy: 0.7940 - val_precision: 0.7491 - val_recall: 0.2918
Epoch 18/40
546/546 [==============================] - 3s 5ms/step - loss: 0.4729 - binary_accuracy: 0.7917 - precision: 0.7614 - recall: 0.2661 - val_loss: 0.4747 - val_binary_accuracy: 0.7940 - val_precision: 0.7641 - val_recall: 0.2808
Epoch 19/40
546/546 [==============================] - 3s 5ms/step - loss: 0.4722 - binary_accuracy: 0.7923 - precision: 0.7646 - recall: 0.2673 - val_loss: 0.4708 - val_binary_accuracy: 0.7918 - val_precision: 0.7230 - val_recall: 0.3007
Epoch 20/40
546/546 [==============================] - 3s 5ms/step - loss: 0.4709 - binary_accuracy: 0.7922 - precision: 0.7685 - recall: 0.2640 - val_loss: 0.4731 - val_binary_accuracy: 0.7920 - val_precision: 0.7208 - val_recall: 0.3040
Epoch 21/40
546/546 [==============================] - 3s 5ms/step - loss: 0.4719 - binary_accuracy: 0.7920 - precision: 0.7659 - recall: 0.2648 - val_loss: 0.4694 - val_binary_accuracy: 0.7934 - val_precision: 0.7441 - val_recall: 0.2925
Epoch 22/40
546/546 [==============================] - 3s 5ms/step - loss: 0.4727 - binary_accuracy: 0.7918 - precision: 0.7676 - recall: 0.2626 - val_loss: 0.4697 - val_binary_accuracy: 0.7927 - val_precision: 0.7413 - val_recall: 0.2904
Epoch 23/40
546/546 [==============================] - 3s 5ms/step - loss: 0.4719 - binary_accuracy: 0.7917 - precision: 0.7592 - recall: 0.2677 - val_loss: 0.4730 - val_binary_accuracy: 0.7937 - val_precision: 0.7379 - val_recall: 0.2993
Epoch 24/40
546/546 [==============================] - 3s 5ms/step - loss: 0.4705 - binary_accuracy: 0.7920 - precision: 0.7622 - recall: 0.2668 - val_loss: 0.4693 - val_binary_accuracy: 0.7936 - val_precision: 0.7533 - val_recall: 0.2865
Epoch 25/40
546/546 [==============================] - 3s 5ms/step - loss: 0.4703 - binary_accuracy: 0.7924 - precision: 0.7683 - recall: 0.2655 - val_loss: 0.4697 - val_binary_accuracy: 0.7938 - val_precision: 0.7576 - val_recall: 0.2843
```
Now I can display the training history:
In [26]:
```python
# Loss and metrics per epoch, for both the training and the validation set
sns.relplot(data=pd.DataFrame(history.history), kind='line', height=4, aspect=4)
plt.show()
```
The last step in this part is to make predictions on the test set and evaluate them in detail:
In [27]:
```python
# np.around turns the sigmoid probabilities into 0/1 labels (threshold 0.5)
Z_pred = np.around(model01.predict(X_test, verbose=0))
metrics = pd.concat([metrics.drop(model01.name, level=0, errors='ignore'),
                     model_quality_reporting(Z_test, Z_pred, model_name=model01.name)])
```
```text
              precision    recall  f1-score   support

        NORM       0.71      0.57      0.64       964
          MI       0.97      0.40      0.56       553
        STTC       0.00      0.00      0.00       523
          CD       0.00      0.00      0.00       498
         HYP       0.00      0.00      0.00       263

   micro avg       0.77      0.28      0.41      2801
   macro avg       0.34      0.19      0.24      2801
weighted avg       0.44      0.28      0.33      2801
 samples avg       0.35      0.32      0.33      2801
```
The results show that the model can only capture the most frequently represented classes in the data (NORM and MI). For the remaining classes (STTC, CD and HYP) it does not correctly identify a single sample.
Again I prepare a function that builds the part of the model specific to dataset Y:
In [28]:
```python
def create_Y_model(X, *, filters=(32, 64, 128), kernel_size=(5, 3, 3), strides=(1, 1, 1)):
    f1, f2, f3 = filters
    k1, k2, k3 = kernel_size
    s1, s2, s3 = strides
    # Per-lead standardization of the raw signals
    X = keras.layers.Normalization(axis=-1, name='Y_norm')(X)
    # Block 1: convolution -> batch normalization -> ReLU -> halve the time axis
    X = keras.layers.Conv1D(f1, k1, strides=s1, padding='same', name='Y_conv_1')(X)
    X = keras.layers.BatchNormalization(name='Y_norm_1')(X)
    X = keras.layers.ReLU(name='Y_relu_1')(X)
    X = keras.layers.MaxPool1D(2, name='Y_pool_1')(X)
    # Block 2
    X = keras.layers.Conv1D(f2, k2, strides=s2, padding='same', name='Y_conv_2')(X)
    X = keras.layers.BatchNormalization(name='Y_norm_2')(X)
    X = keras.layers.ReLU(name='Y_relu_2')(X)
    X = keras.layers.MaxPool1D(2, name='Y_pool_2')(X)
    # Block 3, then average over the whole time axis
    X = keras.layers.Conv1D(f3, k3, strides=s3, padding='same', name='Y_conv_3')(X)
    X = keras.layers.BatchNormalization(name='Y_norm_3')(X)
    X = keras.layers.ReLU(name='Y_relu_3')(X)
    X = keras.layers.GlobalAveragePooling1D(name='Y_aver')(X)
    X = keras.layers.Dropout(0.5, name='Y_drop')(X)
    return X
```
The function contains three separate blocks with 1D convolutional layers. The function parameters set the number of filters, the kernel size and the stride of each layer. Each convolution is followed by batch normalization (`BatchNormalization`) and the ReLU activation function.
The convolutional layers are interleaved with `MaxPool1D` pooling layers, which halve the length of the time axis.
The last block ends with `GlobalAveragePooling1D` (followed only by a `Dropout` layer). It averages over the entire time axis, so only the feature dimension remains.
All of the layers above are preceded by one more layer that normalizes the input data, the `Normalization` layer.
How the dimensions change from layer to layer will probably be easier to see in the model summary after compilation, two steps further on.
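The two pooling operations have simple NumPy counterparts. `MaxPool1D(2)` uses a stride equal to the pool size and `padding='valid'` by default, so it takes the maximum of each non-overlapping pair of time steps. `GlobalAveragePooling1D` is a mean over the whole time axis. The sketch below applies them to a random tensor shaped like the output of the first convolution: (batch, 1000 time steps, 64 features). In the real model the feature count also changes between the pooling steps because of the convolutions, which this sketch leaves out. The helper `max_pool_1d` is mine.

```python
import numpy as np

x = np.random.default_rng(0).normal(size=(4, 1000, 64))   # (batch, time, features)


def max_pool_1d(x, pool=2):
    # Drop an incomplete tail window (padding='valid'), then take the max within each window
    t = (x.shape[1] // pool) * pool
    return x[:, :t].reshape(x.shape[0], t // pool, pool, x.shape[2]).max(axis=2)


pooled = max_pool_1d(max_pool_1d(x))     # 1000 -> 500 -> 250 time steps
gap = pooled.mean(axis=1)                # GlobalAveragePooling1D: (batch, features)
print(pooled.shape, gap.shape)           # (4, 250, 64) (4, 64)
```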
And now the function that creates the whole model:
In [29]:
```python
def create_model02(Y_shape, Z_shape):
    Y_inputs = keras.Input(Y_shape[1:], name='Y_inputs')
    # Feature extractor for dataset Y (1D convolutions over time)
    X = create_Y_model(Y_inputs, filters=(64, 128, 256), kernel_size=(7, 3, 3))
    # Classification head
    X = keras.layers.Dense(128, activation='relu', name='Z_dense_1')(X)
    X = keras.layers.Dense(128, activation='relu', name='Z_dense_2')(X)
    X = keras.layers.Dropout(0.5, name='Z_drop_1')(X)
    outputs = keras.layers.Dense(Z_shape[-1], activation='sigmoid', name='Z_outputs')(X)
    model = keras.Model(inputs=Y_inputs, outputs=outputs, name='Y_model')
    return model
```
The principle is the same as for the previous model, so just a few notes:
- the convolutional layers use 64, 128 and 256 filters
- the kernel size is 7 for the first layer and 3 for the other two
- the stride stays the same for all layers, i.e. 1
And now the model compilation:
In [30]:
```python
model02 = create_model02(Y_train.shape, Z_train.shape)
model02.compile(optimizer='adam', loss='binary_crossentropy',
                metrics=['binary_accuracy', 'Precision', 'Recall'])
model02.summary()
```
```text
Model: "Y_model"
_________________________________________________________________
 Layer (type)                     Output Shape          Param #
=================================================================
 Y_inputs (InputLayer)            [(None, 1000, 12)]    0
 Y_norm (Normalization)           (None, 1000, 12)      25
 Y_conv_1 (Conv1D)                (None, 1000, 64)      5440
 Y_norm_1 (BatchNormalization)    (None, 1000, 64)      256
 Y_relu_1 (ReLU)                  (None, 1000, 64)      0
 Y_pool_1 (MaxPooling1D)          (None, 500, 64)       0
 Y_conv_2 (Conv1D)                (None, 500, 128)      24704
 Y_norm_2 (BatchNormalization)    (None, 500, 128)      512
 Y_relu_2 (ReLU)                  (None, 500, 128)      0
 Y_pool_2 (MaxPooling1D)          (None, 250, 128)      0
 Y_conv_3 (Conv1D)                (None, 250, 256)      98560
 Y_norm_3 (BatchNormalization)    (None, 250, 256)      1024
 Y_relu_3 (ReLU)                  (None, 250, 256)      0
 Y_aver (GlobalAveragePooling1D)  (None, 256)           0
 Y_drop (Dropout)                 (None, 256)           0
 Z_dense_1 (Dense)                (None, 128)           32896
 Z_dense_2 (Dense)                (None, 128)           16512
 Z_drop_1 (Dropout)               (None, 128)           0
 Z_outputs (Dense)                (None, 5)             645
=================================================================
Total params: 180574 (705.37 KB)
Trainable params: 179653 (701.77 KB)
Non-trainable params: 921 (3.60 KB)
_________________________________________________________________
```
The summary shows how the last dimension of the input data (12 leads) successively changes to 64, 128 and 256 features. It also shows how the `MaxPooling1D` layers reduce the time dimension (the second one) from the original 1000 to 500 and then 250 steps. Finally, `GlobalAveragePooling1D` removes the time dimension altogether.
The model's output is again 5 probabilities, one for each class.
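The numbers in this summary follow from simple formulas:
- A Conv1D layer has kernel size × input channels × filters weights plus one bias per filter.
- BatchNormalization keeps 4 values per channel: two trainable (scale and offset) and two moving statistics.
- Normalization keeps a mean and a variance per lead plus one counter.

A small check in plain Python (the helper names are mine):

```python
def conv1d_params(kernel, c_in, filters):
    return kernel * c_in * filters + filters     # kernels + biases


def bn_params(c):
    return 4 * c                                 # gamma, beta + moving mean, moving variance


def dense_params(n_in, n_out):
    return n_in * n_out + n_out


leads, classes = 12, 5
params = {
    'Y_norm': 2 * leads + 1,
    'Y_conv_1': conv1d_params(7, leads, 64), 'Y_norm_1': bn_params(64),
    'Y_conv_2': conv1d_params(3, 64, 128), 'Y_norm_2': bn_params(128),
    'Y_conv_3': conv1d_params(3, 128, 256), 'Y_norm_3': bn_params(256),
    'Z_dense_1': dense_params(256, 128), 'Z_dense_2': dense_params(128, 128),
    'Z_outputs': dense_params(128, classes),
}
# Non-trainable: Normalization statistics + BatchNormalization moving mean/variance
non_trainable = params['Y_norm'] + 2 * (64 + 128 + 256)
print(sum(params.values()), non_trainable)      # 180574 921
```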
So I can move on to training:
In [31]:
```python
MODEL_CHECKPOINT = '/kaggle/working/model/model02.ckpt'

callbacks_list = [
    # Stop after 10 epochs without improvement of val_binary_accuracy
    keras.callbacks.EarlyStopping(monitor='val_binary_accuracy', patience=10),
    # Keep only the best model (by val_binary_accuracy) on disk
    keras.callbacks.ModelCheckpoint(filepath=MODEL_CHECKPOINT, monitor='val_binary_accuracy', save_best_only=True),
]

history = model02.fit(Y_train, Z_train, epochs=100, batch_size=32,
                      callbacks=callbacks_list, validation_data=(Y_valid, Z_valid))

# Reload the best checkpoint and store it in the native Keras format
model02 = keras.models.load_model(MODEL_CHECKPOINT)
model02.save('/kaggle/working/model/model02.keras')
```
```text
Epoch 1/100
546/546 [==============================] - 17s 16ms/step - loss: 0.4157 - binary_accuracy: 0.8152 - precision: 0.6886 - recall: 0.5013 - val_loss: 0.3688 - val_binary_accuracy: 0.8437 - val_precision: 0.7408 - val_recall: 0.5976
Epoch 2/100
546/546 [==============================] - 8s 14ms/step - loss: 0.3537 - binary_accuracy: 0.8499 - precision: 0.7489 - recall: 0.6186 - val_loss: 0.3546 - val_binary_accuracy: 0.8563 - val_precision: 0.7614 - val_recall: 0.6375
Epoch 3/100
546/546 [==============================] - 6s 10ms/step - loss: 0.3339 - binary_accuracy: 0.8613 - precision: 0.7690 - recall: 0.6515 - val_loss: 0.3528 - val_binary_accuracy: 0.8453 - val_precision: 0.7444 - val_recall: 0.6015
Epoch 4/100
546/546 [==============================] - 8s 15ms/step - loss: 0.3215 - binary_accuracy: 0.8685 - precision: 0.7824 - recall: 0.6704 - val_loss: 0.3236 - val_binary_accuracy: 0.8633 - val_precision: 0.7832 - val_recall: 0.6432
Epoch 5/100
546/546 [==============================] - 8s 14ms/step - loss: 0.3106 - binary_accuracy: 0.8721 - precision: 0.7903 - recall: 0.6782 - val_loss: 0.3331 - val_binary_accuracy: 0.8689 - val_precision: 0.7725 - val_recall: 0.6907
Epoch 6/100
546/546 [==============================] - 6s 10ms/step - loss: 0.3061 - binary_accuracy: 0.8750 - precision: 0.7927 - recall: 0.6899 - val_loss: 0.3175 - val_binary_accuracy: 0.8620 - val_precision: 0.7597 - val_recall: 0.6732
Epoch 7/100
546/546 [==============================] - 8s 15ms/step - loss: 0.2968 - binary_accuracy: 0.8795 - precision: 0.8016 - recall: 0.7005 - val_loss: 0.3174 - val_binary_accuracy: 0.8718 - val_precision: 0.7786 - val_recall: 0.6964
Epoch 8/100
546/546 [==============================] - 8s 16ms/step - loss: 0.2917 - binary_accuracy: 0.8821 - precision: 0.8062 - recall: 0.7075 - val_loss: 0.2889 - val_binary_accuracy: 0.8794 - val_precision: 0.8079 - val_recall: 0.6932
Epoch 9/100
546/546 [==============================] - 6s 10ms/step - loss: 0.2880 - binary_accuracy: 0.8827 - precision: 0.8079 - recall: 0.7078 - val_loss: 0.3355 - val_binary_accuracy: 0.8731 - val_precision: 0.7742 - val_recall: 0.7107
Epoch 10/100
546/546 [==============================] - 8s 14ms/step - loss: 0.2834 - binary_accuracy: 0.8851 - precision: 0.8103 - recall: 0.7170 - val_loss: 0.2798 - val_binary_accuracy: 0.8830 - val_precision: 0.8069 - val_recall: 0.7128
Epoch 11/100
546/546 [==============================] - 8s 15ms/step - loss: 0.2783 - binary_accuracy: 0.8876 - precision: 0.8161 - recall: 0.7214 - val_loss: 0.2862 - val_binary_accuracy: 0.8835 - val_precision: 0.8037 - val_recall: 0.7203
Epoch 12/100
546/546 [==============================] - 6s 11ms/step - loss: 0.2773 - binary_accuracy: 0.8882 - precision: 0.8158 - recall: 0.7249 - val_loss: 0.2972 - val_binary_accuracy: 0.8752 - val_precision: 0.7935 - val_recall: 0.6921
Epoch 13/100
546/546 [==============================] - 6s 10ms/step - loss: 0.2752 - binary_accuracy: 0.8896 - precision: 0.8193 - recall: 0.7272 - val_loss: 0.2860 - val_binary_accuracy: 0.8771 - val_precision: 0.7963 - val_recall: 0.6975
Epoch 14/100
546/546 [==============================] - 6s 10ms/step - loss: 0.2698 - binary_accuracy: 0.8911 - precision: 0.8235 - recall: 0.7290 - val_loss: 0.3070 - val_binary_accuracy: 0.8731 - val_precision: 0.7946 - val_recall: 0.6789
Epoch 15/100
546/546 [==============================] - 6s 10ms/step - loss: 0.2693 - binary_accuracy: 0.8918 - precision: 0.8259 - recall: 0.7292 - val_loss: 0.2906 - val_binary_accuracy: 0.8819 - val_precision: 0.8016 - val_recall: 0.7149
Epoch 16/100
546/546 [==============================] - 8s 14ms/step - loss: 0.2658 - binary_accuracy: 0.8926 - precision: 0.8278 - recall: 0.7303 - val_loss: 0.2797 - val_binary_accuracy: 0.8865 - val_precision: 0.8129 - val_recall: 0.7221
Epoch 17/100
546/546 [==============================] - 6s 11ms/step - loss: 0.2632 - binary_accuracy: 0.8933 - precision: 0.8286 - recall: 0.7329 - val_loss: 0.2784 - val_binary_accuracy: 0.8855 - val_precision: 0.8089 - val_recall: 0.7232
Epoch 18/100
546/546 [==============================] - 6s 10ms/step - loss: 0.2631 - binary_accuracy: 0.8946 - precision: 0.8314 - recall: 0.7358 - val_loss: 0.2824 - val_binary_accuracy: 0.8822 - val_precision: 0.8018 - val_recall: 0.7160
Epoch 19/100
546/546 [==============================] - 6s 10ms/step - loss: 0.2608 - binary_accuracy: 0.8958 - precision: 0.8330 - recall: 0.7391 - val_loss: 0.2848 - val_binary_accuracy: 0.8831 - val_precision: 0.7922 - val_recall: 0.7356
Epoch 20/100
546/546 [==============================] - 6s 10ms/step - loss: 0.2571 - binary_accuracy: 0.8959 - precision: 0.8349 - recall: 0.7371 - val_loss: 0.2756 - val_binary_accuracy: 0.8853 - val_precision: 0.8045 - val_recall: 0.7281
Epoch 21/100
546/546 [==============================] - 6s 10ms/step - loss: 0.2572 - binary_accuracy: 0.8969 - precision: 0.8318 - recall: 0.7464 - val_loss: 0.2905 - val_binary_accuracy: 0.8821 - val_precision: 0.8052 - val_recall: 0.7107
Epoch 22/100
546/546 [==============================] - 8s 15ms/step - loss: 0.2544 - binary_accuracy: 0.8984 - precision: 0.8368 - recall: 0.7470 - val_loss: 0.2740 - val_binary_accuracy: 0.8872 - val_precision: 0.8083 - val_recall: 0.7324
Epoch 23/100
546/546 [==============================] - 5s 9ms/step - loss: 0.2523 - binary_accuracy: 0.8983 - precision: 0.8371 - recall: 0.7461 - val_loss: 0.2771 - val_binary_accuracy: 0.8861 - val_precision: 0.7966 - val_recall: 0.7446
Epoch 24/100
546/546 [==============================] - 5s 9ms/step - loss: 0.2516 - binary_accuracy: 0.8992 - precision: 0.8389 - recall: 0.7480 - val_loss: 0.2770 - val_binary_accuracy: 0.8859 - val_precision: 0.8099 - val_recall: 0.7235
Epoch 25/100
546/546 [==============================] - 5s 9ms/step - loss: 0.2488 - binary_accuracy: 0.9002 - precision: 0.8404 - recall: 0.7509 - val_loss: 0.2826 - val_binary_accuracy: 0.8804 - val_precision: 0.8030 - val_recall: 0.7053
Epoch 26/100
546/546 [==============================] - 5s 10ms/step - loss: 0.2475 - binary_accuracy: 0.8999 - precision: 0.8381 - recall: 0.7525 - val_loss: 0.2775 - val_binary_accuracy: 0.8842 - val_precision: 0.7967 - val_recall: 0.7342
Epoch 27/100
546/546 [==============================] - 5s 10ms/step - loss: 0.2467 - binary_accuracy: 0.9014 - precision: 0.8441 - recall: 0.7522 - val_loss: 0.2806 - val_binary_accuracy: 0.8840 - val_precision: 0.8017 - val_recall: 0.7257
Epoch 28/100
546/546 [==============================] - 6s 10ms/step - loss: 0.2447 - binary_accuracy: 0.9019 - precision: 0.8431 - recall: 0.7556 - val_loss: 0.2843 - val_binary_accuracy: 0.8847 - val_precision: 0.8102 - val_recall: 0.7171
Epoch 29/100
546/546 [==============================] - 5s 10ms/step - loss: 0.2421 - binary_accuracy: 0.9028 - precision: 0.8460 - recall: 0.7561 - val_loss: 0.2698 - val_binary_accuracy: 0.8860 - val_precision: 0.8105 - val_recall: 0.7232
Epoch 30/100
546/546 [==============================] - 5s 10ms/step - loss: 0.2420 - binary_accuracy: 0.9019 - precision: 0.8435 - recall: 0.7553 - val_loss: 0.2918 - val_binary_accuracy: 0.8839 - val_precision: 0.8065 - val_recall: 0.7182
Epoch 31/100
546/546 [==============================] - 5s 10ms/step - loss: 0.2408 - binary_accuracy: 0.9037 - precision: 0.8470 - recall: 0.7593 - val_loss: 0.2764 - val_binary_accuracy: 0.8849 - val_precision: 0.7944 - val_recall: 0.7417
Epoch 32/100
546/546 [==============================] - 5s 10ms/step - loss: 0.2413 - binary_accuracy: 0.9039 - precision: 0.8490 - recall: 0.7574 - val_loss: 0.2796 - val_binary_accuracy: 0.8853 - val_precision: 0.7929 - val_recall: 0.7460
```
This is how training went in terms of the optimized function and the metrics, for both the training and the validation set:
In [32]:
```python
# Loss and metrics per epoch, for both the training and the validation set
sns.relplot(data=pd.DataFrame(history.history), kind='line', height=4, aspect=4)
plt.show()
```
Now I can evaluate the model against the test set:
In [33]:
```python
# np.around turns the sigmoid probabilities into 0/1 labels (threshold 0.5)
Z_pred = np.around(model02.predict(Y_test, verbose=0))
metrics = pd.concat([metrics.drop(model02.name, level=0, errors='ignore'),
                     model_quality_reporting(Z_test, Z_pred, model_name=model02.name)])
```
```text
              precision    recall  f1-score   support

        NORM       0.81      0.92      0.86       964
          MI       0.81      0.62      0.70       553
        STTC       0.78      0.71      0.74       523
          CD       0.80      0.65      0.72       498
         HYP       0.72      0.47      0.57       263

   micro avg       0.80      0.73      0.76      2801
   macro avg       0.78      0.67      0.72      2801
weighted avg       0.79      0.73      0.75      2801
 samples avg       0.77      0.74      0.74      2801
```
The report above shows that the results are considerably better this time. The model struggles mainly with the HYP class, where `recall` (0.47) and therefore also `f1-score` (0.57) are rather weak. I attribute this mainly to the small number of samples with this class in the whole dataset (see my notes on data preparation).
A 2D convolution needs the datasets to be extended by one dimension (the channel axis), hence:
In [34]:
```python
# Conv2D expects a trailing channel axis: (samples, leads, 31, 34) -> (samples, leads, 31, 34, 1)
S_train = np.expand_dims(S_train, -1)
S_valid = np.expand_dims(S_valid, -1)
S_test = np.expand_dims(S_test, -1)
print(f"S_train={S_train.shape} S_valid={S_valid.shape} S_test={S_test.shape}")
```
```text
S_train=(17441, 12, 31, 34, 1) S_valid=(2193, 12, 31, 34, 1) S_test=(2203, 12, 31, 34, 1)
```
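The added channel axis and the per-lead split that the S model performs with `tf.split` and `tf.squeeze` have direct NumPy counterparts. They make the shapes easy to follow. The sketch uses placeholder data (a running index, not real spectrograms) with the dimensions from the output above. Only the number of samples is reduced to 8.

```python
import numpy as np

# Placeholder data: (samples, leads, 31, 34)
S = np.arange(8 * 12 * 31 * 34, dtype=float).reshape(8, 12, 31, 34)
S = np.expand_dims(S, -1)                              # (8, 12, 31, 34, 1)

branches = np.split(S, S.shape[1], axis=1)             # 12 arrays, each (8, 1, 31, 34, 1)
branches = [np.squeeze(b, axis=1) for b in branches]   # each (8, 31, 34, 1): one 2D "image" per lead
print(S.shape, len(branches), branches[0].shape)
```

Each element of `branches` has the (height, width, channels) layout that a `Conv2D` layer expects, so every lead can get its own convolutional branch.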
Once again, a function comes into play that builds the convolutional layers turning the spectrograms into features.
At this point I need to mention one more thing: I converted the signal of each lead into a separate spectrogram. If I want to extract separate features from them, I have to do it in separate branches. That is why I created separate 2D convolutional layers for each lead (the function loops over the second dimension, i.e. the leads):
In [35]:
```python
def create_S_model(X, *, filters=(32, 64, 128), kernel_size=((3, 3), (3, 3), (3, 3)),
                   strides=((1, 1), (1, 1), (1, 1))):
    f1, f2, f3 = filters
    k1, k2, k3 = kernel_size
    s1, s2, s3 = strides
    # One common mean/variance for all spectrograms (axis=None)
    X = keras.layers.Normalization(axis=None, name='S_norm')(X)
    lead_branches = []
    # Split along axis 1 (leads) and process every lead in its own branch
    for lead, x in enumerate(tf.split(X, X.shape[1], axis=1, name="S_split")):
        # (None, 1, 31, 34, 1) -> (None, 31, 34, 1)
        x = tf.squeeze(x, axis=1, name=f'S_squeeze_{lead}')
        x = keras.layers.Conv2D(f1, k1, strides=s1, padding='same', name=f'S_conv_{lead}_1')(x)
        x = keras.layers.BatchNormalization(name=f'S_norm_{lead}_1')(x)
        x = keras.layers.ReLU(name=f'S_relu_{lead}_1')(x)
        x = keras.layers.MaxPool2D((2, 2), name=f'S_pool_{lead}_1')(x)
        x = keras.layers.Conv2D(f2, k2, strides=s2, padding='same', name=f'S_conv_{lead}_2')(x)
        x = keras.layers.BatchNormalization(name=f'S_norm_{lead}_2')(x)
        x = keras.layers.ReLU(name=f'S_relu_{lead}_2')(x)
        x = keras.layers.MaxPool2D((2, 2), name=f'S_pool_{lead}_2')(x)
        x = keras.layers.Conv2D(f3, k3, strides=s3, padding='same', name=f'S_conv_{lead}_3')(x)
        x = keras.layers.BatchNormalization(name=f'S_norm_{lead}_3')(x)
        x = keras.layers.ReLU(name=f'S_relu_{lead}_3')(x)
        x = keras.layers.MaxPool2D((2, 2), name=f'S_pool_{lead}_3')(x)
        # Average over both spatial axes -> (None, f3), then add a lead axis -> (None, 1, f3)
        x = keras.layers.GlobalAveragePooling2D(name=f'S_aver_{lead}')(x)
        x = tf.expand_dims(x, 1, name=f'S_expand_{lead}')
        lead_branches.append(x)
    # (None, 12, f3) -> (None, 12 * f3)
    X = keras.layers.Concatenate(axis=1, name="S_concat")(lead_branches)
    X = keras.layers.Flatten(name="S_flatten")(X)
    return X
```
I build three blocks of convolutional layers, this time separately for each lead. The function parameters are the number of filters (features) in each layer, the kernel size and the convolution stride. There are also pooling layers that reduce the dimensions, in this case along both the frequency and the time axis. Each branch then ends with a `GlobalAveragePooling2D` layer, which removes these two dimensions completely. A `Concatenate` layer then joins the outputs of all branches into one feature dimension, which `Flatten` turns into a single vector. The whole model starts with a `Normalization` layer. Thanks to `axis=None`, it normalizes all spectrograms with one common mean and standard deviation.
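The shapes in one branch can be computed by hand. With `padding='same'` and stride 1 the convolutions keep the spatial size. Each `MaxPool2D((2, 2))` (with the default `padding='valid'`) halves it, rounding down. So 31 × 34 becomes 15 × 17, then 7 × 8 and finally 3 × 4 before `GlobalAveragePooling2D`. After concatenation, each of the 12 leads contributes 128 features. Because of `axis=None`, the Normalization layer stores just one scalar mean, one scalar variance and a counter, i.e. 3 parameters. The plain-Python check below uses the 3×3 kernels and the 32/64/128 filters from `create_model03`. The helper names are mine.

```python
def pool_out(size, pool=2):
    # MaxPool2D with padding='valid' and stride equal to the pool size
    return size // pool


def conv2d_params(kh, kw, c_in, filters):
    return kh * kw * c_in * filters + filters    # kernels + biases


h, w = 31, 34
for _ in range(3):                 # three blocks: 'same' convolution (size kept) + pooling
    h, w = pool_out(h), pool_out(w)
print(h, w)                        # 3 4

leads, f1, f2, f3 = 12, 32, 64, 128
# Conv2D + BatchNormalization (4 values per channel) for the three blocks of one lead
per_lead = (conv2d_params(3, 3, 1, f1) + 4 * f1
            + conv2d_params(3, 3, f1, f2) + 4 * f2
            + conv2d_params(3, 3, f2, f3) + 4 * f3)
flatten = leads * f3               # one f3-dimensional vector per lead after GlobalAveragePooling2D
print(per_lead, flatten, flatten * 128 + 128)   # 93568 1536 196736
```

The last number is the parameter count of the first dense layer (`Z_dense_1`) that follows the flattened 1536 features.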
I recommend looking at the model summary after compilation; the dimension transformations will probably be clearer there.
In [36]:
```python
def create_model03(S_shape, Z_shape):
    S_inputs = keras.Input(S_shape[1:], name='S_inputs')
    # Per-lead 2D convolutional branches over the spectrograms
    X = create_S_model(S_inputs, filters=(32, 64, 128), kernel_size=((3, 3), (3, 3), (3, 3)))
    # Classification head
    X = keras.layers.Dense(128, activation='relu', name='Z_dense_1')(X)
    X = keras.layers.Dense(128, activation='relu', name='Z_dense_2')(X)
    X = keras.layers.Dropout(0.5, name='Z_drop_1')(X)
    outputs = keras.layers.Dense(Z_shape[-1], activation='sigmoid', name='Z_outputs')(X)
    model = keras.Model(inputs=S_inputs, outputs=outputs, name='S_model')
    return model
```
I chose an increasing number of features: 32, 64 and 128. The kernels in all layers have size 3×3.
And now the compilation can follow:
In [37]:
```python
model03 = create_model03(S_train.shape, Z_train.shape)
model03.compile(optimizer='adam', loss='binary_crossentropy',
                metrics=['binary_accuracy', 'Precision', 'Recall'])
model03.summary()
```
Model: "S_model" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== S_inputs (InputLayer) [(None, 12, 31, 34, 1)] 0 [] S_norm (Normalization) (None, 12, 31, 34, 1) 3 ['S_inputs[0][0]'] tf.split (TFOpLambda) [(None, 1, 31, 34, 1), 0 ['S_norm[0][0]'] (None, 1, 31, 34, 1), (None, 1, 31, 34, 1), (None, 1, 31, 34, 1), (None, 1, 31, 34, 1), (None, 1, 31, 34, 1), (None, 1, 31, 34, 1), (None, 1, 31, 34, 1), (None, 1, 31, 34, 1), (None, 1, 31, 34, 1), (None, 1, 31, 34, 1), (None, 1, 31, 34, 1)] tf.compat.v1.squeeze (TFOp (None, 31, 34, 1) 0 ['tf.split[0][0]'] Lambda) tf.compat.v1.squeeze_1 (TF (None, 31, 34, 1) 0 ['tf.split[0][1]'] OpLambda) tf.compat.v1.squeeze_2 (TF (None, 31, 34, 1) 0 ['tf.split[0][2]'] OpLambda) tf.compat.v1.squeeze_3 (TF (None, 31, 34, 1) 0 ['tf.split[0][3]'] OpLambda) tf.compat.v1.squeeze_4 (TF (None, 31, 34, 1) 0 ['tf.split[0][4]'] OpLambda) tf.compat.v1.squeeze_5 (TF (None, 31, 34, 1) 0 ['tf.split[0][5]'] OpLambda) tf.compat.v1.squeeze_6 (TF (None, 31, 34, 1) 0 ['tf.split[0][6]'] OpLambda) tf.compat.v1.squeeze_7 (TF (None, 31, 34, 1) 0 ['tf.split[0][7]'] OpLambda) tf.compat.v1.squeeze_8 (TF (None, 31, 34, 1) 0 ['tf.split[0][8]'] OpLambda) tf.compat.v1.squeeze_9 (TF (None, 31, 34, 1) 0 ['tf.split[0][9]'] OpLambda) tf.compat.v1.squeeze_10 (T (None, 31, 34, 1) 0 ['tf.split[0][10]'] FOpLambda) tf.compat.v1.squeeze_11 (T (None, 31, 34, 1) 0 ['tf.split[0][11]'] FOpLambda) S_conv_0_1 (Conv2D) (None, 31, 34, 32) 320 ['tf.compat.v1.squeeze[0][0]'] S_conv_1_1 (Conv2D) (None, 31, 34, 32) 320 ['tf.compat.v1.squeeze_1[0][0] '] S_conv_2_1 (Conv2D) (None, 31, 34, 32) 320 
['tf.compat.v1.squeeze_2[0][0] '] S_conv_3_1 (Conv2D) (None, 31, 34, 32) 320 ['tf.compat.v1.squeeze_3[0][0] '] S_conv_4_1 (Conv2D) (None, 31, 34, 32) 320 ['tf.compat.v1.squeeze_4[0][0] '] S_conv_5_1 (Conv2D) (None, 31, 34, 32) 320 ['tf.compat.v1.squeeze_5[0][0] '] S_conv_6_1 (Conv2D) (None, 31, 34, 32) 320 ['tf.compat.v1.squeeze_6[0][0] '] S_conv_7_1 (Conv2D) (None, 31, 34, 32) 320 ['tf.compat.v1.squeeze_7[0][0] '] S_conv_8_1 (Conv2D) (None, 31, 34, 32) 320 ['tf.compat.v1.squeeze_8[0][0] '] S_conv_9_1 (Conv2D) (None, 31, 34, 32) 320 ['tf.compat.v1.squeeze_9[0][0] '] S_conv_10_1 (Conv2D) (None, 31, 34, 32) 320 ['tf.compat.v1.squeeze_10[0][0 ]'] S_conv_11_1 (Conv2D) (None, 31, 34, 32) 320 ['tf.compat.v1.squeeze_11[0][0 ]'] S_norm_0_1 (BatchNormaliza (None, 31, 34, 32) 128 ['S_conv_0_1[0][0]'] tion) S_norm_1_1 (BatchNormaliza (None, 31, 34, 32) 128 ['S_conv_1_1[0][0]'] tion) S_norm_2_1 (BatchNormaliza (None, 31, 34, 32) 128 ['S_conv_2_1[0][0]'] tion) S_norm_3_1 (BatchNormaliza (None, 31, 34, 32) 128 ['S_conv_3_1[0][0]'] tion) S_norm_4_1 (BatchNormaliza (None, 31, 34, 32) 128 ['S_conv_4_1[0][0]'] tion) S_norm_5_1 (BatchNormaliza (None, 31, 34, 32) 128 ['S_conv_5_1[0][0]'] tion) S_norm_6_1 (BatchNormaliza (None, 31, 34, 32) 128 ['S_conv_6_1[0][0]'] tion) S_norm_7_1 (BatchNormaliza (None, 31, 34, 32) 128 ['S_conv_7_1[0][0]'] tion) S_norm_8_1 (BatchNormaliza (None, 31, 34, 32) 128 ['S_conv_8_1[0][0]'] tion) S_norm_9_1 (BatchNormaliza (None, 31, 34, 32) 128 ['S_conv_9_1[0][0]'] tion) S_norm_10_1 (BatchNormaliz (None, 31, 34, 32) 128 ['S_conv_10_1[0][0]'] ation) S_norm_11_1 (BatchNormaliz (None, 31, 34, 32) 128 ['S_conv_11_1[0][0]'] ation) S_relu_0_1 (ReLU) (None, 31, 34, 32) 0 ['S_norm_0_1[0][0]'] S_relu_1_1 (ReLU) (None, 31, 34, 32) 0 ['S_norm_1_1[0][0]'] S_relu_2_1 (ReLU) (None, 31, 34, 32) 0 ['S_norm_2_1[0][0]'] S_relu_3_1 (ReLU) (None, 31, 34, 32) 0 ['S_norm_3_1[0][0]'] S_relu_4_1 (ReLU) (None, 31, 34, 32) 0 ['S_norm_4_1[0][0]'] S_relu_5_1 (ReLU) (None, 31, 34, 32) 0 
['S_norm_5_1[0][0]'] S_relu_6_1 (ReLU) (None, 31, 34, 32) 0 ['S_norm_6_1[0][0]'] S_relu_7_1 (ReLU) (None, 31, 34, 32) 0 ['S_norm_7_1[0][0]'] S_relu_8_1 (ReLU) (None, 31, 34, 32) 0 ['S_norm_8_1[0][0]'] S_relu_9_1 (ReLU) (None, 31, 34, 32) 0 ['S_norm_9_1[0][0]'] S_relu_10_1 (ReLU) (None, 31, 34, 32) 0 ['S_norm_10_1[0][0]'] S_relu_11_1 (ReLU) (None, 31, 34, 32) 0 ['S_norm_11_1[0][0]'] S_pool_0_1 (MaxPooling2D) (None, 15, 17, 32) 0 ['S_relu_0_1[0][0]'] S_pool_1_1 (MaxPooling2D) (None, 15, 17, 32) 0 ['S_relu_1_1[0][0]'] S_pool_2_1 (MaxPooling2D) (None, 15, 17, 32) 0 ['S_relu_2_1[0][0]'] S_pool_3_1 (MaxPooling2D) (None, 15, 17, 32) 0 ['S_relu_3_1[0][0]'] S_pool_4_1 (MaxPooling2D) (None, 15, 17, 32) 0 ['S_relu_4_1[0][0]'] S_pool_5_1 (MaxPooling2D) (None, 15, 17, 32) 0 ['S_relu_5_1[0][0]'] S_pool_6_1 (MaxPooling2D) (None, 15, 17, 32) 0 ['S_relu_6_1[0][0]'] S_pool_7_1 (MaxPooling2D) (None, 15, 17, 32) 0 ['S_relu_7_1[0][0]'] S_pool_8_1 (MaxPooling2D) (None, 15, 17, 32) 0 ['S_relu_8_1[0][0]'] S_pool_9_1 (MaxPooling2D) (None, 15, 17, 32) 0 ['S_relu_9_1[0][0]'] S_pool_10_1 (MaxPooling2D) (None, 15, 17, 32) 0 ['S_relu_10_1[0][0]'] S_pool_11_1 (MaxPooling2D) (None, 15, 17, 32) 0 ['S_relu_11_1[0][0]'] S_conv_0_2 (Conv2D) (None, 15, 17, 64) 18496 ['S_pool_0_1[0][0]'] S_conv_1_2 (Conv2D) (None, 15, 17, 64) 18496 ['S_pool_1_1[0][0]'] S_conv_2_2 (Conv2D) (None, 15, 17, 64) 18496 ['S_pool_2_1[0][0]'] S_conv_3_2 (Conv2D) (None, 15, 17, 64) 18496 ['S_pool_3_1[0][0]'] S_conv_4_2 (Conv2D) (None, 15, 17, 64) 18496 ['S_pool_4_1[0][0]'] S_conv_5_2 (Conv2D) (None, 15, 17, 64) 18496 ['S_pool_5_1[0][0]'] S_conv_6_2 (Conv2D) (None, 15, 17, 64) 18496 ['S_pool_6_1[0][0]'] S_conv_7_2 (Conv2D) (None, 15, 17, 64) 18496 ['S_pool_7_1[0][0]'] S_conv_8_2 (Conv2D) (None, 15, 17, 64) 18496 ['S_pool_8_1[0][0]'] S_conv_9_2 (Conv2D) (None, 15, 17, 64) 18496 ['S_pool_9_1[0][0]'] S_conv_10_2 (Conv2D) (None, 15, 17, 64) 18496 ['S_pool_10_1[0][0]'] S_conv_11_2 (Conv2D) (None, 15, 17, 64) 18496 
['S_pool_11_1[0][0]'] S_norm_0_2 (BatchNormaliza (None, 15, 17, 64) 256 ['S_conv_0_2[0][0]'] tion) S_norm_1_2 (BatchNormaliza (None, 15, 17, 64) 256 ['S_conv_1_2[0][0]'] tion) S_norm_2_2 (BatchNormaliza (None, 15, 17, 64) 256 ['S_conv_2_2[0][0]'] tion) S_norm_3_2 (BatchNormaliza (None, 15, 17, 64) 256 ['S_conv_3_2[0][0]'] tion) S_norm_4_2 (BatchNormaliza (None, 15, 17, 64) 256 ['S_conv_4_2[0][0]'] tion) S_norm_5_2 (BatchNormaliza (None, 15, 17, 64) 256 ['S_conv_5_2[0][0]'] tion) S_norm_6_2 (BatchNormaliza (None, 15, 17, 64) 256 ['S_conv_6_2[0][0]'] tion) S_norm_7_2 (BatchNormaliza (None, 15, 17, 64) 256 ['S_conv_7_2[0][0]'] tion) S_norm_8_2 (BatchNormaliza (None, 15, 17, 64) 256 ['S_conv_8_2[0][0]'] tion) S_norm_9_2 (BatchNormaliza (None, 15, 17, 64) 256 ['S_conv_9_2[0][0]'] tion) S_norm_10_2 (BatchNormaliz (None, 15, 17, 64) 256 ['S_conv_10_2[0][0]'] ation) S_norm_11_2 (BatchNormaliz (None, 15, 17, 64) 256 ['S_conv_11_2[0][0]'] ation) S_relu_0_2 (ReLU) (None, 15, 17, 64) 0 ['S_norm_0_2[0][0]'] S_relu_1_2 (ReLU) (None, 15, 17, 64) 0 ['S_norm_1_2[0][0]'] S_relu_2_2 (ReLU) (None, 15, 17, 64) 0 ['S_norm_2_2[0][0]'] S_relu_3_2 (ReLU) (None, 15, 17, 64) 0 ['S_norm_3_2[0][0]'] S_relu_4_2 (ReLU) (None, 15, 17, 64) 0 ['S_norm_4_2[0][0]'] S_relu_5_2 (ReLU) (None, 15, 17, 64) 0 ['S_norm_5_2[0][0]'] S_relu_6_2 (ReLU) (None, 15, 17, 64) 0 ['S_norm_6_2[0][0]'] S_relu_7_2 (ReLU) (None, 15, 17, 64) 0 ['S_norm_7_2[0][0]'] S_relu_8_2 (ReLU) (None, 15, 17, 64) 0 ['S_norm_8_2[0][0]'] S_relu_9_2 (ReLU) (None, 15, 17, 64) 0 ['S_norm_9_2[0][0]'] S_relu_10_2 (ReLU) (None, 15, 17, 64) 0 ['S_norm_10_2[0][0]'] S_relu_11_2 (ReLU) (None, 15, 17, 64) 0 ['S_norm_11_2[0][0]'] S_pool_0_2 (MaxPooling2D) (None, 7, 8, 64) 0 ['S_relu_0_2[0][0]'] S_pool_1_2 (MaxPooling2D) (None, 7, 8, 64) 0 ['S_relu_1_2[0][0]'] S_pool_2_2 (MaxPooling2D) (None, 7, 8, 64) 0 ['S_relu_2_2[0][0]'] S_pool_3_2 (MaxPooling2D) (None, 7, 8, 64) 0 ['S_relu_3_2[0][0]'] S_pool_4_2 (MaxPooling2D) (None, 7, 8, 64) 0 
['S_relu_4_2[0][0]'] S_pool_5_2 (MaxPooling2D) (None, 7, 8, 64) 0 ['S_relu_5_2[0][0]'] S_pool_6_2 (MaxPooling2D) (None, 7, 8, 64) 0 ['S_relu_6_2[0][0]'] S_pool_7_2 (MaxPooling2D) (None, 7, 8, 64) 0 ['S_relu_7_2[0][0]'] S_pool_8_2 (MaxPooling2D) (None, 7, 8, 64) 0 ['S_relu_8_2[0][0]'] S_pool_9_2 (MaxPooling2D) (None, 7, 8, 64) 0 ['S_relu_9_2[0][0]'] S_pool_10_2 (MaxPooling2D) (None, 7, 8, 64) 0 ['S_relu_10_2[0][0]'] S_pool_11_2 (MaxPooling2D) (None, 7, 8, 64) 0 ['S_relu_11_2[0][0]'] S_conv_0_3 (Conv2D) (None, 7, 8, 128) 73856 ['S_pool_0_2[0][0]'] S_conv_1_3 (Conv2D) (None, 7, 8, 128) 73856 ['S_pool_1_2[0][0]'] S_conv_2_3 (Conv2D) (None, 7, 8, 128) 73856 ['S_pool_2_2[0][0]'] S_conv_3_3 (Conv2D) (None, 7, 8, 128) 73856 ['S_pool_3_2[0][0]'] S_conv_4_3 (Conv2D) (None, 7, 8, 128) 73856 ['S_pool_4_2[0][0]'] S_conv_5_3 (Conv2D) (None, 7, 8, 128) 73856 ['S_pool_5_2[0][0]'] S_conv_6_3 (Conv2D) (None, 7, 8, 128) 73856 ['S_pool_6_2[0][0]'] S_conv_7_3 (Conv2D) (None, 7, 8, 128) 73856 ['S_pool_7_2[0][0]'] S_conv_8_3 (Conv2D) (None, 7, 8, 128) 73856 ['S_pool_8_2[0][0]'] S_conv_9_3 (Conv2D) (None, 7, 8, 128) 73856 ['S_pool_9_2[0][0]'] S_conv_10_3 (Conv2D) (None, 7, 8, 128) 73856 ['S_pool_10_2[0][0]'] S_conv_11_3 (Conv2D) (None, 7, 8, 128) 73856 ['S_pool_11_2[0][0]'] S_norm_0_3 (BatchNormaliza (None, 7, 8, 128) 512 ['S_conv_0_3[0][0]'] tion) S_norm_1_3 (BatchNormaliza (None, 7, 8, 128) 512 ['S_conv_1_3[0][0]'] tion) S_norm_2_3 (BatchNormaliza (None, 7, 8, 128) 512 ['S_conv_2_3[0][0]'] tion) S_norm_3_3 (BatchNormaliza (None, 7, 8, 128) 512 ['S_conv_3_3[0][0]'] tion) S_norm_4_3 (BatchNormaliza (None, 7, 8, 128) 512 ['S_conv_4_3[0][0]'] tion) S_norm_5_3 (BatchNormaliza (None, 7, 8, 128) 512 ['S_conv_5_3[0][0]'] tion) S_norm_6_3 (BatchNormaliza (None, 7, 8, 128) 512 ['S_conv_6_3[0][0]'] tion) S_norm_7_3 (BatchNormaliza (None, 7, 8, 128) 512 ['S_conv_7_3[0][0]'] tion) S_norm_8_3 (BatchNormaliza (None, 7, 8, 128) 512 ['S_conv_8_3[0][0]'] tion) S_norm_9_3 (BatchNormaliza (None, 7, 8, 128) 
512 ['S_conv_9_3[0][0]'] tion) S_norm_10_3 (BatchNormaliz (None, 7, 8, 128) 512 ['S_conv_10_3[0][0]'] ation) S_norm_11_3 (BatchNormaliz (None, 7, 8, 128) 512 ['S_conv_11_3[0][0]'] ation) S_relu_0_3 (ReLU) (None, 7, 8, 128) 0 ['S_norm_0_3[0][0]'] S_relu_1_3 (ReLU) (None, 7, 8, 128) 0 ['S_norm_1_3[0][0]'] S_relu_2_3 (ReLU) (None, 7, 8, 128) 0 ['S_norm_2_3[0][0]'] S_relu_3_3 (ReLU) (None, 7, 8, 128) 0 ['S_norm_3_3[0][0]'] S_relu_4_3 (ReLU) (None, 7, 8, 128) 0 ['S_norm_4_3[0][0]'] S_relu_5_3 (ReLU) (None, 7, 8, 128) 0 ['S_norm_5_3[0][0]'] S_relu_6_3 (ReLU) (None, 7, 8, 128) 0 ['S_norm_6_3[0][0]'] S_relu_7_3 (ReLU) (None, 7, 8, 128) 0 ['S_norm_7_3[0][0]'] S_relu_8_3 (ReLU) (None, 7, 8, 128) 0 ['S_norm_8_3[0][0]'] S_relu_9_3 (ReLU) (None, 7, 8, 128) 0 ['S_norm_9_3[0][0]'] S_relu_10_3 (ReLU) (None, 7, 8, 128) 0 ['S_norm_10_3[0][0]'] S_relu_11_3 (ReLU) (None, 7, 8, 128) 0 ['S_norm_11_3[0][0]'] S_pool_0_3 (MaxPooling2D) (None, 3, 4, 128) 0 ['S_relu_0_3[0][0]'] S_pool_1_3 (MaxPooling2D) (None, 3, 4, 128) 0 ['S_relu_1_3[0][0]'] S_pool_2_3 (MaxPooling2D) (None, 3, 4, 128) 0 ['S_relu_2_3[0][0]'] S_pool_3_3 (MaxPooling2D) (None, 3, 4, 128) 0 ['S_relu_3_3[0][0]'] S_pool_4_3 (MaxPooling2D) (None, 3, 4, 128) 0 ['S_relu_4_3[0][0]'] S_pool_5_3 (MaxPooling2D) (None, 3, 4, 128) 0 ['S_relu_5_3[0][0]'] S_pool_6_3 (MaxPooling2D) (None, 3, 4, 128) 0 ['S_relu_6_3[0][0]'] S_pool_7_3 (MaxPooling2D) (None, 3, 4, 128) 0 ['S_relu_7_3[0][0]'] S_pool_8_3 (MaxPooling2D) (None, 3, 4, 128) 0 ['S_relu_8_3[0][0]'] S_pool_9_3 (MaxPooling2D) (None, 3, 4, 128) 0 ['S_relu_9_3[0][0]'] S_pool_10_3 (MaxPooling2D) (None, 3, 4, 128) 0 ['S_relu_10_3[0][0]'] S_pool_11_3 (MaxPooling2D) (None, 3, 4, 128) 0 ['S_relu_11_3[0][0]'] S_aver_0 (GlobalAveragePoo (None, 128) 0 ['S_pool_0_3[0][0]'] ling2D) S_aver_1 (GlobalAveragePoo (None, 128) 0 ['S_pool_1_3[0][0]'] ling2D) S_aver_2 (GlobalAveragePoo (None, 128) 0 ['S_pool_2_3[0][0]'] ling2D) S_aver_3 (GlobalAveragePoo (None, 128) 0 ['S_pool_3_3[0][0]'] ling2D) S_aver_4 
(GlobalAveragePoo (None, 128) 0 ['S_pool_4_3[0][0]'] ling2D) S_aver_5 (GlobalAveragePoo (None, 128) 0 ['S_pool_5_3[0][0]'] ling2D) S_aver_6 (GlobalAveragePoo (None, 128) 0 ['S_pool_6_3[0][0]'] ling2D) S_aver_7 (GlobalAveragePoo (None, 128) 0 ['S_pool_7_3[0][0]'] ling2D) S_aver_8 (GlobalAveragePoo (None, 128) 0 ['S_pool_8_3[0][0]'] ling2D) S_aver_9 (GlobalAveragePoo (None, 128) 0 ['S_pool_9_3[0][0]'] ling2D) S_aver_10 (GlobalAveragePo (None, 128) 0 ['S_pool_10_3[0][0]'] oling2D) S_aver_11 (GlobalAveragePo (None, 128) 0 ['S_pool_11_3[0][0]'] oling2D) tf.expand_dims (TFOpLambda (None, 1, 128) 0 ['S_aver_0[0][0]'] ) tf.expand_dims_1 (TFOpLamb (None, 1, 128) 0 ['S_aver_1[0][0]'] da) tf.expand_dims_2 (TFOpLamb (None, 1, 128) 0 ['S_aver_2[0][0]'] da) tf.expand_dims_3 (TFOpLamb (None, 1, 128) 0 ['S_aver_3[0][0]'] da) tf.expand_dims_4 (TFOpLamb (None, 1, 128) 0 ['S_aver_4[0][0]'] da) tf.expand_dims_5 (TFOpLamb (None, 1, 128) 0 ['S_aver_5[0][0]'] da) tf.expand_dims_6 (TFOpLamb (None, 1, 128) 0 ['S_aver_6[0][0]'] da) tf.expand_dims_7 (TFOpLamb (None, 1, 128) 0 ['S_aver_7[0][0]'] da) tf.expand_dims_8 (TFOpLamb (None, 1, 128) 0 ['S_aver_8[0][0]'] da) tf.expand_dims_9 (TFOpLamb (None, 1, 128) 0 ['S_aver_9[0][0]'] da) tf.expand_dims_10 (TFOpLam (None, 1, 128) 0 ['S_aver_10[0][0]'] bda) tf.expand_dims_11 (TFOpLam (None, 1, 128) 0 ['S_aver_11[0][0]'] bda) S_concat (Concatenate) (None, 12, 128) 0 ['tf.expand_dims[0][0]', 'tf.expand_dims_1[0][0]', 'tf.expand_dims_2[0][0]', 'tf.expand_dims_3[0][0]', 'tf.expand_dims_4[0][0]', 'tf.expand_dims_5[0][0]', 'tf.expand_dims_6[0][0]', 'tf.expand_dims_7[0][0]', 'tf.expand_dims_8[0][0]', 'tf.expand_dims_9[0][0]', 'tf.expand_dims_10[0][0]', 'tf.expand_dims_11[0][0]'] S_flatten (Flatten) (None, 1536) 0 ['S_concat[0][0]'] Z_dense_1 (Dense) (None, 128) 196736 ['S_flatten[0][0]'] Z_dense_2 (Dense) (None, 128) 16512 ['Z_dense_1[0][0]'] Z_drop_1 (Dropout) (None, 128) 0 ['Z_dense_2[0][0]'] Z_outputs (Dense) (None, 5) 645 ['Z_drop_1[0][0]'] 
================================================================================================== Total params: 1336712 (5.10 MB) Trainable params: 1331333 (5.08 MB) Non-trainable params: 5379 (21.02 KB) __________________________________________________________________________________________________
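The parameter counts in the summary can be checked by hand. The sketch below assumes 3×3 kernels in every Conv2D layer (inferred from the 320 parameters of the first convolution, 3·3·1·32 + 32) and counts the three non-trainable state values (mean, variance, count) of the input Normalization layer; the helper functions are hypothetical, written only for this check.

```python
def conv2d_params(kernel, c_in, c_out):
    """Conv2D: a kernel x kernel x c_in weight tensor per filter, plus one bias per filter."""
    return kernel * kernel * c_in * c_out + c_out


def batchnorm_params(channels):
    """BatchNormalization: gamma, beta (trainable) + moving mean, moving variance (non-trainable)."""
    return 4 * channels


def dense_params(n_in, n_out):
    """Dense: weight matrix plus bias vector."""
    return n_in * n_out + n_out


LEADS = 12
KERNEL = 3                                 # assumption: 3x3 kernels, inferred from 320 = 3*3*1*32 + 32
CHANNELS = [(1, 32), (32, 64), (64, 128)]  # (input, output) channels of the three conv blocks
NORMALIZATION = 3                          # mean, variance and count of the input Normalization layer

# one per-lead branch: three conv + batch-norm blocks (ReLU, pooling, averaging have no parameters)
branch = sum(conv2d_params(KERNEL, c_in, c_out) + batchnorm_params(c_out)
             for c_in, c_out in CHANNELS)
head = (dense_params(LEADS * 128, 128)     # Z_dense_1 on the flattened (12, 128) tensor
        + dense_params(128, 128)           # Z_dense_2
        + dense_params(128, 5))            # Z_outputs
total = NORMALIZATION + LEADS * branch + head
# non-trainable: moving mean and variance of every BatchNormalization layer + Normalization state
non_trainable = NORMALIZATION + LEADS * sum(2 * c_out for _, c_out in CHANNELS)

print(branch, total, total - non_trainable, non_trainable)  # 93568 1336712 1331333 5379
```

The twelve identical per-lead branches hold 12 × 93,568 = 1,122,816 parameters, roughly 84 % of the whole model; most of the rest sits in `Z_dense_1`.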
Next comes training the model:
In [38]:
```python
MODEL_CHECKPOINT = '/kaggle/working/model/model03.ckpt'

callbacks_list = [
    # stop once val_binary_accuracy has not improved for 10 consecutive epochs
    keras.callbacks.EarlyStopping(monitor='val_binary_accuracy', patience=10),
    # keep only the best model according to val_binary_accuracy
    keras.callbacks.ModelCheckpoint(filepath=MODEL_CHECKPOINT, monitor='val_binary_accuracy', save_best_only=True)
]

history = model03.fit(S_train, Z_train, epochs=100, batch_size=32,
                      callbacks=callbacks_list, validation_data=(S_valid, Z_valid))

# reload the best checkpoint and store it in the native Keras format
model03 = keras.models.load_model(MODEL_CHECKPOINT)
model03.save('/kaggle/working/model/model03.keras')
```

```
Epoch 1/100
546/546 [==============================] - 62s 66ms/step - loss: 0.4174 - binary_accuracy: 0.8142 - precision: 0.6817 - recall: 0.5084 - val_loss: 0.3741 - val_binary_accuracy: 0.8337 - val_precision: 0.7169 - val_recall: 0.5772
Epoch 2/100
546/546 [==============================] - 35s 63ms/step - loss: 0.3655 - binary_accuracy: 0.8436 - precision: 0.7355 - recall: 0.6030 - val_loss: 0.3575 - val_binary_accuracy: 0.8409 - val_precision: 0.7449 - val_recall: 0.5740
Epoch 3/100
546/546 [==============================] - 16s 30ms/step - loss: 0.3469 - binary_accuracy: 0.8526 - precision: 0.7523 - recall: 0.6286 - val_loss: 0.3786 - val_binary_accuracy: 0.8320 - val_precision: 0.7163 - val_recall: 0.5676
Epoch 4/100
546/546 [==============================] - 34s 63ms/step - loss: 0.3342 - binary_accuracy: 0.8575 - precision: 0.7612 - recall: 0.6424 - val_loss: 0.3459 - val_binary_accuracy: 0.8522 - val_precision: 0.7225 - val_recall: 0.6846
Epoch 5/100
546/546 [==============================] - 16s 30ms/step - loss: 0.3315 - binary_accuracy: 0.8591 - precision: 0.7598 - recall: 0.6539 - val_loss: 0.3540 - val_binary_accuracy: 0.8516 - val_precision: 0.7537 - val_recall: 0.6233
Epoch 6/100
546/546 [==============================] - 16s 30ms/step - loss: 0.3186 - binary_accuracy: 0.8659 - precision: 0.7724 - recall: 0.6716 - val_loss: 0.3518 - val_binary_accuracy: 0.8468 - val_precision: 0.7528 - val_recall: 0.5965
Epoch 7/100
546/546 [==============================] - 34s 63ms/step - loss: 0.3097 - binary_accuracy: 0.8691 - precision: 0.7794 - recall: 0.6781 - val_loss: 0.3599 - val_binary_accuracy: 0.8532 - val_precision: 0.7529 - val_recall: 0.6336
Epoch 8/100
546/546 [==============================] - 34s 63ms/step - loss: 0.3048 - binary_accuracy: 0.8712 - precision: 0.7829 - recall: 0.6845 - val_loss: 0.3285 - val_binary_accuracy: 0.8589 - val_precision: 0.7526 - val_recall: 0.6675
Epoch 9/100
546/546 [==============================] - 16s 30ms/step - loss: 0.2985 - binary_accuracy: 0.8754 - precision: 0.7915 - recall: 0.6935 - val_loss: 0.3583 - val_binary_accuracy: 0.8420 - val_precision: 0.7284 - val_recall: 0.6086
Epoch 10/100
546/546 [==============================] - 17s 30ms/step - loss: 0.2919 - binary_accuracy: 0.8780 - precision: 0.7965 - recall: 0.7002 - val_loss: 0.4031 - val_binary_accuracy: 0.8202 - val_precision: 0.6635 - val_recall: 0.6022
Epoch 11/100
546/546 [==============================] - 34s 63ms/step - loss: 0.2826 - binary_accuracy: 0.8817 - precision: 0.8026 - recall: 0.7104 - val_loss: 0.3278 - val_binary_accuracy: 0.8605 - val_precision: 0.7835 - val_recall: 0.6275
Epoch 12/100
546/546 [==============================] - 35s 64ms/step - loss: 0.2760 - binary_accuracy: 0.8852 - precision: 0.8100 - recall: 0.7180 - val_loss: 0.3267 - val_binary_accuracy: 0.8613 - val_precision: 0.7483 - val_recall: 0.6893
Epoch 13/100
546/546 [==============================] - 34s 63ms/step - loss: 0.2705 - binary_accuracy: 0.8882 - precision: 0.8149 - recall: 0.7263 - val_loss: 0.3371 - val_binary_accuracy: 0.8621 - val_precision: 0.7534 - val_recall: 0.6846
Epoch 14/100
546/546 [==============================] - 17s 31ms/step - loss: 0.2627 - binary_accuracy: 0.8910 - precision: 0.8193 - recall: 0.7341 - val_loss: 0.4129 - val_binary_accuracy: 0.8244 - val_precision: 0.6660 - val_recall: 0.6283
Epoch 15/100
546/546 [==============================] - 16s 30ms/step - loss: 0.2544 - binary_accuracy: 0.8953 - precision: 0.8261 - recall: 0.7462 - val_loss: 0.3595 - val_binary_accuracy: 0.8416 - val_precision: 0.6992 - val_recall: 0.6675
Epoch 16/100
546/546 [==============================] - 17s 30ms/step - loss: 0.2460 - binary_accuracy: 0.8988 - precision: 0.8361 - recall: 0.7501 - val_loss: 0.3900 - val_binary_accuracy: 0.8522 - val_precision: 0.7667 - val_recall: 0.6061
Epoch 17/100
546/546 [==============================] - 17s 30ms/step - loss: 0.2364 - binary_accuracy: 0.9046 - precision: 0.8438 - recall: 0.7678 - val_loss: 0.3842 - val_binary_accuracy: 0.8421 - val_precision: 0.6780 - val_recall: 0.7285
Epoch 18/100
546/546 [==============================] - 17s 30ms/step - loss: 0.2307 - binary_accuracy: 0.9062 - precision: 0.8462 - recall: 0.7724 - val_loss: 0.3902 - val_binary_accuracy: 0.8523 - val_precision: 0.7487 - val_recall: 0.6357
Epoch 19/100
546/546 [==============================] - 16s 30ms/step - loss: 0.2180 - binary_accuracy: 0.9126 - precision: 0.8597 - recall: 0.7850 - val_loss: 0.4497 - val_binary_accuracy: 0.8565 - val_precision: 0.7513 - val_recall: 0.6554
Epoch 20/100
546/546 [==============================] - 16s 30ms/step - loss: 0.2098 - binary_accuracy: 0.9157 - precision: 0.8629 - recall: 0.7954 - val_loss: 0.4039 - val_binary_accuracy: 0.8544 - val_precision: 0.7434 - val_recall: 0.6575
Epoch 21/100
546/546 [==============================] - 16s 30ms/step - loss: 0.1974 - binary_accuracy: 0.9212 - precision: 0.8743 - recall: 0.8067 - val_loss: 0.5562 - val_binary_accuracy: 0.8031 - val_precision: 0.6326 - val_recall: 0.5480
Epoch 22/100
546/546 [==============================] - 17s 30ms/step - loss: 0.1954 - binary_accuracy: 0.9210 - precision: 0.8732 - recall: 0.8072 - val_loss: 0.5557 - val_binary_accuracy: 0.7650 - val_precision: 0.5376 - val_recall: 0.5762
Epoch 23/100
546/546 [==============================] - 16s 30ms/step - loss: 0.1851 - binary_accuracy: 0.9265 - precision: 0.8832 - recall: 0.8201 - val_loss: 0.5077 - val_binary_accuracy: 0.8196 - val_precision: 0.6461 - val_recall: 0.6507
```
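Training ended after 23 of the 100 epochs. A simplified re-implementation of the rule that `EarlyStopping(patience=10)` and `ModelCheckpoint(save_best_only=True)` apply to a metric that should increase (strict improvement, no `min_delta`) shows why. The `val_binary_accuracy` values are copied from the log above; the function is an illustrative sketch, not Keras's own code.

```python
# val_binary_accuracy for epochs 1-23, copied from the training log
VAL_BINARY_ACCURACY = [
    0.8337, 0.8409, 0.8320, 0.8522, 0.8516, 0.8468, 0.8532, 0.8589,
    0.8420, 0.8202, 0.8605, 0.8613, 0.8621, 0.8244, 0.8416, 0.8522,
    0.8421, 0.8523, 0.8565, 0.8544, 0.8031, 0.7650, 0.8196,
]


def simulate_early_stopping(values, patience):
    """Return (stop_epoch, best_epoch, checkpoint_epochs) for a metric that should increase."""
    best = float("-inf")
    best_epoch = 0
    wait = 0
    checkpoint_epochs = []  # epochs in which save_best_only=True would (over)write the checkpoint
    for epoch, value in enumerate(values, start=1):
        if value > best:
            best, best_epoch, wait = value, epoch, 0
            checkpoint_epochs.append(epoch)
        else:
            wait += 1
            if wait >= patience:  # `patience` epochs in a row without improvement
                return epoch, best_epoch, checkpoint_epochs
    return len(values), best_epoch, checkpoint_epochs


stop, best, saved = simulate_early_stopping(VAL_BINARY_ACCURACY, patience=10)
print(stop, best, saved)  # 23 13 [1, 2, 4, 7, 8, 11, 12, 13]
```

The best `val_binary_accuracy` (0.8621) was reached in epoch 13 and the next ten epochs brought no improvement, so training stopped in epoch 23. Since the checkpoint was last written in epoch 13, `load_model` restores that model rather than the final, more overfitted one.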
This is how training progressed in terms of the optimized loss function and the metrics:
In [39]:
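```python
# plot every column of history.history (training and validation loss and metrics) against the epoch index
sns.relplot(data=pd.DataFrame(history.history), kind='line', height=4, aspect=4)
plt.show()
```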
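The plot tracks the optimized loss (`binary_crossentropy`) and `binary_accuracy`. A minimal NumPy sketch of both quantities for a multi-label batch, assuming Keras's default decision threshold of 0.5 for `binary_accuracy` and clipping of the probabilities by a small epsilon (1e-7) before taking logarithms; the function names and the toy arrays are hypothetical, not Keras API.

```python
import numpy as np

EPSILON = 1e-7  # assumed clipping constant, matching Keras's default backend epsilon


def binary_crossentropy(y_true, y_prob, eps=EPSILON):
    """Mean binary cross-entropy over labels and then over samples (multi-label targets)."""
    p = np.clip(y_prob, eps, 1.0 - eps)  # avoid log(0)
    per_label = -(y_true * np.log(p) + (1.0 - y_true) * np.log(1.0 - p))
    return float(per_label.mean(axis=-1).mean())


def binary_accuracy(y_true, y_prob, threshold=0.5):
    """Share of individual labels that are correct after thresholding the probabilities."""
    y_pred = (y_prob > threshold).astype(y_true.dtype)
    return float((y_pred == y_true).mean())


# two toy records with five labels each (NORM, MI, STTC, CD, HYP)
y_true = np.array([[1, 0, 0, 0, 0],
                   [0, 1, 1, 0, 0]], dtype=float)
y_prob = np.array([[0.90, 0.20, 0.10, 0.30, 0.05],
                   [0.40, 0.70, 0.45, 0.10, 0.20]])

print(binary_accuracy(y_true, y_prob))  # 0.9: 9 of 10 labels correct (STTC of record 2 is missed)
print(binary_crossentropy(y_true, y_prob))
```

Because `binary_accuracy` counts each of the five labels separately, the many correctly predicted zeros of sparse multi-label targets push it up, which is why per-class precision and recall are worth looking at as well.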
Now I can once again evaluate the model on the test set:
In [40]:
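```python
# round each predicted class probability to 0 or 1 (a 0.5 decision threshold)
Z_pred = np.around(model03.predict(S_test, verbose=0))
# replace any earlier rows of this model in the shared metrics table with the new results
metrics = pd.concat([metrics.drop(model03.name, level=0, errors='ignore'),
                     model_quality_reporting(Z_test, Z_pred, model_name=model03.name)])
```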
```
              precision    recall  f1-score   support

        NORM       0.80      0.85      0.82       964
          MI       0.65      0.62      0.64       553
        STTC       0.81      0.54      0.64       523
          CD       0.74      0.64      0.69       498
         HYP       0.57      0.58      0.57       263

   micro avg       0.74      0.68      0.71      2801
   macro avg       0.71      0.65      0.67      2801
weighted avg       0.74      0.68      0.71      2801
 samples avg       0.70      0.69      0.68      2801
```
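The report is computed from hard labels: `np.around` turns each probability into 0 or 1, and since NumPy rounds halves to the nearest even number, this is the same as the rule `p > 0.5`. The sketch below checks that on toy values and then counts TN, FP, FN and TP per class in the `[[TN, FP], [FN, TP]]` layout that `sklearn.metrics.multilabel_confusion_matrix` returns for each label (the function used inside `model_quality_reporting`). The toy arrays and the helper `per_class_confusion` are illustrative only.

```python
import numpy as np


def per_class_confusion(y_true, y_pred):
    """Per-class 2x2 counts laid out as [[TN, FP], [FN, TP]], shape (n_classes, 2, 2)."""
    t = np.asarray(y_true).astype(bool)
    p = np.asarray(y_pred).astype(bool)
    tn = np.sum(~t & ~p, axis=0)
    fp = np.sum(~t & p, axis=0)
    fn = np.sum(t & ~p, axis=0)
    tp = np.sum(t & p, axis=0)
    return np.stack([np.stack([tn, fp], axis=-1),
                     np.stack([fn, tp], axis=-1)], axis=1)


# toy probabilities for two records and five classes (illustrative values only)
probs = np.array([[0.92, 0.50, 0.51, 0.10, 0.49],
                  [0.20, 0.81, 0.30, 0.66, 0.05]])
rounded = np.around(probs)  # 0.50 -> 0.0 because of round-half-to-even
assert np.array_equal(rounded, (probs > 0.5).astype(probs.dtype))

y_true = np.array([[1, 0, 1, 0, 0],
                   [0, 1, 0, 0, 1]])
cm = per_class_confusion(y_true, rounded)
print(cm[3].tolist())  # [[1, 1], [0, 0]]: one true negative, one false positive
```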
The results are not markedly better than those of the previous model, but they are somewhat more even across all five classes.
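How even the classes are also shows in the summary rows of the report: `macro avg` is the plain mean over the five classes, `weighted avg` weights each class by its support, while `micro avg` pools TP, FP and FN over all classes and `samples avg` averages per ECG record (those two need the raw predictions, so they are not recomputed here). A short check using the printed per-class values; because those are rounded to two decimals, the recomputed averages match the report only to about ±0.01.

```python
# per-class rows copied from the classification report: (precision, recall, f1-score, support)
REPORT = {
    "NORM": (0.80, 0.85, 0.82, 964),
    "MI":   (0.65, 0.62, 0.64, 553),
    "STTC": (0.81, 0.54, 0.64, 523),
    "CD":   (0.74, 0.64, 0.69, 498),
    "HYP":  (0.57, 0.58, 0.57, 263),
}


def macro_and_weighted(report):
    """Return (macro, weighted, total_support); macro and weighted are [precision, recall, f1]."""
    rows = list(report.values())
    total = sum(row[3] for row in rows)
    macro = [sum(row[i] for row in rows) / len(rows) for i in range(3)]          # every class counts equally
    weighted = [sum(row[i] * row[3] for row in rows) / total for i in range(3)]  # weighted by support
    return macro, weighted, total


macro, weighted, total = macro_and_weighted(REPORT)
print([round(v, 3) for v in macro], [round(v, 3) for v in weighted], total)
```

Weighted recall always equals micro recall (both reduce to Σ TP / Σ support), which is why both rows show 0.68. The gap between macro F1 (0.67) and weighted F1 (0.71) comes from the best-scoring class, NORM, also being the largest and the weakest one, HYP, the smallest.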
In the next part I will continue with combined models.
Works as an IT architect. In recent years he has focused on integration and communication projects in healthcare. His hobbies also include paragliding and mountain biking.