
ArXiv Dives: I-JEPA


Today, we’re diving into the I-JEPA paper. JEPA stands for Joint-Embedding Predictive Architecture, and if you have been following Yann LeCun, it is a technique he has been hyping up for a while. We’re excited to dive into what the hype is about and look at the technique.

Teams: Meta AI

Publish Date: April 13th, 2023

Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture
This paper demonstrates an approach for learning highly semantic image representations without relying on hand-crafted data-augmentations. We introduce the Image-based Joint-Embedding Predictive Architecture (I-JEPA), a non-generative approach for self-supervised learning from images. The idea behind I-JEPA is simple: from a single context block, predict the representations of various target blocks in the same image. A core design choice to guide I-JEPA towards producing semantic representations is the masking strategy; specifically, it is crucial to (a) sample target blocks with sufficiently large scale (semantic), and to (b) use a sufficiently informative (spatially distributed) context block. Empirically, when combined with Vision Transformers, we find I-JEPA to be highly scalable. For instance, we train a ViT-Huge/14 on ImageNet using 16 A100 GPUs in under 72 hours to achieve strong downstream performance across a wide range of tasks, from linear classification to object counting and depth prediction.

ArXiv Dives

Every Friday at Oxen.ai we host a paper club called "ArXiv Dives" to make us smarter Oxen 🐂 🧠. We believe diving into the details of research papers is the best way to build fundamental knowledge, spot patterns and keep up with the bleeding edge.

These are the notes from our live session; feel free to follow along with the video for context. If you would like to join live to ask questions or join the discussion, we would love to have you! Sign up below 👇

Oxen.ai · Events Calendar
View and subscribe to events from Oxen.ai on Luma. Build World-Class AI Datasets, Together. Track, iterate, collaborate on, & discover data in any format.

Semantic Image Representations

The goal of this paper is to create “highly semantic image representations” without relying on hand-crafted data augmentations. The authors want to train a neural network in a self-supervised manner, so that it does not require hand-labeled training data.

If you have high-quality semantic image representations, they can be used for many downstream tasks:

  1. Image Generation
  2. Semantic similarity search
  3. Image classification
  4. Semantic segmentation
  5. Object detection
  6. Depth estimation

Name your computer vision task: the process of abstracting pixel space into a latent space is very important.

Instead of representing images in pixel space, neural networks learn a function to transform them into a latent representation. The hope is that this representation contains more semantics about the image itself. For example, the semantic representation would contain information like "dog ear" rather than "white RGB pixel".

Yann LeCun is a big proponent of JEPA architectures and uses the analogies of planning a trip and driving a car. When planning, you think at multiple levels of abstraction. In natural-language terms, if you were to plan a trip from LA → Paris, you would break it down into subproblems, each with its own characteristics. How do I get from my house to the airport? How do I get from my chair to the door? Each one of these “thoughts” could be its own latent space that is processed hierarchically.

Another example he likes to give is driving a car: we take in all the “pixels” of the scene around us, but filter most of them out and use only certain information to make our next decision. We don’t pay attention to every leaf blowing on a tree when deciding to make a left-hand turn. The dream is for a computer vision system to process pixels to whatever level of semantic abstraction the task at hand needs.

While this paper does not perform any of the planning explicitly, the argument is that higher quality semantic representations will allow systems to plan better in the future.

Previous Approaches

In the paper, they compare this approach to two earlier families of methods: invariance-based pre-training and generative modeling. They argue that JEPA creates more robust semantic representations than either.

Invariance Based Pre-Training

The paper refers to invariance-based pre-training, which in this context means learning representations that let you recognize an object as the same object even when its appearance varies in some way. A common approach is to apply distortions to an image, such as random cropping, scaling, and color shifts, and then train the network to predict the same embedding for the original image and the distorted one. However, it is unclear how much this improves the abstraction of the representation.
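To make the invariance idea concrete, here is a minimal, hypothetical sketch in plain Python: two random "views" of one image are produced with a crop, an optional horizontal flip, and a brightness shift, and the objective is the squared distance between their embeddings. The `toy_encoder` is a made-up, parameter-free stand-in for a real network (it just returns mean and variance of the pixels), and the augmentation choices are illustrative assumptions, not the recipe of any specific paper.

```python
import random

def random_crop(img, size, rng):
    """Random size x size crop of a 2D grayscale image stored as a list of rows."""
    top = rng.randrange(len(img) - size + 1)
    left = rng.randrange(len(img[0]) - size + 1)
    return [row[left:left + size] for row in img[top:top + size]]

def hflip(img):
    """Mirror the image left to right."""
    return [row[::-1] for row in img]

def augment(img, size, rng):
    """One random 'view': crop, flip with probability 0.5, shift brightness."""
    view = random_crop(img, size, rng)
    if rng.random() < 0.5:
        view = hflip(view)
    delta = rng.uniform(-0.1, 0.1)
    return [[px + delta for px in row] for row in view]

def toy_encoder(img):
    """Hypothetical parameter-free 'encoder': [mean, variance] of the pixels."""
    flat = [px for row in img for px in row]
    mean = sum(flat) / len(flat)
    var = sum((px - mean) ** 2 for px in flat) / len(flat)
    return [mean, var]

def invariance_loss(img, size, rng):
    """Squared distance between the embeddings of two random views of one image."""
    z1 = toy_encoder(augment(img, size, rng))
    z2 = toy_encoder(augment(img, size, rng))
    return sum((a - b) ** 2 for a, b in zip(z1, z2))
```

In a real method the encoder has learnable weights, and training adjusts them to drive this loss toward zero, which is exactly where the hand-crafted choice of augmentations shapes what the representation ignores.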

Generative Pre-Training

Generative models of the world have shown great promise for self-supervised pre-training: you corrupt the input (for example, by masking out patches or adding noise) and train the model to reconstruct the original image.

They argue that these generative techniques do not capture as much of the semantic meaning within the images, because their latent vectors tend to perform worse on tasks like classification. You often need a second step of fine-tuning on these representations to get competitive image classification scores.

I-JEPA

This work tries to improve the semantic quality of the latent vectors learned during self-supervision by taking a single “context block” from an image and trying to predict the representations of various target blocks, where the target representations are produced by a target encoder network.

The previous approaches of invariance-based training and generative modeling can be seen in Figure 2.

They mention the “energy” of the system; this is just a fancy way of saying they assign a high number to incompatible inputs and a low number to compatible inputs. For example, compatible inputs might be an image and a piece of text that describe the same thing.

The joint-embedding architecture tries to make the distance between the two encoded vectors small for compatible inputs. Invariance-based pre-training is one example: you distort X to get Y and train the encoders so that both produce the same embedding. X and Y do not even need to be the same type of data, as we saw in CLIP, where X and Y are image-text pairs.

Generative architectures compare the output of a decoder with the Y value you are trying to reconstruct, using a pixel-space loss such as the mean squared error used by MAE (Masked Autoencoders). This is not the best metric for evaluating how well an image is generated, because drastic changes in pixel space can still leave the image semantically the same.
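A small sketch of why a pixel-space objective can be misleading. The function below is an MAE-style reconstruction loss (mean squared error over the masked patches only); the patch values are made-up toy data. Shifting a striped patch by one pixel keeps it "the same stripes" semantically, yet produces the maximum possible error for this patch.

```python
def mse(a, b):
    """Mean squared error between two flattened patches."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)

def masked_reconstruction_loss(patches, reconstruction, masked_ids):
    """MAE-style objective: pixel-space MSE, averaged over the masked patches only."""
    losses = [mse(reconstruction[i], patches[i]) for i in masked_ids]
    return sum(losses) / len(losses)

# Toy data: patch 0 is a stripe pattern and is the only masked patch.
patches = [[0.0, 1.0, 0.0, 1.0], [1.0, 1.0, 1.0, 1.0]]
# The "reconstruction" is the same stripes shifted by one pixel.
reconstruction = [[1.0, 0.0, 1.0, 0.0], [1.0, 1.0, 1.0, 1.0]]
loss = masked_reconstruction_loss(patches, reconstruction, masked_ids=[0])
```

Here `loss` comes out as 1.0, even though a human would call the two patches the same texture.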

Joint-Embedding Predictive Architectures (JEPA) try to merge the best of both worlds: you encode X and Y, but add a predictor that learns to predict the embedding of Y from the embedding of X, rather than just scoring how similar the two embeddings are.
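The difference between the two energies can be written in a few lines. This is a hedged sketch: the encoders and predictor are hypothetical toy functions, and `z` stands for the extra information the predictor is conditioned on (in I-JEPA, which target location to predict). The joint-embedding energy compares embeddings directly, while the JEPA energy first maps the embedding of X through the predictor.

```python
def sq_dist(a, b):
    """Squared Euclidean distance between two vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def jea_energy(x, y, enc_x, enc_y):
    """Joint-embedding architecture: compare the two embeddings directly."""
    return sq_dist(enc_x(x), enc_y(y))

def jepa_energy(x, y, z, enc_x, enc_y, predictor):
    """JEPA: predict the embedding of y from the embedding of x, conditioned on z."""
    return sq_dist(predictor(enc_x(x), z), enc_y(y))

# Hypothetical toy components, for illustration only.
def identity(v):
    return list(v)

def shift_predictor(s, z):
    """Toy predictor that adds the conditioning vector z to the context embedding."""
    return [a + b for a, b in zip(s, z)]

x, y, z = [1.0, 2.0], [1.0, 3.0], [0.0, 1.0]
e_jea = jea_energy(x, y, identity, identity)
e_jepa = jepa_energy(x, y, z, identity, identity, shift_predictor)
```

With these toy values, `e_jea` is 1.0 while `e_jepa` is 0.0: the predictor can account for a systematic relationship between X and Y that a plain similarity score would penalize.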

Method

The context encoder, target encoder, and predictor are all Vision Transformers (ViTs). If you are not familiar with ViTs or Transformers in general, we have some other dives where we get into the nitty-gritty details.

Arxiv Dives - Vision Transformers (ViT) | Oxen.ai

The setup is similar to the MAE (Masked Autoencoder) paper; however, the predictions are made in latent space rather than on raw pixel values. We have seen the MAE paper referenced a good bit in the literature, so it may be a useful dive in the future.

Masked Autoencoders Are Scalable Vision Learners
This paper shows that masked autoencoders (MAE) are scalable self-supervised learners for computer vision. Our MAE approach is simple: we mask random patches of the input image and reconstruct the missing pixels. It is based on two core designs. First, we develop an asymmetric encoder-decoder architecture, with an encoder that operates only on the visible subset of patches (without mask tokens), along with a lightweight decoder that reconstructs the original image from the latent representation and mask tokens. Second, we find that masking a high proportion of the input image, e.g., 75%, yields a nontrivial and meaningful self-supervisory task. Coupling these two designs enables us to train large models efficiently and effectively: we accelerate training (by 3x or more) and improve accuracy. Our scalable approach allows for learning high-capacity models that generalize well: e.g., a vanilla ViT-Huge model achieves the best accuracy (87.8%) among methods that use only ImageNet-1K data. Transfer performance in downstream tasks outperforms supervised pre-training and shows promising scaling behavior.

Sampling Context and Targets

First, they split the image into patches, just like the first step in any Vision Transformer. They sample 4 target blocks, which may overlap with each other. The "blocks" are not the patches themselves; they are sets of patch-level representations (s_0 .. s_n) produced by the target encoder.

Then they sample a single context block, which covers anywhere from 85% to 100% of the image. After it is sampled, they remove all the patches that overlap with the targets, since those would be too easy to predict, so the final context block does not overlap with any target.
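The sampling step can be sketched as follows, on the 14x14 patch grid of a 224x224 image with 16x16 patches. The context scale range (0.85 to 1.0, square block) matches the text above; the target scale range (0.15 to 0.2) and aspect-ratio range (0.75 to 1.5) are our reading of the paper's setup and should be treated as assumptions. The function names are hypothetical, not taken from the official code.

```python
import math
import random

GRID = 14  # 224 / 16 patches per side

def sample_block(rng, scale_range, aspect_range, grid=GRID):
    """Return the set of (row, col) patch indices covered by one random rectangle."""
    scale = rng.uniform(*scale_range)    # fraction of the image the block covers
    aspect = rng.uniform(*aspect_range)  # height / width
    area = scale * grid * grid
    h = max(1, min(grid, round(math.sqrt(area * aspect))))
    w = max(1, min(grid, round(math.sqrt(area / aspect))))
    top = rng.randrange(grid - h + 1)
    left = rng.randrange(grid - w + 1)
    return {(r, c) for r in range(top, top + h) for c in range(left, left + w)}

def sample_context_and_targets(rng, num_targets=4):
    """Sample (possibly overlapping) targets, then a large context minus the targets."""
    targets = [sample_block(rng, (0.15, 0.2), (0.75, 1.5)) for _ in range(num_targets)]
    context = sample_block(rng, (0.85, 1.0), (1.0, 1.0))
    covered = set().union(*targets)
    # Drop every context patch that falls inside any target block.
    return context - covered, targets
```

Representing blocks as sets of patch indices makes "remove the overlap" a single set difference.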

Prediction

Back to the example image, there are M target block representations we wish to predict from our single context block. In the example image, M=3.

For a given target block, you feed the predictor the output of the context encoder plus one mask token per target patch. Each mask token is a shared learnable vector with the positional embedding of that target patch added, which tells the predictor where to make its prediction without leaking the target's content. Since there are M target blocks, the predictor is applied M times, each time with the same context and a different set of mask tokens.

The predictor’s job is to predict the latent representations of the target blocks given what it knows about the context.
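Here is a minimal sketch of that loop. `pos_embed` and `toy_predictor` are hypothetical stand-ins (a real I-JEPA predictor is a narrow ViT that attends over the context tokens and the mask tokens together); what the sketch does show is the structure: one predictor call per target block, with one positional query per target patch.

```python
import math

def add(a, b):
    return [x + y for x, y in zip(a, b)]

def pos_embed(pos, dim):
    """Hypothetical stand-in for the positional embedding of patch (row, col)."""
    row, col = pos
    return [math.sin(row + k) + math.cos(col + k) for k in range(dim)]

def toy_predictor(context_tokens, queries):
    """Hypothetical predictor: adds the mean context vector to each query."""
    n = len(context_tokens)
    dim = len(context_tokens[0])
    mean_ctx = [sum(t[d] for t in context_tokens) / n for d in range(dim)]
    return [add(q, mean_ctx) for q in queries]

def predict_targets(context_tokens, target_blocks, mask_token, predictor):
    """Apply the predictor once per target block, i.e. M times for M blocks."""
    dim = len(mask_token)
    predictions = []
    for block in target_blocks:
        positions = sorted(block)
        # One query per target patch: shared mask token + that patch's position.
        queries = [add(mask_token, pos_embed(p, dim)) for p in positions]
        outputs = predictor(context_tokens, queries)
        predictions.append(dict(zip(positions, outputs)))
    return predictions
```

Each returned dict maps a target patch position to its predicted representation, which is the shape the loss needs.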

Where do the latent spaces come from?

With a 16x16 patch size and 3 color channels (16x16x3), you would flatten each patch into 768 values.

If you had a 224x224 image with a patch size of 16, 224/16 = 14, so you have 14x14 = 196 patches. This creates a 196x768 matrix (196 patch tokens of 768 values each) that can be fed into the transformer.
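The patch arithmetic above can be checked with a short sketch. `patchify` is a hypothetical helper written for illustration, using nested lists for an H x W x C image; in an actual ViT, each flattened patch is then passed through a learned linear projection to the model's embedding dimension.

```python
def patchify(image, patch):
    """Split an H x W x C image (nested lists) into flattened non-overlapping patches.

    Returns (H / patch) * (W / patch) vectors of length patch * patch * C,
    ordered row by row over the patch grid.
    """
    height, width = len(image), len(image[0])
    assert height % patch == 0 and width % patch == 0, "image must tile evenly"
    tokens = []
    for top in range(0, height, patch):
        for left in range(0, width, patch):
            vec = []
            for r in range(top, top + patch):
                for c in range(left, left + patch):
                    vec.extend(image[r][c])  # append the C channel values of this pixel
            tokens.append(vec)
    return tokens

# A blank 224x224 RGB image.
image = [[[0.0, 0.0, 0.0] for _ in range(224)] for _ in range(224)]
tokens = patchify(image, 16)
```

Here `len(tokens)` is 196 and each token has 16 * 16 * 3 = 768 values.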

Source: https://yurkovak.medium.com/vision-transformer-vit-under-the-magnifying-glass-part-1-70be8d6661a7

Each image patch can call out to every other patch and say, “Hey, I am a fluffy ear. Any noses to go along with me to help decide whether I’m a cat or a dog?” Then the nose patch says, “I’m a nose! Let us combine our powers into a new representation.”

The way these patches communicate with each other to create their semantic representations is through transformer blocks.

Transformer Block

At the end of the transformer block, you get h(x_i), which is the hidden state we are concerned with. This latent vector is an updated representation that has taken in the context of all the other patches. The idea is that this updated representation will have absorbed context from the other patches, updating the representation from "furry ear" to "furry dog ear".
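The mixing step that lets each patch "hear" the others is self-attention. Below is a minimal single-head, scaled dot-product self-attention in plain Python, as a sketch of one piece of a transformer block (a full block also has multiple heads, residual connections, layer norm, and an MLP). The weight matrices `Wq`, `Wk`, `Wv` are supplied by the caller here instead of being learned.

```python
import math

def softmax(xs):
    """Numerically stable softmax over a list of scores."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def matvec(W, x):
    """Multiply matrix W (out_dim rows of in_dim values) by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def self_attention(tokens, Wq, Wk, Wv):
    """Single-head scaled dot-product self-attention over a list of patch vectors."""
    q = [matvec(Wq, t) for t in tokens]
    k = [matvec(Wk, t) for t in tokens]
    v = [matvec(Wv, t) for t in tokens]
    d = len(q[0])
    out = []
    for qi in q:
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        weights = softmax(scores)  # how much this patch attends to every patch
        out.append([sum(w * vj[n] for w, vj in zip(weights, v)) for n in range(len(v[0]))])
    return out
```

Each output row is a weighted average of the value vectors of all patches, which is how an "ear" patch can pull in information from a "nose" patch.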

Attention Head

The attention heads are the real magic of the transformer block. I updated the diagram we went through in our Mechanistic Interpretability of Transformers series to take in image patches instead of words in a sentence.

If you want to learn what all the fancy math in here means, feel free to check out the past dives.

Arxiv Dives - A Mathematical Framework for Transformer Circuits - Part 1 | Oxen.ai

In the appendix they say that the embedding dimension of the predictor is 384, and the depth (number of transformer blocks) is anywhere between 6-12 in various experiments.

Loss

The loss is a simple L2 distance between patch-level representations within the blocks. Looking at our blocks again: for each of the M target blocks, you compute the squared L2 distance between the predictor's output for each patch and that patch's target-encoder representation, sum over the patches in the block, and then average over the M target blocks.
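A sketch of that loss, written as (1/M) times the sum over blocks i of the sum over patches j in block i of ||predicted_j - target_j||^2. Blocks are represented as dicts from patch position to vector, a hypothetical data layout chosen for readability; some implementations average over patches instead of summing, which only rescales the loss.

```python
def sq_l2(a, b):
    """Squared L2 distance between two representation vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def ijepa_loss(predicted_blocks, target_blocks):
    """Average over the M target blocks of the summed per-patch squared L2 distance.

    predicted_blocks, target_blocks: lists of length M, each a dict
    {patch_position: representation_vector} with matching keys.
    """
    num_blocks = len(target_blocks)
    total = 0.0
    for pred, tgt in zip(predicted_blocks, target_blocks):
        total += sum(sq_l2(pred[p], tgt[p]) for p in tgt)
    return total / num_blocks

pred = [{(0, 0): [1.0, 0.0]}, {(0, 1): [0.0, 0.0], (0, 2): [1.0, 1.0]}]
tgt = [{(0, 0): [0.0, 0.0]}, {(0, 1): [0.0, 0.0], (0, 2): [0.0, 0.0]}]
loss = ijepa_loss(pred, tgt)
```

Block 1 contributes 1.0 and block 2 contributes 0.0 + 2.0, so `loss` is (1.0 + 2.0) / 2 = 1.5. Note that the targets come from the target encoder, not from the context block.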

Evaluation on Image Classification

To study how good the image representations are, they report results on various image classification tasks using a “linear probe” and partial fine-tuning. All I-JEPA models are trained at 224x224 resolution, except for ViT-H/16, which is trained at 448x448. They compare the results to many approaches before it, including the MAE work, which does similar operations in pixel space.

They evaluate the classification accuracy on ImageNet by using an average pooled representation of the output tokens.
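A linear probe keeps the pretrained encoder frozen and trains only a single linear layer on top of its features. A minimal sketch of the forward pass, assuming the encoder outputs are already computed: average-pool the output tokens into one vector, then apply a linear layer to get class logits. The weights and token values below are made-up toy numbers.

```python
def average_pool(tokens):
    """Average a list of token vectors into a single feature vector."""
    n = len(tokens)
    return [sum(t[d] for t in tokens) / n for d in range(len(tokens[0]))]

def linear_probe_logits(tokens, W, b):
    """Frozen features -> average pool -> one linear layer (W: classes x dim)."""
    feat = average_pool(tokens)
    return [sum(w * f for w, f in zip(row, feat)) + bias for row, bias in zip(W, b)]

tokens = [[1.0, 3.0], [3.0, 5.0]]  # toy encoder outputs for two patches
logits = linear_probe_logits(tokens, W=[[1.0, 0.0], [0.0, 1.0]], b=[0.0, 1.0])
predicted_class = max(range(len(logits)), key=lambda i: logits[i])
```

Because only `W` and `b` would be trained, probe accuracy is a direct measure of how linearly separable the frozen representations already are.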

I-JEPA demonstrates that it can match the performance of view-invariance approaches without requiring data augmentation.

You can also see that it takes fewer epochs to reach higher accuracy.

One very convincing ablation study compares the linear probe accuracy when I-JEPA predicts its targets in pixel space versus in latent space. Latent space blows it out of the water, in fewer epochs.

It is also more computationally efficient and learns more semantic off-the-shelf representations.

Conclusion

After reading the diffusion transformers paper as well as the I-JEPA paper, it is clear that working in latent space rather than pixel space:

  1. Is more efficient
  2. Gives higher-quality semantics

A member of our Discord community, @johnweak15, took a stab at implementing I-JEPA on his own and found some interesting takeaways.

GitHub - Ugenteraan/I-JEPA: PyTorch implementation of I-JEPA

The main question we had was how the model gets out of a random initialization state and avoids mode collapse. If the context encoder and the target encoder are randomly initialized, aren't the latent representations also going to be random at the start? So what is the ground-truth latent space you are trying to optimize for, and how does the network get there from a random state?

It feels like the network could either collapse all the latent representations down to a single value or never get enough signal to learn something meaningful. Any clarity on initialization or intuition on how the underlying mechanism works would be appreciated! Feel free to join our Discord if you have any thoughts or answers.
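One detail from the paper that is relevant to this question: the target encoder is not trained by gradient descent. Its weights are an exponential moving average (EMA) of the context encoder's weights, and no gradients flow into it. To our reading, the paper starts the EMA momentum at 0.996 and increases it linearly to 1.0 over training; treat those exact values as an assumption. A minimal sketch, with flat lists of floats standing in for parameter tensors:

```python
def ema_update(target_params, context_params, momentum):
    """target <- momentum * target + (1 - momentum) * context, elementwise.

    The target is only ever updated this way, never by backpropagation.
    """
    return [momentum * t + (1.0 - momentum) * c
            for t, c in zip(target_params, context_params)]

def momentum_schedule(step, total_steps, start=0.996, end=1.0):
    """Linearly increase the EMA momentum from start to end over training."""
    return start + (end - start) * step / total_steps

target = [2.0, -1.0]
context = [0.0, 1.0]
target = ema_update(target, context, momentum_schedule(0, 1000))
```

Whether this slowly moving target, together with the predictor, fully explains why the representations avoid collapse is exactly the open question above.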

Next Up

To continue the conversation, we would love you to join our Discord! There are a ton of smart engineers, researchers, and practitioners that love diving into the latest in AI.

Join the oxen Discord Server!

If you enjoyed this dive, please join us next week live! We always save time for questions at the end, and always enjoy the live discussion where we can clarify and dive deeper as needed.

Arxiv Dives with Oxen.ai · Luma
Hey Nerd, join the Herd!... for a little book/paper review. Make sure to also join our Discord here (https://discord.gg/s3tBEn7Ptg) to share recommendations for future reads and more…

All the past dives can be found on the blog.

Oxen.ai Blog | Oxen.ai
Manage your machine learning datasets with Oxen AI.

The live sessions are posted on YouTube if you want to watch at your own leisure.

Oxen
Each week we dive deep into a topic in machine learning or general artificial intelligence research. The sessions are live with a group of smart Oxen every Friday. Join the discussion: https://lu.ma/oxenbookclub

Best & Moo,

~ The herd at Oxen.ai

Who is Oxen.ai?

Oxen.ai is an open source project aimed at solving some of the challenges with iterating on and curating machine learning datasets. At its core, Oxen is a lightning fast data version control tool optimized for large unstructured datasets. We are currently working on collaboration workflows to enable high-quality, curated public and private data repositories to advance the field of AI, while keeping all the data accessible and auditable.

If you would like to learn more, star us on GitHub or head to Oxen.ai and create an account.

GitHub - Oxen-AI/oxen-release: Lightning fast data version control system for structured and unstructured machine learning datasets. We aim to make versioning datasets as easy as versioning code.