Algorithm/Algorithm

[Python] 세그먼트 트리, Segment Tree 구조 이해 (파이썬 코드)

DongKeun2 2024. 10. 17. 17:58

 오늘은 구간합을 구할 때 효율적인 세그먼트 트리에 대해 알아보겠다. 1차원 배열로 표현이 가능해서 생각보다 간단하게 구현할 수 있고, 단계별로 이해하면 크게 어렵지 않게 이해할 수 있다.

 

모든 알고리즘이나 자료구조는 코드를 암기하는 것 보다 구조와 동작을 이해하는 것이 더 중요하다. 세그먼트 트리를 한 단계씩 파헤쳐보자.

 

세그먼트 트리 (Segment Tree)?

 일단 세그먼트 트리를 먼저 알아보자.

세그먼트 트리는 특정 구간의 합, 곱과 같은 연산을 포함하여 최댓값, 최솟값을 구하는 용도로 사용하는 자료구조이다. 비슷한 자료구조로 구간 합과 비슷하지만 시간복잡도도 우수하고 업데이트에도 용이하다.

 

세그먼트 트리를 이해하기 쉽게 몇 가지 특징을 먼저 알아보자. 그 다음 세그먼트 트리로 구간합을 구해보자.

특징

- 전 이진트리 (Full Binary Tree) 구조를 가진다.

- 부모 노드는 자식 노드들의 합이다.

- 1차원 배열로 표현이 가능하다.

- 구하고자 하는 배열의 모든 수는 리프 노드에 속한다.

- 트리의 총 노드의 수는 리프 노드의 수 * 2 -1이다.

- 리프 노드의 개수는 최대 2N개이다.

 

구성과 의미

 어떻게 세그먼트 트리가 구성되는 지 직접 만들어보는 것이 이해하는데 큰 도움이 되기 때문에 간단히 생김새에 대해서 그림과 함께 설명을 해보자. 자료 구조를 학습하는데 정확한 구조는 알고 가는 게 좋을 것 같다.

 

만약 [3, 4, 5, 7, 1]과 같은 배열에 대해 세그먼트 트리를 구현한다면 다음과 같은 순서를 따라 그릴 수 있다.

 

 일단 배열을 리프노드로 풀어서 둔 뒤 적절하게 나눠준다. (전 이진 트리 형태를 만족해야 한다.)

반을 나눈 뒤.홀수라서 하나가 남는다면 왼쪽에 붙여준다.

나뉜 부분에 대해서도 2개 이하로 묶일 때까지 재귀적으로 수행하면 된다.

 

현재는 5개이므로 3개 2개로 나누어

- 3개짜리는 2개 1개로 나눈 뒤 2개를 묶어주고

- 2개짜리는 묶어준다.

 

이렇게 묶는 것을 코드에서 반복문 또는 재귀로 구현하기 때문에 확실하게 이해하고 넘어가는 것이 좋다.

 

 

 

그럼 2 - 1 - 2 형태로 묶이고 2개는 더해서 쌓아주면 된다.

3과 4를 더해서 7, 7과 1을 더해서 8을 쌓아주고, 남은 5는 그대로 둔다.

양쪽 끝에서 수행하는 하는 이유는 범위를 적절히 분리하기 위해서이다.

다음에는 7과 5를 더해 12를 쌓아준 뒤, 마지막으로 12와 8을 더해준다.

이 모양이 배열 3,4,5,7,1의 세그먼트 트리 모양이다.

 

이렇게 만들면 세그먼트 트리의 중요한 특징 한 가지를 살펴볼 수 있다.

배열의 인덱스를 1부터 시작한다면, 리프노드를 제외하고 새로 만든 노드들은 배열의 어떤 구간의 합으로 표현될 수 있다.

- 20: 1~5번 인덱스 구간합

- 12: 1~3번 인덱스 구간합

-  8: 4~5번 인덱스 구간합

- 7: 1~2번 인덱스  구간합

 

그리고 이런 구간은 해당 노드의 좌 우 끝 리프노드의 인덱스와 일치한다.

