알고리즘 공부/C++

[P3] 백준 9345번 디지털 비디오 디스크(DVDs) C++ 세그먼트 트리

마달랭 2025. 3. 28. 16:55

리뷰

 

https://www.acmicpc.net/problem/9345

누적합, 최대, 최소 세그먼트 트리 모두 사용하여 AC를 받았다.

 

 

전역 변수

  • N : DVD수의 최대값을 저장할 상수 변수
  • t : 테스트 케이스의 개수를 저장할 변수
  • n : DVD의 개수를 저장할 변수
  • k : 쿼리의 개수를 저장할 변수
  • lst : DVD의 초기 위치를 저장할 배열
  • presum : DVD의 초기 위치를 기준으로 누적합을 저장할 배열
  • T : 세그먼트 트리의 누적합 SUM, 최대값 MAX, 최소값 MIN을 정의할 구조체
  • tree : T타입의 세그먼트 트리를 요소를 저장할 배열

 

함수

1. build

void build(int node, int s, int e) {
	if (s == e) tree[node] = { lst[s], lst[s], lst[s] };
	else {
		int mid = (s + e) / 2;
		build(node * 2, s, mid);
		build(node * 2 + 1, mid + 1, e);
		tree[node].SUM = tree[node * 2].SUM + tree[node * 2 + 1].SUM;
		tree[node].MAX = max(tree[node * 2].MAX, tree[node * 2 + 1].MAX);
		tree[node].MIN = min(tree[node * 2].MIN, tree[node * 2 + 1].MIN);
	}
}

 

세그먼트 트리 초기화를 위한 함수

 

2. update

void update(int node, int s, int e, int idx, int val) {
	if (s == e) tree[node] = { val, val, val };
	else {
		int mid = (s + e) / 2;
		if (idx <= mid) update(node * 2, s, mid, idx, val);
		else update(node * 2 + 1, mid + 1, e, idx, val);
		tree[node].SUM = tree[node * 2].SUM + tree[node * 2 + 1].SUM;
		tree[node].MAX = max(tree[node * 2].MAX, tree[node * 2 + 1].MAX);
		tree[node].MIN = min(tree[node * 2].MIN, tree[node * 2 + 1].MIN);
	}
}

 

세그먼트 트리 업데이트를 위한 함수

 

3. sumQuery

ll sumQuery(int node, int s, int e, int L, int R) {
	if (R < s || e < L) return 0;
	if (L <= s && e <= R) return tree[node].SUM;
	int mid = (s + e) / 2;
	ll left = sumQuery(node * 2, s, mid, L, R);
	ll right = sumQuery(node * 2 + 1, mid + 1, e, L, R);
	return left + right;
}

 

세그먼트 트리의 구간 누적합을 구하기 위한 함수

 

4. maxQuery

int maxQuery(int node, int s, int e, int L, int R) {
	if (R < s || e < L) return -1;
	if (L <= s && e <= R) return tree[node].MAX;
	int mid = (s + e) / 2;
	int left = maxQuery(node * 2, s, mid, L, R);
	int right = maxQuery(node * 2 + 1, mid + 1, e, L, R);
	return max(left, right);
}

 

세그먼트 트리의 최대값을 구하기 위한 함수

 

5. minQuery

int minQuery(int node, int s, int e, int L, int R) {
	if (R < s || e < L) return 1e5 + 1;
	if (L <= s && e <= R) return tree[node].MIN;
	int mid = (s + e) / 2;
	int left = minQuery(node * 2, s, mid, L, R);
	int right = minQuery(node * 2 + 1, mid + 1, e, L, R);
	return min(left, right);
}

 

세그먼트 트리의 최소값을 구하기 위한 함수

 

6. getVal

int getVal(int node, int s, int e, int idx) {
	if (s == e) return tree[node].SUM;
	int mid = (s + e) / 2;
	if (idx <= mid) return getVal(node * 2, s, mid, idx);
	return getVal(node * 2 + 1, mid + 1, e, idx);
}

 

세그먼트 트리의 노드 인덱스에 저장된 값을 구하기 위한 함수

 

문제풀이

  1. t값을 입력 받고, t번의 테스트 케이스를 수행해 준다.
  2. 매 테스트 케이스 마다 n, k에 값을 입력 받아준다.
  3. 1~n - 1번째 요소를 순회하며 lst배열에 자기 자신을 값으로 입력 받아주고, presum은 이전까지의 누적합 + 현재 요소값을 저장해 준다.
  4. build함수를 통해 세그먼트 트리 초기화를 진행해 준다.
  5. k개의 쿼리를 수행해 주고 매 쿼리마다 변수 op, a, b에 값을 입력 받아준다.
  6. op가 1일 경우 변수 SUM, MAX, MIN에 sumQuery, maxQuery, minQuery함수의 구간 a, b의 값을 저장해 준다.
  7. 변수 ragne에는 presum[b] - presum[a - 1]의 값을 저장해 준다, a - 1이 음수라면 0으로 적용해 준다.
  8. range와 SUM에 같고, MIN과 a가 같고, MAX와 b가 같다면 YES를 출력해 준다. 아니라면 NO를 출력한다.
  9. op가 0일 경우 변수 A, B에 getVal함수를 통해 세그먼트 트리의 a, b위치에 있는 값을 저장해 준다.
  10. update함수를 통해 a인덱스엔 값 B를, b인덱스엔 값 A로 세그먼트 트리를 업데이트해 준다.

 

