pipapo_set Study Notes


Preface

Analysis

pipapo stands for PIle PAcket POlicies: set for arbitrary concatenations of ranges.

Its primary design goal is matching on concatenations of multiple fields.

PIPAPO is the algorithm used by the Linux kernel (nftables) to efficiently match complex packet fields. Its design is inspired by Grouper and follows a "grouping + table lookup + bitmap intersection" scheme, which gives fast matching of mask/range rules over arbitrary combinations of fields.

Below is a scenario walkthrough (adapted from GPT):

To match a packet's source port (16 bits), the 16 bits can be split into 4 groups of 4 bits each. Consider the following entry:

0000 0001 0101 1001

and a packet whose source port is:

0000 0001 1010 1001

Matching is done group by group with AND: the first and second groups match, but the third group does not, so the conclusion is that the packet does not match this entry.
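To make the grouping concrete, here is a tiny standalone illustration (toy user-space code comparing nibbles directly, not the kernel's bitmap-based matching) that splits both values into 4-bit groups and compares them group by group:

#include <stdio.h>

int main(void)
{
	unsigned short entry = 0x0159;	/* 0000 0001 0101 1001 */
	unsigned short port  = 0x01a9;	/* 0000 0001 1010 1001 */
	int g;

	for (g = 0; g < 4; g++) {
		int shift = 12 - 4 * g;			/* most significant group first */
		int eg = (entry >> shift) & 0xf;	/* group value of the entry */
		int pg = (port  >> shift) & 0xf;	/* group value of the packet */

		printf("group %d: entry=%x packet=%x -> %s\n",
		       g, eg, pg, eg == pg ? "match" : "no match");
	}
	return 0;
}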

The set is translated into a series of lookup tables, one per field (a set handles several fields; each field is one of the things mentioned above, such as an IP address or a port). Each table has two dimensions: the bit groups of the packet field to be matched, and all possible values of those bit groups (the buckets).

Suppose a field is b bits long and is split into groups of t bits each: with b = 32 and t = 4 we get 8 groups, and every 4-bit group has 16 possible values, i.e. the 16 "buckets" the source code talks about. These are exactly the two dimensions of a table: one dimension indexes the groups, the other enumerates every possible value within a group.

Suppose there are two entries, 10.0.0.5:1024 and 192.168.1.0~192.168.2.1:2048. For the address field these expand into three rules:

  • rule#0:10.0.0.5
  • rule#1:192.168.1.0/24
  • rule#2:192.168.2.0/31

References to these rules are added to the lookup table.

If this is not the last field of the set, a mapping array is filled in; it maps rules of this lookup table to the rules in the next lookup table that belong to the same entry.

table0 for field0:

table1 for field1:

Note that field #0 was expanded into three rules in table #0, while field #1 only produced two rules in table #1, so the following mapping exists between the two tables:

0 -> 0

1 -> 1

2 -> 1

If this is the last field of the set, the mapping array is filled with element pointers instead.

During a lookup, a bitmap initialized to all ones is used as the result (this is the res_map in the pipapo_lookup function).

In each table, the current result bitmap is ANDed with the buckets selected by the packet.

For example, 192.168.1.5 is < 12 0 10 8 0 1 0 5 > in 4-bit groups; only the buckets selected by those values are ANDed with the result,

which here comes out as 2.

If this is not the last field, the resulting value is used to do the same work in the next table:

The final result is again 2, which is 010 in binary, i.e. rule #1.

Since this is now the last field, #1 is used to find the corresponding elem:

Summary

The key to be matched may consist of several fields. Each field gets its own table, and every table is a groups × buckets two-dimensional array (assume each array element has enough bits). Insertion then proceeds group by group for every rule, setting the rule's bit in the bucket it belongs to.

A mapping array then links the rules of one field to the rules of the next field.

A lookup is then a series of ANDs against an all-ones bitmap, and at the end a matching elem falls out. (My understanding: at every mapping step the bitmap has to be converted into rule indices first, e.g. 0x2 -> bit #1, 0x4 -> bit #2; the mapping array then gives the mapped target, e.g. 1 -> 1, which is turned back into set bits for the next field, and the ANDing continues; only the last step yields an elem instead.)
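To make that walkthrough concrete, here is a deliberately tiny user-space sketch of the same idea (my own toy types, a pair of 8-bit fields, and simplified mapping handling, not the kernel's structures or its res_map/fill_map logic):

#include <stdio.h>

#define GROUPS_PER_FIELD 2		/* toy: 8-bit fields, 2 nibbles each */
#define BUCKETS          16

struct map_bucket { unsigned int to, n; };

struct field {
	unsigned long lt[GROUPS_PER_FIELD][BUCKETS];	/* bit r set: rule r matches this bucket value */
	struct map_bucket mt[8];			/* rule -> first rule/count in next field */
};

/* AND the result bitmap with the bucket selected by each group of the key */
static unsigned long field_lookup(const struct field *f, unsigned char key,
				  unsigned long res)
{
	res &= f->lt[0][key >> 4];	/* high nibble */
	res &= f->lt[1][key & 0xf];	/* low nibble */
	return res;
}

int main(void)
{
	struct field f0 = { 0 }, f1 = { 0 };
	unsigned long res;
	int r;

	/* rule #0 in field 0 matches key byte 0x15 and maps to rule #0 of field 1 */
	f0.lt[0][0x1] |= 1UL << 0;
	f0.lt[1][0x5] |= 1UL << 0;
	f0.mt[0].to = 0;
	f0.mt[0].n = 1;

	/* rule #0 in field 1 (last field) matches key byte 0x2a */
	f1.lt[0][0x2] |= 1UL << 0;
	f1.lt[1][0xa] |= 1UL << 0;

	/* look up the two-field key 0x15 . 0x2a */
	res = ~0UL;				/* result bitmap starts as all ones */
	res = field_lookup(&f0, 0x15, res);

	/* translate matching rules of field 0 into rules of field 1 */
	unsigned long next = 0;
	for (r = 0; r < 8; r++)
		if (res & (1UL << r))
			next |= ((1UL << f0.mt[r].n) - 1) << f0.mt[r].to;

	res = field_lookup(&f1, 0x2a, next);
	printf("match: %s\n", res ? "yes" : "no");	/* prints "yes" */
	return 0;
}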

Data structures

pipapo_field

https://elixir.bootlin.com/linux/v5.15/source/net/netfilter/nft_set_pipapo.h#L121

pipapo_match

https://elixir.bootlin.com/linux/v5.15/source/net/netfilter/nft_set_pipapo.h#L142

nft_pipapo

As you can see, the data member of a pipapo set (its private data, returned by nft_set_priv()) is an nft_pipapo structure.
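For orientation, here is a simplified sketch of the main structures (members paraphrased from the v5.15 headers linked above; the alignment-related members are omitted and the exact types may differ, so the links remain the authoritative reference):

union nft_pipapo_map_bucket {
	struct {
		u32 to;				/* first rule in the next field */
		u32 n;				/* number of rules in the next field */
	};
	struct nft_pipapo_elem *e;		/* used by the last field only */
};

struct nft_pipapo_field {
	int groups;				/* number of bit groups */
	unsigned long rules;			/* number of inserted rules */
	size_t bsize;				/* size of one bucket, in longs */
	int bb;					/* bits per group ("t"): 4 or 8 */
	unsigned long *lt;			/* lookup table: groups x buckets bitmaps */
	union nft_pipapo_map_bucket *mt;	/* mapping table: one entry per rule */
};

struct nft_pipapo_match {
	int field_count;
	unsigned long * __percpu *scratch;	/* per-CPU scratch result bitmaps */
	size_t bsize_max;
	struct rcu_head rcu;
	struct nft_pipapo_field f[];		/* one descriptor per field */
};

struct nft_pipapo {
	struct nft_pipapo_match __rcu *match;	/* live copy used by lookups */
	struct nft_pipapo_match *clone;		/* working copy for insert/remove */
	int width;				/* total key width, in bytes */
	bool dirty;				/* clone differs from match */
	unsigned long last_gc;
};

struct nft_pipapo_elem {
	struct nft_set_ext ext;
};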

Source code analysis

nft_pipapo_init

static int nft_pipapo_init(const struct nft_set *set,
			   const struct nft_set_desc *desc,
			   const struct nlattr * const nla[])
{
	struct nft_pipapo *priv = nft_set_priv(set);
	struct nft_pipapo_match *m;
	struct nft_pipapo_field *f;
	int err, i, field_count;

	field_count = desc->field_count ? : 1;

	if (field_count > NFT_PIPAPO_MAX_FIELDS)
		return -EINVAL;

	m = kmalloc(sizeof(*priv->match) + sizeof(*f) * field_count,
		    GFP_KERNEL);
	if (!m)
		return -ENOMEM;

	m->field_count = field_count;
	m->bsize_max = 0;

	m->scratch = alloc_percpu(unsigned long *);
	if (!m->scratch) {
		err = -ENOMEM;
		goto out_scratch;
	}
	for_each_possible_cpu(i)
		*per_cpu_ptr(m->scratch, i) = NULL;

#ifdef NFT_PIPAPO_ALIGN
	m->scratch_aligned = alloc_percpu(unsigned long *);
	if (!m->scratch_aligned) {
		err = -ENOMEM;
		goto out_free;
	}
	for_each_possible_cpu(i)
		*per_cpu_ptr(m->scratch_aligned, i) = NULL;
#endif

	rcu_head_init(&m->rcu);

	nft_pipapo_for_each_field(f, i, m) {
		int len = desc->field_len[i] ? : set->klen;

		f->bb = NFT_PIPAPO_GROUP_BITS_INIT;
		f->groups = len * NFT_PIPAPO_GROUPS_PER_BYTE(f);

		priv->width += round_up(len, sizeof(u32));

		f->bsize = 0;
		f->rules = 0;
		NFT_PIPAPO_LT_ASSIGN(f, NULL);
		f->mt = NULL;
	}

	/* Create an initial clone of matching data for next insertion */
	priv->clone = pipapo_clone(m);
	if (IS_ERR(priv->clone)) {
		err = PTR_ERR(priv->clone);
		goto out_free;
	}

	priv->dirty = false;

	rcu_assign_pointer(priv->match, m);

	return 0;

out_free:
#ifdef NFT_PIPAPO_ALIGN
	free_percpu(m->scratch_aligned);
#endif
	free_percpu(m->scratch);
out_scratch:
	kfree(m);

	return err;
}

nft_pipapo_insert

It is reached through the set's ops table:

elem here is a temporary on the stack; what is actually allocated from the heap is elem->priv:

https://elixir.bootlin.com/linux/v5.15/source/net/netfilter/nft_set_pipapo.c#L1153
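For context, the rough shape of struct nft_set_elem is sketched below (paraphrased from memory of include/net/netfilter/nf_tables.h; only the members relevant here are shown, and the data union is omitted):

/* paraphrased sketch, not the verbatim kernel definition */
struct nft_set_elem {
	union {
		u32		buf[NFT_DATA_VALUE_MAXLEN / sizeof(u32)];
		struct nft_data	val;
	} key;				/* start of the (possibly ranged) key */
	union {
		u32		buf[NFT_DATA_VALUE_MAXLEN / sizeof(u32)];
		struct nft_data	val;
	} key_end;			/* end of the range, if any */
	void			*priv;	/* here: the heap-allocated struct nft_pipapo_elem */
};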

static int nft_pipapo_insert(const struct net *net, const struct nft_set *set,
			     const struct nft_set_elem *elem,
			     struct nft_set_ext **ext2)
{
	const struct nft_set_ext *ext = nft_set_elem_ext(set, elem->priv);
	union nft_pipapo_map_bucket rulemap[NFT_PIPAPO_MAX_FIELDS];
	const u8 *start = (const u8 *)elem->key.val.data, *end;
	struct nft_pipapo_elem *e = elem->priv, *dup;	/* elem->priv is the heap-allocated element */
	struct nft_pipapo *priv = nft_set_priv(set);	/* priv = set->data */
	struct nft_pipapo_match *m = priv->clone;
	u8 genmask = nft_genmask_next(net);
	struct nft_pipapo_field *f;
	int i, bsize_max, err = 0;

	if (nft_set_ext_exists(ext, NFT_SET_EXT_KEY_END))
		end = (const u8 *)nft_set_ext_key_end(ext)->data;
	else
		end = start;

	dup = pipapo_get(net, set, start, genmask);
	if (!IS_ERR(dup)) {
		/* Check if we already have the same exact entry */
		const struct nft_data *dup_key, *dup_end;

		dup_key = nft_set_ext_key(&dup->ext);
		if (nft_set_ext_exists(&dup->ext, NFT_SET_EXT_KEY_END))
			dup_end = nft_set_ext_key_end(&dup->ext);
		else
			dup_end = dup_key;

		if (!memcmp(start, dup_key->data, sizeof(*dup_key->data)) &&
		    !memcmp(end, dup_end->data, sizeof(*dup_end->data))) {
			*ext2 = &dup->ext;
			return -EEXIST;
		}

		return -ENOTEMPTY;
	}

	if (PTR_ERR(dup) == -ENOENT) {
		/* Look for partially overlapping entries */
		dup = pipapo_get(net, set, end, nft_genmask_next(net));
	}

	if (PTR_ERR(dup) != -ENOENT) {
		if (IS_ERR(dup))
			return PTR_ERR(dup);
		*ext2 = &dup->ext;
		return -ENOTEMPTY;
	}

	/* Validate */
	/*
	 * This loop checks the data the user passed in via netlink.
	 * Every field is checked individually; each field has a start and an
	 * end because it may be a range. If it is not a range there is no
	 * end, and end == start. The main check is start <= end.
	 */
	nft_pipapo_for_each_field(f, i, m) {
		const u8 *start_p = start, *end_p = end;

		if (f->rules >= (unsigned long)NFT_PIPAPO_RULE0_MAX)
			return -ENOSPC;

		/*
		 * #define NFT_PIPAPO_GROUPS_PER_BYTE(f) (BITS_PER_BYTE / (f)->bb)
		 * f->bb is the "t" value, i.e. the number of bits per group,
		 * so a byte holds 1 or 2 groups.
		 * The memcmp below compares all the bytes occupied by this
		 * field's groups (lengths are always expressed in bytes);
		 * the total number of bytes occupied by the groups is simply
		 * the length of the field, e.g. 192.168.0.1 as passed in by
		 * the user over netlink.
		 * The point is to guarantee start <= end for this field.
		 */
		if (memcmp(start_p, end_p,
			   f->groups / NFT_PIPAPO_GROUPS_PER_BYTE(f)) > 0)
			return -EINVAL;
		/*
		 * #define NFT_PIPAPO_GROUPS_PADDED_SIZE(f) \
		 *	(round_up((f)->groups / NFT_PIPAPO_GROUPS_PER_BYTE(f), sizeof(u32)))
		 * Advance both pointers by the field length, rounded up to
		 * 4-byte alignment, i.e. move on to the next field.
		 */
		start_p += NFT_PIPAPO_GROUPS_PADDED_SIZE(f);
		end_p += NFT_PIPAPO_GROUPS_PADDED_SIZE(f);
	}

	/* Insert */
	priv->dirty = true;

	bsize_max = m->bsize_max;

	/*
	 * The loop below:
	 * 1. reallocs mt and lt field by field and updates their contents;
	 * 2. records in rulemap where the new element's rules land in each field.
	 */
	nft_pipapo_for_each_field(f, i, m) {
		int ret;

		rulemap[i].to = f->rules;	/* the new rule(s) for f start here */

		/* is start equal to end for this field, i.e. is it a range? */
		ret = memcmp(start, end,
			     f->groups / NFT_PIPAPO_GROUPS_PER_BYTE(f));
		if (!ret)	/* a single value: plain insert */
			ret = pipapo_insert(f, start, f->groups * f->bb);	/* f->rules is incremented in here */
		else		/* a range: expand it into rules */
			ret = pipapo_expand(f, start, end, f->groups * f->bb);

		if (f->bsize > bsize_max)
			bsize_max = f->bsize;

		rulemap[i].n = ret;	/* i is the field index, n the number of rules inserted */

		/* move on to the next field */
		start += NFT_PIPAPO_GROUPS_PADDED_SIZE(f);
		end += NFT_PIPAPO_GROUPS_PADDED_SIZE(f);
	}

	if (!*get_cpu_ptr(m->scratch) || bsize_max > m->bsize_max) {
		put_cpu_ptr(m->scratch);

		err = pipapo_realloc_scratch(m, bsize_max);
		if (err)
			return err;

		m->bsize_max = bsize_max;
	} else {
		put_cpu_ptr(m->scratch);
	}

	*ext2 = &e->ext;

	/* update the mt of every field with mappings for the new element */
	pipapo_map(m, rulemap, e);

	return 0;
}

The code of pipapo_insert is as follows:

https://elixir.bootlin.com/linux/v5.15/source/net/netfilter/nft_set_pipapo.c#L901

static int pipapo_insert(struct nft_pipapo_field *f, const uint8_t *k,
			 int mask_bits)
{
	int rule = f->rules++, group, ret, bit_offset = 0;

	ret = pipapo_resize(f, f->rules - 1, f->rules);
	if (ret)
		return ret;

	for (group = 0; group < f->groups; group++) {
		int i, v;
		u8 mask;

		v = k[group / (BITS_PER_BYTE / f->bb)];
		v &= GENMASK(BITS_PER_BYTE - bit_offset - 1, 0);
		v >>= (BITS_PER_BYTE - bit_offset) - f->bb;

		bit_offset += f->bb;
		bit_offset %= BITS_PER_BYTE;

		if (mask_bits >= (group + 1) * f->bb) {
			/* Not masked */
			pipapo_bucket_set(f, rule, group, v);
		} else if (mask_bits <= group * f->bb) {
			/* Completely masked */
			for (i = 0; i < NFT_PIPAPO_BUCKETS(f->bb); i++)
				pipapo_bucket_set(f, rule, group, i);
		} else {
			/* The mask limit falls on this group */
			mask = GENMASK(f->bb - 1, 0);
			mask >>= mask_bits - group * f->bb;
			for (i = 0; i < NFT_PIPAPO_BUCKETS(f->bb); i++) {
				if ((i & ~mask) == (v & ~mask))
					pipapo_bucket_set(f, rule, group, i);
			}
		}
	}

	pipapo_lt_bits_adjust(f);

	return 1;
}

pipapo_map:

/**
 * pipapo_map() - Insert rules in mapping tables, mapping them between fields
 * @m:		Matching data, including mapping table
 * @map:	Table of rule maps: array of first rule and amount of rules
 *		in next field a given rule maps to, for each field
 * @e:		For last field, nft_set_ext pointer matching rules map to
 */
static void pipapo_map(struct nft_pipapo_match *m,
		       union nft_pipapo_map_bucket map[NFT_PIPAPO_MAX_FIELDS],
		       struct nft_pipapo_elem *e)
{
	struct nft_pipapo_field *f;
	int i, j;

	for (i = 0, f = m->f; i < m->field_count - 1; i++, f++) {	/* every field except the last */
		for (j = 0; j < map[i].n; j++) {	/* map[i].n rules were added to field i */
			f->mt[map[i].to + j].to = map[i + 1].to;	/* map[i].to is the index of the first new rule */
			f->mt[map[i].to + j].n = map[i + 1].n;
		}
	}

	/* Last field: map to ext instead of mapping to next field */
	for (j = 0; j < map[i].n; j++)
		f->mt[map[i].to + j].e = e;	/* the last field stores the element pointer directly; this is where the element hangs */
}

Summary:

Adding an element to a pipapo set roughly works like this:

  1. The start/end pairs of the fields are passed in via KEY and KEY_END;
  2. the boundary values of every field are validated, requiring start <= end;
  3. each field is then inserted into the match, concretely:
    1. resize each field's mt and lt and copy over the existing contents;
    2. add the new value(s); this returns ret, the number of rules inserted into the field;
    3. rulemap records the starting index of the newly added rules and their count ret;
  4. finally the mt array of every field in the match is updated:
    1. this essentially chains together the rules newly created in each field;
    2. the entries of the last field point at the address of the new elem.

Within one field, all rules added in the same batch share the same to and n, i.e. they map to the same range in the next field; there is no finer per-rule distinction. The sketch below illustrates the resulting rulemap and mt contents.
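As a toy illustration (using a simplified stand-in for union nft_pipapo_map_bucket, not the kernel types), here is how rulemap would wire the mapping tables when the earlier example entry 192.168.1.0-192.168.2.1 : 2048 is inserted into an empty two-field set:

#include <stdio.h>

struct map_bucket { unsigned int to, n; void *e; };

int main(void)
{
	struct map_bucket rulemap[2] = {
		{ .to = 0, .n = 2 },	/* field 0 (address range): new rules #0..#1 */
		{ .to = 0, .n = 1 },	/* field 1 (port): new rule #0 */
	};
	struct map_bucket f0_mt[2] = { 0 }, f1_mt[1] = { 0 };
	int elem = 42;			/* stand-in for the nft_pipapo_elem pointer */
	unsigned int j;

	/* non-last field: every new rule maps to the next field's new rules */
	for (j = 0; j < rulemap[0].n; j++) {
		f0_mt[rulemap[0].to + j].to = rulemap[1].to;
		f0_mt[rulemap[0].to + j].n  = rulemap[1].n;
	}
	/* last field: store the element pointer */
	for (j = 0; j < rulemap[1].n; j++)
		f1_mt[rulemap[1].to + j].e = &elem;

	for (j = 0; j < 2; j++)
		printf("f0->mt[%u] = { to=%u, n=%u }\n",
		       j, f0_mt[j].to, f0_mt[j].n);
	printf("f1->mt[0].e = %p\n", f1_mt[0].e);
	return 0;
}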

pipapo_resize

/**
 * pipapo_resize() - Resize lookup or mapping table, or both
 * @f:		Field containing lookup and mapping tables
 * @old_rules:	Previous amount of rules in field
 * @rules:	New amount of rules
 *
 * Increase, decrease or maintain tables size depending on new amount of rules,
 * and copy data over. In case the new size is smaller, throw away data for
 * highest-numbered rules.
 *
 * Return: 0 on success, -ENOMEM on allocation failure.
 */
static int pipapo_resize(struct nft_pipapo_field *f, int old_rules, int rules)
{
	long *new_lt = NULL, *new_p, *old_lt = f->lt, *old_p;
	union nft_pipapo_map_bucket *new_mt, *old_mt = f->mt;
	size_t new_bucket_size, copy;
	int group, bucket;

	/* number of longs needed per bucket */
	new_bucket_size = DIV_ROUND_UP(rules, BITS_PER_LONG);
#ifdef NFT_PIPAPO_ALIGN
	new_bucket_size = roundup(new_bucket_size,
				  NFT_PIPAPO_ALIGN / sizeof(*new_lt));
#endif

	if (new_bucket_size == f->bsize)
		goto mt;

	if (new_bucket_size > f->bsize)
		copy = f->bsize;
	else
		copy = new_bucket_size;

	/* allocate the new bucket space */
	new_lt = kvzalloc(f->groups * NFT_PIPAPO_BUCKETS(f->bb) *
			  new_bucket_size * sizeof(*new_lt) +
			  NFT_PIPAPO_ALIGN_HEADROOM,
			  GFP_KERNEL);
	if (!new_lt)
		return -ENOMEM;

	new_p = NFT_PIPAPO_LT_ALIGN(new_lt);
	old_p = NFT_PIPAPO_LT_ALIGN(old_lt);

	for (group = 0; group < f->groups; group++) {
		for (bucket = 0; bucket < NFT_PIPAPO_BUCKETS(f->bb); bucket++) {
			memcpy(new_p, old_p, copy * sizeof(*new_p));
			new_p += copy;
			old_p += copy;

			if (new_bucket_size > f->bsize)
				new_p += new_bucket_size - f->bsize;
			else
				old_p += f->bsize - new_bucket_size;
		}
	}

mt:
	new_mt = kvmalloc(rules * sizeof(*new_mt), GFP_KERNEL);
	if (!new_mt) {
		kvfree(new_lt);
		return -ENOMEM;
	}

	memcpy(new_mt, f->mt, min(old_rules, rules) * sizeof(*new_mt));
	if (rules > old_rules) {
		memset(new_mt + old_rules, 0,
		       (rules - old_rules) * sizeof(*new_mt));
	}

	if (new_lt) {
		f->bsize = new_bucket_size;
		NFT_PIPAPO_LT_ASSIGN(f, new_lt);
		kvfree(old_lt);
	}

	f->mt = new_mt;
	kvfree(old_mt);

	return 0;
}

pipapo_rules_same_key and pipapo_match_field

pipapo_rules_same_key iterates with r over all rules of a field, starting at first.

Its local variable e records the previous rule's mapping (this exploits the union nature of the mapping bucket: the point is not the e pointer itself, but comparing the to and n members in one go).

The first rule is not compared; afterwards each rule is compared against its predecessor, and the loop stops as soon as they differ. In other words, it looks for the first rule whose mapping differs from that of first.

If the loop finishes without finding one (they are all the same), it returns f->rules - first;

if the loop is never entered because first >= f->rules, it returns 0.

The removal path then simply iterates over all fields (see nft_pipapo_remove below).
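For reference, the logic of pipapo_rules_same_key is roughly the following (paraphrased from the v5.15 source; check the tree for the exact code):

static int pipapo_rules_same_key(struct nft_pipapo_field *f, int first)
{
	struct nft_pipapo_elem *e = NULL; /* Keep gcc happy */
	int r;

	for (r = first; r < f->rules; r++) {
		/* comparing .e compares the whole union, i.e. to and n at once */
		if (r != first && e != f->mt[r].e)
			return r - first;

		e = f->mt[r].e;
	}

	if (r != first)
		return r - first;

	return 0;
}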

pipapo_match_field:

static bool pipapo_match_field(struct nft_pipapo_field *f,
			       int first_rule, int rule_count,
			       const u8 *start, const u8 *end)
{
	u8 right[NFT_PIPAPO_MAX_BYTES] = { 0 };
	u8 left[NFT_PIPAPO_MAX_BYTES] = { 0 };

	pipapo_get_boundaries(f, first_rule, rule_count, left, right);

	return !memcmp(start, left,
		       f->groups / NFT_PIPAPO_GROUPS_PER_BYTE(f)) &&
	       !memcmp(end, right, f->groups / NFT_PIPAPO_GROUPS_PER_BYTE(f));
}

The boundaries are computed first; only if both the lower and the upper bound match exactly is this group of rules considered equivalent to the given range.

pipapo_get_boundaries:

/**
 * pipapo_get_boundaries() - Get byte interval for associated rules
 * @f:		Field including lookup table
 * @first_rule:	First rule (lowest index)
 * @rule_count:	Number of associated rules
 * @left:	Byte expression for left boundary (start of range)
 * @right:	Byte expression for right boundary (end of range)
 *
 * Given the first rule and amount of rules that originated from the same entry,
 * build the original range associated with the entry, and calculate the length
 * of the originating netmask.
 *
 * In pictures:
 *
 *                     bucket
 *      group  0  1  2  3  4  5  6  7  8  9  10 11 12 13 14 15
 *        0                                       1,2
 *        1   1,2
 *        2                                 1,2
 *        3                           1,2
 *        4   1,2
 *        5        1  2
 *        6   1,2  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1
 *        7   1,2 1,2 1  1  1  1  1  1  1  1  1  1  1  1  1  1
 *
 * this is the lookup table corresponding to the IPv4 range
 * 192.168.1.0-192.168.2.1, which was expanded to the two composing netmasks,
 * rule #1: 192.168.1.0/24, and rule #2: 192.168.2.0/31.
 *
 * This function fills @left and @right with the byte values of the leftmost
 * and rightmost bucket indices for the lowest and highest rule indices,
 * respectively. If @first_rule is 1 and @rule_count is 2, we obtain, in
 * nibbles:
 *   left:  < 12, 0, 10, 8, 0, 1, 0, 0 >
 *   right: < 12, 0, 10, 8, 0, 2, 2, 1 >
 * corresponding to bytes:
 *   left:  < 192, 168, 1, 0 >
 *   right: < 192, 168, 2, 1 >
 * with mask length irrelevant here, unused on return, as the range is already
 * defined by its start and end points. The mask length is relevant for a single
 * ranged entry instead: if @first_rule is 1 and @rule_count is 1, we ignore
 * rule 2 above: @left becomes < 192, 168, 1, 0 >, @right becomes
 * < 192, 168, 1, 255 >, and the mask length, calculated from the distances
 * between leftmost and rightmost bucket indices for each group, would be 24.
 *
 * Return: mask length, in bits.
 */
static int pipapo_get_boundaries(struct nft_pipapo_field *f, int first_rule,
				 int rule_count, u8 *left, u8 *right)
{
	int g, mask_len = 0, bit_offset = 0;
	u8 *l = left, *r = right;

	for (g = 0; g < f->groups; g++) {	/* iterate over all groups of the field */
		int b, x0, x1;

		x0 = -1;
		x1 = -1;
		for (b = 0; b < NFT_PIPAPO_BUCKETS(f->bb); b++) {
			/* f->bb is the "t" value, the number of bits per
			 * group; NFT_PIPAPO_BUCKETS(f->bb) == 1 << f->bb is
			 * the number of buckets, so this loops over every
			 * possible bucket value.
			 */
			unsigned long *pos;

			/* f->bsize: longs per bucket (rules rounded up to longs)
			 * NFT_PIPAPO_BUCKETS(f->bb) == 2^f->bb: buckets per group
			 * g * NFT_PIPAPO_BUCKETS(f->bb) + b: index of this bucket
			 * (g * NFT_PIPAPO_BUCKETS(f->bb) + b) * f->bsize: its long offset
			 * pos points at this bucket's longs, starting from lt
			 */
			pos = NFT_PIPAPO_LT_ALIGN(f->lt) +
			      (g * NFT_PIPAPO_BUCKETS(f->bb) + b) * f->bsize;

			/* first_rule is the first rule of the batch. Check
			 * whether bit first_rule is set in the bucket pointed
			 * to by pos. x0 is only updated while it is still -1,
			 * i.e. only once: x0 is the first bucket in this
			 * group that matches first_rule.
			 */
			if (test_bit(first_rule, pos) && x0 == -1)
				x0 = b;
			/* first_rule + rule_count - 1 is the last rule. Check
			 * whether its bit is set; x1 is updated every time,
			 * so it ends up as the last bucket in this group that
			 * matches the last rule.
			 */
			if (test_bit(first_rule + rule_count - 1, pos))
				x1 = b;
		}

		/* f->bb is the number of bits per group; bit_offset tracks
		 * the bit offset within the current output byte. Lower bytes
		 * are filled first, and within a byte the high bits first.
		 */
		*l |= x0 << (BITS_PER_BYTE - f->bb - bit_offset);
		*r |= x1 << (BITS_PER_BYTE - f->bb - bit_offset);

		bit_offset += f->bb;	/* prepare for the next group */
		if (bit_offset >= BITS_PER_BYTE) {
			bit_offset %= BITS_PER_BYTE;
			l++;
			r++;
		}
		/* each group of bits in left/right holds a bucket index */

		if (x1 - x0 == 0)
			mask_len += 4;
		else if (x1 - x0 == 1)
			mask_len += 3;
		else if (x1 - x0 == 3)
			mask_len += 2;
		else if (x1 - x0 == 7)
			mask_len += 1;
	}

	return mask_len;
}

The function iterates over all groups of a single field; bit_offset tracks the bit position within the current output byte. For reference:

#define NFT_PIPAPO_BUCKETS(bb)		(1 << (bb))

An example to understand this boundary function: suppose we are handling the rules that represent 192.168.0.1~192.168.2.1. The range is first split into several rules (note that this all happens within one field), and assume the field is split into 8 groups with 16 buckets each.

While iterating over the groups, the IP addresses are effectively cut into 8 parts; for each part a minimum and a maximum are obtained and the corresponding bucket indices are written into the left and right arrays. Once all 8 pairs of boundaries are known, the whole IP range is determined. So the function is really just computing, in another representation, the range covered by a group of rules, and writing the boundary values into the two arrays.
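A tiny standalone sketch of the packing step (*l |= x0 << ...) shows how the per-group bucket indices < 12, 0, 10, 8, 0, 1, 0, 0 > are folded back into the bytes 192.168.1.0 (toy code assuming 4-bit groups):

#include <stdio.h>

#define BITS_PER_BYTE	8

int main(void)
{
	int x[8] = { 12, 0, 10, 8, 0, 1, 0, 0 };	/* leftmost bucket per group */
	unsigned char left[4] = { 0 };
	unsigned char *l = left;
	int bb = 4;			/* bits per group ("t") */
	int bit_offset = 0, g;

	for (g = 0; g < 8; g++) {
		/* high bits of each byte are filled first */
		*l |= x[g] << (BITS_PER_BYTE - bb - bit_offset);
		bit_offset += bb;
		if (bit_offset >= BITS_PER_BYTE) {
			bit_offset %= BITS_PER_BYTE;
			l++;
		}
	}

	printf("%u.%u.%u.%u\n", left[0], left[1], left[2], left[3]);	/* 192.168.1.0 */
	return 0;
}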

nft_pipapo_remove

Called from the nf_tables_commit path:

static void nft_pipapo_remove(const struct net *net, const struct nft_set *set,
			      const struct nft_set_elem *elem)
{
	struct nft_pipapo *priv = nft_set_priv(set);
	struct nft_pipapo_match *m = priv->clone;
	struct nft_pipapo_elem *e = elem->priv;
	int rules_f0, first_rule = 0;
	const u8 *data;

	data = (const u8 *)nft_set_ext_key(&e->ext);

	e = pipapo_get(net, set, data, 0);
	if (IS_ERR(e))
		return;

	/* rules_f0 is the number of rules in the first field, starting at
	 * first_rule, whose downward mapping is exactly the same as that of
	 * first_rule; rules are grouped by the mapping in the first field.
	 */
	while ((rules_f0 = pipapo_rules_same_key(m->f, first_rule))) {
		union nft_pipapo_map_bucket rulemap[NFT_PIPAPO_MAX_FIELDS];
		const u8 *match_start, *match_end;
		struct nft_pipapo_field *f;
		int i, start, rules_fx;

		match_start = data;
		match_end = (const u8 *)nft_set_ext_key_end(&e->ext)->data;

		start = first_rule;
		rules_fx = rules_f0;

		/* record in rulemap all downward mappings of this group of rules */
		nft_pipapo_for_each_field(f, i, m) {
			if (!pipapo_match_field(f, start, rules_fx,	/* rules with the same mapping belong to one elem */
						match_start, match_end)) /* the user passes in a range: compare its bounds */
				break;	/* any mismatch: don't drop this group */

			rulemap[i].to = start;
			rulemap[i].n = rules_fx;

			rules_fx = f->mt[start].n;
			start = f->mt[start].to;

			/* step to the next field */
			match_start += NFT_PIPAPO_GROUPS_PADDED_SIZE(f);
			match_end += NFT_PIPAPO_GROUPS_PADDED_SIZE(f);
		}

		if (i == m->field_count) {	/* every field must have matched */
			priv->dirty = true;
			pipapo_drop(m, rulemap);
			pipapo_commit(set);
			return;
		}

		first_rule += rules_f0;
	}
}

Now consider the figure from the original post:

the black marks are the rules elem0 added in each field, the blue ones those added by elem1.

So pipapo_rules_same_key is mainly there to tell different elems apart.

All fields are then walked and checked level by level, requiring the range at every level to match exactly.

An elem is essentially the formal representation of a composite key such as [ip][port][ip]. To delete it, the range of every part (field) must match. The nft_pipapo_for_each_field(f, i, m) loop ensures every field matches; within it, pipapo_match_field checks one field, which is in turn split into groups whose ranges are compared. Only when everything matches is the elem considered found, and pipapo_drop is called below to delete it.

pipapo_drop

pipapo_drop:

static void pipapo_drop(struct nft_pipapo_match *m,
			union nft_pipapo_map_bucket rulemap[])
{
	struct nft_pipapo_field *f;
	int i;

	nft_pipapo_for_each_field(f, i, m) {
		int g;

		for (g = 0; g < f->groups; g++) {
			unsigned long *pos;
			int b;

			/* bsize is the number of longs per bucket, so
			 * NFT_PIPAPO_BUCKETS(f->bb) * f->bsize is the number
			 * of longs taken by all buckets of one group, and pos
			 * points at the start of the current group in lt.
			 */
			pos = NFT_PIPAPO_LT_ALIGN(f->lt) + g *
			      NFT_PIPAPO_BUCKETS(f->bb) * f->bsize;

			for (b = 0; b < NFT_PIPAPO_BUCKETS(f->bb); b++) {	/* 2^t iterations, one per bucket */
				/* bitmap_cut() works on one bucket's slice of
				 * lt and removes n bits starting at bit "to".
				 */
				bitmap_cut(pos, pos, rulemap[i].to,
					   rulemap[i].n,
					   f->bsize * BITS_PER_LONG);	/* bits per bucket: one bit per rule, rounded up to whole longs */

				pos += f->bsize;	/* next bucket in lt */
			}
		}

		pipapo_unmap(f->mt, f->rules, rulemap[i].to, rulemap[i].n,
			     rulemap[i + 1].n, i == m->field_count - 1);
		if (pipapo_resize(f, f->rules, f->rules - rulemap[i].n)) {
			/* We can ignore this, a failure to shrink tables down
			 * doesn't make tables invalid.
			 */
			;
		}
		f->rules -= rulemap[i].n;

		pipapo_lt_bits_adjust(f);
	}
}

As analysed above, pipapo_drop removes one elem; rulemap holds that elem's mapping (first rule and rule count) in every field.

Before summarizing it, look at the two helpers it uses. bitmap_cut is as follows:

void bitmap_cut(unsigned long *dst, const unsigned long *src,
		unsigned int first, unsigned int cut, unsigned int nbits)
{
	/* nbits is the number of bits a bucket needs to represent all rules
	 * of the field, rounded up to whole longs; first is a rule index.
	 */
	unsigned int len = BITS_TO_LONGS(nbits);
	unsigned long keep = 0, carry;
	int i;

	if (first % BITS_PER_LONG) {
		keep = src[first / BITS_PER_LONG] &
		       (~0UL >> (BITS_PER_LONG - first % BITS_PER_LONG));
	}

	memmove(dst, src, len * sizeof(*dst));

	while (cut--) {
		for (i = first / BITS_PER_LONG; i < len; i++) {
			if (i < len - 1)
				carry = dst[i + 1] & 1UL;
			else
				carry = 0;
			/* carry is the lowest bit of the next long; the point
			 * is to shift the higher bits down by one position.
			 */
			dst[i] = (dst[i] >> 1) | (carry << (BITS_PER_LONG - 1));
		}
	}

	dst[first / BITS_PER_LONG] &= ~0UL << (first % BITS_PER_LONG);
	dst[first / BITS_PER_LONG] |= keep;
}
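The effect can be reproduced on a single word with a toy helper (a standalone illustration that ignores the multi-word carry handling the kernel version has to do):

#include <stdio.h>

/* Remove "cut" bits starting at bit "first" from a single word, shifting the
 * higher bits down, like bitmap_cut() does across a multi-word bitmap. */
static unsigned long cut_bits(unsigned long v, unsigned int first,
			      unsigned int cut)
{
	unsigned long keep = v & ((1UL << first) - 1);	/* bits below "first" */
	unsigned long high = v >> (first + cut);	/* bits above the cut */

	return keep | (high << first);
}

int main(void)
{
	/* rules #0..#4 set; cut rules #1 and #2, so #3/#4 become #1/#2 */
	unsigned long bucket = 0x1f;

	printf("before: %#lx, after: %#lx\n",
	       bucket, cut_bits(bucket, 1, 2));	/* 0x1f -> 0x7 */
	return 0;
}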

pipapo_unmap is as follows:

static void pipapo_unmap(union nft_pipapo_map_bucket *mt, int rules,
			 int start, int n, int to_offset, bool is_last)
{
	/* mt is f->mt, rules is f->rules;
	 * start is the first rule to remove, n the count;
	 * to_offset is the number of rules removed from the next field,
	 * i.e. the amount by which the remaining .to indices must shift;
	 * is_last tells whether this is the last field.
	 */
	int i;

	/* overwrite entries start .. start+n-1 and zero the space freed at the end */
	memmove(mt + start, mt + start + n, (rules - start - n) * sizeof(*mt));
	memset(mt + rules - n, 0, n * sizeof(*mt));

	if (is_last)
		return;

	/* all mappings from start onwards must be adjusted */
	for (i = start; i < rules - n; i++)
		mt[i].to -= to_offset;
}

The principle is illustrated by the figure in the original post.

To summarize what pipapo_drop does:

  1. We want to delete from a match all rules, in every field, that belong to one elem; the start and count of these rules in each field are recorded in rulemap;
  2. a loop handles each field:
    1. within a field, loop over the groups and then over the buckets, cutting the unwanted rules' bits out of every bucket;
    2. each iteration removes all of the target elem's rules from that field;
  3. once everything is done, the tables are simply resized.

pipapo_commit

/**
 * pipapo_commit() - Replace lookup data with current working copy
 * @set:	nftables API set representation
 *
 * While at it, check if we should perform garbage collection on the working
 * copy before committing it for lookup, and don't replace the table if the
 * working copy doesn't have pending changes.
 *
 * We also need to create a new working copy for subsequent insertions and
 * deletions.
 */
static void pipapo_commit(const struct nft_set *set)
{
	struct nft_pipapo *priv = nft_set_priv(set);
	struct nft_pipapo_match *new_clone, *old;

	if (time_after_eq(jiffies, priv->last_gc + nft_set_gc_interval(set)))
		pipapo_gc(set, priv->clone);

	if (!priv->dirty)
		return;

	new_clone = pipapo_clone(priv->clone);
	if (IS_ERR(new_clone))
		return;

	priv->dirty = false;

	old = rcu_access_pointer(priv->match);
	rcu_assign_pointer(priv->match, priv->clone);
	if (old)
		call_rcu(&old->rcu, pipapo_reclaim_match);

	priv->clone = new_clone;
}

/**
 * pipapo_reclaim_match - RCU callback to free fields from old matching data
 * @rcu:	RCU head
 */
static void pipapo_reclaim_match(struct rcu_head *rcu)
{
	struct nft_pipapo_match *m;
	int i;

	m = container_of(rcu, struct nft_pipapo_match, rcu);

	for_each_possible_cpu(i)
		kfree(*per_cpu_ptr(m->scratch, i));

#ifdef NFT_PIPAPO_ALIGN
	free_percpu(m->scratch_aligned);
#endif
	free_percpu(m->scratch);

	pipapo_free_fields(m);

	kfree(m);
}

Summary

References

https://196082.github.io/2024/09/03/nftables-CVEs1/


文章作者: q1ming
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 q1ming !
  目录