2

给定一个具有 A 类和 B 类节点的节点权重的二部图,如下所示:

在此处输入图像描述

我想输出由以下启发式定义的 B 型节点的有序列表:

  1. 对于类型 B 的每个节点,我们将与该节点具有边的类型 A 的节点权重相加,并将总和乘以它自己的节点权重以获得节点值。
  2. 然后我们从类型 B 中选择具有最高值的节点并将其附加到输出集 S。
  3. 我们从类型 B 中删除选定的节点,并从类型 A 中删除它有一条边的所有节点。
  4. 返回第 1 步,直到类型 B 中的任何节点都与类型 A 中的节点有一条边。
  5. 将类型 B 的任何剩余节点按其节点权重的顺序附加到输出集。

下图显示了一个示例:

在此处输入图像描述

对于此示例,输出集将是:(Y, Z, X)

天真的过程将是简单地遍历这个算法,但假设二分图很大,我正在寻找找到输出集的最有效方法。请注意,我只需要 B 型节点的有序列表作为输出,而不需要中间计算值(例如 50、15、2)

4

2 回答 2

1

这是 Dave 在评论中建议的算法的进一步改进。它最大限度地减少了需要重新计算节点值的次数。

  1. 运行第 1 步,将生成的 B 节点按 val 放置在最大堆中
  2. 检查顶部节点是否有任何邻居被删除。如果是,重新计算并重新插入堆。如果否,则添加到输出并删除邻居。
  3. 重复2直到所有B都输出

我已经基于我的PathFinder graph class在 C++ 中实现了这个算法。该代码在具有半个 a 和半个 b 节点的 100 万个节点图上运行,每个 b 节点连接到两个随机 a 节点,需要 1 秒。

这是代码

void cPathFinder::karup()
    {
        raven::set::cRunWatch aWatcher("karup");
        std::cout << "karup on " << nodeCount() << " node graph\n";
        std::vector<int> output;

        // calculate initial values of B nodes
        std::multimap<int, int> mapValueNode;
        for (auto &b : nodes())
        {
            if (b.second.myName[0] != 'b')
                continue;
            int value = 0;
            for (auto a : b.second.myLink)
            {
                value += node(a.first).myCost;
            }
            value *= b.second.myCost;
            mapValueNode.insert(std::make_pair(value, b.first));
        }

        // while not all B nodes output
        while (mapValueNode.size())
        {
            raven::set::cRunWatch aWatcher("select");

            // select node with highest value
            auto remove_it = --mapValueNode.end();
            int remove = remove_it->second;

            if (!remove_it->first)
            {
                /** all remaining nodes have zero value
                 * all the links from B nodes to A nodes have been removed
                 * output remaining nodes in order of decreasing node weight
                 */
                raven::set::cRunWatch aWatcher("Bunlinked");
                std::multimap<int, int> mapNodeValueNode;
                for (auto &nv : mapValueNode)
                {
                   mapNodeValueNode.insert( 
                       std::make_pair( 
                           node(nv.second).myCost,
                           nv.second ));
                }
                for( auto& nv : mapNodeValueNode )
                {
                    myPath.push_back( nv.second );
                }
                break;
            }

            bool OK = true;
            int value = 0;
            {
                raven::set::cRunWatch aWatcher("check");

                // check that no nodes providing value have been removed

                // std::cout << "checking neighbors of " << name(remove) << "\n";

                auto &vl = node(remove).myLink;
                for (auto it = vl.begin(); it != vl.end();)
                {
                    if (!myG.count(it->first))
                    {
                        // A neighbour has been removed
                        OK = false;
                        it = vl.erase(it);
                    }
                    else
                    {
                        // A neighbour remains
                        value += node(it->first).myCost;
                        it++;
                    }
                }
            }

            if (OK)
            {
                raven::set::cRunWatch aWatcher("store");
                // we have a node whose values is highest and valid

                // store result
                output.push_back(remove);

                // remove neighbour A nodes
                auto &ls = node(remove).myLink;
                for (auto &l : ls)
                {
                    myG.erase(l.first);
                }
                // remove the B node
                // std::cout << "remove " << name( remove ) << "\n";
                mapValueNode.erase(remove_it);
            }
            else
            {
                // replace old value with new
                raven::set::cRunWatch aWatcher("replace");
                value *= node(remove).myCost;
                mapValueNode.erase(remove_it);
                mapValueNode.insert(std::make_pair(value, remove));
            }
        }
    }

