我用 C++ 实现了一个与 @ravenspoint 的思路基本相同的解法:维护一个堆,每次取出值最高的 B 节点。这里我使用的是 `priority_queue` 而不是 `set`,因为前者比后者快得多。
#include <chrono>
#include <iostream>
#include <queue>
#include <vector>
// Number of nodes on side A and on side B of the bipartite graph.
int nA, nB;
// Node weights of side A and side B (0-based indices).
std::vector<int> A, B;
// Working buffer of solve(): sum[y] is the total weight of the A-neighbours of
// B-node y that have not been removed yet. 64-bit because a high-degree B node
// can accumulate more than INT_MAX worth of A weights.
std::vector<long long> sum;
// Adjacency lists: adjA[x] holds the B-neighbours of A-node x,
// adjB[y] holds the A-neighbours of B-node y.
std::vector<std::vector<int>> adjA, adjB;

/// Greedily orders every B node.
///
/// Repeatedly takes the B node y with the highest score sum[y] * B[y] (ties go
/// to the larger B[y]). It then removes all A-neighbours of y and subtracts
/// their weights from the sums of the B nodes they were still attached to.
/// The heap uses lazy deletion: instead of updating an entry in place, a new
/// entry is pushed and outdated ones are discarded when popped.
///
/// Reads nA, nB, A, B, adjA, adjB. Reinitialises and mutates `sum`.
/// @return all B indices (0-based) in the order they were selected.
inline std::vector<int> solve() {
    struct Node {
        // Score of B-node y. Computed in 64-bit so the product of an
        // accumulated sum and a weight cannot overflow int.
        static long long score(int y) { return sum[y] * static_cast<long long>(B[y]); }

        // `val` is a snapshot of node `x`'s score taken WHEN THE ENTRY IS
        // CREATED; later changes to sum[x] do not touch it, which is how
        // stale entries are recognised.
        int x;
        long long val;
        explicit Node(int x) : x(x), val(score(x)) {}
        // Max-heap order: higher score first, ties broken by heavier B node.
        bool operator<(const Node &t) const {
            return val == t.val ? (B[x] < B[t.x]) : (val < t.val);
        }
    };

    std::priority_queue<Node> q;
    // Removal flags (char instead of the bit-packed vector<bool>).
    std::vector<char> delA(nA, 0), delB(nB, 0);
    std::vector<int> ret;
    ret.reserve(nB);

    // Start from zero so the result does not depend on what a previous call
    // (or the caller) left in the global buffer.
    sum.assign(nB, 0);
    for (int x = 0; x < nA; ++x)
        for (int y : adjA[x]) sum[y] += A[x];
    for (int y = 0; y < nB; ++y) q.emplace(y);

    while (!q.empty()) {
        const Node node = q.top();
        q.pop();
        const int y = node.x;
        // Obsolete entry: the node was already taken, or its score changed
        // after this entry was pushed (a fresher entry exists in the heap).
        if (delB[y] || Node::score(y) != node.val)
            continue;
        delB[y] = 1;
        ret.push_back(y);
        for (int x : adjB[y]) {
            if (delA[x]) continue;
            delA[x] = 1;
            for (int ny : adjA[x]) {
                if (delB[ny]) continue;
                sum[ny] -= A[x];
                // Each A node is removed at most once, so across the whole run
                // this push happens at most `m` times in total.
                q.emplace(ny);
            }
        }
    }
    return ret;
}
/// Reads a weighted bipartite graph from stdin, runs solve(), and prints the
/// selected order of B nodes (1-based) together with the time solve() took.
///
/// Input (whitespace separated): nA, then nA weights of side A, nB, then nB
/// weights of side B, m, then m edges "x y" with 1 <= x <= nA and 1 <= y <= nB.
/// On malformed or out-of-range input an error is printed to stderr and the
/// program exits with status 1 instead of indexing out of bounds.
int main() {
    // Reading millions of integers through stdio-synchronised streams is slow.
    // cin remains tied to cout, so every prompt is still flushed before a read.
    std::ios::sync_with_stdio(false);

    // Reports malformed input and produces the process exit code.
    const auto fail = [](const char *what) {
        std::cerr << "Invalid input: " << what << std::endl;
        return 1;
    };

    std::cout << "Number of nodes in type A: ";
    if (!(std::cin >> nA) || nA < 0) return fail("number of nodes in type A");
    A.resize(nA);
    adjA.resize(nA);
    std::cout << "Weights of nodes in type A: ";
    for (int &v : A)
        if (!(std::cin >> v)) return fail("weight of a node in type A");

    std::cout << "Number of nodes in type B: ";
    if (!(std::cin >> nB) || nB < 0) return fail("number of nodes in type B");
    B.resize(nB);
    adjB.resize(nB);
    sum.resize(nB, 0);
    std::cout << "Weights of nodes in type B: ";
    for (int &v : B)
        if (!(std::cin >> v)) return fail("weight of a node in type B");

    int m;
    std::cout << "Number of edges: ";
    if (!(std::cin >> m) || m < 0) return fail("number of edges");
    std::cout << "Edges: " << std::endl;
    for (int i = 0; i < m; ++i) {
        int x, y;
        if (!(std::cin >> x >> y)) return fail("edge");
        // Endpoints are 1-based; anything else would index adjA/adjB out of bounds.
        if (x < 1 || x > nA || y < 1 || y > nB) return fail("edge endpoint out of range");
        --x;
        --y;
        adjA[x].push_back(y);
        adjB[y].push_back(x);
    }

    // Only solve() is timed; input parsing and output are excluded.
    auto st_time = std::chrono::steady_clock::now();
    auto ret = solve();
    auto en_time = std::chrono::steady_clock::now();

    std::cout << "Answer:";
    for (int v : ret) std::cout << ' ' << (v + 1);
    std::cout << std::endl;
    std::cout << "Took "
              << std::chrono::duration_cast<std::chrono::milliseconds>(en_time - st_time).count()
              << "ms" << std::endl;
}
我随机生成了几批 $n_A = n_B = 10^6$、$m = 2 \times 10^6$ 的数据,程序在我的电脑上总能在 800ms 以内给出答案(不计 IO 时间,开启 `-O2`)。该解法的时间复杂度为 $O((n+m)\log(n+m))$,因为 `emplace` 最多被调用 $n+m$ 次(初始的 $n_B$ 次,再加上每条边至多一次),所以堆的大小也不超过 $n+m$。
抱歉我的英语不太好,欢迎指出其中的错别字和错误。