library(tensorflow)
library(keras)
library(tfdatasets)
Download the dataset
We will use the Flickr8k dataset for this tutorial. The dataset comprises over 8,000 images, each of which is paired with five different captions.
flickr_images <- get_file(
"fickr8k.zip",
"https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip"
)
flickr_text <- get_file(
"flickr9k_text.zip",
"https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip"
)
if (!fs::dir_exists(fs::path(fs::path_dir(flickr_text), "Flicker8k_Dataset"))) {
unzip(flickr_images, exdir = fs::path_dir(flickr_images))
unzip(flickr_text, exdir = fs::path_dir(flickr_text))
}
# Path to the images
IMAGES_PATH <- "Flicker8k_Dataset"
# Desired image dimensions
IMAGE_SIZE <- shape(299, 299)
# Vocabulary size
VOCAB_SIZE <- 10000
# Fixed length allowed for any sequence
SEQ_LENGTH <- 25
# Dimension for the image embeddings and token embeddings
EMBED_DIM <- 512
# Per-layer units in the feed-forward network
FF_DIM <- 512
# Other training parameters
BATCH_SIZE <- 64
EPOCHS <- 30
AUTOTUNE <- tf$data$AUTOTUNE
Preparing the dataset
captions <- fs::path(fs::path_dir(flickr_text), "Flickr8k.token.txt") %>%
readr::read_delim(
col_names = c("img", "caption"),
delim = "\t"
) %>%
tidyr::separate(img, into = c("img", "caption_id"), sep = "#") %>%
dplyr::select(img, caption) %>%
dplyr::group_by(img) %>%
dplyr::summarise(caption = list(caption)) %>%
dplyr::mutate(img = fs::path(fs::path_dir(flickr_text), "Flicker8k_Dataset", img))
train <- fs::path(fs::path_dir(flickr_text), "Flickr_8k.trainImages.txt") %>%
readr::read_lines()
valid <- fs::path(fs::path_dir(flickr_text), "Flickr_8k.devImages.txt") %>%
readr::read_lines()
test <- fs::path(fs::path_dir(flickr_text), "Flickr_8k.testImages.txt") %>%
readr::read_lines()
train_data <- captions %>%
dplyr::filter(fs::path_file(img) %in% train)
valid_data <- captions %>%
dplyr::filter(fs::path_file(img) %in% valid)
dplyr::n_distinct(train_data$img)
dplyr::n_distinct(valid_data$img)
Vectorizing the text data
We'll use the layer_text_vectorization() layer to vectorize the text data, that is to say, to turn the original strings into integer sequences where each integer represents the index of a word in a vocabulary. We will use a custom string standardization scheme (strip punctuation characters except < and >) and the default splitting scheme (split on whitespace).
punctuation <- c("!", "\\", "\"", "#", "$", "%", "&", "'", "(", ")", "*",
"+", ",", "-", ".", "/", ":", ";", "=", "?", "@", "[",
"\\", "\\", "]", "^", "_", "`", "{", "|", "}", "~")
re <- reticulate::import("re")
punctuation_group <- punctuation %>%
sapply(re$escape) %>%
paste0(collapse = "") %>%
sprintf("[%s]", .)
custom_standardization <- function(input_string) {
lowercase <- tf$strings$lower(input_string)
tf$strings$regex_replace(lowercase, punctuation_group, "")
}
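As a quick check (purely illustrative, with a made-up caption), the standardization lowercases the text and strips the punctuation listed above while keeping the < and > markers intact:
custom_standardization(tf$constant("<start> A dog, running fast! <end>"))
# expected: "<start> a dog running fast <end>"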
vectorization <- layer_text_vectorization(
max_tokens = VOCAB_SIZE,
output_mode = "int",
output_sequence_length = SEQ_LENGTH,
standardize = custom_standardization
)
vectorization %>% adapt(unlist(train_data$caption))
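To see what the adapted layer produces (the exact indices depend on the learned vocabulary, so the values shown will vary), we can vectorize a sample caption and peek at the vocabulary:
vectorization(list("<start> a dog runs through the grass <end>"))  # integer sequence of length SEQ_LENGTH, zero-padded
head(get_vocabulary(vectorization))  # "" (padding), "[UNK]", then the most frequent words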
# Data augmentation for image data
image_augmentation <- keras_model_sequential() %>%
layer_random_flip("horizontal") %>%
layer_random_rotation(0.2) %>%
layer_random_contrast(0.3)
Building a TensorFlow dataset pipeline for training
We will generate pairs of images and corresponding captions using a tf$data$Dataset object. The pipeline consists of two steps: reading and decoding the image from disk, and vectorizing the five captions that belong to that image.
decode_and_resize <- function(img_path) {
img_path %>%
tf$io$read_file() %>%
tf$image$decode_jpeg(channels = 3) %>%
tf$image$resize(IMAGE_SIZE) %>%
tf$image$convert_image_dtype(tf$float32)
}
process_input <- function(img_path, captions) {
reticulate::tuple(
decode_and_resize(img_path),
vectorization(captions)
)
}
make_dataset <- function(data) {
data %>% unname() %>%
tensor_slices_dataset() %>%
dataset_shuffle(nrow(data)) %>%
dataset_map(process_input, num_parallel_calls = AUTOTUNE) %>%
dataset_batch(BATCH_SIZE) %>%
dataset_prefetch(AUTOTUNE)
}
# Pass the list of images and the list of corresponding captions
train_dataset <- make_dataset(train_data)
valid_dataset <- make_dataset(valid_data)
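As a sanity check (a minimal sketch, assuming the pipeline above has been built), we can pull a single batch and confirm the shapes: images of shape (BATCH_SIZE, 299, 299, 3) and captions of shape (BATCH_SIZE, 5, SEQ_LENGTH).
batch <- train_dataset %>%
dataset_take(1) %>%
as_iterator() %>%
iter_next()
dim(batch[[1]])  # image batch
dim(batch[[2]])  # tokenized captions, five per image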
Building the model
Our image captioning architecture consists of three models: a CNN that extracts image features, a Transformer encoder that further processes those features, and a Transformer decoder that generates the caption token by token.
get_cnn_model <- function() {
base_model <- application_efficientnet_b0(
input_shape = c(IMAGE_SIZE, 3),
include_top = FALSE,
weights = "imagenet"
)
# We freeze our feature extractor
base_model$trainable <- FALSE
base_model_out <- base_model$output %>%
layer_reshape(target_shape = c(-1, tail(dim(base_model$output), 1)))
keras_model(base_model$input, base_model_out)
}
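With a 299x299 input, the EfficientNetB0 backbone should produce a 10x10 grid of 1280-dimensional features, which the reshape above flattens into a sequence of 100 feature vectors. An optional check (instantiating the model downloads the ImageNet weights on first use):
cnn <- get_cnn_model()
cnn$output_shape  # expected: (NULL, 100, 1280)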
transformer_encoder_block <- new_layer_class(
"transformer_encoder_block",
initialize = function(embed_dim, dense_dim, num_heads, ...) {
super()$`__init__`(...)
self$embed_dim <- embed_dim
self$dense_dim <- dense_dim
self$num_heads <- num_heads
self$attention_1 <- layer_multi_head_attention(
num_heads = num_heads, key_dim = embed_dim, dropout = 0.0
)
self$layernorm_1 <- layer_normalization()
self$layernorm_2 <- layer_normalization()
self$dense_1 <- layer_dense(units = embed_dim, activation = "relu")
},
call = function(inputs, training, mask = NULL) {
inputs <- self$layernorm_1(inputs)
inputs <- self$dense_1(inputs)
attention_output_1 <- self$attention_1(
query = inputs,
value = inputs,
key = inputs,
attention_mask = NULL,
training = training
)
out_1 <- self$layernorm_2(inputs + attention_output_1)
out_1
}
)
positional_embedding <- new_layer_class(
"positional_embedding",
initialize = function(sequence_length, vocab_size, embed_dim, ...) {
super()$`__init__`(...)
self$token_embeddings <- layer_embedding(
input_dim = vocab_size, output_dim = embed_dim
)
self$position_embeddings <- layer_embedding(
input_dim = sequence_length, output_dim = embed_dim
)
self$sequence_length <- sequence_length
self$vocab_size <- vocab_size
self$embed_dim <- embed_dim
self$embed_scale <- tf$math$sqrt(tf$cast(embed_dim, tf$float32))
},
call = function(inputs) {
length <- tail(dim(inputs), 1)
positions <- tf$range(start = 0L, limit = length, delta = 1L)
embedded_tokens <- self$token_embeddings(inputs)
embedded_tokens <- embedded_tokens * self$embed_scale
embedded_positions <- self$position_embeddings(positions)
embedded_tokens + embedded_positions
},
compute_mask = function(inputs, mask) {
tf$math$not_equal(inputs, 0L)
}
)
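A tiny illustration of the layer with toy dimensions and made-up token ids: each token id is mapped to an embedding, scaled, and summed with a learned embedding of its position, giving one vector per token.
demo_embedding <- positional_embedding(
sequence_length = 10, vocab_size = 100, embed_dim = 8
)
toy_ids <- as_tensor(matrix(c(5L, 9L, 2L, 0L, 0L), nrow = 1))
dim(demo_embedding(toy_ids))  # (1, 5, 8): one 8-dimensional vector per token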
transformer_decoder_block <- new_layer_class(
"transformer_decoder_block",
initialize = function(embed_dim, ff_dim, num_heads, ...) {
super()$`__init__`(...)
self$embed_dim <- embed_dim
self$ff_dim <- ff_dim
self$num_heads <- num_heads
self$attention_1 <- layer_multi_head_attention(
num_heads = num_heads, key_dim = embed_dim, dropout = 0.1
)
self$attention_2 <- layer_multi_head_attention(
num_heads = num_heads, key_dim = embed_dim, dropout = 0.1
)
self$ffn_layer_1 <- layer_dense(units = ff_dim, activation = "relu")
self$ffn_layer_2 <- layer_dense(units = embed_dim)
self$layernorm_1 <- layer_normalization()
self$layernorm_2 <- layer_normalization()
self$layernorm_3 <- layer_normalization()
self$embedding <- positional_embedding(
embed_dim = EMBED_DIM, sequence_length = SEQ_LENGTH, vocab_size = VOCAB_SIZE
)
self$out <- layer_dense(units = VOCAB_SIZE, activation = "softmax")
self$dropout_1 <- layer_dropout(rate = 0.3)
self$dropout_2 <- layer_dropout(rate = 0.5)
self$supports_masking <- TRUE
},
call = function(inputs, encoder_outputs, training, mask = NULL) {
inputs <- self$embedding(inputs)
causal_mask <- self$get_causal_attention_mask(inputs)
# Default to the causal mask alone if no padding mask is provided
padding_mask <- NULL
combined_mask <- causal_mask
if (!is.null(mask)) {
padding_mask <- tf$cast(mask[, , tf$newaxis], dtype = tf$int32)
combined_mask <- tf$cast(mask[, tf$newaxis, ], dtype = tf$int32)
combined_mask <- tf$minimum(combined_mask, causal_mask)
}
attention_output_1 <- self$attention_1(
query = inputs,
value = inputs,
key = inputs,
attention_mask = combined_mask,
training = training
)
out_1 <- self$layernorm_1(inputs + attention_output_1)
attention_output_2 <- self$attention_2(
query = out_1,
value = encoder_outputs,
key = encoder_outputs,
attention_mask = padding_mask,
training = training
)
out_2 <- self$layernorm_2(out_1 + attention_output_2)
ffn_out <- self$ffn_layer_1(out_2)
ffn_out <- self$dropout_1(ffn_out, training = training)
ffn_out <- self$ffn_layer_2(ffn_out)
ffn_out <- self$layernorm_3(ffn_out + out_2, training = training)
ffn_out <- self$dropout_2(ffn_out, training = training)
preds <- self$out(ffn_out)
preds
},
get_causal_attention_mask = function(inputs) {
input_shape <- tf$shape(inputs)
batch_size <- input_shape[1]
sequence_length <- input_shape[2]
i <- tf$range(sequence_length)[, tf$newaxis]
j <- tf$range(sequence_length)
mask <- tf$cast(i >= j, dtype = "int32")
mask <- tf$reshape(mask, list(1L, input_shape[2], input_shape[2]))
mult <- tf$concat(list(
tf$expand_dims(batch_size, -1L),
as_tensor(c(1L, 1L), dtype = tf$int32)
), axis = 0L)
tf$tile(mask, mult)
}
)
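The causal mask built by get_causal_attention_mask() is a lower-triangular matrix, so position i can only attend to positions up to i. A standalone illustration of the same construction for a toy sequence length of 4:
row_id <- tf$range(4L)[, tf$newaxis]
col_id <- tf$range(4L)
tf$cast(row_id >= col_id, dtype = "int32")
# 1 0 0 0
# 1 1 0 0
# 1 1 1 0
# 1 1 1 1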
image_captioning_model <- new_model_class(
"image_captioning_model",
initialize = function(cnn_model, encoder, decoder, num_captions_per_image = 5,
image_aug = NULL) {
super()$`__init__`()
self$cnn_model <- cnn_model
self$encoder <- encoder
self$decoder <- decoder
self$loss_tracker <- metric_mean(name = "loss")
self$acc_tracker <- metric_mean(name = "accuracy")
self$num_captions_per_image <- num_captions_per_image
self$image_aug <- image_aug
},
calculate_loss = function(y_true, y_pred, mask) {
loss <- self$loss(y_true, y_pred)
mask <- tf$cast(mask, dtype = loss$dtype)
loss <- loss * mask
tf$reduce_sum(loss) / tf$reduce_sum(mask)
},
calculate_accuracy = function(y_true, y_pred, mask) {
accuracy <- tf$equal(y_true, tf$argmax(y_pred, axis = 2L))
accuracy <- tf$math$logical_and(mask, accuracy)
accuracy <- tf$cast(accuracy, dtype = tf$float32)
mask <- tf$cast(mask, dtype = tf$float32)
tf$reduce_sum(accuracy) / tf$reduce_sum(mask)
},
.compute_caption_loss_and_acc = function(img_embed, batch_seq, training = TRUE) {
encoder_out <- self$encoder(img_embed, training = training)
# Teacher forcing: the decoder input is the caption without its last token,
# and the prediction target is the caption shifted left by one token.
batch_seq_inp <- batch_seq[, NULL:-2]
batch_seq_true <- batch_seq[, 2:NULL]
mask <- tf$math$not_equal(batch_seq_true, 0L)
batch_seq_pred <- self$decoder(
batch_seq_inp, encoder_out, training = training, mask = mask
)
loss <- self$calculate_loss(batch_seq_true, batch_seq_pred, mask)
acc <- self$calculate_accuracy(batch_seq_true, batch_seq_pred, mask)
list(loss, acc)
},
train_step = function(batch_data) {
batch_img <- batch_data[[1]]
batch_seq <- batch_data[[2]]
batch_loss <- 0
batch_acc <- 0
if (!is.null(self$image_aug)){
batch_img <- self$image_aug(batch_img)
}
# 1. Get image embeddings
img_embed <- self$cnn_model(batch_img)
# 2. Pass each of the five captions one by one to the decoder
# along with the encoder outputs and compute the loss as well as accuracy
# for each caption.
for (i in seq_len(self$num_captions_per_image)) {
with(tf$GradientTape() %as% tape, {
c(loss, acc) %<-% self$.compute_caption_loss_and_acc(
img_embed, batch_seq[, i, ], training = TRUE
)
# 3. Update loss and accuracy
batch_loss <- batch_loss + loss
batch_acc <- batch_acc + acc
})
# 4. Get the list of all the trainable weights
train_vars <- c(self$encoder$trainable_variables,
self$decoder$trainable_variables)
# 5. Get the gradients
grads <- tape$gradient(loss, train_vars)
# 6. Update the trainable weights
self$optimizer$apply_gradients(zip_lists(grads, train_vars))
}
# 7. Update the trackers
batch_acc <- batch_acc/self$num_captions_per_image
self$loss_tracker$update_state(batch_loss)
self$acc_tracker$update_state(batch_acc)
# 8. Return the loss and accuracy values
list(
loss = self$loss_tracker$result(),
acc = self$acc_tracker$result()
)
},
test_step = function(batch_data) {
batch_img <- batch_data[[1]]
batch_seq <- batch_data[[2]]
batch_loss <- 0
batch_acc <- 0
# 1. Get image embeddings
img_embed <- self$cnn_model(batch_img)
# 2. Pass each of the five captions one by one to the decoder
# along with the encoder outputs and compute the loss as well as accuracy
# for each caption.
for (i in seq_len(self$num_captions_per_image)) {
# No gradient tape is needed during evaluation
c(loss, acc) %<-% self$.compute_caption_loss_and_acc(
img_embed, batch_seq[, i, ], training = FALSE
)
# 3. Update loss and accuracy
batch_loss <- batch_loss + loss
batch_acc <- batch_acc + acc
}
batch_acc <- batch_acc / self$num_captions_per_image
# 4. Update the trackers
self$loss_tracker$update_state(batch_loss)
self$acc_tracker$update_state(batch_acc)
# 5. Return the loss and accuracy values
list(
"loss" = self$loss_tracker$result(),
"acc" = self$acc_tracker$result()
)
},
metrics = mark_active(function() {
# We need to list our metrics here so the `reset_states()` can be
# called automatically.
list(self$loss_tracker, self$acc_tracker)
})
)
cnn_model <- get_cnn_model()
encoder <- transformer_encoder_block(embed_dim = EMBED_DIM, dense_dim = FF_DIM, num_heads = 1)
decoder <- transformer_decoder_block(embed_dim = EMBED_DIM, ff_dim = FF_DIM, num_heads = 2)
caption_model <- image_captioning_model(
cnn_model = cnn_model,
encoder = encoder,
decoder = decoder,
image_aug = image_augmentation
)
Model training
# Define the loss function
cross_entropy <- loss_sparse_categorical_crossentropy(
from_logits = FALSE, reduction = "none"
)
# EarlyStopping criteria
early_stopping <- callback_early_stopping(patience = 3, restore_best_weights = TRUE)
# Learning Rate Scheduler for the optimizer
lr_schedule <- new_learning_rate_schedule_class(
"lr_schedule",
initialize = function(post_warmup_learning_rate, warmup_steps) {
super()$`__init__`()
self$post_warmup_learning_rate <- post_warmup_learning_rate
self$warmup_steps <- warmup_steps
},
call = function(step) {
global_step <- tf$cast(step, tf$float32)
warmup_steps <- tf$cast(self$warmup_steps, tf$float32)
warmup_progress <- global_step / warmup_steps
warmup_learning_rate <- self$post_warmup_learning_rate * warmup_progress
tf$cond(
global_step < warmup_steps,
function() warmup_learning_rate,
function() self$post_warmup_learning_rate
)
}
)
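To see the warm-up behaviour (with made-up numbers, not the values used below), the learning rate ramps up linearly during the warm-up steps and then stays at the post-warm-up value:
demo_schedule <- lr_schedule(post_warmup_learning_rate = 1e-4, warmup_steps = 100)
demo_schedule(tf$constant(50L))    # 5e-5, halfway through the warm-up
demo_schedule(tf$constant(1000L))  # 1e-4, after the warm-up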
# Create a learning rate schedule
num_train_steps <- length(train_dataset) * EPOCHS
num_warmup_steps <- num_train_steps %/% 15
lr <- lr_schedule(post_warmup_learning_rate = 1e-4, warmup_steps = num_warmup_steps)
# Compile the model
caption_model %>% compile(
optimizer = optimizer_adam(learning_rate = lr),
loss = cross_entropy
)
# Fit the model
caption_model %>% fit(
train_dataset,
epochs = EPOCHS,
validation_data = valid_dataset,
callbacks = list(early_stopping)
)
Check sample predictions
vocab <- get_vocabulary(vectorization)
max_decoded_sentence_length <- SEQ_LENGTH - 1
valid_images <- valid_data$img
generate_caption <- function() {
# Select a random image from the validation dataset
sample_img <- sample(valid_images, 1)
# Read the image from the disk
sample_img <- decode_and_resize(sample_img)
img <- as.array(tf$clip_by_value(sample_img, 0, 255))
img %>% as.raster(max = 255) %>% plot()
# Pass the image to the CNN
img <- tf$expand_dims(sample_img, 0L)
img <- caption_model$cnn_model(img)
# Pass the image features to the Transformer encoder
encoded_img <- caption_model$encoder(img, training = FALSE)
# Generate the caption using the Transformer decoder
decoded_caption <- "<start> "
for (i in seq_len(max_decoded_sentence_length)) {
tokenized_caption <- vectorization(list(decoded_caption))
mask <- tf$math$not_equal(tokenized_caption, 0L)
predictions <- caption_model$decoder(
tokenized_caption, encoded_img, training = FALSE, mask = mask
)
sampled_token_index <- tf$argmax(predictions[1, i, ])
sampled_token <- vocab[as.integer(sampled_token_index) + 1]
if (sampled_token == " <end>") {
break
}
decoded_caption <- paste(decoded_caption, sampled_token, sep = " ")
}
cat("Predicted Caption: ", decoded_caption)
}
# Check predictions for a few samples
generate_caption()
generate_caption()
generate_caption()
End Notes
We saw that the model starts to generate reasonable captions after a few epochs. To keep this example easily runnable, we trained it with a few constraints, such as a minimal number of attention heads. To improve the predictions, try changing these training settings to find a configuration that works well for your use case.