알고리즘/알고리즘

[Alogorithm] 세그먼트 트리( Segment Tree ) 구간합 구하기 ( + BOJ2042 ) With Python,JAVA

IT록흐 2023. 6. 26. 10:01
반응형
 

[Alogorithm] 세그먼트 트리 ( Segment Tree )

프로그래밍 분야에서는 시간-공간 트레이드-오프(Trade-off) 현상이 자주 일어난다. 처리시간을 줄이면 처리공간이 늘어나고 처리공간이 줄어들면 처리시간이 늘어나는 현상이다. 세그먼트 트리(S

lordofkangs.tistory.com

 

 

지난 포스팅에서 세그먼트 트리를 알아보았다. 

 

세그먼트 트리는 수열의 변경이 잦은 경우 구간합을 빠른속도로 구할 수 있는 알고리즘이다. 세그먼트 트리는 재귀호출로 구현되며, 파라미터로 구간과 구간합이 저장된 노드의 위치를 받는다. 이번 포스팅에서는 원하는 구간의 합을 구하는 로직과 세그먼트 트리를 UPDATE하는 로직을 알아보겠다.

 

 

원하는 구간의 합 구하기 

 

 

 

 

세그먼트 트리는 각 노드에 구간의 합이 들어가 있다. 수열에 12개가 저장되어 있다고 가정하면 각 노드의 구간정보는 아래와 같다.

 

1번 노드 : 전구간

2번 노드 : 1번 - 6번 

3번 노드 : 7번- 12번

4번 노드 : 1번 - 3번

.

.

.

 

만약 5번- 11번 까지의 구간합을 구하고 싶다고 가정해보자. 1-4번 구간합을 가진 트리는 제외한다. 12번 구간합을 가진 트리도 제외한다. 5번-11번 사이의 구간에 존재하는 트리는 구간합에 포함한다.

 

def sum_tree(node,start,end,fl,fr) :
  if fl > end or fr < start : return 0 #탐색구간 밖에 있는 경우
  if fl <= start and fr >= end : return tree[node] # 탐색구간 안에 있는 경우
  #탐색구간에 걸쳐있는 경우
  mid = ( start + end ) // 2
  return sum_tree(2*node,start,mid,fl,fr) + sum_tree(2*node+1,mid+1,end,fl,fr)

 

함수의 파라미터를 보자. start,end는 구간이고 node는 [ start-end ] 구간합이 저장된 인덱스이다. 그리고 fl,fr은 탐색구간이다. 탐색구간 안에 start,end가 포함되어 있는지 여부를 조건문으로 판단한다. 구간합이 탐색구간 밖에 있다면 제외하고 안에 있으면 포함한다. 그리고 걸쳐있다며 재귀호출로 분할한다.

 

그렇게 탐색구간 안 쪽에 있는 구간합만 결과를 return 하여 합을 구할 수 있다. 

 

 

세그먼트 트리 UPDATE 하기

 

만약 수열이 변경된다면 세그먼트 트리도 변경되어야 한다. 수열의 idx번째 수가 변경되었다면 idx를 구간합으로 포함하고 있는 세그먼트 트리의 노드는 모두 변경되어야 한다. 5번째 수가 2에서 5로 변경되었다면 5번째 구간을 포함하고 있는 각 노드는 구간합에 +3을 해야 한다. 

 

def update(node,start,end,idx,diff) :
  if start > idx or end < idx : return  # idx를 포함하지 않은 구간 제외
  # idx를 포함하는 구간인 경우
  tree[node] += diff # 차이(diff) 더하기
  if start != end : # 리프노드가 아닌 경우, 탐색
    mid = (start + end) // 2
    update(2*node,start,mid,idx,diff)
    update(2*node+1,mid+1,end,idx,diff)

 

변경된 수의 위치를 idx라고 한다면 idx를 포함하는 구간은 모두 차이(diff)만큼 구간합에 더해주어야 한다. 그리고 리프노드가 아닌 경우, 분할하여 자식노드도 변경사항을 UPDATE할 수 있도록 한다.

 

 

그럼 이를 문제로 풀어보자. 

 

문제풀이

 

2042번: 구간 합 구하기

첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)과 M(1 ≤ M ≤ 10,000), K(1 ≤ K ≤ 10,000) 가 주어진다. M은 수의 변경이 일어나는 횟수이고, K는 구간의 합을 구하는 횟수이다. 그리고 둘째 줄부터 N+1번째 줄

