
Low light image enhancement using Deep Retinex-Net model

Images have a wide range of applications in engineering: the medical field, remote sensing, transmission and encoding, machine vision, robotics, pattern recognition, etc.
In some cases, the images captured by cameras contain blur, noise, and low lightness, which makes it difficult for the viewer to extract information from them. There can be many causes of this, such as a low-light environment, poor performance of the equipment, or inappropriate configuration of the equipment.
The objective of this case study is to build a model (the Deep Retinex Decomposition model) using deep learning techniques such as CNNs and encoder-decoders that learns from an existing dataset of paired low- and high-quality images and can convert any poor-quality image given to it into a high-quality image.

Contents

  1. Problem statement
  2. Source of data
  3. Deep learning problem
  4. Data cleaning
  5. Existing approaches
  6. Exploratory data analysis
  7. Model explanation-Decomposition
  8. Model explanation-Adjustment
  9. Conclusion
  10. Future work
  11. References

1. Problem statement

The problem is to train a Deep Retinex-Net model on a dataset containing pairs of low-light and high-light images so that the model can convert any given low-light image into a high-light image.

2. Source of data

https://github.com/weichen582/RetinexNet
The dataset used for this problem, the LOL (LOw Light paired) dataset, is taken from the above link. It contains 500 low/normal-light image pairs of different kinds of scenes such as household appliances, toys, books, gardens, food items, playgrounds, clubs, streets, etc.
These raw images are resized to 128×128 and converted to Portable Network Graphics (PNG) format.
The figure below shows a subset of these images.

3. Deep learning problem

The problem is to convert a low-light image into a high-light image using a Deep Retinex-Net model. The enhancement process is divided into three steps: decomposition, adjustment, and reconstruction. In the decomposition step, a subnetwork called Decom-Net decomposes the input image into reflectance and illumination. In the following adjustment step, an encoder-decoder based Enhance-Net brightens the illumination; multi-scale concatenation is introduced to adjust the illumination from multi-scale perspectives, and noise on the reflectance is also removed at this step. Finally, the adjusted illumination and the reflectance are recombined to produce the enhanced result.
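At a high level, this pipeline can be summarised by the following sketch. It is purely illustrative: decom_net, enhance_net, and denoise are placeholders for the sub-networks described in later sections, not actual implementations.

# High-level sketch of the Retinex-Net enhancement pipeline (illustrative only);
# decom_net, enhance_net and denoise are placeholders for the real sub-networks.
def enhance_low_light(s_low, decom_net, enhance_net, denoise):
    r_low, i_low = decom_net(s_low)            # step 1: decomposition
    i_adjusted = enhance_net(r_low, i_low)     # step 2: illumination adjustment
    r_denoised = denoise(r_low, i_low)         # step 3: noise removal on the reflectance
    return r_denoised * i_adjusted             # step 4: reconstruction (element-wise product)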

4. Data cleaning

The images in the dataset are of size 400×600. They are resized to 128×128 for training the model.

from glob import glob

import numpy as np
from PIL import Image
from tqdm import tqdm

def load_images(file):
    # Read an image, resize it to 128x128 and scale the pixels to [0, 1]
    im = Image.open(file)
    im = im.resize((128, 128))
    return np.array(im, dtype="float32") / 255.0

# Collect the paired low/high light file names and keep them aligned
train_low_data_names = glob('/data_2/low/*.png')
train_low_data_names.sort()
train_high_data_names = glob('/data_2/high/*.png')
train_high_data_names.sort()
assert len(train_low_data_names) == len(train_high_data_names)

train_low_data = []
train_high_data = []

# Load every image pair into memory as float32 arrays
for idx in tqdm(range(len(train_low_data_names))):
    train_low_data.append(load_images(train_low_data_names[idx]))
    train_high_data.append(load_images(train_high_data_names[idx]))

5. Existing approaches

Many techniques have been developed to improve the subjective and objective quality of low-light images. Histogram equalization, de-hazing-based methods, simultaneous reflectance and illumination estimation (SRIE), low-light image enhancement via illumination map estimation (LIME), and simple CNN-based networks have been widely used for this task.
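For comparison, the simplest of these baselines, histogram equalization, can be run on a single image with OpenCV. The sketch below only illustrates the baseline and is not part of the Retinex-Net pipeline; the file name low.png is a placeholder.

import cv2

# Histogram-equalization baseline (illustration only).
# Equalizing the luminance (Y) channel avoids the colour shifts that
# per-channel equalization of RGB tends to introduce.
img = cv2.imread('low.png')                            # placeholder input path
ycrcb = cv2.cvtColor(img, cv2.COLOR_BGR2YCrCb)
ycrcb[:, :, 0] = cv2.equalizeHist(ycrcb[:, :, 0])
enhanced = cv2.cvtColor(ycrcb, cv2.COLOR_YCrCb2BGR)
cv2.imwrite('enhanced_hist_eq.png', enhanced)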

6. Exploratory data analysis

The dataset contains paired low/normal-light images captured in real scenes.

Sample low light image

import matplotlib.pyplot as plt

low_img = train_low_data[200]
print("Shape of the image:", low_img.shape)
plt.imshow(low_img)
plt.show()

Sample high light image

high_img = train_high_data[200]
print("Shape of the image:", high_img.shape)
plt.imshow(high_img)
plt.show()

7. Model explanation-Decomposition

First, a subnetwork, Decom-Net, is used to split the observed image into lighting-independent reflectance and structure-aware smooth illumination. Decom-Net is learned with two constraints: first, the low-light and normal-light versions of a scene share the same reflectance; second, the illumination map should be smooth while retaining the main structures, which is enforced by a structure-aware total variation loss.

Decom-Net takes the low-light image S_low and the normal-light image S_normal as input, then estimates the reflectance R_low and the illumination I_low for S_low, as well as R_normal and I_normal for S_normal. It first uses a 3×3 convolutional layer to extract features from the input image. Several 3×3 convolutional layers with the Rectified Linear Unit (ReLU) activation function then map the RGB image into reflectance and illumination. A final 3×3 convolutional layer projects R and I back from feature space, and the sigmoid function constrains both R and I to the range [0, 1].

Reflectance and Illumination of a sample image

Loss functions

The loss L consists of three terms: reconstruction loss L_recon, invariable reflectance loss L_ir, and illumination smoothness loss L_is:

L = L_recon + lambda_ir * L_ir + lambda_is * L_is

where lambda_ir and lambda_is denote the coefficients that balance the consistency of reflectance and the smoothness of illumination.

Reconstruction loss is formulated as

L_recon = sum_{i = low, normal} sum_{j = low, normal} lambda_ij * || R_i * I_j - S_j ||_1

where R, I, and S are the reflectance, illumination, and actual image respectively, the product R_i * I_j is element-wise, and || . ||_1 denotes the L1 norm.

Invariable reflectance loss is introduced to constrain the consistency of reflectance:

L_ir = || R_low - R_normal ||_1

Illumination smoothness loss L_is is formulated as

L_is = sum_{i = low, normal} || delta(I_i) * exp(-lambda_g * delta(R_i)) ||_1

where delta denotes the gradient, including delta_h (horizontal) and delta_v (vertical), the product is element-wise, and lambda_g denotes the coefficient balancing the strength of structure-awareness.

import tensorflow as tf

def concat(layers):
    # Concatenate feature maps along the channel axis
    return tf.concat(layers, axis=3)

def DecomNet(input_im, layer_num, channel=64, kernel_size=3):
    # Append the channel-wise maximum as an extra input channel
    input_max = tf.reduce_max(input_im, axis=3, keepdims=True)
    input_im = concat([input_max, input_im])
    with tf.variable_scope('DecomNet', reuse=tf.AUTO_REUSE):
        # Shallow feature extraction with a large (9x9) kernel
        conv = tf.layers.conv2d(input_im, channel, kernel_size * 3, padding='same', activation=None, name="shallow_feature_extraction")
        # Several 3x3 conv + ReLU layers map the image into feature space
        for idx in range(layer_num):
            conv = tf.layers.conv2d(conv, channel, kernel_size, padding='same', activation=tf.nn.relu, name='activated_layer_%d' % idx)
        # Project back to 4 channels: 3 for reflectance, 1 for illumination
        conv = tf.layers.conv2d(conv, 4, kernel_size, padding='same', activation=None, name='recon_layer')

    # Sigmoid constrains both outputs to [0, 1]
    R = tf.sigmoid(conv[:, :, :, 0:3])
    L = tf.sigmoid(conv[:, :, :, 3:4])

    return R, L
import os
import random
import time

class lowlight_enhance:

  def __init__(self, train_low_data, train_high_data, eval_low_data, batch_size, patch_size, epoch, learning_rate, train_phase, ckpt_dir):

    self.DecomNet_layer_num = 5
    self.sess = tf.Session()

    self.train_low_data = train_low_data
    self.train_high_data = train_high_data
    self.eval_low_data = eval_low_data
    self.batch_size = batch_size
    self.patch_size = patch_size
    self.epoch = epoch
    self.learning_rate = learning_rate
    self.train_phase = train_phase

    self.train_low_data_ph = tf.placeholder(tf.float32, [None, None, None, 3], name='train_low_data_ph')
    self.train_high_data_ph = tf.placeholder(tf.float32, [None, None, None, 3], name='train_high_data_ph')
    self.lr_ph = tf.placeholder(tf.float32, name='lr_ph')

    self.ckpt_dir = ckpt_dir

    [R_low, I_low] = DecomNet(self.train_low_data_ph, layer_num=self.DecomNet_layer_num)
    [R_high, I_high] = DecomNet(self.train_high_data_ph, layer_num=self.DecomNet_layer_num)

    I_low_3 = concat([I_low, I_low, I_low])
    I_high_3 = concat([I_high, I_high, I_high])

    self.output_R_low = R_low
    self.output_I_low = I_low_3

    # Reconstruction losses: each reflectance/illumination pair should reproduce its source image
    self.recon_loss_low = tf.reduce_mean(tf.abs(R_low * I_low_3 - self.train_low_data_ph))
    self.recon_loss_high = tf.reduce_mean(tf.abs(R_high * I_high_3 - self.train_high_data_ph))
    # Mutual reconstruction terms (small weight), since reflectance is shared across exposures
    self.recon_loss_mutal_low = tf.reduce_mean(tf.abs(R_high * I_low_3 - self.train_low_data_ph))
    self.recon_loss_mutal_high = tf.reduce_mean(tf.abs(R_low * I_high_3 - self.train_high_data_ph))
    # Invariable reflectance loss L_ir
    self.equal_R_loss = tf.reduce_mean(tf.abs(R_low - R_high))

    # Illumination smoothness loss L_is
    self.Ismooth_loss_low = self.smooth(I_low, R_low)
    self.Ismooth_loss_high = self.smooth(I_high, R_high)

    # Total Decom-Net loss: weighted sum of the terms above
    self.loss_Decom = self.recon_loss_low + self.recon_loss_high + 0.001 * self.recon_loss_mutal_low + 0.001 * self.recon_loss_mutal_high + 0.1 * self.Ismooth_loss_low + 0.1 * self.Ismooth_loss_high + 0.01 * self.equal_R_loss

    # The learning-rate placeholder self.lr_ph was already defined above
    optimizer = tf.train.AdamOptimizer(self.lr_ph, name='AdamOptimizer')

    self.var_Decom = [var for var in tf.trainable_variables() if 'DecomNet' in var.name] 

    self.train_op_Decom = optimizer.minimize(self.loss_Decom, var_list = self.var_Decom)

    self.sess.run(tf.global_variables_initializer())

    self.saver_Decom = tf.train.Saver(var_list = self.var_Decom)

    print("[*] Initialize model successfully...")

  def gradient(self, input_tensor, direction):
    self.smooth_kernel_x = tf.reshape(tf.constant([[0, 0], [-1, 1]], tf.float32), [2, 2, 1, 1])
    self.smooth_kernel_y = tf.transpose(self.smooth_kernel_x, [1, 0, 2, 3])

    if direction == "x":
        kernel = self.smooth_kernel_x
    elif direction == "y":
        kernel = self.smooth_kernel_y
    return tf.abs(tf.nn.conv2d(input_tensor, kernel, strides=[1, 1, 1, 1], padding='SAME'))

  def ave_gradient(self, input_tensor, direction):
    return tf.layers.average_pooling2d(self.gradient(input_tensor, direction), pool_size=3, strides=1, padding='SAME')

  def smooth(self, input_I, input_R):
    input_R = tf.image.rgb_to_grayscale(input_R)
    return tf.reduce_mean(self.gradient(input_I, "x") * tf.exp(-10 * self.ave_gradient(input_R, "x")) + self.gradient(input_I, "y") * tf.exp(-10 * self.ave_gradient(input_R, "y")))

  def evaluate_test(self):

    print("Evaluating for test data")
    Reflectance = []
    Illuminance = []

    for idx in range(len(self.eval_low_data)):
      input_low_eval = np.expand_dims(self.eval_low_data[idx], axis=0)

      if self.train_phase == "Decom":
        result_1, result_2 = self.sess.run([self.output_R_low, self.output_I_low], feed_dict={self.train_low_data_ph: input_low_eval})

      Reflectance.append(result_1)
      Illuminance.append(result_2)

    return Reflectance, Illuminance

  def evaluate_train(self):

    print("Evaluating for train data")
    Reflectance = []
    Illuminance = []

    for idx in tqdm(range(len(self.train_low_data))):
    #for idx in tqdm(range(start_id, end_id)):
      input_low_train = np.expand_dims(self.train_low_data[idx], axis=0)

      if self.train_phase == "Decom":
        result_1, result_2 = self.sess.run([self.output_R_low, self.output_I_low], feed_dict={self.train_low_data_ph: input_low_train})

      Reflectance.append(result_1)
      Illuminance.append(result_2)

    return Reflectance, Illuminance

  def train(self):
 
    numBatch = 30
    
    # load pretrained model
    
    train_op = self.train_op_Decom
    train_loss = self.loss_Decom
    saver = self.saver_Decom

    iter_num = 0
    start_epoch = 0
    start_step = 0
    lr1 = self.learning_rate

    start_time = time.time()
    image_id = 0

    for epoch in range(start_epoch, self.epoch):

      for batch_id in range(start_step, numBatch):

        # generate data for a batch
        batch_input_low = np.zeros((self.batch_size, self.patch_size, self.patch_size, 3), dtype="float32")
        batch_input_high = np.zeros((self.batch_size, self.patch_size, self.patch_size, 3), dtype="float32")
        for patch_id in range(self.batch_size):

          h, w, _ = self.train_low_data[image_id].shape
          x = random.randint(0, h - self.patch_size)
          y = random.randint(0, w - self.patch_size)

          rand_mode = random.randint(0, 7)
          batch_input_low[patch_id, :, :, :] = data_augmentation(self.train_low_data[image_id][x : x+self.patch_size, y : y+self.patch_size, :], rand_mode)
          batch_input_high[patch_id, :, :, :] = data_augmentation(self.train_high_data[image_id][x : x+self.patch_size, y : y+self.patch_size, :], rand_mode)
          
          image_id = (image_id + 1) % len(self.train_low_data)
          if image_id == 0:
            # Re-shuffle the low/high pairs in unison after a full pass over the data
            tmp = list(zip(self.train_low_data, self.train_high_data))
            random.shuffle(tmp)
            self.train_low_data, self.train_high_data = zip(*tmp)

        # train
        _, loss = self.sess.run([train_op, train_loss], feed_dict={self.train_low_data_ph: batch_input_low, \
                                                                    self.train_high_data_ph: batch_input_high, \
                                                                    self.lr_ph: lr1[epoch]})
        
        print("%s Epoch: [%2d] [%4d/%4d] time: %4.4f, loss: %.6f" \
              % (train_phase, epoch + 1, batch_id + 1, numBatch, time.time() - start_time, loss))
        iter_num += 1   

  def save_pretrained(self):
    self.save(self.saver_Decom, self.ckpt_dir, "RetinexNet-%s" % self.train_phase)

  def load_pretrained(self):
    load_model_status = self.load(self.saver_Decom, self.ckpt_dir)
    print("[*] Model restore success!")

  def save(self, saver, ckpt_dir, model_name):
    if not os.path.exists(ckpt_dir):
      os.makedirs(ckpt_dir)
    print("[*] Saving model %s" % model_name)
    saver.save(self.sess, \
                os.path.join(ckpt_dir, model_name), \
               )

  def load(self, saver, ckpt_dir):
    # Restore the Decom-Net weights saved by save_pretrained()
    full_path = ckpt_dir + '/RetinexNet-Decom'
    saver.restore(self.sess, full_path)
    return True
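The train() method above calls a data_augmentation helper that is not shown in this post. A minimal sketch of such a helper, following the flips and 90° rotations used in the original RetinexNet repository (the exact mapping of rand_mode values here is an assumption), could look like this:

# Hypothetical data_augmentation helper; mode in [0, 7] selects a
# combination of vertical flips and 90-degree rotations.
def data_augmentation(image, mode):
    if mode == 0:
        return image                            # unchanged
    elif mode == 1:
        return np.flipud(image)                 # vertical flip
    elif mode == 2:
        return np.rot90(image)                  # rotate 90 degrees
    elif mode == 3:
        return np.flipud(np.rot90(image))
    elif mode == 4:
        return np.rot90(image, k=2)             # rotate 180 degrees
    elif mode == 5:
        return np.flipud(np.rot90(image, k=2))
    elif mode == 6:
        return np.rot90(image, k=3)             # rotate 270 degrees
    elif mode == 7:
        return np.flipud(np.rot90(image, k=3))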
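With the class defined, training the decomposition network might look like the following sketch. The hyperparameter values and the per-epoch learning-rate list are illustrative assumptions, not values taken from this article.

# Illustrative usage of lowlight_enhance; all hyperparameters here are assumptions.
epochs = 100
lr = [0.001] * epochs                # train() indexes the learning rate by epoch

model = lowlight_enhance(train_low_data, train_high_data,
                         eval_low_data=train_low_data[:10],
                         batch_size=16, patch_size=96, epoch=epochs,
                         learning_rate=lr, train_phase="Decom",
                         ckpt_dir='./checkpoint')
model.train()
model.save_pretrained()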

Plotting Reflectance and Illumination of a sample image
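One way to visualise the decomposition is to run evaluate_train() on the trained model and show a sample's reflectance and illumination next to its low-light input (a sketch; the sample index 200 is arbitrary):

import matplotlib.pyplot as plt

# Decompose the training set and plot one sample's reflectance and illumination
Reflectance, Illuminance = model.evaluate_train()

idx = 200                                        # arbitrary sample index
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
axes[0].imshow(train_low_data[idx])
axes[0].set_title("Low light input")
axes[1].imshow(Reflectance[idx][0])              # drop the batch dimension
axes[1].set_title("Reflectance")
axes[2].imshow(Illuminance[idx][0])
axes[2].set_title("Illumination")
for ax in axes:
    ax.axis("off")
plt.show()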
