Tuesday, August 27, 2024

Update Image Features Using Variational AutoEncoder



 

In the Variational AutoEncoder post, we created a variational autoencoder that can generate new images.

In this post, we use the encodings and the labels to change images in a controlled way: we take an image of a specific digit and gradually morph it into another digit.

The implementation works as follows (a minimal sketch of this step appears after the list):

  1. Target_Encoding =  average encoding for all images of digit X
  2. Non_Target_Encoding = average encoding for all images of other digits
  3. Convert_Vector = Target_Encoding - Non_Target_Encoding
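
As an illustrative sketch of this step only (the helper encode_image and the tensors images and labels are assumptions for the example, not part of the post's code):

import torch

def compute_convert_vector(encode_image, images, labels, target_digit):
    # Encode every image into the latent space.
    encodings = torch.stack([encode_image(image) for image in images])
    target_mask = labels == target_digit

    # Average encoding of the target digit vs. all other digits.
    target_encoding = encodings[target_mask].mean(dim=0)
    non_target_encoding = encodings[~target_mask].mean(dim=0)

    # Direction in latent space pointing from "other digits" toward digit X.
    return target_encoding - non_target_encoding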

Next, we can take any source image and push it toward digit X by a factor alpha (see the sketch after this list):

  1. Encode the image
  2. Add alpha*Convert_Vector to the encoding
  3. Decode
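
A minimal sketch of this conversion, assuming encode_image and decode_image helpers around the model's encoder and decoder (again illustrative names); alpha controls how strongly the source image is pushed toward digit X:

def convert_toward_digit(encode_image, decode_image, source_image, convert_vector, alpha):
    # Encode the source image, shift it along the conversion direction,
    # and decode the shifted encoding back into an image.
    encoding = encode_image(source_image)
    shifted_encoding = encoding + alpha * convert_vector
    return decode_image(shifted_encoding)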


The image at the top displays the gradual conversion of each digit to the target digit '1'.


The implementation is done by adding the following method to the trainer class from the Variational AutoEncoder post.



# Requires torch and time to be imported at module level, as in the original post.
def variation_examples(self, variation_label):
    matching_encoding_samples = 0
    non_matching_encoding_samples = 0

    # Running sums of encodings for the target digit and for all other digits,
    # kept on the same device as the model outputs.
    matching_encoding_sum = torch.zeros(
        self.model.latent_space_internal_features, dtype=torch.float, device=self.device)
    non_matching_encoding_sum = torch.zeros(
        self.model.latent_space_internal_features, dtype=torch.float, device=self.device)

    start_time = time.time()

    # Keep one sample image per digit, to be used later as conversion sources.
    source_image_by_label = {}

    for batch_index, (batch_x, batch_y) in enumerate(self.loader_train):
        batch_x = batch_x.to(device=self.device)
        encoded_batch, _, _, _ = self.model(batch_x)

        for batch_sample in range(batch_x.shape[0]):
            label = batch_y[batch_sample].item()
            source_image_by_label[label] = batch_x[batch_sample]

        matching_label_ids = batch_y == variation_label
        non_matching_label_ids = ~matching_label_ids

        matching_encoding = encoded_batch[matching_label_ids]
        non_matching_encoding = encoded_batch[non_matching_label_ids]

        matching_encoding_sum += torch.sum(matching_encoding, dim=0)
        non_matching_encoding_sum += torch.sum(non_matching_encoding, dim=0)

        matching_encoding_samples += matching_encoding.shape[0]
        non_matching_encoding_samples += non_matching_encoding.shape[0]

        # Print progress at most once every 5 seconds.
        passed_seconds = time.time() - start_time
        if passed_seconds > 5:
            start_time = time.time()
            print('batch %05d/%05d' % (batch_index, len(self.loader_train)))

    matching_encoding_average = matching_encoding_sum / matching_encoding_samples
    non_matching_encoding_average = non_matching_encoding_sum / non_matching_encoding_samples

    # Decode the two average encodings to visualize the target and non-target prototypes.
    matching_encodings = torch.unsqueeze(matching_encoding_average, dim=0)
    non_matching_encodings = torch.unsqueeze(non_matching_encoding_average, dim=0)

    matching_image = self.model.decoder(matching_encodings)
    non_matching_image = self.model.decoder(non_matching_encodings)

    image_row = [matching_image, non_matching_image]
    images_rows = [image_row]
    self.plot_images("variation_base_{}.pdf".format(variation_label), images_rows)

    # For each digit, gradually shift its encoding toward the target digit and decode.
    images_rows = []
    convert_vector = matching_encoding_average - non_matching_encoding_average
    for label in range(10):
        source_image = source_image_by_label[label]
        source_images = torch.unsqueeze(source_image, dim=0)
        label_row = [source_image]
        encoded_source, _, _, _ = self.model(source_images)
        for convert_factor in range(10):
            encoded_converted = encoded_source + convert_vector * convert_factor * 0.5
            decoded_items = self.model.decoder(encoded_converted)
            label_row.append(decoded_items[0])
        images_rows.append(label_row)
    self.plot_images("variation_convert_{}.pdf".format(variation_label), images_rows)
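

For example, assuming a trained trainer instance named trainer, as in the original post, producing the morphing grid toward the target digit '1' amounts to:

# 'trainer' is the trained trainer instance from the Variational AutoEncoder post.
trainer.variation_examples(variation_label=1)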

