알고리즘 공부/C++

[P5] 백준 10090번 Counting Inversions C++ 세그먼트 트리

마달랭 2025. 3. 26. 10:06

리뷰

 

https://www.acmicpc.net/problem/10090

세그먼트 트리를 활용하여 배열 내 자신의 뒷 숫자 중 자기보다 작은 숫자의 개수를 구하는 문제

 

 

전역 변수

  • N : 배열의 최대 길이를 저장할 상수 변수
  • n : 배열의 길이를 저장할 변수
  • tree : 세그먼트 트리 정보를 저장할 배열

 

함수

1. build

void build(int node, int s, int e)

 

세그먼트 트리의 초기화를 진행할 함수

  1. 매개 변수로 노드 정보 node, 탐색 구간 s, e를 전달 받는다.
  2. 기저 조건으로 리프노드에 도달한 경우 현재 노드의 값을 1로 저장해 준다.
  3. 좌, 우 자식 노드로 재귀를 진행하고, 재귀를 빠져나오며 현재 노드를 두 노드의 합으로 저장해 준다.

 

2. update

void update(int node, int s, int e, int idx)

 

세그먼트 트리의 업데이트를 진행할 함수

  1. 매개 변수로 노드 정보 node, 탐색 구간 s, e, 업데이트할 인덱스 idx를 전달 받는다.
  2. 기저 조건으로 리프노드에 도달한 경우 현재 노드의 값을 0으로 변경해 준다.
  3. 변수 mid에 탐색 구간을 반으로 나눈 값을 저장해 준다.
  4. idx가 mid이하라면 왼쪽, mid보다 크다면 오른쪽 자식 노드로 재귀를 진행해 준다.
  5. 재귀를 빠져나오며 현재 노드의 값을 좌, 우 자식 노드의 합으로 저장해 준다.

 

3. query

int query(int node, int s, int e, int L, int R)

 

세그먼트 트리의 구간 합을 구하는 함수

  1. 매개 변수로  노드 정보 node, 탐색 구간 s, e, 합을 구할 구간 L, R을 전달 받는다.
  2. 첫 번째 기저 조건으로 합을 구할 구간이 탐색 구간을 벗어난 경우 0을 리턴해 준다.
  3. 두 번째 기저 조건으로 합을 구할 구간이 탐색 구간과 일치하면 현재 노드에 저장된 값을 리턴해 준다.
  4. 좌, 우 자식 노드로 재귀를 진행하고 재귀를 빠져나온 뒤 두 값의 합을 더한 값을 리턴해 준다.

 

문제풀이

  1. n값을 입력 받고, build 함수를 통해 세그먼트 트리 정보를 초기화 해준다.
  2. long long타입 변수 ans를 0으로 초기화 해준다.
  3. n개의 숫자를 입력 받고, 매 숫자가 입력 될 때 마다 update함수를 통해 a인덱스를 0으로 변경해 준다.
  4. ans에 query함수를 통해 세그먼트 트리 내 1 ~ a까지의 누적합을 더해준다.
  5. ans에 저장된 값을 출력해 준다.

 

트러블 슈팅

없음

 

 

참고 사항

  • 1 ~ a - 1까지 구간합을 구해주어도 상관없다, 현재는 update를 먼저해서 그냥 1 ~ a까지의 구간합을 구했다.

 

정답 코드

#include<iostream>
using namespace std;

const int N = 1000001;
int n;
int tree[N * 4];

void build(int node, int s, int e) {
	if (s == e) tree[node] = 1;
	else {
		int mid = (s + e) / 2;
		build(node * 2, s, mid);
		build(node * 2 + 1, mid + 1, e);
		tree[node] = tree[node * 2] + tree[node * 2 + 1];
	}
}

void update(int node, int s, int e, int idx) {
	if (s == e) tree[node] = 0;
	else {
		int mid = (s + e) / 2;
		if (idx <= mid) update(node * 2, s, mid, idx);
		else update(node * 2 + 1, mid + 1, e, idx);
		tree[node] = tree[node * 2] + tree[node * 2 + 1];
	}
}

int query(int node, int s, int e, int L, int R) {
	if (R < s || e < L) return 0;
	if (L <= s && e <= R) return tree[node];
	int mid = (s + e) / 2;
	int left = query(node * 2, s, mid, L, R);
	int right = query(node * 2 + 1, mid + 1, e, L, R);
	return left + right;
}

int main() {
	ios::sync_with_stdio(0);
	cin.tie(0);

	cin >> n;
	build(1, 1, n);

	long long ans = 0;
	for (int i = 1; i <= n; ++i) {
		int a; cin >> a;
		update(1, 1, n, a);
		ans += query(1, 1, n, 1, a);
	}
	cout << ans;
}
728x90