BL-JUST

Overview

The goal of BL-JUST is to unify unsupervised and supervised learning into a single training process.

It is formulated as a bilevel optimization problem in which the upper-level problem is supervised and the lower-level problem is unsupervised.

Detailed problem definition

1. The bilevel optimization problem

\[ \min_{\theta, \phi} \sum_{(x, y) \in D_{\text{sup}}} \ell_{\text{sup}}((x, y); \theta, \phi) \tag{1} \] subject to \[ \theta \in \arg\min_{\theta', \eta} \sum_{x \in D_{\text{unsup}}} \ell_{\text{unsup}}(x; \theta', \eta) \tag{2} \]

2. Upper- and lower-level objectives

\[ f(\theta, \phi) \triangleq \sum_{(x, y) \in D_{\text{sup}}} \ell_{\text{sup}}((x, y); \theta, \phi) \tag{3} \] \[ g(\theta, \eta) \triangleq \sum_{x \in D_{\text{unsup}}} \ell_{\text{unsup}}(x; \theta, \eta) \tag{4} \]

3. Constrained form of the problem

\[ \min_{\theta, \phi} f(\theta, \phi) \quad \text{s.t.} \quad \theta \in \arg\min_{\theta', \eta} g(\theta', \eta). \tag{5} \]

4. Definition of the value-function gap

\[ p(\theta, \eta) = g(\theta, \eta) - \min_{\theta', \eta'} g(\theta', \eta') \triangleq g(\theta, \eta) - v \tag{6} \] where \[ v \triangleq \min_{\theta', \eta'} g(\theta', \eta'). \tag{7} \]

5. Reformulated optimization problem

\[ \min_{\theta, \phi} f(\theta, \phi) \quad \text{s.t.} \quad p(\theta, \eta) \leq 0. \tag{8} \]
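Because \(v\) is the global minimum of \(g\), the gap is non-negative by construction: \(p(\theta, \eta) \geq 0\) for all \((\theta, \eta)\). The constraint \(p(\theta, \eta) \leq 0\) in (8) therefore forces \(p(\theta, \eta) = 0\), i.e. \((\theta, \eta) \in \arg\min g\), so (8) is equivalent to the constrained problem (5).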

6. Penalty-based single-level reformulation

\[ \min_{\theta, \phi, \eta} F_{\gamma}(\theta, \phi, \eta) \triangleq f(\theta, \phi) + \gamma p(\theta, \eta) \tag{9} \]
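Since \(v\) is a constant that does not depend on \((\theta, \phi, \eta)\), it contributes nothing to the gradients of \(F_{\gamma}\):

\[ \nabla_{\theta} F_{\gamma} = \nabla_{\theta} f(\theta, \phi) + \gamma \nabla_{\theta} g(\theta, \eta), \qquad \nabla_{\phi} F_{\gamma} = \nabla_{\phi} f(\theta, \phi), \qquad \nabla_{\eta} F_{\gamma} = \gamma \nabla_{\eta} g(\theta, \eta). \]

These are exactly the three update directions used in the joint-training phase of the algorithm below; in particular, \(v\) never has to be computed.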

7. Condition on the lower-level objective

The penalty reformulation relies on the lower-level objective \(g\) satisfying a PL (Polyak-Łojasiewicz) condition for some constant \(\mu > 0\):

\[ \| \nabla g(\theta, \eta) \|^2 \geq \mu\, p(\theta, \eta) \tag{10} \]
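Under (10), the gap is bounded by the gradient norm of \(g\):

\[ p(\theta, \eta) \leq \frac{\| \nabla g(\theta, \eta) \|^2}{\mu}, \]

so gradient steps that drive \(\| \nabla g \|\) toward zero also drive the penalty term \(\gamma\, p(\theta, \eta)\) in (9) toward zero.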

The penalized objective (9) is what BL-JUST actually optimizes; in the algorithm below, the penalty factor γ is increased from 0 toward γm over the course of training.

Pseudocode

Input parameters (a small configuration sketch follows the list):

  • (x, y): labeled data, where x is the input feature and y is the corresponding label.
  • x: unlabeled data, with features only and no labels.
  • ρ: learning rate for the self-supervised training stage.
  • α: learning rate for bilevel gradient descent.
  • τ: learning rate for the supervised fine-tuning stage.
  • γm: maximum penalty factor.
  • K: total number of training epochs.
  • N1: number of iterations in the self-supervised training stage.
  • N2: number of iterations of joint unsupervised and supervised training.
  • N3: number of iterations of supervised fine-tuning.
  • Lunsup: lower-level unsupervised empirical risk g(θ, η).
  • Lsup: upper-level supervised empirical risk f(θ, φ).
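For reference, these hyperparameters could be grouped into a small configuration object; the sketch below uses hypothetical placeholder values, not settings from the paper:

```python
# Hypothetical hyperparameter grouping for BL-JUST; all values are placeholders.
from dataclasses import dataclass

@dataclass
class BLJustConfig:
    rho: float = 1e-3       # ρ: learning rate for self-supervised exploration
    alpha: float = 1e-3     # α: learning rate for bilevel gradient descent
    tau: float = 1e-4       # τ: learning rate for supervised fine-tuning
    gamma_max: float = 1.0  # γm: maximum penalty factor
    K: int = 10             # total number of training epochs
    N1: int = 100           # iterations of unsupervised training per epoch
    N2: int = 100           # iterations of joint training per epoch
    N3: int = 100           # iterations of supervised fine-tuning
```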

Pseudocode:

```
Algorithm 1 Bilevel Joint Unsupervised and Supervised Training (BL-JUST)

Input: labeled data (x, y), unlabeled data x
ρ  ← learning rate for self-supervised exploration
α  ← learning rate for bilevel gradient descent
τ  ← learning rate for supervised fine-tuning
γm ← maximum penalty factor
K  ← number of epochs
N1 ← number of iterations in unsupervised training
N2 ← number of iterations in joint unsupervised and supervised training
N3 ← number of iterations for supervised fine-tuning
Lunsup ← lower-level unsupervised empirical risk g(θ, η)
Lsup   ← upper-level supervised empirical risk f(θ, φ)

for k = 1 : K do
    # Self-supervised exploration
    for i = 1 : N1 do
        (θ_k^(i+1), η_k^(i+1)) = (θ_k^i, η_k^i) − ρ ∇_{θ,η} Lunsup(θ_k^i, η_k^i)

    # Joint unsupervised and supervised training
    γ_k = (k − 1) γm / K
    Use θ_k^(N1+1) and η_k^(N1+1) as the starting point
    for j = 1 : N2 do
        # Update θ
        θ_k^(j+1) = θ_k^j − α ∇_θ Lsup(θ_k^j, φ_k^j) − α γ_k ∇_θ Lunsup(θ_k^j, η_k^j)
        # Update φ
        φ_k^(j+1) = φ_k^j − α ∇_φ Lsup(θ_k^j, φ_k^j)
        # Update η
        η_k^(j+1) = η_k^j − α γ_k ∇_η Lunsup(θ_k^j, η_k^j)

    # Supervised fine-tuning
    Use θ_k^(N2+1) and φ_k^(N2+1) as the starting point
    for t = 1 : N3 do
        (θ^(t+1), φ^(t+1)) = (θ^t, φ^t) − τ ∇_{θ,φ} Lsup(θ^t, φ^t)
```

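To make the loop structure concrete, here is a minimal runnable PyTorch sketch. The linear encoder, the two heads, the reconstruction loss standing in for Lunsup, and all hyperparameter values are illustrative assumptions of mine; only the three phases and the penalty schedule γ_k follow the pseudocode above.

```python
# A runnable toy sketch of the BL-JUST loop structure. The encoder/heads, the
# reconstruction-style unsupervised loss, and the random data are illustrative
# assumptions, not the paper's models or tasks.
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

encoder = nn.Linear(16, 32)     # parameters θ (shared backbone)
sup_head = nn.Linear(32, 10)    # parameters φ (supervised head)
unsup_head = nn.Linear(32, 16)  # parameters η (e.g. a reconstruction head)

theta = list(encoder.parameters())
phi = list(sup_head.parameters())
eta = list(unsup_head.parameters())

def L_sup(x, y):   # upper-level supervised risk f(θ, φ)
    return F.cross_entropy(sup_head(encoder(x)), y)

def L_unsup(x):    # lower-level unsupervised risk g(θ, η)
    return F.mse_loss(unsup_head(encoder(x)), x)

def sgd_step(loss, params, lr):
    """One plain SGD step on the given parameter group."""
    grads = torch.autograd.grad(loss, params)
    with torch.no_grad():
        for p, g in zip(params, grads):
            p -= lr * g

# Toy data and hypothetical hyperparameters (placeholders, not from the paper).
x_lab, y_lab = torch.randn(64, 16), torch.randint(0, 10, (64,))
x_unlab = torch.randn(256, 16)
rho, alpha, tau, gamma_max = 1e-2, 1e-2, 1e-3, 1.0
K, N1, N2, N3 = 5, 20, 20, 20

for k in range(1, K + 1):
    # Self-supervised exploration: update (θ, η) with Lunsup.
    for _ in range(N1):
        sgd_step(L_unsup(x_unlab), theta + eta, rho)

    # Joint training: one step on Lsup + γ_k * Lunsup updates θ, φ, η exactly
    # as in Algorithm 1, since v is a constant and drops out of the gradients.
    gamma_k = (k - 1) * gamma_max / K
    for _ in range(N2):
        loss = L_sup(x_lab, y_lab) + gamma_k * L_unsup(x_unlab)
        sgd_step(loss, theta + phi + eta, alpha)

    # Supervised fine-tuning: update (θ, φ) with Lsup.
    for _ in range(N3):
        sgd_step(L_sup(x_lab, y_lab), theta + phi, tau)

    print(f"epoch {k}: Lsup={L_sup(x_lab, y_lab).item():.3f} "
          f"Lunsup={L_unsup(x_unlab).item():.3f}")
```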