트러블 슈팅

  1. 초기에 계속 출력 초과를 받았는데 lst, presum배열을 초기화 할때 범위를 1~n으로 한 것이 원인이었다.
  2. for문의 범위를 1~n - 1로 변경해 주니 출력 초과가 해결되었다.
  3. 83%에서 Fail을 받았다, 구간 합을 통해서만 YES, NO를 검증하는 것이 원인이었다.
  4. MAX, MIN값을 구해 구간이 정확히 연속되는지를 검증해 주어 AC를 받았다.

 

참고 사항

도움이 된 테스트 케이스

1
5 3
0 0 3
0 1 4
1 1 3

 

출력값은 NO가 나와야 한다, 구간합만 처리해 준다면 YES가 출력될 것이다.

 

정답 코드

#include<iostream>
#define ll long long
using namespace std;

const int N = 100000;
int t, n, k;
int lst[N];
ll presum[N * 4];
struct T {
	ll SUM;
	int MAX, MIN;
};
T tree[N * 4];

void build(int node, int s, int e) {
	if (s == e) tree[node] = { lst[s], lst[s], lst[s] };
	else {
		int mid = (s + e) / 2;
		build(node * 2, s, mid);
		build(node * 2 + 1, mid + 1, e);
		tree[node].SUM = tree[node * 2].SUM + tree[node * 2 + 1].SUM;
		tree[node].MAX = max(tree[node * 2].MAX, tree[node * 2 + 1].MAX);
		tree[node].MIN = min(tree[node * 2].MIN, tree[node * 2 + 1].MIN);
	}
}

void update(int node, int s, int e, int idx, int val) {
	if (s == e) tree[node] = { val, val, val };
	else {
		int mid = (s + e) / 2;
		if (idx <= mid) update(node * 2, s, mid, idx, val);
		else update(node * 2 + 1, mid + 1, e, idx, val);
		tree[node].SUM = tree[node * 2].SUM + tree[node * 2 + 1].SUM;
		tree[node].MAX = max(tree[node * 2].MAX, tree[node * 2 + 1].MAX);
		tree[node].MIN = min(tree[node * 2].MIN, tree[node * 2 + 1].MIN);
	}
}

ll sumQuery(int node, int s, int e, int L, int R) {
	if (R < s || e < L) return 0;
	if (L <= s && e <= R) return tree[node].SUM;
	int mid = (s + e) / 2;
	ll left = sumQuery(node * 2, s, mid, L, R);
	ll right = sumQuery(node * 2 + 1, mid + 1, e, L, R);
	return left + right;
}

int maxQuery(int node, int s, int e, int L, int R) {
	if (R < s || e < L) return -1;
	if (L <= s && e <= R) return tree[node].MAX;
	int mid = (s + e) / 2;
	int left = maxQuery(node * 2, s, mid, L, R);
	int right = maxQuery(node * 2 + 1, mid + 1, e, L, R);
	return max(left, right);
}

int minQuery(int node, int s, int e, int L, int R) {
	if (R < s || e < L) return 1e5 + 1;
	if (L <= s && e <= R) return tree[node].MIN;
	int mid = (s + e) / 2;
	int left = minQuery(node * 2, s, mid, L, R);
	int right = minQuery(node * 2 + 1, mid + 1, e, L, R);
	return min(left, right);
}

int getVal(int node, int s, int e, int idx) {
	if (s == e) return tree[node].SUM;
	int mid = (s + e) / 2;
	if (idx <= mid) return getVal(node * 2, s, mid, idx);
	return getVal(node * 2 + 1, mid + 1, e, idx);
}

int main() {
	ios::sync_with_stdio(0);
	cin.tie(0);
	cout.tie(0);

	cin >> t;
	while (t--) {
		cin >> n >> k;
		for (int i = 1; i < n; ++i) {
			lst[i] = i;
			presum[i] = presum[i - 1] + lst[i];
		}
		build(1, 0, n - 1);

		while (k--) {
			int op, a, b; cin >> op >> a >> b;
			if (op) {
				ll range = presum[b] - (a > 0 ? presum[a - 1] : 0);
				ll SUM = sumQuery(1, 0, n - 1, a, b);
				int MAX = maxQuery(1, 0, n - 1, a, b);
				int MIN = minQuery(1, 0, n - 1, a, b);
				//cout << SUM << " " << MAX << " " << MIN << "\n";
				if (range == SUM && MIN == a && MAX == b) cout << "YES\n";
				else cout << "NO\n";
			}
			else {
				int A = getVal(1, 0, n - 1, a);
				int B = getVal(1, 0, n - 1, b);
				update(1, 0, n - 1, a, B);
				update(1, 0, n - 1, b, A);
			}
		}
	}
}
728x90