Basic Deep Learning with Keras (Part 4): ResNet
In this article we look at the ResNet architecture. It was introduced by Microsoft and won the ILSVRC (ImageNet Large Scale Visual Recognition Challenge) competition in 2015. The paper is available at https://arxiv.org/abs/1512.03385.
ResNet architecture diagram
The idea, in very short form, is to make it possible to increase the number of layers by introducing a residual connection (an identity shortcut). The input of a block skips the block's layers and is added directly to its output, which improves the learning process.
Traditional CNN vs. CNN with a residual connection
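To make the difference concrete, here is a minimal sketch of a residual connection written on plain Python lists rather than Keras tensors. The names `plain_block`, `residual_block` and `zero_transform` are hypothetical and exist only for this illustration; `transform` stands in for the stacked convolutional layers of a block.

```python
# Minimal sketch of a residual connection on plain Python lists.
# `transform` is a hypothetical stand-in for the stacked layers F(x).

def relu(values):
    return [max(0.0, v) for v in values]

def plain_block(x, transform):
    # Traditional block: the output depends only on F(x).
    return relu(transform(x))

def residual_block(x, transform):
    # Residual block: the input is added back to F(x) before the activation.
    fx = transform(x)
    return relu([f + xi for f, xi in zip(fx, x)])

def zero_transform(x):
    # F(x) = 0, e.g. layers that have not learned anything useful yet.
    return [0.0 for _ in x]

x = [1.0, 2.0, 3.0]
plain_out = plain_block(x, zero_transform)
residual_out = residual_block(x, zero_transform)
print(plain_out)
print(residual_out)
```

In this toy case the plain block loses the input entirely, while the residual block still passes it through unchanged, which is the intuition behind stacking many such blocks.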
We will run the same experiment as in the previous parts. We skip the steps that import the CIFAR-100 dataset, set up the basic experiment environment, and import the Python libraries, since they are exactly the same.
Training the ResNet architecture
Keras ships this architecture, but it has the drawback that, by default, the input images must be at least 197 pixels per side, so we will define a smaller architecture.
def CustomResNet50(include_top=True, input_tensor=None, input_shape=(32, 32, 3), pooling=None, classes=100):
    if input_tensor is None:
        img_input = Input(shape=input_shape)
    else:
        if not K.is_keras_tensor(input_tensor):
            img_input = Input(tensor=input_tensor, shape=input_shape)
        else:
            img_input = input_tensor
    if K.image_data_format() == 'channels_last':
        bn_axis = 3
    else:
        bn_axis = 1
    x = ZeroPadding2D(padding=(2, 2), name='conv1_pad')(img_input)
    x = resnet50.conv_block(x, 3, [32, 32, 64], stage=2, block='a')
    x = resnet50.identity_block(x, 3, [32, 32, 64], stage=2, block='b')
    x = resnet50.identity_block(x, 3, [32, 32, 64], stage=2, block='c')
    x = resnet50.conv_block(x, 3, [64, 64, 256], stage=3, block='a', strides=(1, 1))
    x = resnet50.identity_block(x, 3, [64, 64, 256], stage=3, block='b')
    x = resnet50.identity_block(x, 3, [64, 64, 256], stage=3, block='c')
    x = resnet50.conv_block(x, 3, [128, 128, 512], stage=4, block='a')
    x = resnet50.identity_block(x, 3, [128, 128, 512], stage=4, block='b')
    x = resnet50.identity_block(x, 3, [128, 128, 512], stage=4, block='c')
    x = resnet50.identity_block(x, 3, [128, 128, 512], stage=4, block='d')
    x = resnet50.conv_block(x, 3, [256, 256, 1024], stage=5, block='a')
    x = resnet50.identity_block(x, 3, [256, 256, 1024], stage=5, block='b')
    x = resnet50.identity_block(x, 3, [256, 256, 1024], stage=5, block='c')
    x = resnet50.identity_block(x, 3, [256, 256, 1024], stage=5, block='d')
    x = resnet50.identity_block(x, 3, [256, 256, 1024], stage=5, block='e')
    x = resnet50.identity_block(x, 3, [256, 256, 1024], stage=5, block='f')
    x = resnet50.conv_block(x, 3, [512, 512, 2048], stage=6, block='a')
    x = resnet50.identity_block(x, 3, [512, 512, 2048], stage=6, block='b')
    x = resnet50.identity_block(x, 3, [512, 512, 2048], stage=6, block='c')
    x = AveragePooling2D((1, 1), name='avg_pool')(x)
    if include_top:
        x = Flatten()(x)
        x = Dense(classes, activation='softmax', name='fc1000')(x)
    else:
        if pooling == 'avg':
            x = GlobalAveragePooling2D()(x)
        elif pooling == 'max':
            x = GlobalMaxPooling2D()(x)
    # Ensure that the model takes into account
    # any potential predecessors of `input_tensor`.
    if input_tensor is not None:
        inputs = get_source_inputs(input_tensor)
    else:
        inputs = img_input
    # Create model.
    model = Model(inputs, x, name='resnet50')
    return model
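The spatial size of the feature maps produced by this stack of blocks can be traced by hand. The sketch below is an illustration, not part of the original experiment. It assumes that each `conv_block` applies its stride in a 1x1 convolution with 'valid' padding (so the output length is `(size - 1) // stride + 1`) and that each `identity_block` keeps the size unchanged. The helper name `conv1x1_valid_output` is hypothetical.

```python
def conv1x1_valid_output(size, stride):
    # Output length of a 1x1 convolution with 'valid' padding (assumed rule).
    return (size - 1) // stride + 1

input_size = 32
padding = 2
size = input_size + 2 * padding  # ZeroPadding2D((2, 2)) turns 32 into 36

# (stage, stride of its conv_block), following CustomResNet50:
# every stage uses the default stride of 2 except stage 3, which uses 1.
stage_strides = [(2, 2), (3, 1), (4, 2), (5, 2), (6, 2)]

sizes = {}
for stage, stride in stage_strides:
    size = conv1x1_valid_output(size, stride)
    sizes[stage] = size
    print(f"stage {stage}: {size}x{size}")

# AveragePooling2D((1, 1)) does not change the size, so Flatten would see
# size * size * 2048 values under these assumptions.
flattened = sizes[6] * sizes[6] * 2048
print(flattened)
```

Under these assumptions the stages come out as 18, 18, 9, 5 and 3 pixels per side, which can be compared against the output shapes in the model summary.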
We compile as we have done so far...
def create_custom_resnet50():
    model = CustomResNet50(include_top=True, input_tensor=None, input_shape=(32, 32, 3), pooling=None, classes=100)
    return model
custom_resnet50_model = create_custom_resnet50()
custom_resnet50_model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['acc', 'mse'])
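As a sanity check on the compiled model, the per-layer parameter counts that a Keras summary reports can be reproduced by hand. The sketch below is only an illustration with hypothetical helper names (`conv2d_params`, `batchnorm_params`). It assumes square kernels with a bias term for `Conv2D`, and four values per channel for `BatchNormalization` (gamma, beta, moving mean and moving variance).

```python
def conv2d_params(kernel, in_channels, out_channels):
    # Weights of a square kernel plus one bias per output channel.
    return kernel * kernel * in_channels * out_channels + out_channels

def batchnorm_params(channels):
    # gamma, beta, moving mean and moving variance for each channel.
    return 4 * channels

# First block of stage 2, fed with the 3-channel padded input.
res2a_branch2a = conv2d_params(1, 3, 32)
res2a_branch2b = conv2d_params(3, 32, 32)
res2a_branch2c = conv2d_params(1, 32, 64)
res2a_branch1 = conv2d_params(1, 3, 64)
bn2a_branch2a = batchnorm_params(32)

print(res2a_branch2a, res2a_branch2b, res2a_branch2c, res2a_branch1, bn2a_branch2a)
```

These values (128, 9248, 2112, 256 and 128) can be checked against the `Param #` column of the corresponding layers in the model summary.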
Once that is done, let's look at a summary of the model we created.
custom_resnet50_model.summary()
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) (None, 32, 32, 3) 0
__________________________________________________________________________________________________
conv1_pad (ZeroPadding2D) (None, 36, 36, 3) 0 input_1[0][0]
__________________________________________________________________________________________________
res2a_branch2a (Conv2D) (None, 18, 18, 32) 128 conv1_pad[0][0]
__________________________________________________________________________________________________
bn2a_branch2a (BatchNormalizati (None, 18, 18, 32) 128 res2a_branch2a[0][0]
__________________________________________________________________________________________________
activation_1 (Activation) (None, 18, 18, 32) 0 bn2a_branch2a[0][0]
__________________________________________________________________________________________________
res2a_branch2b (Conv2D) (None, 18, 18, 32) 9248 activation_1[0][0]
__________________________________________________________________________________________________
bn2a_branch2b (BatchNormalizati (None, 18, 18, 32) 128 res2a_branch2b[0][0]
__________________________________________________________________________________________________
activation_2 (Activation) (None, 18, 18, 32) 0 bn2a_branch2b[0][0]
__________________________________________________________________________________________________
res2a_branch2c (Conv2D) (None, 18, 18, 64) 2112 activation_2[0][0]
__________________________________________________________________________________________________
res2a_branch1 (Conv2D) (None, 18, 18, 64) 256 conv1_pad[0][0]
__________________________________________________________________________________________________
bn2a_branch2c (BatchNormalizati (None, 18, 18, 64) 256 res2a_branch2c[0][0]
__________________________________________________________________________________________________
bn2a_branch1 (BatchNormalizatio (None, 18, 18, 64) 256 res2a_branch1[0][0]
__________________________________________________________________________________________________
add_1 (Add) (None, 18, 18, 64) 0 bn2a_branch2c[0][0]
bn2a_branch1[0][0]
__________________________________________________________________________________________________
activation_3 (Activation) (None, 18, 18, 64) 0 add_1[0][0]
__________________________________________________________________________________________________
res2b_branch2a (Conv2D) (None, 18, 18, 32) 2080 activation_3[0][0]
__________________________________________________________________________________________________
bn2b_branch2a (BatchNormalizati (None, 18, 18, 32) 128 res2b_branch2a[0][0]
__________________________________________________________________________________________________
activation_4 (Activation) (None, 18, 18, 32) 0 bn2b_branch2a[0][0]
__________________________________________________________________________________________________
res2b_branch2b (Conv2D) (None, 18, 18, 32) 9248 activation_4[0][0]
__________________________________________________________________________________________________
bn2b_branch2b (BatchNormalizati (None, 18, 18, 32) 128 res2b_branch2b[0][0]
__________________________________________________________________________________________________
activation_5 (Activation) (None, 18, 18, 32) 0 bn2b_branch2b[0][0]
__________________________________________________________________________________________________
res2b_branch2c (Conv2D) (None, 18, 18, 64) 2112 activation_5[0][0]
__________________________________________________________________________________________________
bn2b_branch2c (BatchNormalizati (None, 18, 18, 64) 256 res2b_branch2c[0][0]
__________________________________________________________________________________________________
add_2 (Add) (None, 18, 18, 64) 0 bn2b_branch2c[0][0]
activation_3[0][0]
__________________________________________________________________________________________________
activation_6 (Activation) (None, 18, 18, 64) 0 add_2[0][0]
__________________________________________________________________________________________________
res2c_branch2a (Conv2D) (None, 18, 18, 32) 2080 activation_6[0][0]
__________________________________________________________________________________________________
bn2c_branch2a (BatchNormalizati (None, 18, 18, 32) 128 res2c_branch2a[0][0]
__________________________________________________________________________________________________
activation_7 (Activation) (None, 18, 18, 32) 0 bn2c_branch2a[0][0]
__________________________________________________________________________________________________
res2c_branch2b (Conv2D) (None, 18, 18, 32) 9248 activation_7[0][0]
__________________________________________________________________________________________________
bn2c_branch2b (BatchNormalizati (None, 18, 18, 32) 128 res2c_branch2b[0][0]
__________________________________________________________________________________________________
activation_8 (Activation) (None, 18, 18, 32) 0 bn2c_branch2b[0][0]
__________________________________________________________________________________________________
res2c_branch2c (Conv2D) (None, 18, 18, 64) 2112 activation_8[0][0]
__________________________________________________________________________________________________
bn2c_branch2c (BatchNormalizati (None, 18, 18, 64) 256 res2c_branch2c[0][0]
__________________________________________________________________________________________________
add_3 (Add) (None, 18, 18, 64) 0 bn2c_branch2c[0][0]
activation_6[0][0]
__________________________________________________________________________________________________
activation_9 (Activation) (None, 18, 18, 64) 0 add_3[0][0]
__________________________________________________________________________________________________
res3a_branch2a (Conv2D) (None, 18, 18, 64) 4160 activation_9[0][0]
__________________________________________________________________________________________________
bn3a_branch2a (BatchNormalizati (None, 18, 18, 64) 256 res3a_branch2a[0][0]
__________________________________________________________________________________________________
activation_10 (Activation) (None, 18, 18, 64) 0 bn3a_branch2a[0][0]
__________________________________________________________________________________________________
res3a_branch2b (Conv2D) (None, 18, 18, 64) 36928 activation_10[0][0]
__________________________________________________________________________________________________
bn3a_branch2b (BatchNormalizati (None, 18, 18, 64) 256 res3a_branch2b[0][0]
__________________________________________________________________________________________________
activation_11 (Activation) (None, 18, 18, 64) 0 bn3a_branch2b[0][0]
__________________________________________________________________________________________________
res3a_branch2c (Conv2D) (None, 18, 18, 256) 16640 activation_11[0][0]
__________________________________________________________________________________________________
res3a_branch1 (Conv2D) (None, 18, 18, 256) 16640 activation_9[0][0]
__________________________________________________________________________________________________
bn3a_branch2c (BatchNormalizati (None, 18, 18, 256) 1024 res3a_branch2c[0][0]
__________________________________________________________________________________________________
bn3a_branch1 (BatchNormalizatio (None, 18, 18, 256) 1024 res3a_branch1[0][0]
__________________________________________________________________________________________________
add_4 (Add) (None, 18, 18, 256) 0 bn3a_branch2c[0][0]
bn3a_branch1[0][0]
__________________________________________________________________________________________________
activation_12 (Activation) (None, 18, 18, 256) 0 add_4[0][0]
__________________________________________________________________________________________________
res3b_branch2a (Conv2D) (None, 18, 18, 64) 16448 activation_12[0][0]
__________________________________________________________________________________________________
bn3b_branch2a (BatchNormalizati (None, 18, 18, 64) 256 res3b_branch2a[0][0]
__________________________________________________________________________________________________
activation_13 (Activation) (None, 18, 18, 64) 0 bn3b_branch2a[0][0]
__________________________________________________________________________________________________
res3b_branch2b (Conv2D) (None, 18, 18, 64) 36928 activation_13[0][0]
__________________________________________________________________________________________________
bn3b_branch2b (BatchNormalizati (None, 18, 18, 64) 256 res3b_branch2b[0][0]
__________________________________________________________________________________________________
activation_14 (Activation) (None, 18, 18, 64) 0 bn3b_branch2b[0][0]
__________________________________________________________________________________________________
res3b_branch2c (Conv2D) (None, 18, 18, 256) 16640 activation_14[0][0]
__________________________________________________________________________________________________
bn3b_branch2c (BatchNormalizati (None, 18, 18, 256) 1024 res3b_branch2c[0][0]
__________________________________________________________________________________________________
add_5 (Add) (None, 18, 18, 256) 0 bn3b_branch2c[0][0]
activation_12[0][0]
__________________________________________________________________________________________________
activation_15 (Activation) (None, 18, 18, 256) 0 add_5[0][0]
__________________________________________________________________________________________________
res3c_branch2a (Conv2D) (None, 18, 18, 64) 16448 activation_15[0][0]
__________________________________________________________________________________________________
bn3c_branch2a (BatchNormalizati (None, 18, 18, 64) 256 res3c_branch2a[0][0]
__________________________________________________________________________________________________
activation_16 (Activation) (None, 18, 18, 64) 0 bn3c_branch2a[0][0]
__________________________________________________________________________________________________
res3c_branch2b (Conv2D) (None, 18, 18, 64) 36928 activation_16[0][0]
__________________________________________________________________________________________________
bn3c_branch2b (BatchNormalizati (None, 18, 18, 64) 256 res3c_branch2b[0][0]
__________________________________________________________________________________________________
activation_17 (Activation) (None, 18, 18, 64) 0 bn3c_branch2b[0][0]
__________________________________________________________________________________________________
res3c_branch2c (Conv2D) (None, 18, 18, 256) 16640 activation_17[0][0]
__________________________________________________________________________________________________
bn3c_branch2c (BatchNormalizati (None, 18, 18, 256) 1024 res3c_branch2c[0][0]
__________________________________________________________________________________________________
add_6 (Add) (None, 18, 18, 256) 0 bn3c_branch2c[0][0]
activation_15[0][0]
__________________________________________________________________________________________________
activation_18 (Activation) (None, 18, 18, 256) 0 add_6[0][0]
__________________________________________________________________________________________________
res4a_branch2a (Conv2D) (None, 9, 9, 128) 32896 activation_18[0][0]
__________________________________________________________________________________________________
bn4a_branch2a (BatchNormalizati (None, 9, 9, 128) 512 res4a_branch2a[0][0]
__________________________________________________________________________________________________
activation_19 (Activation) (None, 9, 9, 128) 0 bn4a_branch2a[0][0]
__________________________________________________________________________________________________
res4a_branch2b (Conv2D) (None, 9, 9, 128) 147584 activation_19[0][0]
__________________________________________________________________________________________________
bn4a_branch2b (BatchNormalizati (None, 9, 9, 128) 512 res4a_branch2b[0][0]
__________________________________________________________________________________________________
activation_20 (Activation) (None, 9, 9, 128) 0 bn4a_branch2b[0][0]
__________________________________________________________________________________________________
res4a_branch2c (Conv2D) (None, 9, 9, 512) 66048 activation_20[0][0]
__________________________________________________________________________________________________
res4a_branch1 (Conv2D) (None, 9, 9, 512) 131584 activation_18[0][0]
__________________________________________________________________________________________________
bn4a_branch2c (BatchNormalizati (None, 9, 9, 512) 2048 res4a_branch2c[0][0]
__________________________________________________________________________________________________
bn4a_branch1 (BatchNormalizatio (None, 9, 9, 512) 2048 res4a_branch1[0][0]
__________________________________________________________________________________________________
add_7 (Add) (None, 9, 9, 512) 0 bn4a_branch2c[0][0]
bn4a_branch1[0][0]
__________________________________________________________________________________________________
activation_21 (Activation) (None, 9, 9, 512) 0 add_7[0][0]
__________________________________________________________________________________________________
res4b_branch2a (Conv2D) (None, 9, 9, 128) 65664 activation_21[0][0]
__________________________________________________________________________________________________
bn4b_branch2a (BatchNormalizati (None, 9, 9, 128) 512 res4b_branch2a[0][0]
__________________________________________________________________________________________________
activation_22 (Activation) (None, 9, 9, 128) 0 bn4b_branch2a[0][0]
__________________________________________________________________________________________________
res4b_branch2b (Conv2D) (None, 9, 9, 128) 147584 activation_22[0][0]
__________________________________________________________________________________________________
bn4b_branch2b (BatchNormalizati (None, 9, 9, 128) 512 res4b_branch2b[0][0]
__________________________________________________________________________________________________
activation_23 (Activation) (None, 9, 9, 128) 0 bn4b_branch2b[0][0]
__________________________________________________________________________________________________
res4b_branch2c (Conv2D) (None, 9, 9, 512) 66048 activation_23[0][0]
__________________________________________________________________________________________________
bn4b_branch2c (BatchNormalizati (None, 9, 9, 512) 2048 res4b_branch2c[0][0]
__________________________________________________________________________________________________
add_8 (Add) (None, 9, 9, 512) 0 bn4b_branch2c[0][0]
activation_21[0][0]
__________________________________________________________________________________________________
activation_24 (Activation) (None, 9, 9, 512) 0 add_8[0][0]
__________________________________________________________________________________________________
res4c_branch2a (Conv2D) (None, 9, 9, 128) 65664 activation_24[0][0]
__________________________________________________________________________________________________
bn4c_branch2a (BatchNormalizati (None, 9, 9, 128) 512 res4c_branch2a[0][0]
__________________________________________________________________________________________________
activation_25 (Activation) (None, 9, 9, 128) 0 bn4c_branch2a[0][0]
__________________________________________________________________________________________________
res4c_branch2b (Conv2D) (None, 9, 9, 128) 147584 activation_25[0][0]
__________________________________________________________________________________________________
bn4c_branch2b (BatchNormalizati (None, 9, 9, 128) 512 res4c_branch2b[0][0]
__________________________________________________________________________________________________
activation_26 (Activation) (None, 9, 9, 128) 0 bn4c_branch2b[0][0]
__________________________________________________________________________________________________
res4c_branch2c (Conv2D) (None, 9, 9, 512) 66048 activation_26[0][0]
__________________________________________________________________________________________________
bn4c_branch2c (BatchNormalizati (None, 9, 9, 512) 2048 res4c_branch2c[0][0]
__________________________________________________________________________________________________
add_9 (Add) (None, 9, 9, 512) 0 bn4c_branch2c[0][0]
activation_24[0][0]
__________________________________________________________________________________________________
activation_27 (Activation) (None, 9, 9, 512) 0 add_9[0][0]
__________________________________________________________________________________________________
res4d_branch2a (Conv2D) (None, 9, 9, 128) 65664 activation_27[0][0]
__________________________________________________________________________________________________
bn4d_branch2a (BatchNormalizati (None, 9, 9, 128) 512 res4d_branch2a[0][0]
__________________________________________________________________________________________________
activation_28 (Activation) (None, 9, 9, 128) 0 bn4d_branch2a[0][0]
__________________________________________________________________________________________________
res4d_branch2b (Conv2D) (None, 9, 9, 128) 147584 activation_28[0][0]
__________________________________________________________________________________________________
bn4d_branch2b (BatchNormalizati (None, 9, 9, 128) 512 res4d_branch2b[0][0]
__________________________________________________________________________________________________
activation_29 (Activation) (None, 9, 9, 128) 0 bn4d_branch2b[0][0]
__________________________________________________________________________________________________
res4d_branch2c (Conv2D) (None, 9, 9, 512) 66048 activation_29[0][0]
__________________________________________________________________________________________________
bn4d_branch2c (BatchNormalizati (None, 9, 9, 512) 2048 res4d_branch2c[0][0]
__________________________________________________________________________________________________
add_10 (Add) (None, 9, 9, 512) 0 bn4d_branch2c[0][0]
activation_27[0][0]
__________________________________________________________________________________________________
activation_30 (Activation) (None, 9, 9, 512) 0 add_10[0][0]
__________________________________________________________________________________________________
res5a_branch2a (Conv2D) (None, 5, 5, 256) 131328 activation_30[0][0]
__________________________________________________________________________________________________
bn5a_branch2a (BatchNormalizati (None, 5, 5, 256) 1024 res5a_branch2a[0][0]
__________________________________________________________________________________________________
activation_31 (Activation) (None, 5, 5, 256) 0 bn5a_branch2a[0][0]
__________________________________________________________________________________________________
res5a_branch2b (Conv2D) (None, 5, 5, 256) 590080 activation_31[0][0]
__________________________________________________________________________________________________
bn5a_branch2b (BatchNormalizati (None, 5, 5, 256) 1024 res5a_branch2b[0][0]
__________________________________________________________________________________________________
activation_32 (Activation) (None, 5, 5, 256) 0 bn5a_branch2b[0][0]
__________________________________________________________________________________________________
res5a_branch2c (Conv2D) (None, 5, 5, 1024) 263168 activation_32[0][0]
__________________________________________________________________________________________________
res5a_branch1 (Conv2D) (None, 5, 5, 1024) 525312 activation_30[0][0]
__________________________________________________________________________________________________
bn5a_branch2c (BatchNormalizati (None, 5, 5, 1024) 4096 res5a_branch2c[0][0]
__________________________________________________________________________________________________
bn5a_branch1 (BatchNormalizatio (None, 5, 5, 1024) 4096 res5a_branch1[0][0]
__________________________________________________________________________________________________
add_11 (Add) (None, 5, 5, 1024) 0 bn5a_branch2c[0][0]
bn5a_branch1[0][0]
__________________________________________________________________________________________________
activation_33 (Activation) (None, 5, 5, 1024) 0 add_11[0][0]
__________________________________________________________________________________________________
res5b_branch2a (Conv2D) (None, 5, 5, 256) 262400 activation_33[0][0]
__________________________________________________________________________________________________
bn5b_branch2a (BatchNormalizati (None, 5, 5, 256) 1024 res5b_branch2a[0][0]
__________________________________________________________________________________________________
activation_34 (Activation) (None, 5, 5, 256) 0 bn5b_branch2a[0][0]
__________________________________________________________________________________________________
res5b_branch2b (Conv2D) (None, 5, 5, 256) 590080 activation_34[0][0]
__________________________________________________________________________________________________
bn5b_branch2b (BatchNormalizati (None, 5, 5, 256) 1024 res5b_branch2b[0][0]
__________________________________________________________________________________________________
activation_35 (Activation) (None, 5, 5, 256) 0 bn5b_branch2b[0][0]
__________________________________________________________________________________________________
res5b_branch2c (Conv2D) (None, 5, 5, 1024) 263168 activation_35[0][0]
__________________________________________________________________________________________________
bn5b_branch2c (BatchNormalizati (None, 5, 5, 1024) 4096 res5b_branch2c[0][0]
__________________________________________________________________________________________________
add_12 (Add) (None, 5, 5, 1024) 0 bn5b_branch2c[0][0]
activation_33[0][0]
__________________________________________________________________________________________________
activation_36 (Activation) (None, 5, 5, 1024) 0 add_12[0][0]
__________________________________________________________________________________________________
res5c_branch2a (Conv2D) (None, 5, 5, 256) 262400 activation_36[0][0]
__________________________________________________________________________________________________
bn5c_branch2a (BatchNormalizati (None, 5, 5, 256) 1024 res5c_branch2a[0][0]
__________________________________________________________________________________________________
activation_37 (Activation) (None, 5, 5, 256) 0 bn5c_branch2a[0][0]
__________________________________________________________________________________________________
res5c_branch2b (Conv2D) (None, 5, 5, 256) 590080 activation_37[0][0]
__________________________________________________________________________________________________
bn5c_branch2b (BatchNormalizati (None, 5, 5, 256) 1024 res5c_branch2b[0][0]
__________________________________________________________________________________________________
activation_38 (Activation) (None, 5, 5, 256) 0 bn5c_branch2b[0][0]
__________________________________________________________________________________________________
res5c_branch2c (Conv2D) (None, 5, 5, 1024) 263168 activation_38[0][0]
__________________________________________________________________________________________________
bn5c_branch2c (BatchNormalizati (None, 5, 5, 1024) 4096 res5c_branch2c[0][0]
__________________________________________________________________________________________________
add_13 (Add) (None, 5, 5, 1024) 0 bn5c_branch2c[0][0]
activation_36[0][0]
__________________________________________________________________________________________________
activation_39 (Activation) (None, 5, 5, 1024) 0 add_13[0][0]
__________________________________________________________________________________________________
res5d_branch2a (Conv2D) (None, 5, 5, 256) 262400 activation_39[0][0]
__________________________________________________________________________________________________
bn5d_branch2a (BatchNormalizati (None, 5, 5, 256) 1024 res5d_branch2a[0][0]
__________________________________________________________________________________________________
activation_40 (Activation) (None, 5, 5, 256) 0 bn5d_branch2a[0][0]
__________________________________________________________________________________________________
res5d_branch2b (Conv2D) (None, 5, 5, 256) 590080 activation_40[0][0]
__________________________________________________________________________________________________
bn5d_branch2b (BatchNormalizati (None, 5, 5, 256) 1024 res5d_branch2b[0][0]
__________________________________________________________________________________________________
activation_41 (Activation) (None, 5, 5, 256) 0 bn5d_branch2b[0][0]
__________________________________________________________________________________________________
res5d_branch2c (Conv2D) (None, 5, 5, 1024) 263168 activation_41[0][0]
__________________________________________________________________________________________________
bn5d_branch2c (BatchNormalizati (None, 5, 5, 1024) 4096 res5d_branch2c[0][0]
__________________________________________________________________________________________________
add_14 (Add) (None, 5, 5, 1024) 0 bn5d_branch2c[0][0]
activation_39[0][0]
__________________________________________________________________________________________________
activation_42 (Activation) (None, 5, 5, 1024) 0 add_14[0][0]
__________________________________________________________________________________________________
res5e_branch2a (Conv2D) (None, 5, 5, 256) 262400 activation_42[0][0]
__________________________________________________________________________________________________
bn5e_branch2a (BatchNormalizati (None, 5, 5, 256) 1024 res5e_branch2a[0][0]
__________________________________________________________________________________________________
activation_43 (Activation) (None, 5, 5, 256) 0 bn5e_branch2a[0][0]
__________________________________________________________________________________________________
res5e_branch2b (Conv2D) (None, 5, 5, 256) 590080 activation_43[0][0]
__________________________________________________________________________________________________
bn5e_branch2b (BatchNormalizati (None, 5, 5, 256) 1024 res5e_branch2b[0][0]
__________________________________________________________________________________________________
activation_44 (Activation) (None, 5, 5, 256) 0 bn5e_branch2b[0][0]
__________________________________________________________________________________________________
res5e_branch2c (Conv2D) (None, 5, 5, 1024) 263168 activation_44[0][0]
__________________________________________________________________________________________________
bn5e_branch2c (BatchNormalizati (None, 5, 5, 1024) 4096 res5e_branch2c[0][0]
__________________________________________________________________________________________________
add_15 (Add) (None, 5, 5, 1024) 0 bn5e_branch2c[0][0]
activation_42[0][0]
__________________________________________________________________________________________________
activation_45 (Activation) (None, 5, 5, 1024) 0 add_15[0][0]
__________________________________________________________________________________________________
res5f_branch2a (Conv2D) (None, 5, 5, 256) 262400 activation_45[0][0]
__________________________________________________________________________________________________
bn5f_branch2a (BatchNormalizati (None, 5, 5, 256) 1024 res5f_branch2a[0][0]
__________________________________________________________________________________________________
activation_46 (Activation) (None, 5, 5, 256) 0 bn5f_branch2a[0][0]
__________________________________________________________________________________________________
res5f_branch2b (Conv2D) (None, 5, 5, 256) 590080 activation_46[0][0]
__________________________________________________________________________________________________
bn5f_branch2b (BatchNormalizati (None, 5, 5, 256) 1024 res5f_branch2b[0][0]
__________________________________________________________________________________________________
activation_47 (Activation) (None, 5, 5, 256) 0 bn5f_branch2b[0][0]
__________________________________________________________________________________________________
res5f_branch2c (Conv2D) (None, 5, 5, 1024) 263168 activation_47[0][0]
__________________________________________________________________________________________________
bn5f_branch2c (BatchNormalizati (None, 5, 5, 1024) 4096 res5f_branch2c[0][0]
__________________________________________________________________________________________________
add_16 (Add) (None, 5, 5, 1024) 0 bn5f_branch2c[0][0]
activation_45[0][0]
__________________________________________________________________________________________________
activation_48 (Activation) (None, 5, 5, 1024) 0 add_16[0][0]
__________________________________________________________________________________________________
res6a_branch2a (Conv2D) (None, 3, 3, 512) 524800 activation_48[0][0]
__________________________________________________________________________________________________
bn6a_branch2a (BatchNormalizati (None, 3, 3, 512) 2048 res6a_branch2a[0][0]
__________________________________________________________________________________________________
activation_49 (Activation) (None, 3, 3, 512) 0 bn6a_branch2a[0][0]
__________________________________________________________________________________________________
res6a_branch2b (Conv2D) (None, 3, 3, 512) 2359808 activation_49[0][0]
__________________________________________________________________________________________________
bn6a_branch2b (BatchNormalizati (None, 3, 3, 512) 2048 res6a_branch2b[0][0]
__________________________________________________________________________________________________
activation_50 (Activation) (None, 3, 3, 512) 0 bn6a_branch2b[0][0]
__________________________________________________________________________________________________
res6a_branch2c (Conv2D) (None, 3, 3, 2048) 1050624 activation_50[0][0]
__________________________________________________________________________________________________
res6a_branch1 (Conv2D) (None, 3, 3, 2048) 2099200 activation_48[0][0]
__________________________________________________________________________________________________
bn6a_branch2c (BatchNormalizati (None, 3, 3, 2048) 8192 res6a_branch2c[0][0]
__________________________________________________________________________________________________
bn6a_branch1 (BatchNormalizatio (None, 3, 3, 2048) 8192 res6a_branch1[0][0]
__________________________________________________________________________________________________
add_17 (Add) (None, 3, 3, 2048) 0 bn6a_branch2c[0][0]
bn6a_branch1[0][0]
__________________________________________________________________________________________________
activation_51 (Activation) (None, 3, 3, 2048) 0 add_17[0][0]
__________________________________________________________________________________________________
res6b_branch2a (Conv2D) (None, 3, 3, 512) 1049088 activation_51[0][0]
__________________________________________________________________________________________________
bn6b_branch2a (BatchNormalizati (None, 3, 3, 512) 2048 res6b_branch2a[0][0]
__________________________________________________________________________________________________
activation_52 (Activation) (None, 3, 3, 512) 0 bn6b_branch2a[0][0]
__________________________________________________________________________________________________
res6b_branch2b (Conv2D) (None, 3, 3, 512) 2359808 activation_52[0][0]
__________________________________________________________________________________________________
bn6b_branch2b (BatchNormalizati (None, 3, 3, 512) 2048 res6b_branch2b[0][0]
__________________________________________________________________________________________________
activation_53 (Activation) (None, 3, 3, 512) 0 bn6b_branch2b[0][0]
__________________________________________________________________________________________________
res6b_branch2c (Conv2D) (None, 3, 3, 2048) 1050624 activation_53[0][0]
__________________________________________________________________________________________________
bn6b_branch2c (BatchNormalizati (None, 3, 3, 2048) 8192 res6b_branch2c[0][0]
__________________________________________________________________________________________________
add_18 (Add) (None, 3, 3, 2048) 0 bn6b_branch2c[0][0]
activation_51[0][0]
__________________________________________________________________________________________________
activation_54 (Activation) (None, 3, 3, 2048) 0 add_18[0][0]
__________________________________________________________________________________________________
res6c_branch2a (Conv2D) (None, 3, 3, 512) 1049088 activation_54[0][0]
__________________________________________________________________________________________________
bn6c_branch2a (BatchNormalizati (None, 3, 3, 512) 2048 res6c_branch2a[0][0]
__________________________________________________________________________________________________
activation_55 (Activation) (None, 3, 3, 512) 0 bn6c_branch2a[0][0]
__________________________________________________________________________________________________
res6c_branch2b (Conv2D) (None, 3, 3, 512) 2359808 activation_55[0][0]
__________________________________________________________________________________________________
bn6c_branch2b (BatchNormalizati (None, 3, 3, 512) 2048 res6c_branch2b[0][0]
__________________________________________________________________________________________________
activation_56 (Activation) (None, 3, 3, 512) 0 bn6c_branch2b[0][0]
__________________________________________________________________________________________________
res6c_branch2c (Conv2D) (None, 3, 3, 2048) 1050624 activation_56[0][0]
__________________________________________________________________________________________________
bn6c_branch2c (BatchNormalizati (None, 3, 3, 2048) 8192 res6c_branch2c[0][0]
__________________________________________________________________________________________________
add_19 (Add) (None, 3, 3, 2048) 0 bn6c_branch2c[0][0]
activation_54[0][0]
__________________________________________________________________________________________________
activation_57 (Activation) (None, 3, 3, 2048) 0 add_19[0][0]
__________________________________________________________________________________________________
avg_pool (AveragePooling2D) (None, 3, 3, 2048) 0 activation_57[0][0]
__________________________________________________________________________________________________
flatten_1 (Flatten) (None, 18432) 0 avg_pool[0][0]
__________________________________________________________________________________________________
fc1000 (Dense) (None, 100) 1843300 flatten_1[0][0]
==================================================================================================
Total params: 25,461,700
Trainable params: 25,407,812
Non-trainable params: 53,888
__________________________________________________________________________________________________
Recall that the VGG-16 architecture had approximately 34 million trainable parameters. In other words, we have increased the depth while reducing the number of parameters to train (about 25.4 million here).
With that said, let's train the model.
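As a sanity check on the summary above, the per-layer counts can be reproduced by hand. The sketch below assumes the usual Keras conventions: a Conv2D layer holds kernel_h * kernel_w * in_channels * filters weights plus one bias per filter, and BatchNormalization keeps four values per channel, two of them non-trainable. The helper names are made up for illustration.

```python
def conv2d_params(kernel, in_channels, filters):
    """Weights plus biases of a Conv2D layer with a square kernel."""
    return kernel * kernel * in_channels * filters + filters

def batchnorm_params(channels):
    """gamma and beta (trainable) plus moving mean and variance (non-trainable)."""
    return 4 * channels

def dense_params(inputs, units):
    """Weights plus biases of a fully connected layer."""
    return inputs * units + units

# Bottleneck block 5f from the summary: 1x1 -> 3x3 -> 1x1 convolutions
print(conv2d_params(1, 1024, 256))   # res5f_branch2a -> 262400
print(conv2d_params(3, 256, 256))    # res5f_branch2b -> 590080
print(conv2d_params(1, 256, 1024))   # res5f_branch2c -> 263168
print(batchnorm_params(1024))        # bn5f_branch2c  -> 4096
print(dense_params(3 * 3 * 2048, 100))  # fc1000      -> 1843300

# Trainable and non-trainable parameters add up to the reported total
print(25_407_812 + 53_888)           # -> 25461700
```

The 1x1 convolutions at both ends of each block are what keep the count low: the expensive 3x3 convolution only ever sees the reduced number of channels.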
crn50 = custom_resnet50_model.fit(x=x_train, y=y_train, batch_size=32, epochs=10, verbose=1, validation_data=(x_test, y_test), shuffle=True)
Train on 50000 samples, validate on 10000 samples
Epoch 1/10
50000/50000 [==============================] - 441s 9ms/step - loss: 4.5655 - acc: 0.0817 - mean_squared_error: 0.0101 - val_loss: 4.2085 - val_acc: 0.1228 - val_mean_squared_error: 0.0099
Epoch 2/10
50000/50000 [==============================] - 434s 9ms/step - loss: 4.1448 - acc: 0.1348 - mean_squared_error: 0.0098 - val_loss: 4.2032 - val_acc: 0.1236 - val_mean_squared_error: 0.0099
Epoch 3/10
50000/50000 [==============================] - 433s 9ms/step - loss: 4.2682 - acc: 0.1146 - mean_squared_error: 0.0099 - val_loss: 4.3306 - val_acc: 0.1066 - val_mean_squared_error: 0.0100
Epoch 4/10
50000/50000 [==============================] - 434s 9ms/step - loss: 4.1581 - acc: 0.1340 - mean_squared_error: 0.0098 - val_loss: 4.1405 - val_acc: 0.1384 - val_mean_squared_error: 0.0098
Epoch 5/10
50000/50000 [==============================] - 431s 9ms/step - loss: 3.9395 - acc: 0.1653 - mean_squared_error: 0.0096 - val_loss: 3.8838 - val_acc: 0.1718 - val_mean_squared_error: 0.0095
Epoch 6/10
50000/50000 [==============================] - 432s 9ms/step - loss: 3.9598 - acc: 0.1698 - mean_squared_error: 0.0096 - val_loss: 4.0047 - val_acc: 0.1608 - val_mean_squared_error: 0.0096
Epoch 7/10
50000/50000 [==============================] - 433s 9ms/step - loss: 3.8715 - acc: 0.1797 - mean_squared_error: 0.0095 - val_loss: 4.2620 - val_acc: 0.1184 - val_mean_squared_error: 0.0099
Epoch 8/10
50000/50000 [==============================] - 434s 9ms/step - loss: 3.9661 - acc: 0.1666 - mean_squared_error: 0.0096 - val_loss: 3.8181 - val_acc: 0.1898 - val_mean_squared_error: 0.0095
Epoch 9/10
50000/50000 [==============================] - 434s 9ms/step - loss: 3.8110 - acc: 0.1901 - mean_squared_error: 0.0095 - val_loss: 3.7521 - val_acc: 0.1966 - val_mean_squared_error: 0.0094
Epoch 10/10
50000/50000 [==============================] - 432s 9ms/step - loss: 3.7247 - acc: 0.2048 - mean_squared_error: 0.0094 - val_loss: 3.8206 - val_acc: 0.1929 - val_mean_squared_error: 0.0095
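To put these numbers in context, it helps to know what a model that predicts the same probability for every class would score. The sketch below assumes the loss is categorical cross-entropy (the compile call is not shown in this section) and that the mean_squared_error metric is averaged over the 100 outputs.

```python
import math

n_classes = 100
p = 1.0 / n_classes  # probability a uniform model assigns to each class

# Cross-entropy of a uniform prediction: -ln(1 / n_classes)
uniform_ce = -math.log(p)

# Per-output MSE between a one-hot target and a uniform prediction
uniform_mse = ((1.0 - p) ** 2 + (n_classes - 1) * p ** 2) / n_classes

print(f"chance accuracy:       {p:.4f}")           # 0.0100
print(f"uniform cross-entropy: {uniform_ce:.4f}")  # 4.6052
print(f"uniform MSE:           {uniform_mse:.4f}") # 0.0099
```

Under those assumptions, the first-epoch loss (4.5655) and MSE (0.0101) sit very close to the uniform baseline, and the final values (3.7247 and 0.0094) show how far ten epochs moved the model away from it.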
Let's plot the metrics obtained for training and validation.
plt.figure(0)
plt.plot(crn50.history['acc'],'r')
plt.plot(crn50.history['val_acc'],'g')
plt.xticks(np.arange(0, 11, 2.0))
plt.rcParams['figure.figsize'] = (8, 6)
plt.xlabel("Num of Epochs")
plt.ylabel("Accuracy")
plt.title("Training Accuracy vs Validation Accuracy")
plt.legend(['train','validation'])
plt.figure(1)
plt.plot(crn50.history['loss'],'r')
plt.plot(crn50.history['val_loss'],'g')
plt.xticks(np.arange(0, 11, 2.0))
plt.rcParams['figure.figsize'] = (8, 6)
plt.xlabel("Num of Epochs")
plt.ylabel("Loss")
plt.title("Training Loss vs Validation Loss")
plt.legend(['train','validation'])
plt.show()
Accuracy
Loss
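Training generalized well: in the last epoch, the gap between training accuracy (0.2048) and validation accuracy (0.1929) is only 0.0119. Absolute accuracy after 10 epochs is still modest, at roughly 20%.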
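The 0.0119 figure is the difference between training and validation accuracy in the last epoch. A minimal sketch of that calculation follows, using a plain dict with the values copied from the training log; in the notebook the same dict would be available as crn50.history. The helper name is hypothetical.

```python
# Accuracy values copied from the training log (epochs 1 to 10)
history = {
    "acc": [0.0817, 0.1348, 0.1146, 0.1340, 0.1653,
            0.1698, 0.1797, 0.1666, 0.1901, 0.2048],
    "val_acc": [0.1228, 0.1236, 0.1066, 0.1384, 0.1718,
                0.1608, 0.1184, 0.1898, 0.1966, 0.1929],
}

def generalization_gap(hist, epoch=-1):
    """Training accuracy minus validation accuracy at the given epoch index."""
    return hist["acc"][epoch] - hist["val_acc"][epoch]

gap = generalization_gap(history)
best_val = max(history["val_acc"])
best_epoch = history["val_acc"].index(best_val) + 1  # epochs are 1-based in the log

print(f"final gap: {gap:.4f}")                                       # 0.0119
print(f"best validation accuracy: {best_val:.4f} (epoch {best_epoch})")  # 0.1966 (epoch 9)
```

A small gap only says the model is not overfitting yet; it says nothing about whether the accuracy itself is high.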
Confusion matrix
Let's now look at the confusion matrix and the precision, recall and F1-score metrics.
We run a prediction over the validation dataset and, from it, build the confusion matrix and show the metrics mentioned above.
crn50_pred = custom_resnet50_model.predict(x_test, batch_size=32, verbose=1)
crn50_predicted = np.argmax(crn50_pred, axis=1)
crn50_cm = confusion_matrix(np.argmax(y_test, axis=1), crn50_predicted)
# Visualize the confusion matrix
crn50_df_cm = pd.DataFrame(crn50_cm, range(100), range(100))
plt.figure(figsize = (20,14))
sn.set(font_scale=1.4) #for label size
sn.heatmap(crn50_df_cm, annot=True, annot_kws={"size": 12}) # font size
plt.show()
Confusion matrix
Finally, we print the metrics.
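With 100 classes the heatmap is hard to read, so it can help to see how per-class metrics fall out of the matrix itself. The sketch below uses a made-up 3-class matrix with the same layout scikit-learn produces (rows are true classes, columns are predicted classes); the numbers are illustrative only.

```python
import numpy as np

# Toy confusion matrix (made-up numbers): rows = true class, columns = predicted class
cm = np.array([[5, 1, 0],
               [2, 3, 1],
               [0, 2, 6]])

true_positives = np.diag(cm)
recall = true_positives / cm.sum(axis=1)      # correct / all samples of that class
precision = true_positives / cm.sum(axis=0)   # correct / all predictions of that class
accuracy = true_positives.sum() / cm.sum()

print(np.round(recall, 2))
print(np.round(precision, 2))
print(round(float(accuracy), 2))
```

In the notebook, cm would be crn50_cm. A class that is never predicted has a column sum of zero, so its precision needs special handling to avoid a division by zero.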
crn50_report = classification_report(np.argmax(y_test, axis=1), crn50_predicted)
print(crn50_report)
precision recall f1-score support
0 0.46 0.32 0.38 100
1 0.25 0.17 0.20 100
2 0.17 0.09 0.12 100
3 0.05 0.62 0.09 100
4 0.18 0.06 0.09 100
5 0.25 0.05 0.08 100
6 0.11 0.14 0.12 100
7 0.15 0.12 0.13 100
8 0.21 0.20 0.20 100
9 0.49 0.21 0.29 100
10 0.11 0.03 0.05 100
11 0.08 0.05 0.06 100
12 0.38 0.13 0.19 100
13 0.23 0.10 0.14 100
14 0.18 0.05 0.08 100
15 0.14 0.06 0.08 100
16 0.19 0.24 0.21 100
17 0.40 0.19 0.26 100
18 0.19 0.24 0.21 100
19 0.20 0.22 0.21 100
20 0.42 0.31 0.36 100
21 0.31 0.23 0.26 100
22 0.35 0.09 0.14 100
23 0.36 0.37 0.37 100
24 0.31 0.49 0.38 100
25 0.17 0.03 0.05 100
26 0.43 0.06 0.11 100
27 0.11 0.03 0.05 100
28 0.31 0.35 0.33 100
29 0.12 0.10 0.11 100
30 0.27 0.33 0.30 100
31 0.11 0.09 0.10 100
32 0.22 0.20 0.21 100
33 0.23 0.30 0.26 100
34 0.17 0.05 0.08 100
35 0.09 0.02 0.03 100
36 0.10 0.23 0.14 100
37 0.15 0.16 0.16 100
38 0.08 0.24 0.12 100
39 0.23 0.18 0.20 100
40 0.26 0.20 0.22 100
41 0.45 0.49 0.47 100
42 0.12 0.17 0.14 100
43 0.11 0.02 0.03 100
44 0.14 0.09 0.11 100
45 0.08 0.01 0.02 100
46 0.07 0.29 0.12 100
47 0.55 0.18 0.27 100
48 0.23 0.31 0.26 100
49 0.27 0.23 0.25 100
50 0.12 0.05 0.07 100
51 0.28 0.09 0.14 100
52 0.47 0.62 0.54 100
53 0.25 0.13 0.17 100
54 0.18 0.25 0.21 100
55 0.00 0.00 0.00 100
56 0.27 0.27 0.27 100
57 0.27 0.11 0.16 100
58 0.15 0.41 0.22 100
59 0.18 0.10 0.13 100
60 0.41 0.63 0.50 100
61 0.33 0.32 0.32 100
62 0.15 0.07 0.09 100
63 0.31 0.26 0.28 100
64 0.11 0.11 0.11 100
65 0.15 0.11 0.13 100
66 0.10 0.06 0.08 100
67 0.15 0.15 0.15 100
68 0.37 0.66 0.47 100
69 0.38 0.25 0.30 100
70 0.21 0.04 0.07 100
71 0.27 0.54 0.36 100
72 0.20 0.01 0.02 100
73 0.30 0.21 0.25 100
74 0.14 0.15 0.14 100
75 0.30 0.29 0.29 100
76 0.40 0.40 0.40 100
77 0.13 0.14 0.13 100
78 0.15 0.08 0.10 100
79 0.14 0.05 0.07 100
80 0.08 0.05 0.06 100
81 0.14 0.11 0.12 100
82 0.37 0.24 0.29 100
83 0.08 0.02 0.03 100
84 0.10 0.11 0.10 100
85 0.23 0.39 0.29 100
86 0.36 0.21 0.26 100
87 0.21 0.19 0.20 100
88 0.05 0.06 0.05 100
89 0.24 0.18 0.20 100
90 0.21 0.24 0.22 100
91 0.33 0.31 0.32 100
92 0.11 0.11 0.11 100
93 0.16 0.10 0.12 100
94 0.38 0.26 0.31 100
95 0.21 0.50 0.30 100
96 0.22 0.23 0.22 100
97 0.10 0.18 0.13 100
98 0.12 0.02 0.03 100
99 0.24 0.08 0.12 100
avg / total 0.22 0.19 0.19 10000
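The F1-score column is the harmonic mean of the precision and recall columns. A minimal sketch, with a hypothetical helper name, checked against three rows of the report:

```python
def f1_from(precision, recall):
    """Harmonic mean of precision and recall (0 when both are 0)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)

print(round(f1_from(0.05, 0.62), 2))  # class 3  -> 0.09
print(round(f1_from(0.41, 0.63), 2))  # class 60 -> 0.5
print(f1_from(0.00, 0.00))            # class 55 -> 0.0
```

Because the report rounds precision and recall to two decimals, recomputing F1 from the printed values can differ from the printed F1 in the last digit. Class 3 is worth a look: a recall of 0.62 with a precision of 0.05 means the model predicts that class far more often than it actually occurs. Since every class has the same support (100), the averaged recall of 0.19 matches the validation accuracy of the last epoch (0.1929).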
ROC curve (true positive and false positive rates)
Let's code the ROC curve.
from itertools import cycle
from numpy import interp  # scipy.interp is deprecated; numpy.interp takes the same arguments
from sklearn.metrics import roc_curve, auc
n_classes = 100
# Plot linewidth.
lw = 2
# Compute ROC curve and ROC area for each class
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_test[:, i], crn50_pred[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])
# Compute micro-average ROC curve and ROC area
fpr["micro"], tpr["micro"], _ = roc_curve(y_test.ravel(), crn50_pred.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
# Compute macro-average ROC curve and ROC area
# First aggregate all false positive rates
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
# Then interpolate all ROC curves at these points
mean_tpr = np.zeros_like(all_fpr)
for i in range(n_classes):
    mean_tpr += interp(all_fpr, fpr[i], tpr[i])
# Finally average it and compute AUC
mean_tpr /= n_classes
fpr["macro"] = all_fpr
tpr["macro"] = mean_tpr
roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])
# Plot all ROC curves
plt.figure(1)
plt.plot(fpr["micro"], tpr["micro"],
label='micro-average ROC curve (area = {0:0.2f})'
''.format(roc_auc["micro"]),
color='deeppink', linestyle=':', linewidth=4)
plt.plot(fpr["macro"], tpr["macro"],
label='macro-average ROC curve (area = {0:0.2f})'
''.format(roc_auc["macro"]),
color='navy', linestyle=':', linewidth=4)
colors = cycle(['aqua', 'darkorange', 'cornflowerblue'])
# Plot only the first 3 classes to keep the figure readable
for i, color in zip(range(3), colors):
    plt.plot(fpr[i], tpr[i], color=color, lw=lw,
             label='ROC curve of class {0} (area = {1:0.2f})'
             ''.format(i, roc_auc[i]))
plt.plot([0, 1], [0, 1], 'k--', lw=lw)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Some extension of Receiver operating characteristic to multi-class')
plt.legend(loc="lower right")
plt.show()
# Zoom in view of the upper left corner.
plt.figure(2)
plt.xlim(0, 0.2)
plt.ylim(0.8, 1)
plt.plot(fpr["micro"], tpr["micro"],
label='micro-average ROC curve (area = {0:0.2f})'
''.format(roc_auc["micro"]),
color='deeppink', linestyle=':', linewidth=4)
plt.plot(fpr["macro"], tpr["macro"],
label='macro-average ROC curve (area = {0:0.2f})'
''.format(roc_auc["macro"]),
color='navy', linestyle=':', linewidth=4)
colors = cycle(['aqua', 'darkorange', 'cornflowerblue'])
for i, color in zip(range(10), colors):
    plt.plot(fpr[i], tpr[i], color=color, lw=lw,
             label='ROC curve of class {0} (area = {1:0.2f})'
             ''.format(i, roc_auc[i]))
plt.plot([0, 1], [0, 1], 'k--', lw=lw)
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Some extension of Receiver operating characteristic to multi-class')
plt.legend(loc="lower right")
plt.show()
The results are shown in the following plots: the first shows the curves for three classes, and the second zooms in on the upper-left corner for ten classes.
ROC curve for 3 classes
Zoomed ROC curve for 10 classes
We save the training history so it can be compared with other models, and we also save the model with its trained weights for future use.
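To make the one-vs-rest, micro and macro averaging in the ROC code more concrete, here is a small NumPy-only sketch on made-up data (6 samples, 3 classes). It does not use scikit-learn; roc_points and area are hypothetical helpers that assume there are no tied scores.

```python
import numpy as np

def roc_points(y_true, scores):
    """ROC points for one class, sweeping the threshold from high to low scores."""
    order = np.argsort(-scores)
    hits = y_true[order]
    tpr = np.concatenate([[0.0], np.cumsum(hits) / hits.sum()])
    fpr = np.concatenate([[0.0], np.cumsum(1 - hits) / (1 - hits).sum()])
    return fpr, tpr

def area(fpr, tpr):
    """Trapezoidal area under a curve given by its points."""
    return float(np.sum(np.diff(fpr) * (tpr[1:] + tpr[:-1]) / 2))

# Made-up one-hot labels and predicted probabilities
y = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1],
              [1, 0, 0], [0, 1, 0], [0, 0, 1]])
p = np.array([[0.70, 0.21, 0.09], [0.52, 0.31, 0.17], [0.12, 0.24, 0.64],
              [0.41, 0.36, 0.23], [0.14, 0.61, 0.25], [0.30, 0.43, 0.27]])

n_classes = y.shape[1]

# One-vs-rest: each column is treated as its own binary problem
curves = [roc_points(y[:, i], p[:, i]) for i in range(n_classes)]
per_class_auc = [area(fpr, tpr) for fpr, tpr in curves]

# Micro-average: pool every (label, score) pair into a single binary problem
micro_auc = area(*roc_points(y.ravel(), p.ravel()))

# Macro-average: interpolate every curve on a shared grid, then average the curves
grid = np.unique(np.concatenate([fpr for fpr, _ in curves]))
mean_tpr = sum(np.interp(grid, fpr, tpr) for fpr, tpr in curves) / n_classes
macro_auc = area(grid, mean_tpr)

print(per_class_auc)  # [0.875, 0.75, 1.0]
print(micro_auc, macro_auc)
```

The micro-average weights every sample equally, while the macro-average weights every class equally. With a balanced test set such as the one used here (100 images per class), the two tend to be close.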
# Model
custom_resnet50_model.save(path_base + '/crn50.h5')
# History
with open(path_base + '/crn50_history.txt', 'wb') as file_pi:
    pickle.dump(crn50.history, file_pi)
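For completeness, here is a sketch of how the pickled history could be read back in a later session. It uses a temporary directory in place of path_base and a small stand-in dict in place of crn50.history, so both are assumptions made only to keep the example runnable on its own.

```python
import os
import pickle
import tempfile

# Stand-in for crn50.history: a plain dict of lists (last two epochs of the log)
history = {"val_acc": [0.1966, 0.1929], "val_loss": [3.7521, 3.8206]}

with tempfile.TemporaryDirectory() as path_base:  # hypothetical location
    history_path = os.path.join(path_base, "crn50_history.txt")

    with open(history_path, "wb") as file_pi:
        pickle.dump(history, file_pi)

    # Read in binary mode, matching the 'wb' used when writing
    with open(history_path, "rb") as file_pi:
        restored = pickle.load(file_pi)

print(restored == history)  # True
```

Note that the restored object is the dict itself, so it is indexed as restored['val_acc'] rather than through a .history attribute. The saved model can be reloaded with keras.models.load_model on the .h5 file.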
Next, we compare the metrics with those of the previous models.
plt.figure(0)
plt.plot(snn.history['val_acc'],'r')
plt.plot(scnn.history['val_acc'],'g')
plt.plot(vgg16.history['val_acc'],'b')
plt.plot(vgg19.history['val_acc'],'y')
plt.plot(vgg16Bis.history['val_acc'],'m')
plt.plot(crn50.history['val_acc'],'gold')
plt.xticks(np.arange(0, 11, 2.0))
plt.rcParams['figure.figsize'] = (8, 6)
plt.xlabel("Num of Epochs")
plt.ylabel("Accuracy")
plt.title("Validation accuracy by model")
plt.legend(['simple NN','CNN','VGG 16','VGG 19','Custom VGG','Custom ResNet'])
Validation accuracy of all models
plt.figure(0)
plt.plot(snn.history['val_loss'],'r')
plt.plot(scnn.history['val_loss'],'g')
plt.plot(vgg16.history['val_loss'],'b')
plt.plot(vgg19.history['val_loss'],'y')
plt.plot(vgg16Bis.history['val_loss'],'m')
plt.plot(crn50.history['val_loss'],'gold')
plt.xticks(np.arange(0, 11, 2.0))
plt.rcParams['figure.figsize'] = (8, 6)
plt.xlabel("Num of Epochs")
plt.ylabel("Loss")
plt.title("Validation loss by model")
plt.legend(['simple NN','CNN','VGG 16','VGG 19','Custom VGG','Custom ResNet'])
Validation loss of all models
plt.figure(0)
plt.plot(snn.history['val_mean_squared_error'],'r')
plt.plot(scnn.history['val_mean_squared_error'],'g')
plt.plot(vgg16.history['val_mean_squared_error'],'b')
plt.plot(vgg19.history['val_mean_squared_error'],'y')
plt.plot(vgg16Bis.history['val_mean_squared_error'],'m')
plt.plot(crn50.history['val_mean_squared_error'],'gold')
plt.xticks(np.arange(0, 11, 2.0))
plt.rcParams['figure.figsize'] = (8, 6)
plt.xlabel("Num of Epochs")
plt.ylabel("Mean Squared Error")
plt.title("Validation MSE by model")
plt.legend(['simple NN','CNN','VGG 16','VGG 19','Custom VGG','Custom ResNet'])
Validation MSE of all models
Conclusion of the experiment
As can be seen, this architecture marks a turning point. Not only are its results among the best of the architectures tested so far, but it also stands out in training time, since it lets us add layers while keeping training time acceptable, and in the number of parameters, which is considerably lower than in the VGG architecture.
In the next article, we will present the DenseNet architecture.