以下是计时结果

karup on 1000000 node graph
raven::set::cRunWatch code timing profile
Calls           Mean (secs)     Total           Scope
       1        1.16767 1.16767 karup
  581457        1.37921e-06     0.801951        select
  581456        4.71585e-07     0.274206        check
  564546        3.04042e-07     0.171646        replace
       1        0.153269        0.153269        Bunlinked
   16910        8.10422e-06     0.137042        store
于 2021-06-15T13:03:38.367 回答
1

我在 C++ 中提供了一个基本上类似于 @ravenspoint 的想法的解决方案。它维护一个堆,每次取值最高的B节点。在这里,我使用priority_queue而不是set导致第一个比第二个快得多。


#include <chrono>
#include <iostream>
#include <queue>
#include <vector>

int nA, nB;
std::vector<int> A, B, sum;
std::vector<std::vector<int>> adjA, adjB;
inline std::vector<int> solve() {
    struct Node {
        // We store the value of the node `x` WHEN IT IS INSERTED
        // Modifying the value of the node `x` (sum) won't affect this Node basically
        int x, val;

        Node(int x): x(x), val(sum[x] * B[x]) {}

        bool operator<(const Node &t) const { return val == t.val? (B[x] < B[t.x]): (val < t.val); }
    };

    std::priority_queue<Node> q;
    std::vector<bool> delA(nA, false), delB(nB, false);
    std::vector<int> ret; ret.reserve(nB);

    for (int x = 0; x < nA; ++x)
        for (int y : adjA[x]) sum[y] += A[x];
    for (int y = 0; y < nB; ++y) q.emplace(y);
    while (!q.empty()) {
        const Node node = q.top(); q.pop();
        const int y = node.x;
        if (sum[y] * B[y] != node.val || delB[y]) // This means this Node is obsolete
            continue;
        delB[y] = true;
        ret.push_back(y);
        for (int x : adjB[y]) {
            if (delA[x]) continue;
            delA[x] = true;
            for (int ny : adjA[x]) {
                if (delB[ny]) continue;
                sum[ny] -= A[x];
                // This happens at most `m` time
                q.emplace(ny);
            }
        }
    }

    return ret;
}
int main() {
    std::cout << "Number of nodes in type A: "; std::cin >> nA;
    A.resize(nA); adjA.resize(nA);
    std::cout << "Weights of nodes in type A: ";
    for (int &v : A) std::cin >> v;

    std::cout << "Number of nodes in type B: "; std::cin >> nB;
    B.resize(nB); adjB.resize(nB); sum.resize(nB, 0);
    std::cout << "Weights of nodes in type B: ";
    for (int &v : B) std::cin >> v;

    int m;
    std::cout << "Number of edges: "; std::cin >> m;
    std::cout << "Edges: " << std::endl;
    for (int i = 0; i < m; ++i) {
        int x, y; std::cin >> x >> y;
        --x; --y;
        adjA[x].push_back(y);
        adjB[y].push_back(x);
    }

    auto st_time = std::chrono::steady_clock::now();
    auto ret = solve();
    auto en_time = std::chrono::steady_clock::now();
    std::cout << "Answer:";
    for (int v : ret) std::cout << ' ' << (v + 1);
    std::cout << std::endl;

    std::cout << "Took "
        << std::chrono::duration_cast<std::chrono::milliseconds>(en_time - st_time).count()
        << "ms" << std::endl;
}

我随机生成了几批数据 where nA = nB = 1e6, m = 2e6,程序在我的电脑上总能在不到 800ms 的时间内产生答案(不考虑 IO 时间,启用 O2)。该解决方案的时间复杂度是O((m+n)log m)emplace调用以来的最n+m长时间。

对不起我的英语不好。随时指出我的错别字和错误。

于 2021-06-18T05:30:56.067 回答