구간 트리(segment tree)
저장된 자료들을 전처리해서 질의에 빠르게 대답할 수 있도록 하는 것이다.
예를 들어, 일차원 배열의 특정 구간에 대한 최소치를 구하는 문제에서, 주어진 배열을 구간 트리로 만들어 최소치를 찾는 연산 속도를 최소화하는 것이다. 이 구간 트리의 루트는 항상 배열 전체이다.
위 트리에서 각 노드는 각 구간에 대한 최소치 값을 저장하게 된다.
이 예제는 구간 최소 쿼리 (range minimum query, RMQ) 라고 불린다. 이런 꽉 찬 이진트리의 형태는 배열로 표현하는 것이 메모리를 더 절약할 수 있다. 노드 i
의 왼쪽 자손과 오른쪽 자손을 2i
와 2i + 1
로 표현하면, 구간 트리에 들어갈 정보를 일차원 배열로 간단하게 표현할 수 있다.
struct RMQ {
int n;
vector<int> rangeMin;
RMQ(const vector<int>& array) {
n = array.size();
rangeMin.resize(n * 4);
init(array, 0, n - 1, 1);
}
// 현재 구간을 두 개로 나눠, 두 구간의 최소치 중 더 작은 값을 선택한다.
// node 노드가 array[left, right] 배열을 표현한다.
// node를 루트로 하는 서브 트리를 초기화하고, 이 구간의 최소치를 반환
int init(const vector<int>& array, int left, int right, int node) {
if (left == right)
return rangeMin[node] = array[left];
int mid = (left + right) / 2;
int leftMin = init(array, left, mid, node * 2);
int rightMin = init(array, mid + 1, right, node * 2 + 1);
return rangeMin[node] = min(leftMin, rightMin);
}
};
주어진 배열의 길이가 n이라고 했을 때, 구간 트리 노드의 수는 n보다 큰 2의 거듭제곱에서 2를 곱해야 한다.
ex) n = 6인 경우, 가까운 거듭제곱 8의 두 배인 16을 구간트리의 크기로 해야한다. 계산이 복잡하니 절충안으로 4n의 크기를 사용한다
query
함수 : node가 표현하는 범위 [nodeLeft, nodeRight] 와 우리가 최소치를 찾기 원하는 구간 [left, right] 의 교집합의 최소 원소를 반환한다.
int query(int left, int right, int node, int nodeLeft, int nodeRight) {
// 노드가 표현하는 구간과 관련없으면 무시
if (right < nodeLeft || nodeRight < left) return INT_MAX;
// (nodeLeft, nodeRight) 안에 (left, right)가 속해있는 경우
if (left <= nodeLeft && nodeRight <= right)
return rangeMin[node];
int mid = (nodeLeft + nodeRight) / 2;
return min(query(left, right, node * 2, nodeLeft, mid),
query(left, right, node * 2 + 1, mid + 1, nodeRight));
}
int query(int left, int right) {
return query(left, right, 1, 0, n - 1);
}
update
함수 : 원래 배열의 index 위치의 값이 newValue로 바뀌었을 때, node를 루트로 하는 구간트리를 갱신하는 함수
int update(int index, int newValue, int node, int nodeLeft, int nodeRight) {
// index가 노드가 표현하는 구간과 상관없다면 무시한다.
if (index < nodeLeft || nodeRight < index)
return rangeMin[node];
if (nodeLeft == nodeRight) return rangeMin[node] = newValue;
int mid = (nodeLeft + nodeRight) / 2;
return rangeMin[node] = min(
update(index, newValue, node * 2, nodeLeft, mid),
update(index, newValue, node * 2 + 1, mid + 1, nodeRight)
);
}
int update(int index, int newValue) {
return update(index, newValue, 1, 0, n - 1);
}