예를 들어 합이 12인 노드의 경우를 살펴보면

- 왼쪽 자식을 타고 내려가면 index1

- 오른쪽 자식을 타고 내려가면 index3

그래서 1에서 3까지의 구간합을 나타낸다.

 

구간합 구하기

이렇게 세그먼트 트리를 만들었다면, 구간합의 모든 경우의 수를 더 빠르게 구할 수 있다.

 

방법은 이러하다.

루트 노드부터 시작한다. 구하고자 하는 구간이 노드가 포함하는 구간과 일치하는 지를 살펴본다.

1 . 구간이 정확하게 일치하거나 구해야 할 범위가 현재 노드를 전부 포함하고 있는 경우에는 해당 노드의 값을 리턴한다.

    (e.g. 1~5까지의 구간합을 구하고자 할 때 1~4를 포함하는 노드의 값을 리턴한다. 자손은 이 범위를 벗어날 수 없다.)

2. 정확하게 일치하지 않지만 노드의 구간이 구하고자 하는 구간을 일부라도 포함하면 자식 노드로 내려가 반복한다.

    (e.g. 2~4까지의 구간합을 구하고자 할 때 루트노드는 1~5로 더 큰 범위이다. 자식 노드로 내려가서 확인한다.)

3. 겹치는 구간이 없으면 포함시키지 않고 바로 리턴한다.

 

가벼운 예시로 살펴보자.

 

만약 1~3까지의 구간의 합을 구한다고 생각해보자. 루트 노드는 1~5까지의 합이므로 1~3사이를 포함하고 있다.

그렇다면 자식으로 내려가서 다시 확인한다.

- 왼쪽 자식의 경우, 1~3의 구간의 합을 나타내고 있고 2~4에 걸쳐있다. 뒤쪽 숫자가 겹치니 자식으로 다시 내린다.

    - 다시 왼쪽 자식의 경우 1~2의 구간합이고 

- 오른쪽 자식의 경우, 4~5의 구간의 합을 나타내고 있고 2~4에 걸쳐있으므로 자식으로 내려준다.  

    - 4번과 5번 리프 노드이므로 4번만 합에 포함시켜준다.

 

 

세그먼트 트리 파이썬 코드

코드로 세그먼트 트리를 해석하고 풀기 위해 다음과 같이 1차원 배열로 나타낸다

INDEX 0 1 2 3 4 5 6 7 8 9
ARRAY 0 20 12 8 7 5 7 1 3 4

 

 진하게 표시한 0은 부모 자식 관계를 계산하기 편하게 하기 위한 더미데이터이다. 루트의 index가 1부터 시작해야 계산이 편하다.

 

이렇게 풀어서 나타내면 어떤 노드의 인덱스가 N이라면 자식은 2N, 2N+1이다. 

반대로 인덱스 M 노드의 부모의 인덱스는 M//2로 표현된다. 

 

예를 들어, 2번 인덱스는 4번과 5번 인덱스를 자식으로 가진다. (12 = 7 + 5)

 

세그먼트 트리 구하기

일단 배열이 주어졌을 때 세그먼트 트리를 만들어보자.

루트 노드를 시작으로 좌, 우 노드의 합을 저장하는 방식으로 재귀를 통해 구현한다.

앞서 살펴봤던 세그먼트 트리의 특징들을 잘 생각하면서 코드를 본다면 쉽게 이해할 수 있을 것이다.

 

# node: 구하고자 하는 노드의 인덱스
# start: 현재 노드의 좌측 끝 인덱스 (포함하는 범위의 좌측 기준)
# end: 현재 노드의 우측 끝 인덱스 (포함하는 범위의 우측 기준)

arr # 길이 n의 1차원 숫자 배열 (e.g. [3,4,5,7,1])
tree = [0] * (4*n) # 최대범위, 리프노드의 수를 2N으로 가정, 총 노드 수 4N-1개


