-
[알고리즘] 세그먼트 트리 ( Segment Tree )IT 발자취.../알고리즘 2019. 8. 3. 15:07
요약 : 주어진 쿼리에 대해 빠르게 응답하기 위해 만들어진 자료구조
예제 문제 : https://www.acmicpc.net/problem/2042 ( 구간 합 구하기 )
예제 문제 코드 : https://gintrie.tistory.com/32
참고 블로그 : https://www.crocus.co.kr/648
예로 1 2 3 4 5라는 배열 arr 이 있다. ( 배열 인덱스가 1부터 시작한다고 가정 )
2번째 부터 5번째의 구간의 수를 더한다.
arr[2] + arr[3] + arr[4] + arr[5]를 구하는 쿼리가 있다.
가장 쉽게 푸는 방법은 모든 경우의 수에서 모든 배열의 수를 다 더해주는 것이다.
지금 당장은 2 + 3 + 4 + 5로 간단히 해결할 수 있지만, arr[3]을 6으로 변경하고 arr[2] + arr[3] + arr[4] + arr[5]를 구하는 쿼리가 또 오면 다시 arr[2] + arr[3] + arr[4] + arr[5]를 해야한다.
두 번째로, A, B 구간의 합을 구할 때, 배열의 모든 합을 저장하는 sum[] 함수를 이용해서 sum[idx] = sum[ idx - 1] + arr[idx]를 만들어 sum [B] - sum[A-1]을 이용해서 구하는 방법이다.
하지만 중간에 값 arr[3] 이 변경 될 경우 변경된 인덱스부터 마지막 인덱스까지 특정 값을 더해줘야한다.
위 두가지 방법을 사용할 경우 한번 실행하는데 걸리는 시간은
수를 바꾸는데 O(1), 수를 더하는데 O(N)이 걸리고, M번 수행하면 O(MN + M) -> O(MN)의 시간이 걸린다.
세그먼트 트리를 배우는 이유
세그먼트 트리에서는 수를 바꾸는 과정과 수를 더하는 과정이
수를 바꾸는 과정 = O(logN)
수를 더하는 과정 = O(logN)으로 변하게 된다.
예를 들어 M = 100, N = 2^20이라고 친다.
O(MN) = 100*2^20 = 10,000,000 ( 대략 )
O(MlogN) = 100 *20 = 2000으로 데이터와 반복 수행이 잦아 질수록 시간 복잡도 차이가 커진다.
( 세그먼트 트리는 대부분 완전 이진 트리이다.)
세그먼트 알아보기
백준 알고리즘에서의 예제를 그림으로 표현해 봤습니다.
빨간색 칠해진 리프 노드가 실제 처음 받아온 데이터를 의미합니다.
리프 노드가 모인 부모 노드 x-y로 되어 있는것은 x부터 y까지의 합의 범위를 나타 내는 것 입니다.
즉, arr[1] = 1, arr[2] = 2 라면, Tree의 [ 1 노드에는 1 ], [ 2 노드에서 2 ], [ 1-2 노드에는 3 ]이 들어갑니다.
결국 ROOT 노드 1-5는 처음 받아온 데이터의 총 합을 의미합니다.
세그먼트 트리 전체 크기 구하기
처음 세그먼트 트리를 생성하는 함수 (init함수)를 만들기 전에 우선 세그먼트 트리의 전체 크기 (배열사이즈)를 알아야 합니다.
세그먼트 트리는 완전 이진 트리이므로 자식이 생길때 마다 노드가 2개씩 증가하는 것을 알 수 있습니다.
세그먼트 트리에서 N = 5 라고 하는 말은 리프 노드의 갯수가 위의 그림과 같이 5개가 있다는 것을 알 수 있습니다.
먼저 트리의 전체 크기를 구하려면 트리의 높이(h)를 먼저 구해야 합니다.
완전 이진 트리의 특성상, ( h >= 1) , 2^(h-1) < N <= 2^h 이 성립하므로, 모든 항에 log2를 해주면
h-1 < log2N <= h 이기 때문에 올림을 해주면 높이 h 값을 구할 수 있습니다.
즉, N = 5 일 경우 h = 3입니다.
int h = (int) Math.ceil(Math.log(N) / Math.log(2)) // log2(N) int segmentSize = (int) Math.pow(2, h + 1);
h = 3 일 경우의 포화 이진 트리의 크기를 트리의 전체 크기로 하도록 합니다.
segmentSize를 구할 경우 등비 수열의 합을 이용해서 구할 수 있습니다.
트리의 노드의 갯수는 첫째 항 a = 1이고, 공비 r= 2인 등비 수열입니다.
1 + 2 + 4 + 8 + ... + 2^n ( n = 1부터 시작 )
즉, 높이가 h 인 포화 이진 트리의 전체 크기는 등비수열의 합 공식에 따라
(2 ^ (h+1) ) - 1로 구할 수 있지만, 트리 노드의 인덱스를 1부터 시작 하기 때문에 전체 트리 사이즈를 2^(h+1)로 초기화 합니다.
( 말이 복잡해 졌으나, 세그먼트 트리 노드의 인덱스가 0부터 시작하면, 이해하는데 헤깔릴 수 있으니 tree[0]은 버리고 tree[1]부터 값을 쌓기 시작 )
세그먼트 트리 Node에 값이 쌓이는 원리
세그먼트 트리를 형성할 때 ROOT 노드를 1로 생각한다.
ROOT 노드의 왼쪽 노드는 2번 노드가 되고,
ROOT 노드의 오른쪽 노드는 3번 노드가 됩니다.
그리고 2번 노드의 왼쪽 노드는 4번 오른쪽 노드는 5번이 될 것입니다.
또한 3번 노드의 왼쪽 노드는 6번 오른쪽 노드는 7번이 될 것입니다.
그림에서처럼 배열로 트리를 만들게 된다.
그 이유는 세그먼트 트리는 Full binary tree에 가깝기에 배열에 모든 값들이 꽉꽉차서 올 가능성이 매우 높기 때문에 포인터보다는 배열을 이용하게 된다.
현재 노드의 번호가 node 일경우
- 노드의 왼쪽 자식의 배열 번호 : node * 2
- 노드의 오른쪽 자식의 배열 번호 : (node * 2) +1
세그먼트 트리 만드는 방법
사용될 배열
- tree 배열 : 새그먼트 트리가 만들어지는 배열
- nums 배열 : 처음에 입력받아 생성된 배열
1. 초기화 과정 (init)
초기화란 트리를 생성하는 과정이다. ( 가장 초기의 트리 )
초기화가 끝나면 구간 합 트리가 형성 된다.
num[5] = {1, 2, 3, 4, 5}일 경우, 다음과 같은 구간 합 세그먼트 트리가 생성된다.
private static long init (int[] nums, long[] tree, int node, int start, int end) { // 리프 노드 if(start == end) return tree[node] = nums[start]; // 중간을 기준으로 좌 / 우 노드로 이동 int mid = (start + end) / 2; // 좌 / 우 노드로 이동하고 // 연산이 끝나고 좌 노드 + 우 노드를 반환하여 구간의 합을 구한다. return tree[node] = init(nums, tree, node * 2, start, mid) + init(nums, tree, node*2 + 1, mid + 1, end); }
init 함수 해석
private static long init (int[] nums, long[] tree, int node, int start, int end)
parameter : 순서대로 ( 입력 배열, tree배열, 노드번호, 노드의 시작 번호, 노드의 끝 번호 )
if(start == end) return tree[node] = nums[start];
start == end 일때 tree[node] = nums[start]
start == end 인 케이스에 대해서 알아보자.
start != end 일 경우
return tree[node] = init(nums, tree, node * 2, start, mid) + init(nums, tree, node*2 + 1, mid + 1, end);
return tree[node] = init(a) + init(b)로 들어가게 됩니다.
이때, a, b의 인자를 확인하면, node * 2, (node * 2) +1이 전달되게 됩니다.
위 그림에서 설명한 것과 같이 이 과정은 노드에서 왼쪽, 오른쪽 노드로 분리되는 과정입니다.
왼쪽 자식은 start ~ mid
오른쪽 자식은 mid + 1 ~ end 로 보냅니다.
예시
nums[5] = {1, 2, 3, 4, 5} 라고 가정.
제일 첫 init()은
init(nums, tree, 1, 1, 5) 을 호출하게되고,
왼쪽, 오른쪽 차례대로
왼쪽 : init ( nums, tree, 1 * 2, 1, 3 ) // init(nums, tree, node * 2, start, end)
오른쪽 : init (nums, tree, (1 * 2) +1, 4, 5) // init(nums, tree, node * 2 + 1, start, end)
2. 갱신 과정 (update)
특정 인덱스의 값을 변경하고자 할때 update를 사용한다.
private static void update (long[] tree, int node, int start, int end, int index, long diff) { // index 가 start와 end 사이에 있지 않다면 계산할 필요가 없다. if (!(start <= index && index <= end)) return; tree[node] += diff; if(start != end) { int mid = (start+end)/2; update(tree, node*2, start, mid, index, diff); update(tree, node*2 + 1, mid + 1, end, index, diff); } }
기존의 nums 배열에서는 nums[3] 의 값을, 위의 세그먼트 트리 기준으로는 5번째 노드의 값을 3에서 6으로 변경합니다.
Node는 1번부터 시작하고, start는 1, end는 5, index는 3, diff는 6 -3 을한 3입니다.
update(tree, 1, 1, 5, 3, 3);
if(!(start<=index && index<=end)) return;
1<= 3 <= 5 이므로 index 가 start와 end 사이에 포함되므로 return 되지 않고 넘어간다.
tree[node] += diff; if(start != end) { int mid = (start+end)/2; update(tree, node*2, start, mid, index, diff); update(tree, node*2 + 1, mid + 1, end, index, diff); }
tree[node] += diff;
tree[1] += 3;
리프 노드에서 하나의 원소가 3->6으로 3의 값의 변화가 있었으므로, 전체 합도 3이 증가한다.
start != end 이니
좌 우 노드로 이동한다.
- update(tree, 1*2, 1, 3, 3, 3); // 좌측 노드
- update(tree, 1*2 + 1, 4, 5, 3, 3); // 우측 노드
이렇게 이동하다보면 다음과 같이 동작한다.
3. 합 과정 (sum)
합을 구한다는 말은 부분의 합을 구하는 방법을 의미한다.
만약, 위의 트리에서 2~4까지의 구간의 합을 구하고 싶다면 어떻게 될까??
- 2, 3번 treeNode에서 해당 값을 바로 얻을 수 없기 때문에, 9번 트리 노드, 5번 트리 노드와 6번 트리 노드의 값을 더해야 한다.
합을 구하는 과정은 4가지 경우로 나눌 수 있다.
여기서 left, right는 합을 구하고자 하는 구간의 시작과 끝점
1. [left, right]와 [start, end]가 겹치지 않는 경우
-> 구간 합을 구하고자 하는 범위와 상관이 없다.
* if(left > end || right < start)
2. [left, right]가 [start, end]를 완전히 포함하는 경우
-> 구하고자 하는 구간 합 구간에 포함되는 경우
* if(left<=start && end < right)
3. [start, end]가 [left, right]를 완전히 포함하는 경우
-> 구하고자 하는 구간 합 범위보다는 크게 있지만, 그 내부에 구하고자 하는 구간 합 범위가 있는 경우
* return sum(tree, node*2, start, mid, left, right) + sum(tree, node*2 + 1, mid + 1, end, left, right)
4. [left, right]와 [start, end]가 겹쳐져 있는 경우 (1,2,3을 제외한 나머지 경우)
-> left<=start<=right<=end 같은 방식인 경우
* return sum(tree, node*2, start, mid, left, right) + sum(tree, node*2 + 1, mid + 1, end, left, right)
private static long sum(long[] tree, int node, int start, int end, int left, int right) { if(left > end || right < start) return 0; if(left <= start && end <= right) return tree[node]; int mid = (start+end)/2; return sum(tree, node * 2, start, mid, left, right) + sum(tree, node*2 + 1, mid+1, end, left, right); }
예제 문제로 나와있는 2 ~ 5까지의 구간의 합을 구해보기로 한다.
* 즉, left : 2, right : 5
루트 노드부터 탐색을 시작한다.
left : 2, end : 5, right : 5, start : 1
if(left > end || right < start) return 0;
조건에 만족하지 않기 때문에 다음으로 넘어간다.
if(left <= start && end <= right) return tree[node];
if(2 <= 1 && 5 <= 5) 이니 조건을 만족하지 않아 다음 단계로 넘어 간다.
int mid = (start+end)/2; return sum(tree, node * 2, start, mid, left, right) + sum(tree, node*2 + 1, mid+1, end, left, right);
2번 노드 (nums 1~3번 배열의 구간합을 가진 노드)와 3번 노드(nums 4~5번 배열의 구간합을 가진 노드)를 탐색한다.
3번 노드부터 보면 left : 2, right: 5, start : 4, end : 5
if(left > end || right < start) return 0;
if( 2 > 5 || 5 < 4) 조건을 만족시키지 않으므로 다음으로 넘어간다.
if(left <= start && end <= right) return tree[node];
if( 2<= 4 && 5<=5) 조건을 충족시키므로 tree[3] 의 값을 반환한다.
2번 노드를 보면 left : 2 right: 5 start : 1, end : 3
if(left > end || right < start) return 0;
if( 2 > 3 || 5 < 1 ) 조건을 만족시키지 않으므로 다음으로 넘어간다.
if(left <= start && end <= right) return tree[node];
if( 2 <= 1 && 3 < = 5) 조건을 만족하지 않으므로 다음으로 넘어간다.
int mid = (start+end)/2; return sum(tree, node * 2, start, mid, left, right) + sum(tree, node*2 + 1, mid+1, end, left, right);
2 ~ 5번 ( 배열의 시작이 0부터 시작이라 그림에서 -1된 값으로 봐야함) 구간의 합을 다음과 같이 구할 수 있다.
'IT 발자취... > 알고리즘' 카테고리의 다른 글
[자료구조] 완전 이진 트리를 배열로 만들 경우 크기 계산 (0) 2019.08.03 [알고리즘] 구간 합 구하기 ( 백준 2042 ) (0) 2019.08.03 [알고리즘] 다익스트라 알고리즘 (최단 경로 알고리즘) - 이론편 (0) 2019.07.24 [알고리즘] 큐빙 (0) 2018.12.12 [TDD] 피보나치수열 (0) 2018.12.09 댓글