State Space Models
Sequence Modeling with Linear Time Complexity
What Are State Space Models?
State Space Models (SSMs) are a mathematical framework for processing sequence data efficiently. SSMs combine the strengths of RNNs and Transformers:
- Linear time complexity at inference (like an RNN)
- Parallelizable training via convolution (like a Transformer)
- Long-range dependencies without degradation
- Continuous-time modeling for flexibility
Why Do SSMs Matter?
Problems with existing models:
- RNN: slow sequential processing, vanishing gradients
- Transformer: O(L²) complexity, memory-intensive for long sequences
- SSM: O(L) inference, O(L log L) training, strong long-range modeling
What You Will Learn
State Space Mathematics
Continuous and discrete systems
Dual Modes
Recurrent and convolutional
Mamba Architecture
Selective state spaces
Implementation
S4 and Mamba in PyTorch
Comparison Table
| Model | Training | Inference | Long Range |
|---|---|---|---|
| RNN | O(L) sequential | O(L) | Poor |
| Transformer | O(L²) parallel | O(L²) | Excellent |
| SSM (S4/Mamba) | O(L log L) parallel | O(L) | Excellent |
State Space Basics
Mathematical Foundations of SSMs
Continuous-Time State Space
A state space model in continuous time is defined by two differential equations:

x'(t) = A x(t) + B u(t)
y(t) = C x(t) + D u(t)

x(t): hidden state (N-dimensional)
u(t): input signal
y(t): output signal
A: state matrix (N×N)
B: input matrix (N×1)
C: output matrix (1×N)
D: feedthrough (skip connection)
Analogy: RC Circuit
Imagine an RC circuit (resistor-capacitor):
- x(t): voltage across the capacitor (state)
- u(t): input voltage
- A: decay rate (-1/RC)
- B: input coupling (1/RC)
The state x(t) evolves according to the input and its previous "memory"!
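As a quick numerical check of this analogy, here is a minimal pure-Python sketch (the R and C values below are arbitrary, and forward-Euler integration stands in for the exact solution):

```python
# Simulate the RC circuit as a 1-state SSM with forward-Euler steps.
# dx/dt = A*x + B*u, with A = -1/RC, B = 1/RC (values chosen arbitrarily).
R, C_cap = 1.0, 1.0          # 1 ohm, 1 farad -> time constant RC = 1 s
A, B = -1.0 / (R * C_cap), 1.0 / (R * C_cap)

def simulate(u, x0=0.0, dt=0.001, steps=5000):
    """Step response: drive the circuit with a constant input voltage u."""
    x = x0
    for _ in range(steps):
        x = x + dt * (A * x + B * u)   # Euler update of the state
    return x

# After 5 time constants the capacitor is ~99% charged toward u.
v = simulate(u=1.0)
```

The state exponentially "forgets" its initial condition and tracks the input, which is exactly the behavior the A matrix encodes in a general SSM.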
Key Properties
Linearity
Superposition holds: the output for a sum of inputs equals the sum of the individual outputs.
Time-Invariance
The matrices A, B, C, D are constant (they do not depend on time).
Structured
Structured matrices (e.g., HiPPO) can be used for long-range modeling.
Discretization
Continuous → Discrete Conversion
Why Discretization?
Computers operate on discrete time steps, while the SSM is defined in continuous time. We need to convert the differential equations into difference equations.
Step size Δ: sampling interval (e.g., 0.001 s for 1 kHz audio)
Discrete SSM Equations
Ā = exp(ΔA)
B̄ = (ΔA)⁻¹(exp(ΔA) − I) · ΔB ≈ ΔB
Zero-Order Hold (ZOH)
The most common discretization method: assume the input is constant over each interval [kΔ, (k+1)Δ].
Sample Input
u(kΔ) → u_k
Compute Ā, B̄
Matrix exponentials
Discrete Update
x_k = Ā x_{k-1} + B̄ u_k
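The three steps above can be sketched in PyTorch, assuming a toy 2-state system (the matrix values are arbitrary illustrations):

```python
import torch

def discretize_zoh(A, B, step):
    """Zero-order hold: Abar = exp(dA), Bbar = A^{-1}(exp(dA) - I)B."""
    N = A.shape[0]
    dA = torch.matrix_exp(step * A)
    dB = torch.inverse(A) @ (dA - torch.eye(N)) @ B
    return dA, dB

# Toy 2-state system (a stable, invertible A chosen arbitrarily)
A = torch.tensor([[-1.0, 0.0], [0.0, -2.0]])
B = torch.tensor([[1.0], [1.0]])
dA, dB = discretize_zoh(A, B, step=0.1)

# Discrete update: x_k = dA @ x_{k-1} + dB * u_k
x = torch.zeros(2, 1)
for u_k in [1.0, 1.0, 1.0]:
    x = dA @ x + dB * u_k
```

Note how dB comes out close to Δ·B (here ≈ 0.095 vs. 0.1 for the first state), matching the approximation B̄ ≈ ΔB for small step sizes.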
Recurrent Mode
Sequential Processing - O(L) Time
SSM as Recurrence
Recurrent mode: process the sequence element by element, like an RNN. Useful for inference and deployment (streaming).
Space: O(N) for a state of size N
Characteristics
- Fast inference: constant time per step
- Low memory: only the current state is stored
- Streaming: process input as it arrives
- Slow training: sequential updates cannot be parallelized
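A minimal sketch of this streaming loop, assuming `dA`, `dB`, `C` are already-discretized matrices (the toy 1-state values below are chosen purely for illustration):

```python
import torch

def ssm_recurrent(u, dA, dB, C):
    """Process a sequence one step at a time: O(L) time, O(N) memory.

    u:  (L,) scalar input sequence
    dA: (N, N), dB: (N, 1), C: (1, N) discretized SSM matrices
    """
    x = torch.zeros(dA.shape[0], 1)         # only the current state is kept
    ys = []
    for u_k in u:                           # streaming: one input at a time
        x = dA @ x + dB * u_k               # state update
        ys.append((C @ x).item())           # y_k = C x_k
    return torch.tensor(ys)

# Toy example: 1-state system with decay 0.9
dA = torch.tensor([[0.9]])
dB = torch.tensor([[1.0]])
C = torch.tensor([[1.0]])
y = ssm_recurrent(torch.ones(4), dA, dB, C)
# constant input of 1 accumulates with decay: 1, 1.9, 2.71, 3.439
```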
Convolutional Mode
Parallel Training - O(L log L)
SSM as Convolution
An SSM can be reformulated as a global convolution! This enables parallel training:

y = K * u, where K = (C̄B̄, C̄ĀB̄, C̄Ā²B̄, ..., C̄Ā^{L-1}B̄)

K is the SSM convolution kernel.
Kernel Construction
Compute Powers
Ā⁰, Ā¹, Ā², ..., Ā^{L-1}
Build Kernel
K[i] = C̄ Ā^i B̄
FFT Convolution
y = IFFT(FFT(K) ⊙ FFT(u))
Efficiency
Training complexity:
- Naive convolution: O(L²)
- FFT: O(L log L), much faster!
- Fully parallelizable on GPU
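Both steps can be sketched for a toy scalar-output system (the O(LN²) kernel loop is written naively for clarity, not speed):

```python
import torch

def ssm_conv(u, dA, dB, C):
    """Compute y = K * u via FFT, where K[i] = C dA^i dB."""
    L = u.shape[0]
    # 1. Build the kernel: successive powers of dA applied to dB
    K, x = [], dB
    for _ in range(L):
        K.append((C @ x).squeeze())
        x = dA @ x
    K = torch.stack(K)                                       # (L,)
    # 2. FFT convolution, zero-padded to 2L to avoid circular wrap-around
    y = torch.fft.irfft(
        torch.fft.rfft(u, n=2 * L) * torch.fft.rfft(K, n=2 * L),
        n=2 * L,
    )
    return y[:L]

dA = torch.tensor([[0.9]])
dB = torch.tensor([[1.0]])
C = torch.tensor([[1.0]])
y = ssm_conv(torch.ones(4), dA, dB, C)
```

On this 1-state example the FFT path reproduces the recurrent result (1, 1.9, 2.71, 3.439 for a constant input of 1 with decay 0.9): the two modes compute the same function.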
Training SSM
Parameter Learning & Optimization
Learned Parameters
An SSM has parameters A, B, C, D that are optimized via backpropagation.
HiPPO Initialization
High-order Polynomial Projection Operator (HiPPO): a structured initialization for A that is optimal for long-range dependencies.
HiPPO matrices "remember" history via polynomial approximation. Their eigenvalues are designed to capture different timescales!
Optimization Tips
- Use HiPPO initialization for the A matrix
- Train in convolutional mode (parallel)
- Deploy in recurrent mode (efficient)
- Normalize the state for numerical stability
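The HiPPO-LegS variant of this initialization can be sketched directly from its closed form (a minimal illustration, not the full S4 parameterization):

```python
import torch

def hippo_legs(N):
    """HiPPO-LegS state matrix:
    A[n, k] = -sqrt(2n+1) * sqrt(2k+1)  for n > k,
    A[n, n] = -(n+1), and 0 above the diagonal."""
    n = torch.arange(N, dtype=torch.float32)
    pre = torch.sqrt(2 * n + 1)
    A = -pre.unsqueeze(1) * pre.unsqueeze(0)   # -sqrt(2n+1)sqrt(2k+1)
    A = torch.tril(A, diagonal=-1)             # keep the strictly lower triangle
    A = A - torch.diag(n + 1)                  # diagonal entries: -(n+1)
    return A

A = hippo_legs(4)
# lower-triangular, with A[0,0] = -1 and A[1,0] = -sqrt(3)
```

Each row mixes progressively longer history into the state, which is what gives HiPPO-initialized SSMs their long-range memory.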
Mamba Architecture
Selective State Spaces
What Is Mamba?
Mamba is an evolved SSM with a selective mechanism: the parameters Δ, B, C become input-dependent!
Key Innovation:
Standard SSM: A, B, C fixed
Mamba: B, C, Δ = functions of the input → selective focus
Selective SSM
Δ: step size (controls remembering vs. forgetting)
B, C: input/output projections
Why Selective?
Content-Aware
The model decides what to remember and what to forget.
Hardware-Efficient
Fused kernels; the expanded state is never materialized.
SOTA Performance
Outperforms Transformers on long sequences.
Mamba vs. Transformer
| Aspect | Transformer | Mamba |
|---|---|---|
| Inference | O(L²) | O(L) |
| Memory (seq 16k) | ~10GB | ~1GB |
| Throughput | Baseline | 5x faster |
| Quality (long-seq) | Good | Better |
Implementation
PyTorch Code & Use Cases
PyTorch S4 Implementation

```python
import torch
import torch.nn as nn

class S4Layer(nn.Module):
    def __init__(self, d_model, d_state=64):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        # SSM parameters
        self.A = nn.Parameter(torch.randn(d_state, d_state))
        self.B = nn.Parameter(torch.randn(d_state, 1))
        self.C = nn.Parameter(torch.randn(1, d_state))
        self.D = nn.Parameter(torch.randn(1))
        # Discretization step size (learned in log space for positivity)
        self.log_step = nn.Parameter(torch.log(torch.rand(1)))

    def forward(self, u):
        """
        u: (batch, length, d_model)
        Returns: (batch, length, d_model)
        """
        L = u.size(1)
        # Discretize (ZOH): Abar = exp(dA), Bbar = A^{-1}(exp(dA) - I)B
        step = torch.exp(self.log_step)
        dA = torch.matrix_exp(step * self.A)
        dB = (dA - torch.eye(self.d_state)) @ torch.inverse(self.A) @ self.B
        # Convolutional mode (training)
        K = self._compute_kernel(dA, dB, L)                  # (L,)
        # FFT convolution, zero-padded to 2L to avoid circular wrap-around
        y = torch.fft.irfft(
            torch.fft.rfft(u, n=2 * L, dim=1)
            * torch.fft.rfft(K, n=2 * L).view(1, -1, 1),
            n=2 * L, dim=1,
        )[:, :L]
        return y + self.D * u

    def _compute_kernel(self, dA, dB, L):
        """Naive kernel construction: K[i] = C dA^i dB."""
        K, x = [], dB
        for _ in range(L):
            K.append((self.C @ x).squeeze())
            x = dA @ x
        return torch.stack(K)
```
Mamba Selective SSM

```python
import torch.nn.functional as F

class Mamba(nn.Module):
    def __init__(self, d_model, d_state=16):
        super().__init__()
        self.d_state = d_state
        # Input-dependent parameter generators
        self.x_proj = nn.Linear(d_model, d_state * 2 + 1)
        # Static A matrix (HiPPO-initialized)
        self.A = nn.Parameter(self._init_hippo(d_state))

    def forward(self, x):
        """Selective SSM with data-dependent B, C, delta."""
        d_state = self.d_state
        # Generate B, C, delta from the input
        projections = self.x_proj(x)             # (B, L, 2*d_state + 1)
        delta = F.softplus(projections[..., 0])  # (B, L)
        B = projections[..., 1:d_state + 1]      # (B, L, d_state)
        C = projections[..., d_state + 1:]       # (B, L, d_state)
        # Selective scan (hardware-aware fused kernel)
        y = selective_scan(x, delta, self.A, B, C)
        return y
```
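The `selective_scan` call above refers to Mamba's fused, hardware-aware kernel. As a rough reference sketch of the recurrence it computes, simplified to a single channel with a diagonal A (illustrative only, not the real implementation):

```python
import torch

def selective_scan_naive(u, delta, A, B, C):
    """Reference (unfused) selective scan: single channel, diagonal A.

    u:     (batch, L)        input sequence
    delta: (batch, L)        per-step sizes delta_t > 0
    A:     (N,)              diagonal of the state matrix
    B, C:  (batch, L, N)     input-dependent projections
    Returns y: (batch, L)
    """
    bsz, L = u.shape
    x = torch.zeros(bsz, A.shape[0])              # hidden state, O(N) memory
    ys = []
    for t in range(L):
        dt = delta[:, t:t + 1]                    # (batch, 1)
        dA = torch.exp(dt * A)                    # ZOH for diagonal A
        dB = dt * B[:, t]                         # Euler approximation of Bbar
        x = dA * x + dB * u[:, t:t + 1]           # input-dependent state update
        ys.append((C[:, t] * x).sum(-1))          # y_t = C_t x_t
    return torch.stack(ys, dim=1)                 # (batch, L)
```

Because Δ, B, and C vary with t, each step decides how strongly to decay the state and how much of the current input to admit; that per-token choice is the "selective" part.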
Use Cases
Time Series
Financial forecasting, weather prediction
Audio Processing
Speech recognition, music generation
DNA Sequences
Genomics, protein folding
Long-form Text
Document understanding, books
Congratulations!
Tutorial Complete!
You have learned:
- State space mathematics (continuous & discrete)
- Dual modes: recurrent & convolutional
- Training with HiPPO initialization
- Mamba selective state spaces
- PyTorch implementation