세그먼트 트리 (구간 합 트리)
개념
이진 트리의 형태로 여러 개의 데이터가 존재할 때 특정 구간의 합을 구하는데 사용하는 자료구조이다.
구간 (a,b)에 대한 구간 합 연산 시 O(lgN), i번째 수를 j로 바꿀 시 O(lgN)의 시간 복잡도를 갖는다.
구현
세그먼트 트리 인덱스는 1부터 시작하는데, 이는 세그먼트 트리를 재귀적으로 구성할 때 자식 노드의 인덱스를 쉽게 계산하기 위함이다. (현재 인덱스 * 2 = 왼쪽 자식의 인덱스)
구간 합을 구하고 구간 합 트리를 갱신하기 위해서는 아래 3가지 함수를 구현하면 된다.
1
2
3
- 트리를 생성하는 init 함수
- 데이터 값을 업데이트하는 update 함수
- 구간 합을 구하는 sum 함수
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#include<iostream>
#include<vector>
using namespace std;
// 세그먼트 트리 생성
int init(vector<int> &a, vector<int> &tree, int node, int start, int end) {
if (start == end) { // 리프 노드인 경우
return tree[node] = a[start];
} else {
int mid = (start + end) / 2;
// 부모 노드는 왼쪽과 오른쪽 자식의 구간 합을 값으로 갖는다.
return tree[node] = init(a, tree, node*2, start, mid) + init(a, tree, node*2+1, mid+1, end);
}
}
// 배열 값 갱신
void update(vector<int> &tree, int node, int start, int end, int index, int diff) {
if (index < start || index > end) return; // 갱신 구간에 포함되지 않는 경우
tree[node] = tree[node] + diff; // diff 만큼 보정
if (start != end) { // 리프 노드가 아닌 경우
int mid = (start + end) / 2;
update(tree,node*2, start, mid, index, diff); // 왼쪽 자식 업데이트
update(tree,node*2+1, mid+1, end, index, diff); // 오른쪽 자식 업데이트
}
}
int sum(vector<int> &tree, int node, int start, int end, int left, int right) {
if (left > end || right < start) { // [left,right]와 [start,end]가 겹치지 않는 경우
return 0;
}
if (left <= start && end <= right) { // [left,right]가 [start,end]를 완전히 포함하는 경우
return tree[node];
}
// [start,end]가 [left,right]를 완전히 포함하는 경우
// 혹은 [left,right]와 [start,end]가 겹쳐져 있는 경우
int mid = (start + end) / 2;
return sum(tree, node*2, start, mid, left, right) + sum(tree, node*2+1, mid+1, end, left, right);
}
int main() {
int n, m, k; // n: 데이터 갯수, m: 데이터 변경이 일어나는 횟수, k: 구간 합을 구하는 횟수
cin >> n >> m >>k;
vector<int> a(n);
int h = (int)ceil(log2(n)); // 트리의 높이는 노드의 개수에 로그를 취한 것과 같다
int tree_size = (1 << (h+1));
vector<int> tree(tree_size);
m += k;
for (int i=0; i<n; i++) {
cin >> a[i];
}
init(a, tree, 1, 0, n-1);
while (m--) {
int t1,t2,t3;
cin >> t1 >> t2 >> t3; // t1이 1인 쿼리의 경우, 데이터 변경. t1이 2인 경우, 구간 합 구하기.
if (t1 == 1) {
t2-=1;
int diff = t3-a[t2];
a[t2] = t3;
update(tree, 1, 0, n-1, t2, diff);
} else if (t1 == 2) {
cout << sum(tree, 1, 0, n-1, t2-1, t3-1));
}
}
return 0;
}
참고: https://www.acmicpc.net/blog/view/9