TOP > データ構造
スライド区間の昇順k個の和(Priority-Sum-Structure)
説明
スライドする区間の昇順(降順) $k$ 個の総和を効率良く求めるデータ構造。
priority_queue を 2 つ持てばできる。
要素の削除がある場合は priority_queue を multiset にして直接削除しても良いが, 削除用の priority_queue を用意して削除を遅延させると定数倍が軽い実装になる。
計算量
- query $O(1)$
- insert, erase $O(\log n)$ (ならし)
- set_k $((それ以前の $k$ との差) \log n)$
実装例
- $\mathrm{MinimumSum}(k)$: 昇順 $k$ 個に指定
- $\mathrm{MaximumSum}(k)$: 降順 $k$ 個に指定
- $\mathrm{insert}(x)$: 要素 $x$ を追加する
- $\mathrm{erase}(x)$: 要素 $x$ を削除する
- $\mathrm{query}()$: 上位 $k$ 個の和(要素数が $k$ に満たないとき, 要素すべての和) を返す
- $\mathrm{set\_k}()$: $k$ を指定しなおす
- $\mathrm{get\_k}()$: $k$ を返す
- $\mathrm{size}()$: 全体の要素数を返す
template< typename T, typename Compare = less< T >, typename RCompare = greater< T > >
struct PrioritySumStructure {
size_t k;
T sum;
priority_queue< T, vector< T >, Compare > in, d_in;
priority_queue< T, vector< T >, RCompare > out, d_out;
PrioritySumStructure(int k) : k(k), sum(0) {}
void modify() {
while(in.size() - d_in.size() < k && !out.empty()) {
auto p = out.top();
out.pop();
if(!d_out.empty() && p == d_out.top()) {
d_out.pop();
} else {
sum += p;
in.emplace(p);
}
}
while(in.size() - d_in.size() > k) {
auto p = in.top();
in.pop();
if(!d_in.empty() && p == d_in.top()) {
d_in.pop();
} else {
sum -= p;
out.emplace(p);
}
}
while(!d_in.empty() && in.top() == d_in.top()) {
in.pop();
d_in.pop();
}
}
T query() const {
return sum;
}
void insert(T x) {
in.emplace(x);
sum += x;
modify();
}
void erase(T x) {
assert(size());
if(!in.empty() && in.top() == x) {
sum -= x;
in.pop();
} else if(!in.empty() && RCompare()(in.top(), x)) {
sum -= x;
d_in.emplace(x);
} else {
d_out.emplace(x);
}
modify();
}
void set_k(size_t kk) {
k = kk;
modify();
}
size_t get_k() const {
return k;
}
size_t size() const {
return in.size() + out.size() - d_in.size() - d_out.size();
}
};
template< typename T >
using MaximumSum = PrioritySumStructure< T, greater< T >, less< T > >;
template< typename T >
using MinimumSum = PrioritySumStructure< T, less< T >, greater< T > >;
検証
AtCoder Grand Contest 034 C - Tests
int main() {
int N, X;
cin >> N >> X;
int64 loss = 0;
vector< int64 > A(N), B(N), C(N);
vector< int > ord;
for(int i = 0; i < N; i++) {
cin >> A[i] >> B[i] >> C[i];
loss += B[i] * A[i];
ord.emplace_back(i);
}
sort(begin(ord), end(ord), [&](int a, int b) {
return B[a] > B[b];
});
auto get_cost = [&](int idx, int64 sum) {
int64 add = 0;
add += B[idx] * min(sum, A[idx]);
add += C[idx] * max(0LL, sum - A[idx]);
return add;
};
MaximumSum< int64 > tap(0);
for(int i = 0; i < N; i++) {
tap.insert(get_cost(i, X));
}
auto check = [&](int64 sum) {
int64 need = sum / X;
tap.set_k(need);
for(int i : ord) {
tap.erase(get_cost(i, X));
bool f = tap.query() + get_cost(i, sum % X) >= loss;
tap.insert(get_cost(i, X));
if(f) return true;
}
return false;
};
int64 ok = 1LL * N * X, ng = -1;
while(ok - ng > 1) {
auto mid = (ok + ng) / 2;
if(check(mid)) ok = mid;
else ng = mid;
}
cout << ok << endl;
}