코딩테스트

[알고리즘] 인덱스 트리

행복한 토마토 2024. 11. 5. 01:03

인덱스 트리 : 구간 합 또는 구간 최소값/최대값을 빠르게 구할 수 있는 자료구조. 세그먼트 트리를 포함하고 있는 개념이며 좀 더 간단하다.

 

인덱스 트리는 다음의 단계를 통해 만들 수 있다. 

 

1. 인덱스 트리 생성, 초기화

2. 구간 합 계산

3. 트리 업데이트

 

예제를 통해 알아보자

구간을 구하는 문제이고, 값의 변화가 잦기 때문에 인덱스 트리로 풀기에 적절한 예제이다.

https://www.acmicpc.net/problem/2042

 

//입력
5 2 2
1
2
3
4
5
1 3 6
2 2 5
1 5 2
2 3 5

//출력
17
12

 

기본 배열의 값을 리프 노드로 두고, 부모 노드는 자식 노드의 합으로 구성한다. 이 구조 덕분에 부모 노드가 자식 구간의 합을 저장하고, 구간 합을 구하거나 값을 업데이트하는 작업이 O(log N)에 가능하게 된다. 

 

1. 인덱스 트리 생성, 초기화

 

  • 리프 노드에 각 입력 값을 채워넣고, 상위 노드로 올라가며 각 부모 노드에 자식들의 합을 저장한다.
  • 이를 통해 이후 쿼리와 업데이트 연산 시, 각 구간의 합이 필요한 부모 노드에 미리 계산된 상태가 된다.
  • for문을 통해 리프 노드부터 시작하여 부모 노드로 올라가며, 배열을 한 번 순회해 트리를 완성한다.

GPT가 그려준 그래프

 

public static void main(String[] args) throws Exception {
	// 인덱스 트리 초기화
        leafCnt = 1;
        while (leafCnt < N) {
            leafCnt *= 2;
        }
        indexedTree = new long[leafCnt * 2];

        // 리프 노드에 값 저장
        for (int i = 0; i < N; i++) {
            indexedTree[leafCnt + i] = Long.parseLong(br.readLine());
        }

        // 상위 노드 초기화
        init();
}

// 트리 초기화
    public static void init() {
        for (int i = leafCnt - 1; i > 0; i--) {
            indexedTree[i] = indexedTree[i * 2] + indexedTree[i * 2 + 1];
        }
    }

 

2. 구간 합 계산

 

  • 특정 인덱스의 값을 변경하면, 변경된 값과 이전 값의 차이를 계산한다. 그리고 영향을 받는 노드들에 차이를 더해주며 갱신한다.
  • 변경된 노드로부터 루트 노드까지 올라가며 차이만큼 상위 노드들에 반영하므로, 시간 복잡도가 O(log N)이다.

 

// 값 업데이트
    public static void update(int idx, long val) {
        long diff = val - indexedTree[idx];
        while (idx > 0) {
            indexedTree[idx] += diff;
            idx >>= 1;
        }
    }

 

3. 트리 업데이트

 

  • 주어진 구간의 시작점과 끝점에 대해, 각 인덱스가 리프 노드의 자식 노드인 경우 현재 값을 결과에 더하고 부모로 이동한다.
  • 구간을 반으로 나누며 합을 구해가고, 상위 구간으로 올라가면서 전체 구간 합을 계산하므로 O(log N)의 시간이 걸린다.

 

public static void main(String[] args) throws Exception {
    	// 쿼리 처리
        for (int i = 0; i < M + K; i++) {
            st = new StringTokenizer(br.readLine());
            int a = Integer.parseInt(st.nextToken());
            int b = Integer.parseInt(st.nextToken());
            long c = Long.parseLong(st.nextToken());

            if (a == 1) {
                update(leafCnt + b - 1, c);
            } else {
                long result = query(leafCnt + b - 1, leafCnt + (int)c - 1);
                sb.append(result).append("\n");
            }
        }

    }
    
// 구간 합 구하기
    public static long query(int start, int end) {
        long result = 0;
        while (start <= end) {
            if (start % 2 == 1) result += indexedTree[start++];
            if (end % 2 == 0) result += indexedTree[end--];
            start >>= 1;
            end >>= 1;
        }
        return result;
    }

 

 

전체 코드

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.io.BufferedWriter;
import java.io.OutputStreamWriter;
import java.util.StringTokenizer;

public class 구간합구하기 {
    private static int N, M, K;
    private static long[] indexedTree;
    private static int leafCnt;

    public static void main(String[] args) throws Exception {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));
        StringBuilder sb = new StringBuilder();
        StringTokenizer st = new StringTokenizer(br.readLine());

        N = Integer.parseInt(st.nextToken());
        M = Integer.parseInt(st.nextToken());
        K = Integer.parseInt(st.nextToken());

        // 인덱스 트리 초기화
        leafCnt = 1;
        while (leafCnt < N) {
            leafCnt *= 2;
        }
        indexedTree = new long[leafCnt * 2];

        // 리프 노드에 값 저장
        for (int i = 0; i < N; i++) {
            indexedTree[leafCnt + i] = Long.parseLong(br.readLine());
        }

        // 상위 노드 초기화
        init();

        // 쿼리 처리
        for (int i = 0; i < M + K; i++) {
            st = new StringTokenizer(br.readLine());
            int a = Integer.parseInt(st.nextToken());
            int b = Integer.parseInt(st.nextToken());
            long c = Long.parseLong(st.nextToken());

            if (a == 1) {
                update(leafCnt + b - 1, c);
            } else {
                long result = query(leafCnt + b - 1, leafCnt + (int)c - 1);
                sb.append(result).append("\n");
            }
        }

        bw.write(sb.toString());
        bw.flush();
        bw.close();
        br.close();
    }

    // 트리 초기화
    public static void init() {
        for (int i = leafCnt - 1; i > 0; i--) {
            indexedTree[i] = indexedTree[i * 2] + indexedTree[i * 2 + 1];
        }
    }

    // 값 업데이트
    public static void update(int idx, long val) {
        long diff = val - indexedTree[idx];
        while (idx > 0) {
            indexedTree[idx] += diff;
            idx >>= 1;
        }
    }

    // 구간 합 구하기
    public static long query(int start, int end) {
        long result = 0;
        while (start <= end) {
            if (start % 2 == 1) result += indexedTree[start++];
            if (end % 2 == 0) result += indexedTree[end--];
            start >>= 1;
            end >>= 1;
        }
        return result;
    }
}