Pruned RNN-T 论文阅读

RNN-T 的缺点

  1. 损失函数计算相对较慢
  2. 占用较多内存

RNN-T模型需要储存一个4维的张量,维度是 \((N, T, U, C)\), 其中\(N\)是batch_size,\(T\)是Transcription Network的输出长度,\(U\)是Prediction的输出长度,\(V\)是词表大小。

本文将\(U\)限制为 \(S\subseteq U\),从而加速训练。

基本思想

仅计算对最终损失有重要贡献的 \((t,u)\) 对应项的joiner network,从而减少计算量。

具体来说,我们执行两次递推:

  1. 使用一个trivial的joiner network,该网络计算速度非常快。利用初步计算结果:确定哪些索引对损失贡献较大。
  2. 第二次计算:仅对筛选出的 \((t,u)\)子集来计算的joiner network。

Trivial Joiner Network

Trivial Joiner Network的思路为:

直接对encoder embedding和decoder embedding进行投影,得到未归一化的对数概率:

\[ L_{\text{enc}}(t, v) \]

\[ L_{\text{dec}}(u, v) \]

那么最终的loss就是:

\[L_{\text{trivial}}(t, u, v) = L_{\text{enc}}(t, v) + L_{\text{dec}}(u, v) - L_{\text{normalizer}}(t, u)\]

其中

\[ L_{\text{normalizer}}(t, u) = \log \sum_{v} \exp \left( L_{\text{enc}}(t, v) + L_{\text{dec}}(u, v) \right) \]

这个公式本质上是log-space的矩阵乘法,容易实现。

Pruning Bounds

引入常数\(S\),表示对于任意的\(t\),我们将仅计算考虑 \(L(t, u, v)\)\(p_t \leq u < p_t + S\) 位置上的值。其中 \(p_t\) 是通过Trivial Joiner Network得到的。

剪枝就是将 \(u < p_t\)\(u \geq p_t + S\)\(P(t, u, \cdot)\) 视为 \(-\infty\)

理想情况下,我们希望找到一组整数剪枝边界 \(p = p_0, p_t, \dots, p_{T-1}\),该组边界能够最大化Trivial Joiner Network的总保留概率。换句话说,我们要最大化数据的似然性。

为了简化问题,我们通过为每个帧 \(t\) 找到局部最优的剪枝边界来解决,然后应用一些连续性约束来改进结果。

具体来说,就是求解下面这个公式:

\[p_t = \arg\max_{p=0}^{U-S+1} \left( - y'_0(t, p-1) + \sum_{u=p}^{p+S-1} \varnothing'_0(t, u) \right)\]

其中 \(y'_0(t, u)\)\(\varnothing'_0(t, u)\)\(L_{tot}\) 关于\(y(t, u)\)\(\varnothing(t, u)\).的导数。

这个公式可以形象的理解为下图中绿线减去红线的值。

Pruned-RNN-T

减去红线的原因是?

当然还需要满足一些条件(连续性等):

\[ 0 \leq p_t \leq U - S + 1 \]

\[ p_t \leq p_{t+1} \]

\[ p_{t+1} - p_t \leq S \]

Loss function

\[ L_{\text{smoothed}}(t, u, v) \]

\[ = (1 - \alpha_{\text{lm}} - \alpha_{\text{acoustic}}) \cdot L_{\text{trivial}}(t, u, v) + \alpha_{\text{lm}} \cdot L_{\text{lm}}(t, u, v) + \alpha_{\text{acoustic}} \cdot L_{\text{acoustic}}(t, u, v) \]

其中:

\[ L_{\text{trivial}}(t, u, v) = \text{LogSoftmax}_v \left( L_{\text{enc}}(t, v) + L_{\text{dec}}(u, v) \right) \]

\[ L_{\text{acoustic}}(t, u, v) = \text{LogSoftmax}_v \left( L_{\text{enc}}(t, v) + L_{\text{avg-dec}}(u, v) \right) \]

\[ L_{\text{lm}}(t, u, v) = \text{LogSoftmax}_v (L_{\text{dec}}(u, v)) \]

\[ L_{\text{avg-dec}}(u, v) = \log \left( \frac{1}{U + 1} \right) \sum_{u=0}^{X} \text{Softmax}_v L_{\text{dec}}(u, v) \]

为什么要平均?


Pruned RNN-T 论文阅读
http://example.com/2025/03/05/Pruned-RNN-T-论文阅读/
作者
Yujie Tu
发布于
2025年3月5日
许可协议