返回
Featured image of post 主席树

主席树

简介

主席树是一种数据结构,用来解决区间第k小(大)问题。要问为什么叫主席树?因为这是一个叫黄嘉泰的dalao发明的 名字缩写跟我国某位主席一样,Orz
时间复杂度是Onlogn,十分优秀 空间复杂度倒是有点大

前置姿势

权值线段树

  • 众所周知,权值线段树是一种线段树 话说这不是废话吗2333
  • 它不同于一般线段树的是它的每一个结点的值代表这个区间中的数在序列中出现的总次数,正所谓“权值”
  • 它可以被用来处理区间第k小(大)问题

查询原理

此处利用二分的思想: 现有区间[l,r],不妨设区间[l,m] ( m为区间中间位置 ) 内的数的个数为n: 显然:
若 n < k ,第k小就是区间[m+1,r]的第 k - n 小 反之,则为区间[l,r]的第k小 于是可以写出伪代码

1
2
3
4
5
6
7
int que(k,l,r)//询问区间[l,r]的第k小
{
    if(l==r) return l;
    int m=(l+r)/2;
    if(cnt(l,m)>=k) return que(k,l,m);
    return que(k-cnt(l,r),m+1,r);
}

举个栗子

对于序列 1,3,2,2,0,1 建立的权值线段树如下


查询第3小:

  1. 访问区间[0,3],有6个数,> 3
  2. 访问左区间[0,1]查询第3小,有3个数,=3
  3. 访问右区间[1,1],查询第2小,返回结果->1

建树

