I’m trying to make a UNet-based autoencoder model to deblur images. I defined the model as follows:
    from tensorflow.keras.layers import (Activation, BatchNormalization, Concatenate,
                                         Conv2D, Conv2DTranspose, Input, ReLU)
    from tensorflow.keras.models import Model

    def conv_operation(x, filters, kernel_size, strides=2):
        x = Conv2D(filters=filters, kernel_size=kernel_size, strides=strides, padding='same')(x)
        x = BatchNormalization()(x)
        x = ReLU()(x)
        return x

    def conv_transpose_operation(x, filters, kernel_size):
        x = Conv2DTranspose(filters=filters, kernel_size=kernel_size, strides=2, padding='same')(x)
        x = BatchNormalization()(x)
        x = ReLU()(x)
        return x

    def deblurring_autoencoder():
        dae_inputs = Input(shape=(200, 200, 3), name='dae_input')
        # Encoder: four stride-2 blocks, then one stride-1 block
        conv_block1 = conv_operation(dae_inputs, 32, 3)
        conv_block2 = conv_operation(conv_block1, 64, 3)
        conv_block3 = conv_operation(conv_block2, 128, 3)
        conv_block4 = conv_operation(conv_block3, 256, 3)
        conv_block5 = conv_operation(conv_block4, 256, 3, 1)
        # Decoder: stride-2 transposed convolutions with skip connections
        deconv_block1 = conv_transpose_operation(conv_block5, 256, 3)
        merge1 = Concatenate()([conv_block3, deconv_block1])  # this is the line that fails
        deconv_block2 = conv_transpose_operation(merge1, 128, 2)
        merge2 = Concatenate()([deconv_block2, conv_block2])
        deconv_block3 = conv_transpose_operation(merge2, 64, 3)
        merge3 = Concatenate()([deconv_block3, conv_block1])
        deconv_block4 = conv_transpose_operation(merge3, 32, 3)
        final_deconv = Conv2DTranspose(filters=3, kernel_size=3)(deconv_block4)
        dae_outputs = Activation('sigmoid', name='dae_output')(final_deconv)
        return Model(dae_inputs, dae_outputs, name='dae')
After these definitions, I try to build the model by calling `deblurring_autoencoder()`.
When I run that, I get a long error that basically tells me my code is breaking at the line:
merge1 = Concatenate()([conv_block3,deconv_block1])
due to a dimensionality error. The error says:
A `Concatenate` layer requires inputs with matching shapes except for the concat axis. Got inputs shapes: [(None, 25, 25, 128), (None, 26, 26, 256)]
I tried to manually check all the dimensions after each convolution, and they seem to fit perfectly. One thing I noticed is that whenever I set the input shape in the `Input()` function to (32, 32, 3), (64, 64, 3), (128, 128, 3), etc., I get no errors.
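The shapes can be traced by hand with a small sketch. It assumes the usual Keras output-size rules for `padding='same'`: `Conv2D` produces `ceil(n / stride)` and `Conv2DTranspose` (without `output_padding`) produces `n * stride`, independent of kernel size. The helper names `same_conv_out`, `same_deconv_out` and `trace_skip_shapes` are hypothetical and only mirror the layers in the model above.

```python
import math

def same_conv_out(n, stride):
    # Assumed Keras rule for Conv2D with padding='same': ceil(n / stride)
    return math.ceil(n / stride)

def same_deconv_out(n, stride=2):
    # Assumed Keras rule for Conv2DTranspose with padding='same': n * stride
    return n * stride

def trace_skip_shapes(n):
    """Return (skip, upsampled) spatial sizes for each Concatenate in the model."""
    c1 = same_conv_out(n, 2)    # conv_block1
    c2 = same_conv_out(c1, 2)   # conv_block2
    c3 = same_conv_out(c2, 2)   # conv_block3
    c4 = same_conv_out(c3, 2)   # conv_block4 (an odd size gets rounded up here)
    c5 = same_conv_out(c4, 1)   # conv_block5, stride 1
    d1 = same_deconv_out(c5)    # deconv_block1, concatenated with conv_block3
    d2 = same_deconv_out(d1)    # deconv_block2, concatenated with conv_block2
    d3 = same_deconv_out(d2)    # deconv_block3, concatenated with conv_block1
    return [(c3, d1), (c2, d2), (c1, d3)]

for size in (32, 64, 128, 200):
    pairs = trace_skip_shapes(size)
    ok = all(skip == up for skip, up in pairs)
    print(size, pairs[0], "ok" if ok else "mismatch")
# 32 (4, 4) ok
# 64 (8, 8) ok
# 128 (16, 16) ok
# 200 (25, 26) mismatch
```

Under these rules, 200 halves to 100, 50 and then 25 at `conv_block3`; `conv_block4` rounds `25 / 2` up to 13, and upsampling 13 gives 26, which matches the 25 vs 26 in the error. Input sizes divisible by 16 (one factor of 2 per stride-2 encoder block) never hit an odd size, which is consistent with 32, 64 and 128 working.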
How can I resolve this?
Author: Aryan Sethi