BL-JUST
Overview
BL-JUST unifies unsupervised and supervised learning into a single training procedure.
It casts training as a bilevel optimization in which the upper-level problem is the supervised objective and the lower-level problem is the unsupervised objective.
Problem definition
1. The bilevel optimization problem
\[ \min_{\theta, \phi} \sum_{(x, y) \in D_{\text{sup}}} \ell_{\text{sup}}((x, y); \theta, \phi) \tag{1} \] subject to \[ \theta \in \arg\min_{\theta', \eta} \sum_{x \in D_{\text{unsup}}} \ell_{\text{unsup}}(x; \theta', \eta) \tag{2} \]
2. Upper- and lower-level objectives
\[ f(\theta, \phi) \triangleq \sum_{(x, y) \in D_{\text{sup}}} \ell_{\text{sup}}((x, y); \theta, \phi) \tag{3} \] \[ g(\theta, \eta) \triangleq \sum_{x \in D_{\text{unsup}}} \ell_{\text{unsup}}(x; \theta, \eta) \tag{4} \]
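As a direct reading of (3) and (4), the sketch below spells out the two empirical risks as sums of per-sample losses over the labeled and unlabeled datasets; `loss_sup`, `loss_unsup`, and the data containers are hypothetical placeholders, not tied to any particular model.

```python
# Hypothetical spelling-out of Eqs. (3)-(4): f sums a supervised per-sample
# loss over labeled pairs, g sums an unsupervised per-sample loss over
# unlabeled inputs. loss_sup / loss_unsup are placeholder callables.
def f(theta, phi, labeled_data, loss_sup):
    """Upper-level supervised empirical risk f(theta, phi), Eq. (3)."""
    return sum(loss_sup((x, y), theta, phi) for x, y in labeled_data)

def g(theta, eta, unlabeled_data, loss_unsup):
    """Lower-level unsupervised empirical risk g(theta, eta), Eq. (4)."""
    return sum(loss_unsup(x, theta, eta) for x in unlabeled_data)
```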
3. The optimization problem in constrained form
\[ \min_{\theta, \phi} f(\theta, \phi) \quad \text{s.t.} \quad \theta \in \arg\min_{\theta', \eta} g(\theta', \eta). \tag{5} \]
4. Definition of the value-function gap
\[ p(\theta, \eta) \triangleq g(\theta, \eta) - \min_{\theta', \eta'} g(\theta', \eta') = g(\theta, \eta) - v \tag{6} \] where \[ v \triangleq \min_{\theta', \eta'} g(\theta', \eta'). \tag{7} \]
5. The reformulated optimization problem
\[ \min_{\theta, \phi} f(\theta, \phi) \quad \text{s.t.} \quad p(\theta, \eta) \leq 0. \tag{8} \] Since p(θ, η) ≥ 0 for every (θ, η), the constraint p(θ, η) ≤ 0 forces (θ, η) to be a minimizer of g, so (8) recovers the lower-level condition in (5).
6. Penalty-based single-level reformulation
\[ \min_{\theta, \phi, \eta} F_{\gamma}(\theta, \phi, \eta) \triangleq f(\theta, \phi) + \gamma p(\theta, \eta) \tag{9} \]
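A minimal sketch of how the penalized objective in (9) could be assembled, assuming the lower-level optimal value v is available or approximated; the names mirror the notation above but are purely illustrative.

```python
# Illustrative penalized single-level objective from Eq. (9):
# F_gamma(theta, phi, eta) = f(theta, phi) + gamma * (g(theta, eta) - v),
# where v approximates the lower-level minimum of g. v is a constant with
# respect to the parameters, so it does not affect gradients.
def F_gamma(theta, phi, eta, gamma, f, g, v):
    return f(theta, phi) + gamma * (g(theta, eta) - v)
```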
7. Computational condition
\[ \| \nabla g(\theta, \eta) \|^2 \geq \mu \tag{10} \]
Under this condition, equation (9) is the optimization objective used for training.
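Because v in (6) is a constant, ∇p = ∇g, and the partial gradients of F_γ reduce to the three expressions below; these are exactly the update directions used in the joint-training stage of the pseudocode that follows.

\[ \nabla_{\theta} F_{\gamma} = \nabla_{\theta} f(\theta, \phi) + \gamma \nabla_{\theta} g(\theta, \eta), \qquad \nabla_{\phi} F_{\gamma} = \nabla_{\phi} f(\theta, \phi), \qquad \nabla_{\eta} F_{\gamma} = \gamma \nabla_{\eta} g(\theta, \eta). \]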
Pseudocode
Input parameters:
- (x, y): labeled data, where x is the input features and y is the corresponding label.
- x: unlabeled data, features only, with no labels.
- ρ: learning rate for the self-supervised training stage.
- α: learning rate for bilevel gradient descent.
- τ: learning rate for the supervised fine-tuning stage.
- γm: maximum penalty factor.
- K: total number of training epochs.
- N1: number of iterations in the self-supervised training stage.
- N2: number of iterations in the joint unsupervised and supervised training stage.
- N3: number of iterations in the supervised fine-tuning stage.
- Lunsup: lower-level unsupervised empirical risk g(θ, η).
- Lsup: upper-level supervised empirical risk f(θ, φ).
Pseudocode:
```
Algorithm 1 Bilevel Joint Unsupervised and Supervised Training (BL-JUST)
Input: labeled data (x, y), unlabeled data x
ρ  ← learning rate for self-supervised exploration
α  ← learning rate for bilevel gradient descent
τ  ← learning rate for supervised fine-tuning
γm ← maximum penalty factor
K  ← number of epochs
N1 ← number of iterations in unsupervised training
N2 ← number of iterations in joint unsupervised and supervised training
N3 ← number of iterations for supervised fine-tuning
Lunsup ← lower-level unsupervised empirical risk g(θ, η)
Lsup   ← upper-level supervised empirical risk f(θ, φ)

for k = 1 : K do
    # Self-supervised exploration
    for i = 1 : N1 do
        (θ_k^{i+1}, η_k^{i+1}) = (θ_k^i, η_k^i) − ρ ∇_{θ,η} Lunsup(θ_k^i, η_k^i)
    # Joint unsupervised and supervised training
    γ_k = (k − 1) γm / K
    Use θ_k^{N1+1} and η_k^{N1+1} as the starting point
    for j = 1 : N2 do
        # Update θ
        θ_k^{j+1} = θ_k^j − α ∇_θ Lsup(θ_k^j, φ_k^j) − α γ_k ∇_θ Lunsup(θ_k^j, η_k^j)
        # Update φ
        φ_k^{j+1} = φ_k^j − α ∇_φ Lsup(θ_k^j, φ_k^j)
        # Update η
        η_k^{j+1} = η_k^j − α γ_k ∇_η Lunsup(θ_k^j, η_k^j)
# Supervised fine-tuning
Use θ_K^{N2+1} and φ_K^{N2+1} as the starting point
for t = 1 : N3 do
    (θ^{t+1}, φ^{t+1}) = (θ^t, φ^t) − τ ∇_{θ,φ} Lsup(θ^t, φ^t)
```
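To make the control flow above concrete, here is a minimal runnable sketch of the three training phases, assuming toy quadratic objectives in place of the real supervised risk f and unsupervised risk g; every name and constant (grad_sup, grad_unsup, the learning rates and iteration counts) is an illustrative placeholder rather than the paper's implementation.

```python
# Minimal runnable sketch of the BL-JUST loop with toy quadratic losses.
# Toy objectives: g(theta, eta) = ||theta - 1||^2 + ||eta||^2   (unsupervised)
#                 f(theta, phi) = ||theta - 2||^2 + ||phi - theta||^2 (supervised)
import numpy as np

def grad_unsup(theta, eta):
    # Gradients of g with respect to theta and eta.
    return 2 * (theta - 1.0), 2 * eta

def grad_sup(theta, phi):
    # Gradients of f with respect to theta and phi.
    return 2 * (theta - 2.0) + 2 * (theta - phi), 2 * (phi - theta)

rho, alpha, tau = 0.05, 0.05, 0.05               # learning rates ρ, α, τ
gamma_max, K, N1, N2, N3 = 1.0, 10, 20, 20, 50   # penalty cap and iteration counts

theta = np.zeros(3)
phi = np.zeros(3)
eta = np.zeros(3)

for k in range(1, K + 1):
    # Phase 1: self-supervised exploration (descend g in (theta, eta)).
    for _ in range(N1):
        g_th, g_eta = grad_unsup(theta, eta)
        theta, eta = theta - rho * g_th, eta - rho * g_eta

    # Phase 2: joint training with a linearly ramped penalty factor gamma_k.
    gamma_k = (k - 1) * gamma_max / K
    for _ in range(N2):
        f_th, f_phi = grad_sup(theta, phi)
        g_th, g_eta = grad_unsup(theta, eta)
        theta = theta - alpha * f_th - alpha * gamma_k * g_th
        phi = phi - alpha * f_phi
        eta = eta - alpha * gamma_k * g_eta

# Phase 3: supervised fine-tuning on f alone.
for _ in range(N3):
    f_th, f_phi = grad_sup(theta, phi)
    theta, phi = theta - tau * f_th, phi - tau * f_phi

print("theta:", theta, "phi:", phi)
```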