InfoNCE

Noise Contrastive Estimation

🔥 What is InfoNCE?

InfoNCE (Info Noise Contrastive Estimation) is a loss function for contrastive learning that maximizes mutual information between positive pairs while minimizing similarity with negative samples.

Core Idea:

Pull positive pairs closer together,
Push negative pairs farther apart!

📜 History

InfoNCE was introduced and popularized in several influential papers:

  • 📊 CPC (2018): Contrastive Predictive Coding - predictive representation learning
  • 🎨 MoCo (2019): Momentum Contrast - self-supervised vision
  • 🖼️ SimCLR (2020): Simple framework for contrastive learning
  • 🌐 CLIP (2021): Vision-language alignment with InfoNCE

💡 Why is InfoNCE Powerful?

✅ Positive Pairs

The anchor & positive sample should have high similarity

Example: two augmentations of the same image

❌ Negative Pairs

The anchor & negative samples should have low similarity

Example: images from different classes

🎯 What You Will Learn

📏

Mathematical Foundation

InfoNCE derivation

⚙️

Contrastive Mechanics

Pull/push forces

🌡️

Temperature τ

Controlling sharpness

💻

PyTorch Code

Implementation

Mathematical Foundation

InfoNCE Derivation

πŸ“ InfoNCE Formula

InfoNCE is a loss function derived from Noise Contrastive Estimation:

L = -log(exp(sim(q, k⁺)/τ) / Σⱼ exp(sim(q, kⱼ)/τ))
sim(q, k) = (q·k) / (||q|| ||k||) (cosine similarity)
τ: temperature parameter

Minimizing the loss → maximizing similarity with the positive while minimizing it with the negatives

πŸ” Component Breakdown

  • q: Query/anchor embedding
  • k⁺: Positive key embedding (matched pair)
  • kᵢ⁻: Negative key embeddings (i = 1...N-1)
  • sim(·,·): Similarity function (usually cosine)
  • τ: Temperature (controls distribution sharpness)

💭 Intuition: Softmax Classification

InfoNCE can be viewed as an N-way classification problem:

logits = [sim(q,k₁), sim(q,k₂), ..., sim(q,k_N)] / τ
L = CrossEntropy(logits, target=positive_index)

The model must "classify" which of the N candidates is the positive pair!
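This classification view can be checked numerically. A minimal sketch in plain Python (hypothetical similarity values; `info_nce` is an illustrative helper, not a library function) computing -log softmax at the positive index:

```python
import math

def info_nce(sims, positive_index, tau=0.07):
    """InfoNCE = cross-entropy of softmax(sims / tau) against the positive index."""
    logits = [s / tau for s in sims]
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(l - m) for l in logits))
    return log_sum_exp - logits[positive_index]  # -log softmax[positive]

# Query vs. 4 candidate keys (hypothetical cosine similarities); index 0 is the positive
sims = [0.9, 0.1, 0.2, 0.05]
loss = info_nce(sims, positive_index=0)
```

With the positive similarity well above the negatives, the loss is near zero; shrinking that gap (or raising τ) increases it.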

📊 Mutual Information Perspective

InfoNCE maximizes a lower bound on the mutual information I(q; k⁺):

Goal: Maximize I(q; k⁺)

InfoNCE loss = -log P(positive | q, {k₁, ..., k_N})
I(q; k⁺) ≥ log(N) - L_InfoNCE
Minimizing InfoNCE ≈ Maximizing MI (the bound tightens as N grows)

Contrastive Mechanics

Pull Positive, Push Negative

βš™οΈ How Contrastive Learning Works

Contrastive learning trains with pull and push forces:

➡️ Pull Force

Positive pairs are pulled closer together in the embedding space

The gradient drives:
sim(anchor, positive) → 1.0
⬅️ Push Force

Negative pairs are pushed farther apart

The gradient drives:
sim(anchor, negative) → 0.0

📦 Batch Construction

For a batch size of N, we have:

  • ✅ N positive pairs: (anchor₁, positive₁), ..., (anchor_N, positive_N)
  • ❌ N×(N-1) negative pairs: all other combinations in the batch!

Efficient Negatives:

With a batch size of N = 256, each sample has:

  • 1 positive pair
  • 255 negative pairs (in-batch negatives)

No explicit negative sampling needed!

Example: SimCLR Augmentation

In self-supervised vision learning (SimCLR):

  1. Take a batch of N images
  2. Create 2 augmented views per image → 2N total
  3. Positive: two views of the same image
  4. Negative: views from different images
Image "dog.jpg"
↓
[Augment 1: crop+flip] & [Augment 2: color+rotate]
↓
These are positive pairs!
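The pairing scheme above can be sketched in plain Python. This assumes one hypothetical layout of the 2N views, [aug1 of images 0..N-1, then aug2 of images 0..N-1]; `simclr_positive_index` is an illustrative helper, not part of SimCLR's code:

```python
def simclr_positive_index(i, n):
    """Index of the positive partner of view i among 2n views,
    laid out as [aug1 of images 0..n-1, aug2 of images 0..n-1]."""
    return i + n if i < n else i - n

n = 4  # a batch of 4 images yields 8 augmented views
pairs = [(i, simclr_positive_index(i, n)) for i in range(2 * n)]

# Every view has exactly one positive; it is contrasted against
# all other views except itself and that positive.
negatives_per_view = 2 * n - 2
```

The pairing is symmetric: the positive of view i's positive is view i itself.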

🎬 Contrastive Forces Animation

Visualize pull/push forces in embedding space

Temperature Parameter

Controlling Distribution Sharpness

🌡️ What is Temperature τ?

Temperature τ (tau) is a parameter that controls the "sharpness" of the softmax distribution in the InfoNCE loss.

Softmax = exp(sim/τ) / Σ exp(sim/τ)

Small τ → sharp distribution
Large τ → smooth distribution

📊 Temperature Effects

Low Temperature (τ = 0.07)

  • ✅ Sharp distribution
  • ✅ Very confident predictions
  • ✅ Better differentiation
  • ❌ Harder optimization
P_positive ≈ 0.95, P_negative ≈ 0.001

High Temperature (τ = 1.0)

  • ✅ Smooth distribution
  • ✅ Easier optimization
  • ✅ Gradients spread out
  • ❌ Less confident
P_positive ≈ 0.60, P_negative ≈ 0.08
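These effects are easy to verify directly. A minimal sketch in plain Python (hypothetical similarity scores) comparing the softmax at the two temperatures:

```python
import math

def softmax_with_temperature(sims, tau):
    """Softmax over similarity scores scaled by 1/tau."""
    exps = [math.exp(s / tau) for s in sims]
    total = sum(exps)
    return [e / total for e in exps]

sims = [0.9, 0.3, 0.2, 0.1]  # positive pair first, then negatives

sharp = softmax_with_temperature(sims, tau=0.07)   # nearly all mass on the positive
smooth = softmax_with_temperature(sims, tau=1.0)   # mass spread across candidates
```

With τ = 0.07 the positive takes almost all of the probability mass; with τ = 1.0 the distribution stays much flatter, which smooths the gradients.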

πŸŽ›οΈ Temperature Interactive Demo

Adjust slider to see temperature effect on distribution

🔧 Learnable vs Fixed Temperature

The temperature can be set as:

  • 🔒 Fixed: τ = 0.07 (common in SimCLR, CLIP)
  • 📈 Learnable: τ as an nn.Parameter (the CLIP approach)

CLIP's Approach:

Learnable log-scale temperature:
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1/0.07))
temperature = logit_scale.exp()

Symmetric Loss

Bidirectional Alignment

βš–οΈ Why Symmetric?

In multimodal learning (e.g., CLIP), we have two modalities: image & text. A symmetric loss ensures alignment in both directions!

L = (L_I→T + L_T→I) / 2
L_I→T: InfoNCE from image to text (row-wise)
L_T→I: InfoNCE from text to image (column-wise)

📊 Similarity Matrix

For a batch of N = 4, the similarity matrix S (4×4):

       T₁     T₂     T₃     T₄
I₁   0.91   0.15   0.08   0.12
I₂   0.18   0.89   0.11   0.09
I₃   0.10   0.14   0.93   0.07
I₄   0.13   0.16   0.06   0.95

➡️ Row-wise Loss (Image → Text)

For each image, classify which text is the match:

L_I→T = -(1/N) Σᵢ log(exp(Sᵢᵢ/τ) / Σⱼ exp(Sᵢⱼ/τ))

Softmax across each row

⬇️ Column-wise Loss (Text β†’ Image)

For each text, classify which image is the match:

L_T→I = -(1/N) Σⱼ log(exp(Sⱼⱼ/τ) / Σᵢ exp(Sᵢⱼ/τ))

Softmax across each column
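Both directions can be checked on the example 4×4 matrix. A small sketch in plain Python (`diagonal_cross_entropy` is an illustrative helper) that applies the row-wise formula and then transposes S for the column-wise one:

```python
import math

def diagonal_cross_entropy(S, tau=0.07):
    """Mean of -log softmax(row/tau) taken at the diagonal (matched-pair) entry."""
    n = len(S)
    total = 0.0
    for i in range(n):
        exps = [math.exp(S[i][j] / tau) for j in range(n)]
        total += -math.log(exps[i] / sum(exps))
    return total / n

S = [[0.91, 0.15, 0.08, 0.12],
     [0.18, 0.89, 0.11, 0.09],
     [0.10, 0.14, 0.93, 0.07],
     [0.13, 0.16, 0.06, 0.95]]

loss_i2t = diagonal_cross_entropy(S)                           # softmax across rows
loss_t2i = diagonal_cross_entropy([list(r) for r in zip(*S)])  # softmax across columns
loss = (loss_i2t + loss_t2i) / 2
```

Because every diagonal entry dominates its row and column, both directional losses are close to zero here; a misaligned pair would raise both.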

🎬 Symmetric Loss Visualization

Visualize row-wise and column-wise softmax

Implementation

PyTorch Code

💻 InfoNCE PyTorch Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F

class InfoNCE(nn.Module):
    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature
    
    def forward(self, query, keys):
        """
        Compute InfoNCE loss.
        
        Args:
            query: (batch, dim) - anchor embeddings
            keys: (batch, dim) - key embeddings
        
        Returns:
            loss: scalar InfoNCE loss
        """
        # Normalize embeddings
        query = F.normalize(query, dim=-1)
        keys = F.normalize(keys, dim=-1)
        
        # Compute cosine similarity matrix
        logits = query @ keys.T / self.temperature  # (batch, batch)
        
        # Diagonal are positive pairs
        labels = torch.arange(len(query), device=query.device)
        
        # Cross-entropy loss
        loss = F.cross_entropy(logits, labels)
        
        return loss

📊 Symmetric InfoNCE (CLIP-style)

class SymmetricInfoNCE(nn.Module):
    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature
    
    def forward(self, image_embeds, text_embeds):
        """
        Symmetric contrastive loss.
        
        Args:
            image_embeds: (N, dim)
            text_embeds: (N, dim)
        
        Returns:
            loss: symmetric InfoNCE loss
        """
        # Normalize
        image_embeds = F.normalize(image_embeds, dim=-1)
        text_embeds = F.normalize(text_embeds, dim=-1)
        
        # Similarity matrix
        logits = image_embeds @ text_embeds.T / self.temperature  # (N, N)
        
        # Labels: diagonal indices
        labels = torch.arange(len(image_embeds), device=image_embeds.device)
        
        # Row-wise (image → text)
        loss_i2t = F.cross_entropy(logits, labels)
        
        # Column-wise (text → image)
        loss_t2i = F.cross_entropy(logits.T, labels)
        
        # Symmetric average
        loss = (loss_i2t + loss_t2i) / 2
        
        return loss

🎓 Training Loop Example

def train_contrastive(model, dataloader, optimizer, device):
    """Training loop dengan InfoNCE."""
    model.train()
    criterion = SymmetricInfoNCE(temperature=0.07).to(device)
    
    for images, texts in dataloader:
        images = images.to(device)
        texts = texts.to(device)
        
        # Forward: get embeddings
        image_embeds = model.encode_image(images)  # (N, dim)
        text_embeds = model.encode_text(texts)      # (N, dim)
        
        # Compute symmetric InfoNCE loss
        loss = criterion(image_embeds, text_embeds)
        
        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        print(f"Loss: {loss.item():.4f}")
    
    return loss.item()

🔧 Learnable Temperature

import numpy as np

class LearnableTemperature(nn.Module):
    def __init__(self, init_temp=0.07):
        super().__init__()
        # Log-scale learnable parameter (CLIP approach)
        self.logit_scale = nn.Parameter(
            torch.ones([]) * np.log(1 / init_temp)
        )
    
    def forward(self, query, keys):
        # Normalize
        query = F.normalize(query, dim=-1)
        keys = F.normalize(keys, dim=-1)
        
        # Similarity with learnable scale
        temperature = self.logit_scale.exp()
        logits = query @ keys.T * temperature
        
        # InfoNCE
        labels = torch.arange(len(query), device=query.device)
        loss = F.cross_entropy(logits, labels)
        
        return loss, temperature.item()

Applications

InfoNCE in the Wild

🚀 InfoNCE Applications

InfoNCE has become the foundational loss behind many breakthrough models:

🌐 CLIP

Vision-Language Alignment

Trained on 400M (image, text) pairs with symmetric InfoNCE

Zero-shot classification, text-to-image retrieval

πŸ–ΌοΈ SimCLR

Self-Supervised Vision

Learns representations from augmented views with InfoNCE

Pre-training for downstream vision tasks

🎯 MoCo

Momentum Contrast

Large negative queue + momentum encoder

Efficient contrastive learning for vision

📊 CPC

Contrastive Predictive Coding

Predicts future representations with InfoNCE

Time-series, audio, video understanding

🔊 Audio SSL

wav2vec 2.0

Learn speech representations without transcripts

Speech recognition, speaker identification

🎬 Video Understanding

VideoMoCo

Temporal contrastive learning for video

Action recognition, video retrieval

💡 Why InfoNCE Works So Well

  • ✅ Simple: Easy to implement (just softmax + cross-entropy)
  • ✅ Scalable: Efficient with large batches
  • ✅ Effective: Strong performance across domains
  • ✅ Flexible: Works for unimodal & multimodal
  • ✅ No labels needed: Self-supervised learning

🎯 Key Takeaways

InfoNCE Core Principles:

  • 🔥 Maximize similarity for positive pairs
  • ❄️ Minimize similarity for negative pairs
  • 🌡️ Temperature controls sharpness
  • ⚖️ Symmetric loss for bidirectional alignment
  • 📊 The softmax framework makes it easy

✅ Congratulations!

🎉 Tutorial Complete!

You have learned:

  • ✅ InfoNCE mathematical foundation
  • ✅ Contrastive mechanics (pull/push)
  • ✅ Temperature parameter effects
  • ✅ Symmetric loss for multimodal learning
  • ✅ PyTorch implementation
  • ✅ Real-world applications

🚀 Next Steps

β€’ Implement InfoNCE untuk your dataset

β€’ Experiment dengan temperature values

β€’ Explore CLIP, SimCLR source code

β€’ Read papers: CPC, MoCo, SimCLR, CLIP