最近在做 4bit PQ(Accelerated Nearest Neighbor Search with Quick ADC) , 直觉上的话 4bit PQ 能比 8bit PQ 快很多,不过到完整实现这个东西后才发现坑还挺大的……

Product Quantization

从 PQ 说起,PQ 就是把 DD 维的向量分为等长的 MM 个块,然后把所有向量的相同块放一起做 kmeans,跑出 2B2^{B} 个中心,然后再把每个向量的每一块表示为离它的最近的中心的编号。这里 BB 就是 bit 数,通常是8,取8的原因很简单,这样可以用一个 u8 表示一个块,能取得很好的压缩效果。当大家讨论 PQ 时,基本都默认 B=8B=8,等价于把 DD 个 f32 压缩成 MM 个 u8。

查询的时候,为了计算查询向量 qq 到数据集中所有点的距离,因为量化后每个向量的每一维本质都是这一维度的某个中心点,所以可以先构造一个 distance table,预计算 qq 到每个中心的距离。后面计算 qq 到数据点的距离时,就只需要查表了。具体来说:

distance_table[i][j] = distance to i-th block j-th centorid

对于一个压缩后的向量 codecode,计算到它的距离也就是

j=0M1distance_table[j][code[j]]\sum_{j=0}^{M-1} distance\_table[j][code[j]]

Transposing

这个 idea 是我从 Accelerated Nearest Neighbor Search with Quick ADC 里看到的,但是其实这个优化同样适用于 8bit,而且效果非常好。

考虑我们计算到一个压缩后的向量的距离的方法:

let mut dists = vec![0.0f32; n]
for i in 0..n {
    for j in 0..M {
        dists[i] += distance_table[j][code[i][j]]
    }
}

这个计算的问题是它对 distance_table 的访问是完全随机的,性能会被内存访问限制。如果我们把枚举顺序调换一下,好不少:

let mut dists = vec![0.0f32; n]
for j in 0..M {
    for i in 0..n {
        dists[i] += distance_table[j][code[i][j]]
    }
}

distance table 的第二维是 28=2562^8=256 个 f32,也就是 1KiB,locality 还是很好的。
但是这个时候对 code 的访问又变成随机的了,不过这个比较容易解决,对 code 做矩阵转置就可以解决:

let mut dists = vec![0.0f32; n]
for j in 0..M {
    for i in 0..n {
        dists[i] += distance_table[j][code[j][i]]
    }
}

这样不仅是 locality 会好很多,也可以让编译器生成更好的代码,对性能的提升是相当可观的。

4bit PQ

回到正题。
基于前面的优化,其实还可以更进一步,比 cache 更快的是寄存器。让我们先不考虑寄存器大小,上面的代码其实等价于下面的伪代码:

let mut dists = vec![0.0f32; n]
for j in 0..M {
    let num_centroids = 2.pow(B);
    for i in (0..n).step_by(num_centroids) {
        let shuffled = shuffle(distance_table[j], code[j][i..i+num_centroids]);
        dists[i..i+num_centroids] += shuffled;
    }
}

也就是说计算可以通过 shuffle + add 两条指令完成。前面说过,distance_table[i] 是 256 个 f32,也就是 1KiB,目前的寄存器最长的是 512bit,而 ARM 上最长的是 128bit。要把 distance_table[i] 压到这个大小,paper 用了两个方法:

  • 用 4bit 代替 8bit,这样长度也就是 24=162^4=16。不过 16 个 f32 就已经是 512bit 了,还需要再压一下
  • 把 distance 量化成 u8(scalar quantization)。这样也就是 16 个 u8,刚好是 128bit

实际实现

到目前为止的思路都是很直接的,不过我上手实现之后量化距离这个点把我狠狠的 confuse 了,paper 里 dists[i..i+num_centroids] 也是一个 u8x16,这意味着整个累加过程中,u8都没有溢出。怎么才能做到这一点呢,最简单方法的就是做 scalar quantization 的时候,把 max 设为整个 distance table 的和,不出意外的 recall 直接拉闸了。

我又回去看了一下 paper,看看作者怎么做的,而作者的做法狠狠的让我感叹学术和实现之间的巨大 gap……作者的方法是先做个爆搜,算一下比如 top200,然后拿最远的那个距离做 max。后面做 add 的时候做 saturating add,反正如果你搜 top100,会成为最终结果的那部份不会溢出。这个方法过于不实际了,反正我没用……

我最后放弃了把结果一直存 u8x16 里,还是写内存了。因为 4bit 是把两个块压到一个 byte 里,我需要在 u8x16 中存储两个块加起来的结果,所以其实我只需要求 distance table 中相邻两个块的和的最大值,作为 max 去做 scalar quantization 就可以了。 写内存的次数也可以减少一半,所以其实性能依然不错,最主要的是这样 recall loss 很小。