Home BOJ 11654 아스키 코드
Post
Cancel

BOJ 11654 아스키 코드

세그먼트 트리 (구간 합 트리)


개념

이진 트리의 형태로 여러 개의 데이터가 존재할 때 특정 구간의 합을 구하는데 사용하는 자료구조이다.

구간 (a,b)에 대한 구간 합 연산 시 O(lgN), i번째 수를 j로 바꿀 시 O(lgN)의 시간 복잡도를 갖는다.

구현

세그먼트 트리 인덱스는 1부터 시작하는데, 이는 세그먼트 트리를 재귀적으로 구성할 때 자식 노드의 인덱스를 쉽게 계산하기 위함이다. (현재 인덱스 * 2 = 왼쪽 자식의 인덱스)

구간 합을 구하고 구간 합 트리를 갱신하기 위해서는 아래 3가지 함수를 구현하면 된다.

1
2
3
- 트리를 생성하는 init 함수
- 데이터 값을 업데이트하는 update 함수
- 구간 합을 구하는 sum 함수
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#include<iostream>
#include<vector>
using namespace std;

// 세그먼트 트리 생성
int init(vector<int> &a, vector<int> &tree, int node, int start, int end) {
    if (start == end) {	// 리프 노드인 경우
        return tree[node] = a[start];
    } else {
    	int mid = (start + end) / 2;
        // 부모 노드는 왼쪽과 오른쪽 자식의 구간 합을 값으로 갖는다.
        return tree[node] = init(a, tree, node*2, start, mid) + init(a, tree, node*2+1, mid+1, end);
    }
}

// 배열 값 갱신
void update(vector<int> &tree, int node, int start, int end, int index, int diff) {
    if (index < start || index > end) return;	// 갱신 구간에 포함되지 않는 경우
    tree[node] = tree[node] + diff;				// diff 만큼 보정
    if (start != end) {							// 리프 노드가 아닌 경우
    	int mid = (start + end) / 2;
        update(tree,node*2, start, mid, index, diff);	// 왼쪽 자식 업데이트
        update(tree,node*2+1, mid+1, end, index, diff);	// 오른쪽 자식 업데이트
    }
}
int sum(vector<int> &tree, int node, int start, int end, int left, int right) {
    if (left > end || right < start) {	// [left,right]와 [start,end]가 겹치지 않는 경우
        return 0;
    }
    if (left <= start && end <= right) {	// [left,right]가 [start,end]를 완전히 포함하는 경우
        return tree[node];
    }
    
    // [start,end]가 [left,right]를 완전히 포함하는 경우
    // 혹은 [left,right]와 [start,end]가 겹쳐져 있는 경우
    int mid = (start + end) / 2;
    return sum(tree, node*2, start, mid, left, right) + sum(tree, node*2+1, mid+1, end, left, right);
}

int main() {
    int n, m, k;		// n: 데이터 갯수, m: 데이터 변경이 일어나는 횟수, k: 구간 합을 구하는 횟수
    cin >> n >> m >>k;
    vector<int> a(n);
    int h = (int)ceil(log2(n));	// 트리의 높이는 노드의 개수에 로그를 취한 것과 같다
    int tree_size = (1 << (h+1));
    vector<int> tree(tree_size);
    m += k;
    for (int i=0; i<n; i++) {
    	cin >> a[i];
    }
    init(a, tree, 1, 0, n-1);
    while (m--) {
        int t1,t2,t3;
        cin >> t1 >> t2 >> t3;	// t1이 1인 쿼리의 경우, 데이터 변경. t1이 2인 경우, 구간 합 구하기.
        if (t1 == 1) {
            t2-=1;
            int diff = t3-a[t2];
            a[t2] = t3;
            update(tree, 1, 0, n-1, t2, diff);
        } else if (t1 == 2) {
            cout << sum(tree, 1, 0, n-1, t2-1, t3-1));
        }
    }
    return 0;
}

참고: https://www.acmicpc.net/blog/view/9

This post is licensed under CC BY 4.0 by the author.