www.acmicpc.net

 

코드

#BOJ2042 구간 합 구하기
import sys
input = sys.stdin.readline

n,m,k = map(int,input().split())
arr = [0]*(n+1) # 수열
for i in range(1,n+1) :
  arr[i] = int(input())
tree = [0]*(4*n) # 세그먼트 트리

#세그먼트 트리 초기화
def init(node,start,end) :
  if start == end : #리프노드인 경우
    tree[node] = arr[start]
  else :  #리프노드가 아닌 경우
    mid = ( start + end ) // 2
    tree[node] = init(2*node,start,mid) + init(2*node+1,mid+1,end)
  return tree[node]

#세그먼트 트리 구간합 탐색하기
def sum_tree(node,start,end,fl,fr) :
  if fl > end or fr < start : return 0
  if fl <= start and fr >= end : return tree[node]
  mid = ( start + end ) // 2
  return sum_tree(2*node,start,mid,fl,fr) + sum_tree(2*node+1,mid+1,end,fl,fr)

#세그먼트 트리 UPDATE 하기
def update(node,start,end,idx,diff) :
  if start > idx or end < idx : return 
  tree[node] += diff
  if start != end :
    mid = (start + end) // 2
    update(2*node,start,mid,idx,diff)
    update(2*node+1,mid+1,end,idx,diff)

#문제풀이 시작
init(1,0,len(arr)-1) # 세그먼트 트리 초기화
for _ in range(m+k) :
  a,b,c = map(int,input().split())
  # 수정
  if a == 1 :
    diff = c - arr[b] # diff, 변경정도
    arr[b] = c # 수열에 변경내용 반영
    update(1,0,len(arr)-1,b,diff)  # 세그먼트리 변경내용 반영
  #구간합출력
  elif a == 2 :
    print(sum_tree(1,0,len(arr)-1,b,c)) # 구간합 탐색 및 출력

 

JAVA

 

JAVA는 개발자가 자료형을 제어해야 하므로 문제가 더 까다로웠다. 수열과 tree를 long타입 배열로 만들어야 하고 a=1일때, c도 long타입으로 받아야 한다.

 

import java.io.*;
import java.util.StringTokenizer;

//BOJ2042 구간 합 구하기
public class Main {
    static long[] arr;
    static long[] tree;
    static StringTokenizer st;

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));
        StringBuilder sb = new StringBuilder();

        st = new StringTokenizer(br.readLine());

        int n = Integer.parseInt(st.nextToken());
        int m = Integer.parseInt(st.nextToken());
        int k = Integer.parseInt(st.nextToken());
        arr = new long[n+1];
        tree = new long[4*n+1];

        for(int i =1;i<n+1;i++){
            st = new StringTokenizer(br.readLine());
            arr[i] = Long.parseLong(st.nextToken());
        }

        init_tree(1,n,1);
        int z = m+k;
        while(z-- > 0){
            st = new StringTokenizer(br.readLine());
            int a = Integer.parseInt(st.nextToken());
            int b = Integer.parseInt(st.nextToken());
            long c = Long.parseLong(st.nextToken());

            //수정
            if(a==1){
                long diff = c - arr[b];
                arr[b] = c;
                update_tree(1,n,1,b,diff);
            }
            //구간합
            else if(a==2){
                sb.append(sum_tree(1,n,1,b,(int)c)).append("\n");
            }
        }

        bw.write(sb.toString());
        bw.flush();
        bw.close();
        br.close();
    }
    public static long init_tree(int start,int end,int node){
        if( start == end ){
            tree[node] = arr[start];
        }
        else {
            int mid = (start + end) / 2;
            tree[node] = init_tree(start, mid, 2 * node) + init_tree(mid + 1, end, 2 * node + 1);
        }
        return tree[node];
    }

    public static void update_tree(int start,int end,int node,int idx,long diff){
        if(idx < start || end < idx) return;
        tree[node] += diff;
        if ( start != end ){
            int mid = (start+end)/2;
            update_tree(start,mid,2*node,idx,diff);
            update_tree(mid+1,end,2*node+1,idx,diff);
        }
    }

    public static long sum_tree(int start,int end, int node,int fs, int fe){
        if ( fe < start || end < fs ) return 0;
        if ( fs <= start && end <= fe ) return tree[node];

        int mid = (start+end)/2;
        return sum_tree(start,mid,2*node,fs,fe) + sum_tree(mid+1,end,2*node+1,fs,fe);
    }
}

 

반응형