ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • [알고리즘] 세그먼트 트리 ( Segment Tree )
    IT 발자취.../알고리즘 2019. 8. 3. 15:07

    요약 : 주어진 쿼리에 대해 빠르게 응답하기 위해 만들어진 자료구조

    예제 문제 : https://www.acmicpc.net/problem/2042 ( 구간 합 구하기 ) 

    예제 문제 코드 : https://gintrie.tistory.com/32

    참고 블로그 : https://www.crocus.co.kr/648

     

    예로  1 2 3 4 5라는 배열 arr 이 있다. ( 배열 인덱스가 1부터 시작한다고 가정 )

    2번째 부터 5번째의 구간의 수를 더한다.

    arr[2] + arr[3] + arr[4] + arr[5]를 구하는 쿼리가 있다.

     

    가장 쉽게 푸는 방법은 모든 경우의 수에서 모든 배열의 수를 다 더해주는 것이다.

    지금 당장은 2 + 3 + 4 + 5로 간단히 해결할 수 있지만, arr[3]을 6으로 변경하고 arr[2] + arr[3] + arr[4] + arr[5]를 구하는 쿼리가 또 오면 다시   arr[2] + arr[3] + arr[4] + arr[5]를 해야한다.

     

    두 번째로, A, B 구간의 합을 구할 때, 배열의 모든 합을 저장하는 sum[] 함수를 이용해서 sum[idx] = sum[ idx - 1] + arr[idx]를 만들어 sum [B] - sum[A-1]을 이용해서 구하는 방법이다.

    하지만 중간에 값 arr[3] 이 변경 될 경우 변경된 인덱스부터 마지막 인덱스까지 특정 값을 더해줘야한다.

     

    위 두가지 방법을 사용할 경우 한번 실행하는데 걸리는 시간은

    수를 바꾸는데 O(1), 수를 더하는데 O(N)이 걸리고, M번 수행하면 O(MN + M) -> O(MN)의 시간이 걸린다.

     

    세그먼트 트리를 배우는 이유

    세그먼트 트리에서는 수를 바꾸는 과정과 수를 더하는 과정이

    수를 바꾸는 과정 = O(logN)

    수를 더하는 과정 = O(logN)으로 변하게 된다.

     

    예를 들어 M = 100, N = 2^20이라고 친다.

     

    O(MN) = 100*2^20 = 10,000,000 ( 대략 )

    O(MlogN) = 100 *20 = 2000으로 데이터와 반복 수행이 잦아 질수록 시간 복잡도 차이가 커진다.

     

    ( 세그먼트 트리는 대부분 완전 이진 트리이다.)

     

    세그먼트 알아보기

     

    백준 알고리즘에서의 예제를 그림으로 표현해 봤습니다.

    N = 5, 세그먼트 트리

    빨간색 칠해진 리프 노드가 실제 처음 받아온 데이터를 의미합니다.

    리프 노드가 모인 부모 노드 x-y로 되어 있는것은 x부터 y까지의 합의 범위를 나타 내는 것 입니다.

     

    즉, arr[1] = 1, arr[2] = 2 라면, Tree의 [ 1 노드에는 1 ], [ 2 노드에서 2 ], [ 1-2 노드에는 3 ]이 들어갑니다.

     

    결국 ROOT 노드 1-5는 처음 받아온 데이터의 총 합을 의미합니다.

     

    세그먼트 트리 전체 크기 구하기

    처음 세그먼트 트리를 생성하는 함수 (init함수)를 만들기 전에 우선 세그먼트 트리의 전체 크기 (배열사이즈)를 알아야 합니다.

    N=5 일 때, 트리의 전체 크기 구하기

    세그먼트 트리는 완전 이진 트리이므로 자식이 생길때 마다 노드가 2개씩 증가하는 것을 알 수 있습니다.

    세그먼트 트리에서 N = 5 라고 하는 말은 리프 노드의 갯수가 위의 그림과 같이 5개가 있다는 것을 알 수 있습니다.

     

    먼저 트리의 전체 크기를 구하려면 트리의 높이(h)를 먼저 구해야 합니다.

    완전 이진 트리의 특성상, ( h >= 1) , 2^(h-1) < N <= 2^h 이 성립하므로, 모든 항에 log2를 해주면

    h-1 < log2N <= h 이기 때문에 올림을 해주면 높이 h 값을 구할 수 있습니다.

     

    즉, N = 5 일 경우 h = 3입니다.

    int h = (int) Math.ceil(Math.log(N) / Math.log(2)) // log2(N)
    int segmentSize = (int) Math.pow(2, h + 1);

    h = 3 일 경우의 포화 이진 트리의 크기를 트리의 전체 크기로 하도록 합니다.

    segmentSize를 구할 경우 등비 수열의 합을 이용해서 구할 수 있습니다.

    등비수열의 합 ( 네이버 검색 )

    트리의 노드의 갯수는 첫째 항 a = 1이고, 공비 r= 2인 등비 수열입니다.

    1 + 2 + 4 + 8 + ... + 2^n ( n = 1부터 시작 )

    즉, 높이가 h 인 포화 이진 트리의 전체 크기는 등비수열의 합 공식에 따라

    (2 ^ (h+1) ) - 1로 구할 수 있지만, 트리 노드의 인덱스를 1부터 시작 하기 때문에 전체 트리 사이즈를 2^(h+1)로 초기화 합니다. 

     

    ( 말이 복잡해 졌으나, 세그먼트 트리 노드의 인덱스가 0부터 시작하면, 이해하는데 헤깔릴 수 있으니 tree[0]은 버리고 tree[1]부터 값을 쌓기 시작 )

     

    세그먼트 트리 Node에 값이 쌓이는 원리

    세그먼트 트리 구성 원리

    세그먼트 트리를 형성할 때 ROOT 노드를 1로 생각한다.

    ROOT 노드왼쪽 노드는 2번 노드가 되고,

    ROOT 노드오른쪽 노드는 3번 노드가 됩니다.

    그리고 2번 노드의 왼쪽 노드는 4번 오른쪽 노드는 5번이 될 것입니다.

    또한 3번 노드의 왼쪽 노드는 6번 오른쪽 노드는 7번이 될 것입니다.

     

    그림에서처럼 배열로 트리를 만들게 된다.

    그 이유는 세그먼트 트리는 Full binary tree에 가깝기에 배열에 모든 값들이 꽉꽉차서 올 가능성이 매우 높기 때문에 포인터보다는 배열을 이용하게 된다.

     

    현재 노드의 번호가 node 일경우

    - 노드의 왼쪽 자식의 배열 번호 : node * 2

    - 노드의 오른쪽 자식의 배열 번호 : (node * 2) +1

     

    세그먼트 트리 만드는 방법

    사용될 배열

    - tree 배열   : 새그먼트 트리가 만들어지는 배열

    - nums 배열 : 처음에 입력받아 생성된 배열

     

    1. 초기화 과정 (init)

    초기화란 트리를 생성하는 과정이다. ( 가장 초기의 트리 )

    초기화가 끝나면 구간 합 트리가 형성 된다.

    N = 5, {1, 2, 3, 4, 5} 일 경우 새그먼트 트리

    num[5] = {1, 2, 3, 4, 5}일 경우, 다음과 같은 구간 합 세그먼트 트리가 생성된다.

    private static long init (int[] nums, long[] tree, int node, int start, int end) {
        // 리프 노드
        if(start == end)
            return tree[node] = nums[start];
        
        // 중간을 기준으로 좌 / 우 노드로 이동
        int mid = (start + end) / 2;
        
        // 좌 / 우 노드로 이동하고
        // 연산이 끝나고 좌 노드 + 우 노드를 반환하여 구간의 합을 구한다.
        return tree[node] = init(nums, tree, node * 2, start, mid) + init(nums, tree, node*2 + 1, mid + 1, end);
    }

    init 함수 해석

    private static long init (int[] nums, long[] tree, int node, int start, int end)

    parameter : 순서대로 ( 입력 배열, tree배열, 노드번호, 노드의 시작 번호, 노드의 끝 번호 )

    if(start == end)
            return tree[node] = nums[start];

    start == end 일때 tree[node] = nums[start]

    start == end 인 케이스에 대해서 알아보자.

     

    start != end 일 경우

    return tree[node] = init(nums, tree, node * 2, start, mid) + init(nums, tree, node*2 + 1, mid + 1, end);

    return tree[node] = init(a) + init(b)로 들어가게 됩니다.

    이때, a, b의 인자를 확인하면, node * 2, (node * 2) +1이 전달되게 됩니다.

     

    위 그림에서 설명한 것과 같이 이 과정은 노드에서 왼쪽, 오른쪽 노드로 분리되는 과정입니다.

    왼쪽 자식은 start ~ mid

    오른쪽 자식은 mid + 1 ~ end 로 보냅니다.

     

    예시

    nums[5] = {1, 2, 3, 4, 5} 라고 가정.

     

    제일 첫 init()은

    init(nums, tree, 1, 1, 5) 을 호출하게되고,

    왼쪽, 오른쪽 차례대로

    왼쪽 : init ( nums, tree, 1 * 2, 1, 3 )          // init(nums, tree, node * 2, start, end)

    오른쪽 : init (nums, tree, (1 * 2) +1, 4, 5)   // init(nums, tree, node * 2 + 1, start, end)

    세그먼트 트리 초기화 과정

    2. 갱신 과정 (update)

    특정 인덱스의 값을 변경하고자 할때 update를 사용한다.

    private static void update (long[] tree, int node, int start, int end, int index, long diff) {
        // index 가 start와 end 사이에 있지 않다면 계산할 필요가 없다.
        if (!(start <= index && index <= end))
            return;
        tree[node] += diff;
        if(start != end) {
            int mid = (start+end)/2;
            update(tree, node*2, start, mid, index, diff);
            update(tree, node*2 + 1, mid + 1, end, index, diff);
        }
    }

    값 변경하기 (update)

    기존의 nums 배열에서는 nums[3] 의 값을, 위의 세그먼트 트리 기준으로는 5번째 노드의 값을 3에서 6으로 변경합니다.

    Node는 1번부터 시작하고, start는 1, end는 5, index는 3, diff는 6 -3 을한 3입니다.

    update(tree, 1, 1, 5, 3, 3);
    if(!(start<=index && index<=end))
        return;

    1<= 3 <= 5 이므로 index 가 start와 end 사이에 포함되므로 return 되지 않고 넘어간다.

     

    tree[node] += diff;
    if(start != end) {
        int mid = (start+end)/2;
        update(tree, node*2, start, mid, index, diff);
        update(tree, node*2 + 1, mid + 1, end, index, diff);
    }

    tree[node] += diff;

    tree[1] += 3;

    리프 노드에서 하나의 원소가 3->6으로 3의 값의 변화가 있었으므로, 전체 합도 3이 증가한다.

    start != end 이니

    좌 우 노드로 이동한다.

    - update(tree, 1*2, 1, 3, 3, 3);      // 좌측 노드

    - update(tree, 1*2 + 1, 4, 5, 3, 3); // 우측 노드

     

    이렇게 이동하다보면 다음과 같이 동작한다.

    update 완료 후 모습

    3. 합 과정 (sum)

    합을 구한다는 말은 부분의 합을 구하는 방법을 의미한다.

    만약, 위의 트리에서 2~4까지의 구간의 합을 구하고 싶다면 어떻게 될까??

    - 2, 3번 treeNode에서 해당 값을 바로 얻을 수 없기 때문에, 9번 트리 노드, 5번 트리 노드와 6번 트리 노드의 값을 더해야 한다.

    합을 구하는 과정은 4가지 경우로 나눌 수 있다.

    여기서 left, right는 합을 구하고자 하는 구간의 시작과 끝점

     

    1. [left, right]와 [start, end]가 겹치지 않는 경우

    -> 구간 합을 구하고자 하는 범위와 상관이 없다.

    * if(left > end || right < start)

     

    2. [left, right]가 [start, end]를 완전히 포함하는 경우

    -> 구하고자 하는 구간 합 구간에 포함되는 경우

    * if(left<=start && end < right)

     

    3. [start, end]가 [left, right]를 완전히 포함하는 경우

    -> 구하고자 하는 구간 합 범위보다는 크게 있지만, 그 내부에 구하고자 하는 구간 합 범위가 있는 경우

     * return sum(tree, node*2, start, mid, left, right) + sum(tree, node*2 + 1, mid + 1, end, left, right)

     

    4. [left, right]와 [start, end]가 겹쳐져 있는 경우 (1,2,3을 제외한 나머지 경우)

    -> left<=start<=right<=end 같은 방식인 경우 

    * return sum(tree, node*2, start, mid, left, right) + sum(tree, node*2 + 1, mid + 1, end, left, right)

     

    private static long sum(long[] tree, int node, int start, int end, int left, int right) {
        if(left > end || right < start)
            return 0;
        if(left <= start && end <= right)
            return tree[node];
        int mid = (start+end)/2;
        return sum(tree, node * 2, start, mid, left, right) + sum(tree, node*2 + 1, mid+1, end, left, right);
    }

    nums[2]가 6으로 변경된 후 세그먼트 트리

    예제 문제로 나와있는 2 ~ 5까지의 구간의 합을 구해보기로 한다.

    * 즉, left : 2, right : 5

     

    루트 노드부터 탐색을 시작한다.

    left : 2, end : 5, right : 5, start : 1

    if(left > end || right < start)
            return 0;

    조건에 만족하지 않기 때문에 다음으로 넘어간다.

    if(left <= start && end <= right)
            return tree[node];

    if(2 <= 1 && 5 <= 5) 이니 조건을 만족하지 않아 다음 단계로 넘어 간다.

    int mid = (start+end)/2;
    return sum(tree, node * 2, start, mid, left, right) + sum(tree, node*2 + 1, mid+1, end, left, right);

    구간의 합 구하기

    2번 노드 (nums 1~3번 배열의 구간합을 가진 노드)와 3번 노드(nums 4~5번 배열의 구간합을 가진 노드)를 탐색한다.

    3번 노드부터 보면 left : 2, right: 5, start : 4, end : 5

    if(left > end || right < start)
            return 0;

    if( 2 > 5 || 5 < 4) 조건을 만족시키지 않으므로 다음으로 넘어간다.

    if(left <= start && end <= right)
            return tree[node];

    if( 2<= 4 && 5<=5) 조건을 충족시키므로 tree[3] 의 값을 반환한다.

     

    2번 노드를 보면 left : 2 right: 5 start : 1, end : 3

    if(left > end || right < start)
            return 0;

    if( 2 > 3 || 5 < 1 ) 조건을 만족시키지 않으므로 다음으로 넘어간다.

    if(left <= start && end <= right)
            return tree[node];

    if( 2 <= 1 && 3 < = 5) 조건을 만족하지 않으므로 다음으로 넘어간다.

    int mid = (start+end)/2;
    return sum(tree, node * 2, start, mid, left, right) + sum(tree, node*2 + 1, mid+1, end, left, right);

    최종 결과

    2 ~ 5번 ( 배열의 시작이 0부터 시작이라 그림에서 -1된 값으로 봐야함) 구간의 합을 다음과 같이 구할 수 있다.

    댓글

Designed by Gintire