State Space Models

Sequence Modeling with Linear Time Complexity

⚡ What Are State Space Models?

State Space Models (SSMs) are a mathematical framework for processing sequence data with high efficiency. SSMs combine the strengths of RNNs and Transformers:

  • 📊 Linear time complexity at inference (like an RNN)
  • ⚡ Parallelizable training via convolution (like a Transformer)
  • 🎯 Long-range dependencies without degradation
  • 🔄 Continuous-time modeling for flexibility

💡 Why Do SSMs Matter?

Problems with existing models:

  • ❌ RNN: slow sequential processing, vanishing gradients
  • ❌ Transformer: O(L²) complexity, memory-intensive for long sequences
  • ✅ SSM: O(L) inference, O(L log L) training, long-range modeling

🎯 What You Will Learn

📐

State Space Mathematics

Continuous and discrete systems

🔄

Dual Modes

Recurrent and convolutional

🧠

Mamba Architecture

Selective state spaces

💻

Implementation

S4 and Mamba in PyTorch

📊 Comparison Table

Model          | Training            | Inference | Long Range
RNN            | O(L) sequential     | O(L)      | ❌ Poor
Transformer    | O(L²) parallel      | O(L²)     | ✅ Excellent
SSM (S4/Mamba) | O(L log L) parallel | O(L)      | ✅ Excellent

State Space Basics

The Mathematical Foundations of SSMs

πŸ“ Continuous-Time State Space

State space model dalam continuous time didefinisikan dengan dua persamaan differensial:

State Equation
dx/dt = Ax(t) + Bu(t)
x(t): state vector (dimensi N)
u(t): input signal
A: state matrix (NΓ—N)
B: input matrix (NΓ—1)
Output Equation
y(t) = Cx(t) + Du(t)
y(t): output signal
C: output matrix (1Γ—N)
D: feedthrough (skipconnection)

💡 Analogy: An RC Circuit

Imagine an RC (resistor-capacitor) circuit:

  • x(t): the voltage across the capacitor (the state)
  • u(t): the input voltage
  • A: the decay rate (-1/RC)
  • B: the input coupling (1/RC)

The state x(t) evolves according to the input and its previous "memory"!
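
The analogy can be checked numerically with a forward-Euler loop. The component values, the 1 V step input, and the step size below are illustrative choices, not from the text.

```python
# Forward-Euler simulation of the RC circuit as a 1-D state space model.
# R, C, the 1 V step input, and dt are illustrative values.
R, C_cap = 1e3, 1e-3           # 1 kOhm, 1 mF -> time constant RC = 1 s
A = -1.0 / (R * C_cap)         # state matrix (scalar): decay rate
B = 1.0 / (R * C_cap)          # input matrix (scalar): input coupling

dt = 0.01                      # integration step
x = 0.0                        # initial capacitor voltage (the state)
u = 1.0                        # constant 1 V input
for _ in range(1000):          # simulate 10 s = 10 time constants
    x += dt * (A * x + B * u)  # dx/dt = Ax + Bu

print(round(x, 3))             # → 1.0 (capacitor charged to the input)
```

After ten time constants the state has converged to the steady state x = -B/A · u = 1 V, exactly the "memory plus input" behavior described above.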

🔑 Key Properties

1

Linearity

Superposition holds: the output for a sum of inputs equals the sum of the individual outputs

2

Time-Invariance

The matrices A, B, C, D are constant (they do not depend on time)

3

Structured

Structured matrices (HiPPO) can be used for long-range modeling

Discretization

Continuous → Discrete Conversion

🔄 Why Discretize?

Computers work in discrete time steps, while the SSM is defined in continuous time. We need to convert the differential equations into difference equations.

Step size Δ: the sampling interval (e.g., 0.001 s for 1 kHz audio)

πŸ“ Discrete SSM Equations

x_k = AΜ… x_{k-1} + BΜ… u_k
y_k = CΜ… x_k + DΜ… u_k
AΜ… = exp(Ξ”A) β‰ˆ I + Ξ”A + (Ξ”A)Β²/2! + ...
BΜ… = (Ξ”A)⁻¹(exp(Ξ”A) - I)B β‰ˆ Ξ”B

βš™οΈ Zero-Order Hold (ZOH)

Metode discretization paling umum: asumsikan input konstan dalam interval [kΞ”, (k+1)Ξ”].

1

Sample Input

u(kΞ”) β†’ u_k

2

Compute AΜ…, BΜ…

Matrix exponentials

3

Discrete Update

x_k = AΜ…x_{k-1} + BΜ…u_k
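
The steps above can be sketched with PyTorch; the state size and the random stable A are illustrative.

```python
import torch

torch.manual_seed(0)
N = 4                                        # state size (illustrative)
A = -torch.eye(N) + 0.1 * torch.randn(N, N)  # a roughly stable state matrix
B = torch.randn(N, 1)
delta = 0.1                                  # step size

# Exact zero-order-hold discretization
A_bar = torch.matrix_exp(delta * A)                       # exp(ΔA)
B_bar = torch.linalg.solve(A, A_bar - torch.eye(N)) @ B   # A^{-1}(exp(ΔA) - I)B

# First-order approximations from the series expansion
A_euler = torch.eye(N) + delta * A
B_euler = delta * B

# For a small step size the exact and approximate forms agree closely
print(torch.allclose(A_bar, A_euler, atol=0.05),
      torch.allclose(B_bar, B_euler, atol=0.05))
```

The gap between the two forms is O(Δ²), which is why the first-order formulas A̅ ≈ I + ΔA and B̅ ≈ ΔB work well for small Δ.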

Recurrent Mode

Sequential Processing - O(L) Time

πŸ” SSM as Recurrence

Mode recurrent: process sequence element by element, seperti RNN. Berguna untuk inference/deployment (streaming).

x_k = AΜ… x_{k-1} + BΜ… u_k
y_k = C x_k
Time complexity: O(L) untuk sequence length L
Space: O(N) untuk state size N

📊 Characteristics

  • ✅ Fast inference: constant time per step
  • ✅ Low memory: only the current state is stored
  • ✅ Streaming: process input as it arrives
  • ❌ Slow training: the sequential updates cannot be parallelized
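
The recurrence takes only a few lines of PyTorch; the pre-discretized matrices and sizes here are illustrative.

```python
import torch

torch.manual_seed(0)
N, L = 4, 8                      # state size, sequence length (illustrative)
A_bar = 0.9 * torch.eye(N)       # pre-discretized state matrix
B_bar = torch.randn(N, 1)
C = torch.randn(1, N)

u = torch.randn(L)               # input sequence
x = torch.zeros(N, 1)            # O(N) memory: only the current state is kept
ys = []
for k in range(L):               # O(L) time: one constant-cost update per step
    x = A_bar @ x + B_bar * u[k]
    ys.append((C @ x).item())

print(len(ys))  # → 8 outputs, produced one at a time (streaming)
```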

Convolutional Mode

Parallel Training - O(L log L)

⚡ SSM as a Convolution

An SSM can be reformulated as a global convolution, which enables parallel training.

y = K ∗ u
K = (C̅B̅, C̅A̅B̅, C̅A̅²B̅, ..., C̅A̅^{L-1}B̅)

K is the SSM convolution kernel

🔢 Kernel Construction

1

Compute Powers

A̅⁰, A̅¹, A̅², ..., A̅^{L-1}

2

Build Kernel

K[i] = C̅ A̅^i B̅

3

FFT Convolution

y = IFFT(FFT(K) ⊙ FFT(u))

⚡ Efficiency

Training complexity:

  • Naive convolution: O(L²)
  • FFT convolution: O(L log L), much faster
  • Fully parallelizable on a GPU
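
A quick numerical check (small, illustrative sizes) that the FFT convolution reproduces the recurrent outputs; zero-padding to length 2L keeps the circular FFT convolution causal.

```python
import torch

torch.manual_seed(0)
N, L = 4, 16
A_bar = 0.8 * torch.eye(N) + 0.05 * torch.randn(N, N)
B_bar = torch.randn(N, 1)
C = torch.randn(1, N)
u = torch.randn(L)

# Kernel K[i] = C A_bar^i B_bar, via repeated multiplication
K, v = [], B_bar
for _ in range(L):
    K.append((C @ v).item())
    v = A_bar @ v
K = torch.tensor(K)

# Convolutional mode: FFT convolution, zero-padded to avoid wraparound
Kf = torch.fft.rfft(K, n=2 * L)
uf = torch.fft.rfft(u, n=2 * L)
y_conv = torch.fft.irfft(Kf * uf, n=2 * L)[:L]

# Recurrent mode for comparison
x, y_rec = torch.zeros(N, 1), []
for k in range(L):
    x = A_bar @ x + B_bar * u[k]
    y_rec.append((C @ x).item())
y_rec = torch.tensor(y_rec)

print(torch.allclose(y_conv, y_rec, atol=1e-3))  # → True
```

Both modes compute y_k = Σ_{j≤k} C A̅^{k-j} B̅ u_j; only the order of operations differs, which is why training can use the parallel form and inference the sequential one.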

Training SSM

Parameter Learning & Optimization

🎓 Learned Parameters

An SSM has parameters A, B, C, D that are optimized via backpropagation:

A (N×N)
State dynamics
B (N×1)
Input projection
C (1×N)
Output projection
D (scalar)
Skip connection

🧠 HiPPO Initialization

High-order Polynomial Projection Operator (HiPPO): a structured initialization for A that is well-suited to long-range dependencies.

HiPPO matrices "remember" history via polynomial approximation; their eigenvalues are designed to capture different timescales.
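
As a sketch, the HiPPO-LegS matrix from the HiPPO/S4 papers can be built in a few lines: A[n, k] = -√(2n+1)·√(2k+1) below the diagonal, -(n+1) on it, and 0 above.

```python
import torch

def hippo_legs(N):
    # A[n, k] = -sqrt(2n+1)*sqrt(2k+1) if n > k, -(n+1) if n == k, 0 if n < k
    n = torch.arange(N, dtype=torch.float)
    outer = torch.sqrt(2 * n + 1).unsqueeze(1) * torch.sqrt(2 * n + 1).unsqueeze(0)
    return -torch.tril(outer, diagonal=-1) - torch.diag(n + 1)

A = hippo_legs(4)
print(A)
```

The lower-triangular structure makes each state dimension integrate the input at a different timescale, which is what gives the long-range memory.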

🔧 Optimization Tips

  • 🎯 Use HiPPO initialization for the A matrix
  • ⚡ Train in convolutional mode (parallel)
  • 🔄 Deploy in recurrent mode (efficient)
  • 📐 Normalize the state for numerical stability

Mamba Architecture

Selective State Spaces

🦎 What Is Mamba?

Mamba is an evolved SSM with a selective mechanism: the parameters Δ, B, and C become input-dependent!

Key Innovation:

Standard SSM: A, B, C are fixed
Mamba: B, C, Δ are functions of the input → selective focus

🎯 Selective SSM

B = Linear_B(x)
C = Linear_C(x)
Δ = Softplus(Linear_Δ(x))
x: input token
Δ: step size (controls remembering vs. forgetting)
B, C: input/output projections
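
The role of Δ can be made concrete with a scalar example: after discretization the previous state is scaled by exp(Δ·A), so a small Δ preserves it and a large Δ wipes it. A = -1 here is an illustrative value.

```python
import math

A = -1.0                            # scalar state matrix (illustrative)
for delta in (0.01, 1.0, 10.0):
    retention = math.exp(delta * A)  # fraction of the previous state kept
    print(f"delta={delta:5}: retention={retention:.4f}")
# small delta -> retention near 1 (remember the context)
# large delta -> retention near 0 (forget, focus on the current token)
```

Because Δ is produced per token, the model chooses token by token whether to carry its state forward or reset it.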

💡 Why Selective?

1

Content-Aware

The model decides what to remember and what to forget

2

Hardware-Efficient

Fused kernels; the expanded state is never materialized

3

SOTA Performance

Outperforms Transformers on long sequences

📊 Mamba vs Transformer

Aspect                   | Transformer | Mamba
Inference                | O(L²)       | O(L)
Memory (16k sequence)    | ~10 GB      | ~1 GB
Throughput               | baseline    | ~5× faster
Quality (long sequences) | Good        | Better

Implementation

PyTorch Code & Use Cases

💻 PyTorch S4 Implementation

A didactic sketch: one scalar SSM is shared across all channels. Production S4 uses structured (e.g., diagonal) per-channel state matrices and a HiPPO-based A.

import torch
import torch.nn as nn

class S4Layer(nn.Module):
    def __init__(self, d_model, d_state=64):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        
        # SSM parameters (random init; in practice A uses a HiPPO matrix)
        self.A = nn.Parameter(torch.randn(d_state, d_state))
        self.B = nn.Parameter(torch.randn(d_state, 1))
        self.C = nn.Parameter(torch.randn(1, d_state))
        self.D = nn.Parameter(torch.randn(1))
        
        # Discretization step size (log-parameterized to stay positive)
        self.log_step = nn.Parameter(torch.log(torch.rand(1)))
    
    def _compute_kernel(self, dA, dB, L):
        # K[i] = C dA^i dB, built by repeated multiplication
        K, v = [], dB
        for _ in range(L):
            K.append((self.C @ v).squeeze())
            v = dA @ v
        return torch.stack(K)  # (L,)
    
    def forward(self, u):
        """
        u: (batch, length, d_model); the same scalar SSM is applied
        to every channel independently.
        Returns: (batch, length, d_model)
        """
        L = u.size(1)
        
        # Discretize (zero-order hold): A_bar = exp(step*A),
        # B_bar = A^{-1}(exp(step*A) - I)B
        step = torch.exp(self.log_step)
        dA = torch.matrix_exp(step * self.A)
        dB = torch.linalg.solve(self.A, dA - torch.eye(self.d_state)) @ self.B
        
        # Convolutional mode (training): causal convolution via FFT,
        # zero-padded to 2L to avoid circular wraparound
        K = self._compute_kernel(dA, dB, L)
        y = torch.fft.irfft(
            torch.fft.rfft(K, n=2 * L).unsqueeze(-1)
            * torch.fft.rfft(u, n=2 * L, dim=1),
            n=2 * L, dim=1,
        )[:, :L]
        
        return y + self.D * u  # feedthrough / skip connection

🦎 Mamba Selective SSM

import torch
import torch.nn as nn
import torch.nn.functional as F

def selective_scan(x, delta, A, B, C):
    """Naive sequential reference for the selective scan.
    x: (batch, L, D), delta: (batch, L), A: (N,) diagonal,
    B, C: (batch, L, N). Mamba replaces this loop with a fused,
    hardware-aware parallel kernel."""
    bsz, L, D = x.shape
    h = x.new_zeros(bsz, D, A.size(0))                # hidden state per channel
    ys = []
    for t in range(L):
        dA = torch.exp(delta[:, t, None, None] * A)   # input-dependent decay
        dBx = delta[:, t, None, None] * B[:, t, None, :] * x[:, t, :, None]
        h = dA * h + dBx                              # selective state update
        ys.append((h * C[:, t, None, :]).sum(-1))     # input-dependent readout
    return torch.stack(ys, dim=1)                     # (batch, L, D)

class Mamba(nn.Module):
    def __init__(self, d_model, d_state=16):
        super().__init__()
        self.d_state = d_state
        
        # Input-dependent parameter generators for Δ, B, C
        self.x_proj = nn.Linear(d_model, d_state * 2 + 1)
        
        # Static A matrix (diagonal here for simplicity; Mamba uses an
        # S4D/HiPPO-inspired initialization)
        self.A = nn.Parameter(-torch.arange(1.0, d_state + 1))
    
    def forward(self, x):
        """Selective SSM with data-dependent B, C, Δ."""
        d_state = self.d_state
        proj = self.x_proj(x)                    # (batch, L, 2*d_state + 1)
        
        delta = F.softplus(proj[..., 0])         # (batch, L)
        B = proj[..., 1:d_state + 1]             # (batch, L, d_state)
        C = proj[..., d_state + 1:]              # (batch, L, d_state)
        
        # Selective scan (in practice a fused, hardware-aware kernel)
        y = selective_scan(x, delta, self.A, B, C)
        return y

🎯 Use Cases

📈 Time Series

Financial forecasting, weather prediction

Example: stock price prediction on a 10k+ step history

🎡 Audio Processing

Speech recognition, music generation

Example: 16kHz audio (long sequences)

🧬 DNA Sequences

Genomics, protein folding

Example: 100k+ nucleotide sequences

πŸ“ Long-form Text

Document understanding, books

Example: 32k token documents

✅ Congratulations!

🎉 Tutorial Complete!

You have learned:

  • ✅ State space mathematics (continuous & discrete)
  • ✅ The dual modes: recurrent & convolutional
  • ✅ Training with HiPPO initialization
  • ✅ Mamba's selective state spaces
  • ✅ A PyTorch implementation

🚀 Next Steps

• Implement S4/Mamba on your own dataset

• Explore the S4D, S5, and Mamba-2 variants

• Read the papers: S4 (Gu et al., 2022) and Mamba (Gu & Dao, 2023)