# topdown으로 1번 노드를 채우기 위해 재귀
def init(node, start, end):
	# 범위가 1이면 리프 노드이다. 저장 후 리턴
    if start == end:
    	tree[node] = arr[start]
    
    # 리프노드가 아니라면 좌우 자식의 합을 저장 
    else:
    	tree[node] = init(node*2, start, (start+end)//2) + init(node*2 + 1, (start+end)//2 + 1, end)
   
	return tree[node]

init(1, 0, n-1)

print(tree)
# [0, 20, 12, 8, 7, 5, 7, 1, 3, 4]

이렇게 구현하여 아까 원리를 살펴보며 그려봤던 세그먼트 트리와 같은 값을 구할 수 있다.

 

구간합 구하기

 이제 만들어 둔 세그먼트 트리로 구간 합을 구해보자.

앞서 봤던 배열은 너무 짧으니까 뒤에 수를 몇 개 추가해서

arr = [3, 4, 5, 7, 1, 6, 6, 3] 으로 세그먼트 트리를 만들면 tree = [0, 35, 19, 16, 7, 12, 7, 9, 3, 4, 5, 7, 1, 6, 6, 3]이다.

 

여기서 3번째(5)부터 7번째(6)까지 구간합을 구해보자.

arr = [3, 4, 5, 7, 1, 6, 6, 3]
tree = [0, 35, 19, 16, 7, 12, 7, 9, 3, 4, 5, 7, 1, 6, 6, 3]
left, right = 2, 6 # 구하고자 하는 범위
n = len(arr)


# 세그먼트 트리로 구간합을 구하는 함수
# node: 현재 노드

# start, end: 현재 노드가 담당하는 범위
def find_prefix_sum(node, start, end):
	# 겹치지 않는다면 종료
	if end < left or right < start:
    	return 0
        
    # 노드의 범위가 완전히 포함되고 있다면 전부 더해준다.
    if left <= start and end <= right:
    	return tree[node]
    
    # 조금이라도 걸친다면 좌우 자식으로 내려가서 반복한다.
    return find_prefix_sum(node*2, start, (start+end)//2) + find_prefix_sum(node*2 + 1, (start+end)//2 + 1, end)
    

print(find_prefix_sum(1, 0, n-1)) # 25

 

5+7+1+6+6 = 25으로 결과가 잘 나오는 것을 볼 수 있다.

 

업데이트

 이제 배열 중 어떤 수가 다른 수로 바뀌었다고 가정하자. 구간합 자료구조는 업데이트를 하려면 이후의 모든 값을 새로 구해줘야 한다. 하지만 세그먼트 트리는 해당 수가 포함된 노드들만 건드려주면 된다.

 

 방금 봤던 배열에서 3번째 5를 10으로 바꾼다고 생각하자. 어떤 구간을 구할 때 이 5를 포함했던 구간들을 모두 5만큼 더해주면 올바른 계산이 될 것이다.

arr = [3, 4, 5, 7, 1, 6, 6, 3]
n = len(arr)
tree = [0, 35, 19, 16, 7, 12, 7, 9, 3, 4, 5, 7, 1, 6, 6, 3]

target = 2 
diff = 10 - 5 

# node: 현재 노드
# start, end: 현재 노드가 포함하는 범위
# target: 업데이트 할 index
# diff: target을 포함하는 노드에 가감해줄 상수
def tree_update(node, start, end):

    # 포함하지 않는 경우에는 종료
    if target < start or end < target: return
    
    # 포함하는 경우에는 diff만큼 업데이트
    tree[node] += diff
    
    # 리프노드라면 종료
    if start == end: return

    # 좌, 우 자식도 확인
    tree_update(node*2, start, (start+end)//2)
    tree_update(node*2 + 1, (start+end)//2 + 1, end)

tree_update(1, 0, n-1)
print(tree) # [0, 40, 24, 16, 7, 17, 7, 9, 3, 4, 10, 7, 1, 6, 6, 3]

 

 

 원리를 이해하고 세그먼트 트리를 만들고, 구간합을 구하고, 업데이트를 하는 정도만 숙지하고 있자.