알고리즘 공부/C++

[P4] 백준 11962번 Counting Haybales C++ 세그먼트 트리, 느리게 갱신되는 세그먼트 트리

마달랭 2025. 3. 31. 10:45

리뷰

 

https://www.acmicpc.net/problem/11962

구간 업데이트 + 최소값 및 구간합 세그먼트 트리 문제

 

 

전역 변수

  • N : 배열 크기의 최대값을 저장할 상수 변수
  • n : 배열의 크기를 저장할 변수
  • q : 쿼리의 개수를 저장할 변수
  • lst : 배열의 초기값을 저장할 배열
  • MS : 세그먼트 트리의 최소값 M, 구간 합 S를 정의할 구조체
  • tree : 세그먼트 트리 정보를 저장할 MS타입 배열
  • lazy : 세그먼트 트리의 느린 업데이트 처리를 위한 배열

 

함수

1. build

void build(int node, int s, int e) {
	if (s == e) tree[node] = { lst[s], lst[s] };
	else {
		int mid = (s + e) / 2;
		build(node * 2, s, mid);
		build(node * 2 + 1, mid + 1, e);
		tree[node].S = tree[node * 2].S + tree[node * 2 + 1].S;
		tree[node].M = min(tree[node * 2].M, tree[node * 2 + 1].M);
	}
}

 

세그먼트 트리의 구간 최소 및 구간 합 정보를 초기화할 함수

 

2. propagate

void propagate(int node, int s, int e) {
	if (lazy[node]) {
		tree[node].S += (e - s + 1) * lazy[node];
		tree[node].M += lazy[node];
		if (s != e) {
			lazy[node * 2] += lazy[node];
			lazy[node * 2 + 1] += lazy[node];
		}
		lazy[node] = 0;
	}
}

 

lazy에 업데이트할 것이 있다면 수행하고, 자식 노드에 전파하기 위한 함수

 

3. update

void update(int node, int s, int e, int L, int R, ll val) {
	propagate(node, s, e);
	if (R < s || e < L) return;
	if (L <= s && e <= R) {
		lazy[node] += val;
		propagate(node, s, e);
		return;
	}
	int mid = (s + e) / 2;
	update(node * 2, s, mid, L, R, val);
	update(node * 2 + 1, mid + 1, e, L, R, val);
	tree[node].S = tree[node * 2].S + tree[node * 2 + 1].S;
	tree[node].M = min(tree[node * 2].M, tree[node * 2 + 1].M);
}

 

세그먼트 트리 업데이트를 진행하기 위한 함수

 

4. minQuery

ll minQuery(int node, int s, int e, int L, int R) {
	propagate(node, s, e);
	if (R < s || e < L) return 2e9;
	if (L <= s && e <= R) return tree[node].M;
	int mid = (s + e) / 2;
	ll left = minQuery(node * 2, s, mid, L, R);
	ll right = minQuery(node * 2 + 1, mid + 1, e, L, R);
	return min(left, right);
}

 

세그먼트 트리의 구간 최소값을 구하기 위한 함수

 

5. sumQuery

ll sumQuery(int node, int s, int e, int L, int R) {
	propagate(node, s, e);
	if (R < s || e < L) return 0;
	if (L <= s && e <= R) return tree[node].S;
	int mid = (s + e) / 2;
	ll left = sumQuery(node * 2, s, mid, L, R);
	ll right = sumQuery(node * 2 + 1, mid + 1, e, L, R);
	return left + right;
}

 

세그먼트 트리의 구간합을 구하기 위한 함수

 

문제풀이

  1. n, q값을 입력 받고, lst배열에 n개의 요소를 입력 받는다.
  2. build함수를 통해 세그먼트 트리의 초기화를 진행해 준다.
  3. q개의 쿼리문을 입력 받고, 매 쿼리마다 변수 op, l, r에 값을 입력 받아준다.
  4. op가 'M'일 경우 minQuery함수를 통해 l~r범위의 최소값을 구해 출력해 준다.
  5. op가 'S'일 경우 minQuery함수를 통해 l~r범위의 구간합을 구해 출력해 준다.
  6. op가 'P'일 경우 변수 v에 값을 입력 받고, l~r범위에 v만큼 값을 더해준다.

 

트러블 슈팅

없음

 

 

참고 사항

  • 쿼리가 10만개에 val입력값이 10만으로 int범위를 벗어날 것을 염두해 long long타입으로 설정하였다.

 

정답 코드

#include<iostream>
#define ll long long
using namespace std;

const int N = 2e5 + 1;
int n, q;
int lst[N];
struct MS {
	ll M, S;
};
MS tree[N * 4];
ll lazy[N * 4];

void build(int node, int s, int e) {
	if (s == e) tree[node] = { lst[s], lst[s] };
	else {
		int mid = (s + e) / 2;
		build(node * 2, s, mid);
		build(node * 2 + 1, mid + 1, e);
		tree[node].S = tree[node * 2].S + tree[node * 2 + 1].S;
		tree[node].M = min(tree[node * 2].M, tree[node * 2 + 1].M);
	}
}

void propagate(int node, int s, int e) {
	if (lazy[node]) {
		tree[node].S += (e - s + 1) * lazy[node];
		tree[node].M += lazy[node];
		if (s != e) {
			lazy[node * 2] += lazy[node];
			lazy[node * 2 + 1] += lazy[node];
		}
		lazy[node] = 0;
	}
}

void update(int node, int s, int e, int L, int R, ll val) {
	propagate(node, s, e);
	if (R < s || e < L) return;
	if (L <= s && e <= R) {
		lazy[node] += val;
		propagate(node, s, e);
		return;
	}
	int mid = (s + e) / 2;
	update(node * 2, s, mid, L, R, val);
	update(node * 2 + 1, mid + 1, e, L, R, val);
	tree[node].S = tree[node * 2].S + tree[node * 2 + 1].S;
	tree[node].M = min(tree[node * 2].M, tree[node * 2 + 1].M);
}

ll minQuery(int node, int s, int e, int L, int R) {
	propagate(node, s, e);
	if (R < s || e < L) return 2e9;
	if (L <= s && e <= R) return tree[node].M;
	int mid = (s + e) / 2;
	ll left = minQuery(node * 2, s, mid, L, R);
	ll right = minQuery(node * 2 + 1, mid + 1, e, L, R);
	return min(left, right);
}

ll sumQuery(int node, int s, int e, int L, int R) {
	propagate(node, s, e);
	if (R < s || e < L) return 0;
	if (L <= s && e <= R) return tree[node].S;
	int mid = (s + e) / 2;
	ll left = sumQuery(node * 2, s, mid, L, R);
	ll right = sumQuery(node * 2 + 1, mid + 1, e, L, R);
	return left + right;
}

int main() {
	ios::sync_with_stdio(0);
	cin.tie(0);
	cout.tie(0);

	cin >> n >> q;
	for (int i = 1; i <= n; ++i) cin >> lst[i];
	build(1, 1, n);

	while (q--) {
		char op; int l, r;
		cin >> op >> l >> r;
		if (op == 'M') cout << minQuery(1, 1, n, l, r) << "\n";
		else if (op == 'S') cout << sumQuery(1, 1, n, l, r) << "\n";
		else {
			ll v; cin >> v;
			update(1, 1, n, l, r, v);
		}
	}
}
728x90