Visual question answering with multimodal transformers
PyTorch implementation of VQA models using text and image transformers from Hugging Face
Recent years have seen significant advancements not only in the respective domains of Natural Language Processing (NLP) and Computer Vision (CV) but also in tasks involving multiple modalities (text + image features) such as image captioning, visual question answering (VQA), cross-modal retrieval, visual common-sense reasoning, and more. Among these, VQA has particularly drawn the interest of several researchers.
What is VQA?
VQA is a multimodal task wherein, given an image and a natural language question related to the image, the objective is to produce a correct natural language answer as output.
It involves understanding the content of the image and correlating it with the context of the question asked. Because we need to compare the semantics of information present in both of the modalities — the image and the natural language question related to it — VQA entails a wide range of sub-problems in both CV and NLP (such as object detection and recognition, scene classification, counting, and so on). Thus, it is often considered an AI-complete task.
VQA with multimodal fusion models
Multimodal models can be of various forms to capture information from the text and image modalities, along with some cross-modal interaction as well. In fusion models, the information from the text and image encoders are fused into a combined representation to perform the downstream task.
A typical fusion model for a VQA system involves the following steps:
- Featurization of image and question: We need to extract features from the image and obtain the embeddings of the question after tokenization. The question can be featurized using simple embeddings (like GloVe), Seq2Seq models (like LSTMs), or transformers. Similarly, the image features can be extracted using simple CNNs (convolutional neural networks), early layers of object detection or image classification models, or image transformers.
- Feature fusion: Since VQA involves a comparison of the semantic information present in the image and the question, there is a need to jointly represent the features from both modalities. This is usually accomplished through a fusion layer that allows cross-modal interaction between image and text features to generate a fused multimodal representation.
- Answer generation: Depending on how the VQA task is modeled, the correct answers could either be generated purely using natural language generation (for longer, descriptive answers) or predicted by a simple classifier model (for one-word/phrase answers drawn from a fixed answer space).
A variety of methods can be used for the individual feature extraction and feature fusion steps, ranging from classical embeddings and CNNs to transformer-based encoders.
In this article, I explore the idea of late fusion by fine-tuning pretrained text and image transformer models, as they are simpler to train.
With this background in place, it’s time to delve into the code and implement our multimodal model for VQA. First, we process the DAQUAR dataset. Because all the questions have single word/phrase-type answers, we consider the entire vocabulary of answers available (the answer space) and treat them as labels. This converts visual question answering into a multiclass classification problem. We then train our multimodal transformer model and evaluate it using some established metrics for VQA. Toward the end, we compare and explain the results for various combinations of textual and image transformers used for featurization.
Tl;dr: The accompanying GitHub repository contains all the code mentioned in this article. Although GitHub gists are used as code snippets throughout this article, they may not work as intended if copied directly. Please refer to the code present in the repository for working implementations.
Preliminaries
Installing required packages
We need to create a virtual environment and install the required packages:
datasets==1.17.0
nltk==3.5
pandas==1.3.5
Pillow==9.0.0
scikit-learn==0.23.2
torch==1.8.2+cu111
transformers==4.14.0
Note: It is advisable to have some GPU access to train the multimodal models because they are large and require considerable time for training otherwise.
Setting up the environment
To set up the environment for training our multimodal VQA model, we need to import the required modules and set the appropriate device for PyTorch.
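A minimal sketch of this setup might look like the following (the exact imports depend on the models and utilities used later in the article):

```python
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoFeatureExtractor, Trainer, TrainingArguments

# Use a GPU if one is available; training on a CPU is possible but very slow.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```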
Data preparation
For the VQA model training, we use the full DAQUAR dataset, which contains approximately 12,500 question-answer pairs based on images from the NYU-Depth V2 dataset.
Preprocessing the dataset
The raw dataset contains the actual images separately in the images/ directory. All the question-answer pairs are present on consecutive lines in a .txt file, as shown below:
what is on the desk and behind the black cup in the image4 ?
bottle
what is in front of the monitor in the image6 ?
keyboard
...
We run a pre-processing script over these question-answer pairs. It normalizes the questions by removing the image IDs present in them. The questions and answers, along with the corresponding image IDs extracted during normalization, are stored in a tabular (CSV) format. Moreover, because the original DAQUAR dataset provides only about 54 percent of the question-answer pairs for training (this amounts to only around 6,700 samples, which is rather small for training), we produce our own custom split (80 percent training and 20 percent evaluation) from the overall data.
This script produces data_train.csv and data_eval.csv files, along with answer_space.txt, which contains the vocabulary of all the answers. These files are already available in the dataset/ folder of the repository for direct consumption.
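A rough sketch of such a pre-processing script is shown below (the file paths, the regular expression used to extract image IDs, and the 80/20 split parameters are assumptions; refer to the repository for the actual script):

```python
import re
import pandas as pd
from sklearn.model_selection import train_test_split

def preprocess_qa_pairs(txt_path: str):
    """Parse alternating question/answer lines and normalize the questions."""
    with open(txt_path) as f:
        lines = [line.strip() for line in f if line.strip()]

    records = []
    # Questions and answers appear on consecutive lines of the .txt file.
    for question, answer in zip(lines[0::2], lines[1::2]):
        # The image ID (e.g. "image4") is embedded in the question text.
        match = re.search(r"image(\d+)", question)
        image_id = f"image{match.group(1)}" if match else None
        # Remove the image reference to normalize the question (simplified pattern).
        question = re.sub(r"\s*in the image\d+", "", question).strip()
        records.append({"question": question, "answer": answer, "image_id": image_id})

    df = pd.DataFrame(records)

    # Custom 80/20 train/eval split over the full data.
    train_df, eval_df = train_test_split(df, test_size=0.2, random_state=42)
    train_df.to_csv("dataset/data_train.csv", index=False)
    eval_df.to_csv("dataset/data_eval.csv", index=False)

    # Build the answer space (vocabulary of all answers) used as classification labels.
    with open("dataset/answer_space.txt", "w") as f:
        f.write("\n".join(sorted(df["answer"].unique())))
```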
Loading the data
Now we are set to load this processed data. For this, we use the datasets library from Hugging Face. Since we model this task as a multiclass classification problem, we need to assign a label to every answer. These labels are derived from the indices of the answers in the answer space.
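A minimal sketch of this loading step is shown below (the column names follow the CSVs produced above; the label mapping in the repository may handle multi-word answers differently):

```python
from datasets import load_dataset

# Load the pre-processed CSV files as a DatasetDict with "train" and "test" splits.
dataset = load_dataset(
    "csv",
    data_files={"train": "dataset/data_train.csv", "test": "dataset/data_eval.csv"},
)

# The answer space doubles as the label vocabulary for multiclass classification.
with open("dataset/answer_space.txt") as f:
    answer_space = f.read().splitlines()

# Derive a label for each sample from the index of its answer in the answer space.
dataset = dataset.map(
    lambda example: {"label": answer_space.index(example["answer"])}
)
```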
We can also inspect specific or random entries from our training or evaluation dataset in a Jupyter notebook.
Defining a multimodal collator for data
Up to this point, we have just loaded the questions, answers, and corresponding image IDs, along with the labels. To feed the information about the question and actual images batchwise into our multimodal model, we need to define a data collator.
This collator will process the question (text) and the image and return the tokenized text (with attention masks) along with the featurized image (basically, the pixel values). These will be fed into our multimodal transformer model for question answering.
We use AutoTokenizer and AutoFeatureExtractor from the Hugging Face transformers library to convert the raw questions and images into inputs for featurization by the respective text and image transformers.
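A sketch of such a collator is given below (the images/ directory layout and the .png extension are assumptions about how the DAQUAR images are stored; the repository's implementation may differ in details):

```python
import os
from dataclasses import dataclass
from typing import Dict, List

import torch
from PIL import Image
from transformers import AutoTokenizer, AutoFeatureExtractor

@dataclass
class MultimodalCollator:
    tokenizer: AutoTokenizer
    feature_extractor: AutoFeatureExtractor
    image_dir: str = "dataset/images"

    def __call__(self, batch: List[Dict]) -> Dict[str, torch.Tensor]:
        # Tokenize the questions, padded to the longest question in the batch.
        text_inputs = self.tokenizer(
            [item["question"] for item in batch],
            padding="longest",
            truncation=True,
            return_tensors="pt",
        )

        # Load the corresponding images and featurize them into pixel values.
        images = [
            Image.open(os.path.join(self.image_dir, f"{item['image_id']}.png")).convert("RGB")
            for item in batch
        ]
        image_inputs = self.feature_extractor(images=images, return_tensors="pt")

        return {
            "input_ids": text_inputs["input_ids"],
            "attention_mask": text_inputs["attention_mask"],
            "pixel_values": image_inputs["pixel_values"],
            "labels": torch.tensor([item["label"] for item in batch], dtype=torch.int64),
        }
```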
Defining the multimodal VQA model architecture
As mentioned previously, we use the idea of late fusion to define our multimodal model comprising:
- A text transformer to encode the question and generate embeddings
- An image transformer to encode the image and generate features
- A reasonably simple fusion layer that concatenates the textual and image features and passes them through a linear layer to generate an intermediate output
- A classifier, which is a fully connected network whose output dimension equals the size of the answer space
We model VQA as a multiclass classification task. Thus, cross-entropy loss becomes a natural choice for the loss function to be minimized.
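A sketch of such a late fusion model is shown below (the intermediate dimension, dropout, and default checkpoint names are illustrative choices; the model in the repository may differ in details):

```python
import torch
import torch.nn as nn
from transformers import AutoModel

class MultimodalVQAModel(nn.Module):
    def __init__(
        self,
        num_labels: int,
        intermediate_dim: int = 512,
        pretrained_text_name: str = "bert-base-uncased",
        pretrained_image_name: str = "google/vit-base-patch16-224-in21k",
    ):
        super().__init__()
        # Text and image encoders (any Hugging Face text/image transformer can be used).
        self.text_encoder = AutoModel.from_pretrained(pretrained_text_name)
        self.image_encoder = AutoModel.from_pretrained(pretrained_image_name)

        # Fusion layer: concatenate pooled text and image features, then project.
        self.fusion = nn.Sequential(
            nn.Linear(
                self.text_encoder.config.hidden_size
                + self.image_encoder.config.hidden_size,
                intermediate_dim,
            ),
            nn.ReLU(),
            nn.Dropout(0.5),
        )

        # Classifier over the answer space.
        self.classifier = nn.Linear(intermediate_dim, num_labels)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask, pixel_values, labels=None):
        text_out = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        image_out = self.image_encoder(pixel_values=pixel_values)

        # Use the [CLS]-token representation from both encoders and fuse them.
        fused = self.fusion(
            torch.cat(
                [text_out.last_hidden_state[:, 0, :], image_out.last_hidden_state[:, 0, :]],
                dim=1,
            )
        )
        logits = self.classifier(fused)

        out = {"logits": logits}
        if labels is not None:
            out["loss"] = self.criterion(logits, labels)
        return out
```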
Besides training a particular VQA model with multimodal transformers, we intend to experiment with various pre-trained model combinations and evaluate their performance on the DAQUAR dataset.
Pretrained models for textual encoding
For encoding the questions, we experiment with pretrained text transformers such as BERT, RoBERTa, and ALBERT.
Pretrained models for image encoding
For encoding the images, we experiment with pretrained image transformers such as ViT and BEiT.
Creating the collator and multimodal model
Because we aim to experiment with multiple combinations of text and image transformers, it is reasonable to implement a function for creating the corresponding collators with the respective models.
For demonstration in this article, we will create the collator and model using the tokenizer, feature extractor, and models from pretrained BERT and ViT.
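One possible shape for this factory function, built on the collator and model classes sketched above, is shown below (the checkpoint identifiers are the commonly used Hugging Face ones and may differ from the repository's configuration):

```python
def createMultimodalVQACollatorAndModel(text: str, image: str, num_labels: int):
    # Tokenizer and feature extractor matching the chosen checkpoints.
    tokenizer = AutoTokenizer.from_pretrained(text)
    feature_extractor = AutoFeatureExtractor.from_pretrained(image)

    collator = MultimodalCollator(tokenizer=tokenizer, feature_extractor=feature_extractor)
    model = MultimodalVQAModel(
        num_labels=num_labels,
        pretrained_text_name=text,
        pretrained_image_name=image,
    ).to(device)
    return collator, model

# BERT + ViT combination used for the demonstration in this article.
collator, model = createMultimodalVQACollatorAndModel(
    text="bert-base-uncased",
    image="google/vit-base-patch16-224-in21k",
    num_labels=len(answer_space),
)
```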
Evaluation metrics
We approach the VQA task as a multiclass classification problem in this article. Hence, accuracy and macro F1 score are straightforward choices as metrics for evaluating the performance of our model. However, because these metrics may often be too restrictive, penalizing almost correct answers (‘tree’ versus ‘plant’) as heavily as incorrect answers (‘tree’ versus ‘table’), we select a metric like WUPS as our primary evaluation metric. Such a metric considers the semantic similarity between the predicted answer and the ground truth.
Wu and Palmer Similarity (WUPS) Score
One option to evaluate open-ended natural language answers is to perform exact string matching. However, it is too stringent and cannot capture the semantic relatedness between the predicted answer and the ground truth. This prompts the use of other metrics that capture the semantic similarity of strings effectively. One such commonly used metric is the Wu and Palmer Similarity (WUPS) Score.
WUPS computes the semantic similarity between two words based on the depth of their least common subsumer (their most specific common ancestor) in the taxonomy tree. This score works well for single-word answers (hence, we use it for our task), but may not work for phrases or sentences.
Although nltk has an implementation of WUPS based on the WordNet taxonomy, for our experimentation we use a separate implementation exposed through the wup_measure(...) function.
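The sketch below is only an approximation of such a metric using NLTK's WordNet interface; it is not the wup_measure(...) implementation used in the repository:

```python
from nltk.corpus import wordnet as wn  # requires: nltk.download("wordnet")

def approximate_wups(prediction: str, ground_truth: str) -> float:
    """Approximate WUPS for single-word answers using WordNet synsets.

    Wu-Palmer similarity = 2 * depth(least common subsumer)
                           / (depth(synset1) + depth(synset2))
    """
    pred_synsets = wn.synsets(prediction)
    truth_synsets = wn.synsets(ground_truth)
    if not pred_synsets or not truth_synsets:
        # Fall back to exact string matching when a word is not in WordNet.
        return float(prediction == ground_truth)
    # Take the best score over all synset pairs of the two words.
    return max(
        (s1.wup_similarity(s2) or 0.0) for s1 in pred_synsets for s2 in truth_synsets
    )

def batch_wups(predictions, ground_truths) -> float:
    """Average the per-sample scores over a batch of predictions."""
    scores = [approximate_wups(p, t) for p, t in zip(predictions, ground_truths)]
    return sum(scores) / len(scores)
```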
Training the multimodal VQA model
We finally come to the part where we use the previously defined functions to initialize our multimodal model and train it. We use the Trainer from Hugging Face, which abstracts away most of the code required for setting up a PyTorch training loop. Hyperparameters such as the number of training epochs, the batch size, and so on are passed to the Trainer by setting the corresponding values in the TrainingArguments.
For this article, we use a fairly standard set of hyperparameters; the exact values used for the reported experiments are available in the repository.
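A minimal training setup along these lines might look as follows, reusing the answer_space list and the batch_wups helper from the earlier sketches (the hyperparameter values here are placeholders, not the ones used for the reported experiments):

```python
from sklearn.metrics import accuracy_score, f1_score
from transformers import Trainer, TrainingArguments

def compute_metrics(eval_tuple):
    logits, labels = eval_tuple
    preds = logits.argmax(axis=-1)
    predicted_answers = [answer_space[p] for p in preds]
    true_answers = [answer_space[l] for l in labels]
    return {
        "wups": batch_wups(predicted_answers, true_answers),
        "accuracy": accuracy_score(labels, preds),
        "macro_f1": f1_score(labels, preds, average="macro"),
    }

training_args = TrainingArguments(
    output_dir="checkpoints/bert_vit",
    num_train_epochs=5,              # placeholder value
    per_device_train_batch_size=32,  # placeholder value
    per_device_eval_batch_size=32,
    evaluation_strategy="steps",
    save_strategy="steps",
    logging_strategy="steps",
    save_total_limit=3,
    metric_for_best_model="wups",
    load_best_model_at_end=True,
    # Keep the raw columns (question, image_id, ...) so the collator can access them.
    remove_unused_columns=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    data_collator=collator,
    compute_metrics=compute_metrics,
)
trainer.train()
```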
The training of this BERT + ViT model (and all other combinations of transformers) was carried out on a GPU.
The model checkpoints are saved periodically in the indicated output directory, based on the information provided in the TrainingArguments.
Making inferences using the trained model
To use any of the saved model checkpoints for inferencing, the question must be tokenized, and image features must be extracted appropriately (as done in the collator). These would serve as input to the model, with weights loaded from the trained checkpoint. The label predicted by the model is then mapped to the index of the actual answer in the answer space.
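One way to run such an inference step is sketched below (the checkpoint path and the example question/image are placeholders; the tokenizer and feature extractor are reused from the collator created earlier):

```python
import torch
from PIL import Image

@torch.no_grad()
def answer_question(question, image_path, model, tokenizer, feature_extractor):
    """Predict the answer for a single (question, image) pair."""
    model.eval()
    text_inputs = tokenizer(question, return_tensors="pt")
    image = Image.open(image_path).convert("RGB")
    image_inputs = feature_extractor(images=[image], return_tensors="pt")

    outputs = model(
        input_ids=text_inputs["input_ids"].to(device),
        attention_mask=text_inputs["attention_mask"].to(device),
        pixel_values=image_inputs["pixel_values"].to(device),
    )
    # Map the predicted label back to the corresponding answer in the answer space.
    return answer_space[outputs["logits"].argmax(dim=-1).item()]

# Load weights from a saved checkpoint (the path is a placeholder).
model.load_state_dict(
    torch.load("checkpoints/bert_vit/checkpoint-1500/pytorch_model.bin", map_location=device)
)
model.to(device)

print(answer_question(
    "what is on the desk and behind the black cup ?",
    "dataset/images/image4.png",
    model,
    collator.tokenizer,
    collator.feature_extractor,
))
```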
Comparing the performance of various models
A similar approach is followed to train VQA models with various combinations of text and image transformers, simply by changing the text and image arguments when calling the createMultimodalVQACollatorAndModel(...) function.
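For example, the RoBERTa + BEiT combination could be created as follows (the checkpoint identifiers are the standard Hugging Face ones and may differ from the repository's configuration):

```python
# RoBERTa + BEiT combination.
collator, model = createMultimodalVQACollatorAndModel(
    text="roberta-base",
    image="microsoft/beit-base-patch16-224",
    num_labels=len(answer_space),
)
```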
The key findings from comparing the performance of these different models are summarized below:
- RoBERTa + BEiT performs the best in terms of both WUPS and accuracy.
- RoBERTa-based models generally perform better than the rest. This can be attributed to the larger number of trainable parameters and the embeddings generated through more robust pre-training.
- ALBERT-based models are expected to have lower performance because ALBERT is much smaller compared to BERT and RoBERTa. Yet, the ALBERT + ViT model can achieve scores comparable to the BERT + ViT model, despite having only around half the number of parameters.
- For BERT and RoBERTa-based text transformers, the best results are achieved using BEiT as the image transformer. However, it does not perform up to the mark with ALBERT. This could indicate that higher quality textual embeddings are required to complement the image embeddings generated by BEiT.
Concluding remarks
In summary, we successfully implemented, trained, and evaluated a late fusion type of multimodal transformer model in PyTorch for visual question answering using the DAQUAR dataset. We also learned how to use the model weights from a trained checkpoint to answer questions related to an image. Last, we compared the performance of several models using different text and image transformers to featurize the question and image before performing fusion.
I hope this article has given you a good overview of some of the concepts involved in visual question answering and helped you understand the nuances of training multimodal transformer models in PyTorch for such a task. Feel free to check out the References section below for more details regarding concepts and terms that I might have breezed through in this article. Please leave any feedback or suggestions in the comments section below.
All the code mentioned in this article is available in the accompanying repository, together with the tracked performance metrics for the different combinations of transformers. Feel free to clone the repository, follow the steps mentioned in the README.md, and tweak the params.yaml file to experiment with different model architectures for the VQA task.