| 1 | /* | 
| 2 | * Souffle - A Datalog Compiler | 
| 3 | * Copyright (c) 2017 The Souffle Developers. All rights reserved | 
| 4 | * Licensed under the Universal Permissive License v 1.0 as shown at: | 
| 5 | * - https://opensource.org/licenses/UPL | 
| 6 | * - <souffle root>/licenses/SOUFFLE-UPL.txt | 
| 7 | */ | 
| 8 |  | 
| 9 | /************************************************************************ | 
| 10 | * | 
| 11 | * @file UnionFind.h | 
| 12 | * | 
| 13 | * Defines a union-find data-structure | 
| 14 | * | 
| 15 | ***********************************************************************/ | 
| 16 |  | 
| 17 | #pragma once | 
| 18 |  | 
| 19 | #include "souffle/datastructure/LambdaBTree.h" | 
| 20 | #include "souffle/datastructure/PiggyList.h" | 
| 21 | #include "souffle/utility/MiscUtil.h" | 
| 22 | #include <atomic> | 
| 23 | #include <cstddef> | 
| 24 | #include <cstdint> | 
| 25 | #include <functional> | 
| 26 | #include <utility> | 
| 27 |  | 
| 28 | namespace souffle { | 
| 29 |  | 
// Branch-prediction hints. These are preprocessor macros, so they ignore the enclosing
// namespace and leak into every file that includes this header. They are therefore guarded,
// so that an existing definition (e.g. from another library) is not redefined.
// __builtin_expect is a GCC/Clang builtin; other compilers get a no-op hint.
#ifndef unlikely
#if defined(__GNUC__) || defined(__clang__)
#define unlikely(x) __builtin_expect((x), 0)
#else
#define unlikely(x) (x)
#endif
#endif

#ifndef likely
#if defined(__GNUC__) || defined(__clang__)
#define likely(x) __builtin_expect((x), 1)
#else
#define likely(x) (x)
#endif
#endif
| 33 |  | 
// Rank of a union-find node; stored in the low split_size bits of a block_t.
using rank_t = uint8_t;
// Index of a node's parent. Technically a uint56_t (the low split_size bits of a block_t hold
// the rank), but no such type exists. Just be careful about storing > 2^56 elements.
using parent_t = uint64_t;

// Number of low-order bits of a block_t that hold the rank.
constexpr uint8_t split_size = 8u;

// A block_t packs a node's parent into its high (64 - split_size) bits and its rank into its
// low split_size bits (i.e. 56 bits of parent, 8 bits of rank -- not two equal halves).
using block_t = uint64_t;
// block_t & rank_mask extracts the rank. Computed in block_t (not unsigned long, which is only
// 32 bits wide on LLP64 platforms).
constexpr block_t rank_mask = (block_t{1} << split_size) - 1;

// The packing above is only sound if the rank field fits in rank_t and leaves room for a parent.
static_assert(split_size <= sizeof(rank_t) * 8, "rank_t must be able to hold split_size bits");
static_assert(split_size < sizeof(block_t) * 8, "block_t must have bits left over for the parent");
| 45 |  | 
| 46 | /** | 
| 47 | * Structure that emulates a Disjoint Set, i.e. a data structure that supports efficient union-find operations | 
| 48 | */ | 
| 49 | class DisjointSet { | 
| 50 | template <typename TupleType> | 
| 51 | friend class EquivalenceRelation; | 
| 52 |  | 
| 53 | PiggyList<std::atomic<block_t>> a_blocks; | 
| 54 |  | 
| 55 | public: | 
| 56 | DisjointSet() = default; | 
| 57 |  | 
| 58 | // copy ctor | 
| 59 | DisjointSet(DisjointSet& other) = delete; | 
| 60 | // move ctor | 
| 61 | DisjointSet(DisjointSet&& other) = delete; | 
| 62 |  | 
| 63 | // copy assign ctor | 
| 64 | DisjointSet& operator=(DisjointSet& ds) = delete; | 
| 65 | // move assign ctor | 
| 66 | DisjointSet& operator=(DisjointSet&& ds) = delete; | 
| 67 |  | 
| 68 | /** | 
| 69 | * Return the number of elements in this disjoint set (not the number of pairs) | 
| 70 | */ | 
| 71 | inline std::size_t size() { | 
| 72 | auto sz = a_blocks.size(); | 
| 73 | return sz; | 
| 74 | }; | 
| 75 |  | 
| 76 | /** | 
| 77 | * Yield reference to the node by its node index | 
| 78 | * @param node node to be searched | 
| 79 | * @return the parent block of the specified node | 
| 80 | */ | 
| 81 | inline std::atomic<block_t>& get(parent_t node) const { | 
| 82 | auto& ret = a_blocks.get(node); | 
| 83 | return ret; | 
| 84 | }; | 
| 85 |  | 
| 86 | /** | 
| 87 | * Equivalent to the find() function in union/find | 
| 88 | * Find the highest ancestor of the provided node - flattening as we go | 
| 89 | * @param x the node to find the parent of, whilst flattening its set-tree | 
| 90 | * @return The parent of x | 
| 91 | */ | 
| 92 | parent_t findNode(parent_t x) { | 
| 93 | // while x's parent is not itself | 
| 94 | while (x != b2p(get(x))) { | 
| 95 | block_t xState = get(x); | 
| 96 | // yield x's parent's parent | 
| 97 | parent_t newParent = b2p(get(b2p(xState))); | 
| 98 | // construct block out of the original rank and the new parent | 
| 99 | block_t newState = pr2b(newParent, b2r(xState)); | 
| 100 |  | 
| 101 | this->get(x).compare_exchange_strong(xState, newState); | 
| 102 |  | 
| 103 | x = newParent; | 
| 104 | } | 
| 105 | return x; | 
| 106 | } | 
| 107 |  | 
| 108 | private: | 
| 109 | /** | 
| 110 | * Update the root of the tree of which x is, to have y as the base instead | 
| 111 | * @param x : old root | 
| 112 | * @param oldrank : old root rank | 
| 113 | * @param y : new root | 
| 114 | * @param newrank : new root rank | 
| 115 | * @return Whether the update succeeded (fails if another root update/union has been perfomed in the | 
| 116 | * interim) | 
| 117 | */ | 
| 118 | bool updateRoot(const parent_t x, const rank_t oldrank, const parent_t y, const rank_t newrank) { | 
| 119 | block_t oldState = get(x); | 
| 120 | parent_t nextN = b2p(oldState); | 
| 121 | rank_t rankN = b2r(oldState); | 
| 122 |  | 
| 123 | if (nextN != x || rankN != oldrank) return false; | 
| 124 | // set the parent and rank of the new record | 
| 125 | block_t newVal = pr2b(y, newrank); | 
| 126 |  | 
| 127 | return this->get(x).compare_exchange_strong(oldState, newVal); | 
| 128 | } | 
| 129 |  | 
| 130 | public: | 
| 131 | /** | 
| 132 | * Clears the DisjointSet of all nodes | 
| 133 | * Invalidates all iterators | 
| 134 | */ | 
| 135 | void clear() { | 
| 136 | a_blocks.clear(); | 
| 137 | } | 
| 138 |  | 
| 139 | /** | 
| 140 | * Check whether the two indices are in the same set | 
| 141 | * @param x node to be checked | 
| 142 | * @param y node to be checked | 
| 143 | * @return where the two indices are in the same set | 
| 144 | */ | 
| 145 | bool sameSet(parent_t x, parent_t y) { | 
| 146 | while (true) { | 
| 147 | x = findNode(x); | 
| 148 | y = findNode(y); | 
| 149 | if (x == y) return true; | 
| 150 | // if x's parent is itself, they are not the same set | 
| 151 | if (b2p(get(x)) == x) return false; | 
| 152 | } | 
| 153 | } | 
| 154 |  | 
| 155 | /** | 
| 156 | * Union the two specified index nodes | 
| 157 | * @param x node to be unioned | 
| 158 | * @param y node to be unioned | 
| 159 | */ | 
| 160 | void unionNodes(parent_t x, parent_t y) { | 
| 161 | while (true) { | 
| 162 | x = findNode(x); | 
| 163 | y = findNode(y); | 
| 164 |  | 
| 165 | // no need to union if both already in same set | 
| 166 | if (x == y) return; | 
| 167 |  | 
| 168 | rank_t xrank = b2r(get(x)); | 
| 169 | rank_t yrank = b2r(get(y)); | 
| 170 |  | 
| 171 | // if x comes before y (better rank or earlier & equal node) | 
| 172 | if (xrank > yrank || ((xrank == yrank) && x > y)) { | 
| 173 | std::swap(x, y); | 
| 174 | std::swap(xrank, yrank); | 
| 175 | } | 
| 176 | // join the trees together | 
| 177 | // perhaps we can optimise the use of compare_exchange_strong here, as we're in a pessimistic loop | 
| 178 | if (!updateRoot(x, xrank, y, yrank)) { | 
| 179 | continue; | 
| 180 | } | 
| 181 | // make sure that the ranks are orderable | 
| 182 | if (xrank == yrank) { | 
| 183 | updateRoot(y, yrank, y, yrank + 1); | 
| 184 | } | 
| 185 | break; | 
| 186 | } | 
| 187 | } | 
| 188 |  | 
| 189 | /** | 
| 190 | * Create a node with its parent as itself, rank 0 | 
| 191 | * @return the newly created block | 
| 192 | */ | 
| 193 | inline block_t makeNode() { | 
| 194 | // make node and find out where we've added it | 
| 195 | std::size_t nodeDetails = a_blocks.createNode(); | 
| 196 |  | 
| 197 | a_blocks.get(nodeDetails).store(pr2b(nodeDetails, 0)); | 
| 198 |  | 
| 199 | return a_blocks.get(nodeDetails).load(); | 
| 200 | }; | 
| 201 |  | 
| 202 | /** | 
| 203 | * Extract parent from block | 
| 204 | * @param inblock the block to be masked | 
| 205 | * @return The parent_t contained in the upper half of block_t | 
| 206 | */ | 
| 207 | static inline parent_t b2p(const block_t inblock) { | 
| 208 | return (parent_t)(inblock >> split_size); | 
| 209 | }; | 
| 210 |  | 
| 211 | /** | 
| 212 | * Extract rank from block | 
| 213 | * @param inblock the block to be masked | 
| 214 | * @return the rank_t contained in the lower half of block_t | 
| 215 | */ | 
| 216 | static inline rank_t b2r(const block_t inblock) { | 
| 217 | return (rank_t)(inblock & rank_mask); | 
| 218 | }; | 
| 219 |  | 
| 220 | /** | 
| 221 | * Yield a block given parent and rank | 
| 222 | * @param parent the top half bits | 
| 223 | * @param rank the lower half bits | 
| 224 | * @return the resultant block after merge | 
| 225 | */ | 
| 226 | static inline block_t pr2b(const parent_t parent, const rank_t rank) { | 
| 227 | return (((block_t)parent) << split_size) | rank; | 
| 228 | }; | 
| 229 | }; | 
| 230 |  | 
/**
 * Three-way comparator over pair-like entries that orders solely by `.first`.
 * `.second` is ignored, so a lookup key carrying any placeholder in `.second` compares equal
 * to the stored entry with the same `.first`.
 * Requires StorePair::first_type to provide operator<.
 * The members are const so the comparator can also be used through a const object/reference.
 */
template <typename StorePair>
struct EqrelMapComparator {
    /// @return -1 if a.first < b.first, 1 if b.first < a.first, otherwise 0
    int operator()(const StorePair& a, const StorePair& b) const {
        if (a.first < b.first) {
            return -1;
        }
        if (b.first < a.first) {
            return 1;
        }
        return 0;
    }

    /// @return whether a orders strictly before b
    bool less(const StorePair& a, const StorePair& b) const {
        return (*this)(a, b) < 0;
    }

    /// @return whether a and b have equivalent `.first` keys
    bool equal(const StorePair& a, const StorePair& b) const {
        return (*this)(a, b) == 0;
    }
};
| 251 |  | 
| 252 | template <typename SparseDomain> | 
| 253 | class SparseDisjointSet { | 
| 254 | DisjointSet ds; | 
| 255 |  | 
| 256 | template <typename TupleType> | 
| 257 | friend class EquivalenceRelation; | 
| 258 |  | 
| 259 | using PairStore = std::pair<SparseDomain, parent_t>; | 
| 260 | using SparseMap = | 
| 261 | LambdaBTreeSet<PairStore, std::function<parent_t(PairStore&)>, EqrelMapComparator<PairStore>>; | 
| 262 | using DenseMap = RandomInsertPiggyList<SparseDomain>; | 
| 263 |  | 
| 264 | typename SparseMap::operation_hints last_ins; | 
| 265 |  | 
| 266 | SparseMap sparseToDenseMap; | 
| 267 | // mapping from union-find val to souffle, union-find encoded as index | 
| 268 | DenseMap denseToSparseMap; | 
| 269 |  | 
| 270 | public: | 
| 271 | /** | 
| 272 | * Retrieve dense encoding, adding it in if non-existent | 
| 273 | * @param in the sparse value | 
| 274 | * @return the corresponding dense value | 
| 275 | */ | 
| 276 | parent_t toDense(const SparseDomain in) { | 
| 277 | // insert into the mapping - if the key doesn't exist (in), the function will be called | 
| 278 | // and a dense value will be created for it | 
| 279 | PairStore p = {in, -1}; | 
| 280 | return sparseToDenseMap.insert(p, [&](PairStore& p) { | 
| 281 | parent_t c2 = DisjointSet::b2p(this->ds.makeNode()); | 
| 282 | this->denseToSparseMap.insertAt(c2, p.first); | 
| 283 | p.second = c2; | 
| 284 | return c2; | 
| 285 | }); | 
| 286 | } | 
| 287 |  | 
| 288 | public: | 
| 289 | SparseDisjointSet() = default; | 
| 290 |  | 
| 291 | // copy ctor | 
| 292 | SparseDisjointSet(SparseDisjointSet& other) = delete; | 
| 293 |  | 
| 294 | // move ctor | 
| 295 | SparseDisjointSet(SparseDisjointSet&& other) = delete; | 
| 296 |  | 
| 297 | // copy assign ctor | 
| 298 | SparseDisjointSet& operator=(SparseDisjointSet& other) = delete; | 
| 299 |  | 
| 300 | // move assign ctor | 
| 301 | SparseDisjointSet& operator=(SparseDisjointSet&& other) = delete; | 
| 302 |  | 
| 303 | /** | 
| 304 | * For the given dense value, return the associated sparse value | 
| 305 | *   Undefined behaviour if dense value not in set | 
| 306 | * @param in the supplied dense value | 
| 307 | * @return the sparse value from the denseToSparseMap | 
| 308 | */ | 
| 309 | inline const SparseDomain toSparse(const parent_t in) const { | 
| 310 | return denseToSparseMap.get(in); | 
| 311 | }; | 
| 312 |  | 
| 313 | /* a wrapper to enable checking in the sparse set - however also adds them if not already existing */ | 
| 314 | inline bool sameSet(SparseDomain x, SparseDomain y) { | 
| 315 | return ds.sameSet(toDense(x), toDense(y)); | 
| 316 | }; | 
| 317 | /* finds the node in the underlying disjoint set, adding the node if non-existent */ | 
| 318 | inline SparseDomain findNode(SparseDomain x) { | 
| 319 | return toSparse(ds.findNode(toDense(x))); | 
| 320 | }; | 
| 321 | /* union the nodes, add if not existing */ | 
| 322 | inline void unionNodes(SparseDomain x, SparseDomain y) { | 
| 323 | ds.unionNodes(toDense(x), toDense(y)); | 
| 324 | }; | 
| 325 |  | 
| 326 | inline std::size_t size() { | 
| 327 | return ds.size(); | 
| 328 | }; | 
| 329 |  | 
| 330 | /** | 
| 331 | * Remove all elements from this disjoint set | 
| 332 | */ | 
| 333 | void clear() { | 
| 334 | ds.clear(); | 
| 335 | sparseToDenseMap.clear(); | 
| 336 | denseToSparseMap.clear(); | 
| 337 | } | 
| 338 |  | 
| 339 | /* wrapper for node creation */ | 
| 340 | inline void makeNode(SparseDomain val) { | 
| 341 | // dense has the behaviour of creating if not exists. | 
| 342 | toDense(val); | 
| 343 | }; | 
| 344 |  | 
| 345 | /* whether the supplied node exists */ | 
| 346 | inline bool nodeExists(const SparseDomain val) const { | 
| 347 | return sparseToDenseMap.contains({val, -1}); | 
| 348 | }; | 
| 349 |  | 
| 350 | inline bool contains(SparseDomain v1, SparseDomain v2) { | 
| 351 | if (nodeExists(v1) && nodeExists(v2)) { | 
| 352 | return sameSet(v1, v2); | 
| 353 | } | 
| 354 | return false; | 
| 355 | } | 
| 356 | }; | 
| 357 | }  // namespace souffle |