当然是递归建树啦ψ(`∇´)ψ
此处同普通线段树,只需统计区间个数即可
此处使用暴力统计 其实是懒得打

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
int ls[100],rs[100],cnt=0,v[100];
int l[100],r[100];
int num[100];
int n;
void update(int p)
{
    v[p]=v[ls[p]]+v[rs[p]];
}
void build(int p,int l,int r)
{
    l[p]=l;r[p]=r;
    if(l==r)
    {
        for(int i=0;i<n;i++)
        {
            if(num[i]>=l&&num[i]<=r) v[p]++;
        }
        return;
    }
    int m=l+r>>1;
    ls[p]=++cnt;
    rs[p]=++cnt;
    build(ls[p],l,m);
    build(rs[p],m+1,r);
    update(p);
}

查询

1
2
3
4
5
6
7
int que(int p,int k)
{
    if(l[p]==r[p]) return l;
    int m=l[p]+r[p]>>1;
    if(v[p]<k) return que(rs[p],k-v[p]);
    return que(ls[p],k);
}

思想

权值线段树有一个缺点,只能查询整个区间的第k小,对于子区间就无能为力了
那是不是对于所有子区间都能搞出一个权值线段树就好了呢?

前缀和

这里就可以用前缀和的思想:
若用相同范围(即相同的l,r)对任意子区间建树,所有的权值线段树结构相同,这意味着两棵树可以相减(即每一个等位节点相减)!
于是我们可以依次对[0,0],[0,1],[0,2]…[0,n]建树,如需得到[l,r]的树,只需将[0,r]和[0,l-1]两棵线段树相减即可!
查询可以机智地使用宏定义

1
2
#define lv (value[ls[r]]-value[ls[l]])//相减后得到的线段树左儿子的值
#define rv (value[rs[r]]-value[rs[l]])//相减后得到的线段树右儿子的值

空间优化

然而这里有一个问题,每一个区间建立线段树,空间复杂度太大,无法接受
怎么优化呢?
实际上,下一棵线段树就相当于上一棵线段树进行单点更新,即[0,n]相对于[0,n-1]只是更新了[num[n],num[n]]的值而已。
由于主席树不是完全二叉树,需要开数组保存左右儿子的编号,这里利用一下:如果更新左儿子,就把右儿子编号赋为上一棵树该节点右儿子的编号,反之也一样。这样每次建树只需新开一条链就行了 有没有感觉特别机智
更新

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
void updata(int last/*上一棵线段树等位结点编号*/, int p, int v)//建立下一棵线段树,相当于线段树的单点更新,只需保存一条链而已
{
	ln[p] = ln[last]; rn[p] = rn[last];//这棵线段树和上一棵一样,直接赋值
	value[p] = value[last] + 1;//明显v包含在区间里,直接加一即可
	if (ln[p] == rn[p]) return;//叶节点,不用继续更新,直接返回即可
	if (v <= (ln[p] + rn[p] >> 1))//包含在左儿子的区间中
	{
		rs[p] = rs[last];//右子树(即右儿子)和上一棵线段树的一样
		ls[p] = ++cnt;//累计节点编号
		//本来更新节点的值是写在这里的,但这样根节点就更新不了了,于是写到了前面
		updata(ls[last], ls[p], v);//建立左子树
	}
	else//反之亦然同理
	{
		ls[p] = ls[last];
		rs[p] = ++cnt;
		updata(rs[last], rs[p], v);
	}
}

循环建树

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
void build()//建立主席树
{
	root[0] = cnt;//0号根节点为空树,便于相减
	bud(0, n2 - 1, cnt);//先建一发空树
	for (int i = 0; i < n; i++)//[0,n)和(0,n]皆可,只不过是root[i-1]和root[i+1]的区别
	{
		root[i + 1] = ++cnt;//累计根节点
		updata(root[i], root[i + 1], num[i]);//更新这棵线段树
	}
}

离散化

然而还有一个问题,权值线段树的空间复杂度是与数据范围相关的,所以此处需要离散化。所谓离散化就是把一堆范围很大的数映射到一段很小的区间里
例如对于序列:12 , 76 , 123 , 67 , 1 , 20
排序后:1 , 12 , 20 , 67 , 76 , 123
令:1->1 , 12->2 , 20->3 , 67->4 , 76->5 , 123->6
原序列就变成了:2 , 5 , 6 , 4 , 1 , 3
至于怎么搞映射?这不是map最擅长的吗( ̄︶ ̄)↗ 

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
map <int, int> po/*由原值映射到离散化后的值*/, re/*由离散化后的值映射到原值*/;
sort(tmp, tmp + n);//先来一发排序
//巨坑!!tmp+n是排序区间末区间端点的下一个值的地址!!!
//我也不知道STL为什么要这样反正我就这样爆了两个点
n2 = unique(tmp, tmp + n) - tmp;//去重也一样!!!(STL的东西都是这个尿性)
for (int i = 0; i < n2; i++)//记录离散化后的数值
{
	po[tmp[i]] = i;
	re[i] = tmp[i];
}
for (int i = 0; i < n; i++)//将原值变为离散化后的值
{
	num[i] = po[num[i]];
}

输出答案时也别忘了搞回去

1
cout << re[ans] << '\n';//输出答案,一定要离散化的数值搞回来!!!小心爆零!!!

完整代码

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
/*
主席树 20191210
皮:震惊,cin和cout,竟然没有没有TLE???数据竟然没有爆int???!!!
*/
#include<iostream>
#include<map>
#include<algorithm>
#include<cstdio>
using namespace std;
int ls[20000000]/*保存节点左儿子*/, rs[20000000]/*保存节点右儿子*/,
root[200010]/*保存第一个元素到第n个数对应权值线段树的根节点*/,
value[20000000]/*保存权值线段树节点的值*/, 
ln[20000000]/*保存节点区间左端点*/, rn[20000000]/*保存节点区间右端点*/;
int cnt/*累计节点编号*/, n/*数的个数*/;
int n2/*排序去重后数的个数(离散化巨坑)*/;
int num[200010]/*保存序列*/, tmp[200010]/*便于离散化*/;
void bud(int l, int r, int p)//建立空树(便于线段树相减),同普通线段树,只是没有保存节点权值而已
{
	ln[p] = l; rn[p] = r;//显然
	if (l != r)//巨坑,否则死循环!!!
	{
		ls[p] = ++cnt;//累计子节点
		bud(l, (l + r >> 1), cnt);//递归建立左子树
		rs[p] = ++cnt;//累计子节点
		bud((l + r >> 1) + 1, r, cnt);//递归建立右子树
	}
}
void updata(int last/*上一棵线段树等位结点编号*/, int p, int v)//建立下一棵线段树,相当于线段树的单点更新,只需保存一条链而已
{
	ln[p] = ln[last]; rn[p] = rn[last];//这棵线段树和上一棵一样,直接赋值
	value[p] = value[last] + 1;//明显v包含在区间里,直接加一即可
	if (ln[p] == rn[p]) return;//叶节点,不用继续更新,直接返回即可
	if (v <= (ln[p] + rn[p] >> 1))//包含在左儿子的区间中
	{
		rs[p] = rs[last];//右子树(即右儿子)和上一棵线段树的一样
		ls[p] = ++cnt;//累计节点编号
		//本来更新节点的值是写在这里的,但这样根节点就更新不了了,于是写到了前面
		updata(ls[last], ls[p], v);//建立左子树
	}
	else//反之亦然同理
	{
		ls[p] = ls[last];
		rs[p] = ++cnt;
		updata(rs[last], rs[p], v);
	}
}
void build()//建立主席树
{
	root[0] = cnt;//0号根节点为空树,便于相减
	bud(0, n2 - 1, cnt);//先建一发空树
	for (int i = 0; i < n; i++)//[0,n)和(0,n]皆可,只不过是root[i-1]和root[i+1]的区别
	{
		root[i + 1] = ++cnt;//累计根节点
		updata(root[i], root[i + 1], num[i]);//更新这棵线段树
	}
}
//这类宏定义一定要加括号,否则死的很惨!!!!!!
//我才不会告诉你我就是因为这个第一次交爆零了...
//还有,宏定义这种东西作用范围是当前文件...写在哪里都一样...
#define lv (value[ls[r]]-value[ls[l]])//相减后得到的线段树左儿子的值
#define rv (value[rs[r]]-value[rs[l]])//相减后得到的线段树右儿子的值
int ans;//落谷猥琐巨坑,直接返回会MLE!!!被迫改void并使用全局变量保存答案
void query(int k, int l, int r)//这里的l和r指的不是数列的区间,而是l棵线段树和第r棵线段树的根节点
{
	//直接写成(其实就是)权值的查询就好
	if (ln[l] == rn[l])//叶节点,直接记录答案
	{
		ans = ln[l];
		return;
	}
	if (k <= lv)//左儿子包含的数的个数大于等于待查询的k
		query(k, ls[l], ls[r]);//向左儿子查询第k小即可
	else//左儿子包含的数的个数小于待查询的k
		query(k - lv, rs[l], rs[r]);//整段区间的第k小即为右区间的第(k-左区间包含数的个数)小
}
map <int, int> po/*由原值映射到离散化后的值*/, re/*由离散化后的值映射到原值*/;
int main()
{
	cin >> n;//输入,显然
	int m;
	cin >> m;
	for (int i = 0; i < n; i++)//输入并保存副本用于离散化
	{
		cin >> num[i];
		tmp[i] = num[i];
	}
	//---------------------------------------离散化--------------------------------------------
	sort(tmp, tmp + n);//先来一发排序
	//巨坑!!tmp+n是排序区间末区间端点的下一个值的地址!!!
	//我也不知道STL为什么要这样反正我就这样爆了两个点
	n2 = unique(tmp, tmp + n) - tmp;//去重也一样!!!(STL的东西都是这个尿性)
	for (int i = 0; i < n2; i++)//记录离散化后的数值
	{
		po[tmp[i]] = i;
		re[i] = tmp[i];
	}
	for (int i = 0; i < n; i++)//将原值变为离散化后的值
	{
		num[i] = po[num[i]];
	}
	//-----------------------------------------------------------------------------------------
	build();//建立主席树
	int l, r, k;
	while (m--)
	{
		cin >> l >> r >> k;
		l--;
		query(k, root[l], root[r]);//一定要传递根节点的编号!!!!!!
		cout << re[ans] << '\n';//输出答案,一定要离散化的数值搞回来!!!小心爆零!!!
	}
	return 0;
}
Built with Hugo
Theme Stack designed by Jimmy