| 1 | /*
|
| 2 | * Souffle - A Datalog Compiler
|
| 3 | * Copyright (c) 2017, The Souffle Developers. All rights reserved
|
| 4 | * Licensed under the Universal Permissive License v 1.0 as shown at:
|
| 5 | * - https://opensource.org/licenses/UPL
|
| 6 | * - <souffle root>/licenses/SOUFFLE-UPL.txt
|
| 7 | */
|
| 8 |
|
| 9 | /************************************************************************
|
| 10 | *
|
| 11 | * @file ExplainProvenanceImpl.h
|
| 12 | *
|
| 13 | * Implementation of the abstract class in ExplainProvenance.h for guided provenance explanations
|
| 14 | *
|
| 15 | ***********************************************************************/
|
| 16 |
|
| 17 | #pragma once
|
| 18 |
|
| 19 | #include "souffle/BinaryConstraintOps.h"
|
| 20 | #include "souffle/RamTypes.h"
|
| 21 | #include "souffle/SouffleInterface.h"
|
| 22 | #include "souffle/SymbolTable.h"
|
| 23 | #include "souffle/provenance/ExplainProvenance.h"
|
| 24 | #include "souffle/provenance/ExplainTree.h"
|
| 25 | #include "souffle/utility/ContainerUtil.h"
|
| 26 | #include "souffle/utility/MiscUtil.h"
|
| 27 | #include "souffle/utility/StreamUtil.h"
|
| 28 | #include "souffle/utility/StringUtil.h"
|
| 29 | #include <algorithm>
|
| 30 | #include <cassert>
|
| 31 | #include <chrono>
|
| 32 | #include <cstdio>
|
| 33 | #include <iostream>
|
| 34 | #include <map>
|
| 35 | #include <memory>
|
| 36 | #include <regex>
|
| 37 | #include <sstream>
|
| 38 | #include <string>
|
| 39 | #include <tuple>
|
| 40 | #include <type_traits>
|
| 41 | #include <utility>
|
| 42 | #include <vector>
|
| 43 |
|
| 44 | namespace souffle {
|
| 45 |
|
| 46 | using namespace stream_write_qualified_char_as_number;
|
| 47 |
|
| 48 | class ExplainProvenanceImpl : public ExplainProvenance {
|
| 49 | using arity_type = Relation::arity_type;
|
| 50 |
|
| 51 | public:
|
| 52 | ExplainProvenanceImpl(SouffleProgram& prog) : ExplainProvenance(prog) {
|
| 53 | setup();
|
| 54 | }
|
| 55 |
|
| 56 | void setup() override {
|
| 57 | // for each clause, store a mapping from the head relation name to body relation names
|
| 58 | for (auto& rel : prog.getAllRelations()) {
|
| 59 | std::string name = rel->getName();
|
| 60 |
|
| 61 | // only process info relations
|
| 62 | if (name.find("@info") == std::string::npos) {
|
| 63 | continue;
|
| 64 | }
|
| 65 |
|
| 66 | // find all the info tuples
|
| 67 | for (auto& tuple : *rel) {
|
| 68 | std::vector<std::string> bodyLiterals;
|
| 69 |
|
| 70 | // first field is rule number
|
| 71 | RamDomain ruleNum;
|
| 72 | tuple >> ruleNum;
|
| 73 |
|
| 74 | // middle fields are body literals
|
| 75 | for (std::size_t i = 1; i + 1 < rel->getArity(); i++) {
|
| 76 | std::string bodyLit;
|
| 77 | tuple >> bodyLit;
|
| 78 | bodyLiterals.push_back(bodyLit);
|
| 79 | }
|
| 80 |
|
| 81 | // last field is the rule itself
|
| 82 | std::string rule;
|
| 83 | tuple >> rule;
|
| 84 |
|
| 85 | std::string relName = name.substr(0, name.find(".@info"));
|
| 86 | info.insert({std::make_pair(relName, ruleNum), bodyLiterals});
|
| 87 | rules.insert({std::make_pair(relName, ruleNum), rule});
|
| 88 | }
|
| 89 | }
|
| 90 | }
|
| 91 |
|
| 92 | Own<TreeNode> explain(std::string relName, std::vector<RamDomain> tuple, int ruleNum, int levelNum,
|
| 93 | std::size_t depthLimit) {
|
| 94 | std::stringstream joinedArgs;
|
| 95 | joinedArgs << join(decodeArguments(relName, tuple), ", ");
|
| 96 | auto joinedArgsStr = joinedArgs.str();
|
| 97 |
|
| 98 | // if fact
|
| 99 | if (levelNum == 0) {
|
| 100 | return mk<LeafNode>(relName + "(" + joinedArgsStr + ")");
|
| 101 | }
|
| 102 |
|
| 103 | assert(contains(info, std::make_pair(relName, ruleNum)) && "invalid rule for tuple");
|
| 104 |
|
| 105 | // if depth limit exceeded
|
| 106 | if (depthLimit <= 1) {
|
| 107 | tuple.push_back(ruleNum);
|
| 108 | tuple.push_back(levelNum);
|
| 109 |
|
| 110 | // find if subproof exists already
|
| 111 | std::size_t idx = 0;
|
| 112 | auto it = std::find(subproofs.begin(), subproofs.end(), tuple);
|
| 113 | if (it != subproofs.end()) {
|
| 114 | idx = it - subproofs.begin();
|
| 115 | } else {
|
| 116 | subproofs.push_back(tuple);
|
| 117 | idx = subproofs.size() - 1;
|
| 118 | }
|
| 119 |
|
| 120 | return mk<LeafNode>("subproof " + relName + "(" + std::to_string(idx) + ")");
|
| 121 | }
|
| 122 |
|
| 123 | tuple.push_back(levelNum);
|
| 124 |
|
| 125 | auto internalNode =
|
| 126 | mk<InnerNode>(relName + "(" + joinedArgsStr + ")", "(R" + std::to_string(ruleNum) + ")");
|
| 127 |
|
| 128 | // make return vector pointer
|
| 129 | std::vector<RamDomain> ret;
|
| 130 |
|
| 131 | // execute subroutine to get subproofs
|
| 132 | prog.executeSubroutine(relName + "_" + std::to_string(ruleNum) + "_subproof", tuple, ret);
|
| 133 |
|
| 134 | // recursively get nodes for subproofs
|
| 135 | std::size_t tupleCurInd = 0;
|
| 136 | auto bodyRelations = info.at(std::make_pair(relName, ruleNum));
|
| 137 |
|
| 138 | // start from begin + 1 because the first element represents the head atom
|
| 139 | for (auto it = bodyRelations.begin() + 1; it < bodyRelations.end(); it++) {
|
| 140 | std::string bodyLiteral = *it;
|
| 141 | // split bodyLiteral since it contains relation name plus arguments
|
| 142 | std::string bodyRel = splitString(bodyLiteral, ',')[0];
|
| 143 |
|
| 144 | // check whether the current atom is a constraint
|
| 145 | assert(bodyRel.size() > 0 && "body of a relation should have positive length");
|
| 146 | bool isConstraint = contains(constraintList, bodyRel);
|
| 147 |
|
| 148 | // handle negated atom names
|
| 149 | auto bodyRelAtomName = bodyRel;
|
| 150 | if (bodyRel[0] == '!' && bodyRel != "!=") {
|
| 151 | bodyRelAtomName = bodyRel.substr(1);
|
| 152 | }
|
| 153 |
|
| 154 | // traverse subroutine return
|
| 155 | std::size_t arity;
|
| 156 | std::size_t auxiliaryArity;
|
| 157 | if (isConstraint) {
|
| 158 | // we only handle binary constraints, and assume arity is 4 to account for hidden provenance
|
| 159 | // annotations
|
| 160 | arity = 4;
|
| 161 | auxiliaryArity = 2;
|
| 162 | } else {
|
| 163 | arity = prog.getRelation(bodyRelAtomName)->getArity();
|
| 164 | auxiliaryArity = prog.getRelation(bodyRelAtomName)->getAuxiliaryArity();
|
| 165 | }
|
| 166 | auto tupleEnd = tupleCurInd + arity;
|
| 167 |
|
| 168 | // store current tuple
|
| 169 | std::vector<RamDomain> subproofTuple;
|
| 170 |
|
| 171 | for (; tupleCurInd < tupleEnd - auxiliaryArity; tupleCurInd++) {
|
| 172 | subproofTuple.push_back(ret[tupleCurInd]);
|
| 173 | }
|
| 174 |
|
| 175 | int subproofRuleNum = ret[tupleCurInd];
|
| 176 | int subproofLevelNum = ret[tupleCurInd + 1];
|
| 177 |
|
| 178 | tupleCurInd += 2;
|
| 179 |
|
| 180 | // for a negation, display the corresponding tuple and do not recurse
|
| 181 | if (bodyRel[0] == '!' && bodyRel != "!=") {
|
| 182 | std::stringstream joinedTuple;
|
| 183 | joinedTuple << join(decodeArguments(bodyRelAtomName, subproofTuple), ", ");
|
| 184 | auto joinedTupleStr = joinedTuple.str();
|
| 185 | internalNode->add_child(mk<LeafNode>(bodyRel + "(" + joinedTupleStr + ")"));
|
| 186 | internalNode->setSize(internalNode->getSize() + 1);
|
| 187 | // for a binary constraint, display the corresponding values and do not recurse
|
| 188 | } else if (isConstraint) {
|
| 189 | std::stringstream joinedConstraint;
|
| 190 |
|
| 191 | // FIXME: We need type info in order to figure out how to print arguments.
|
| 192 | BinaryConstraintOp rawBinOp = toBinaryConstraintOp(bodyRel);
|
| 193 | if (isOrderedBinaryConstraintOp(rawBinOp)) {
|
| 194 | joinedConstraint << subproofTuple[0] << " " << bodyRel << " " << subproofTuple[1];
|
| 195 | } else {
|
| 196 | joinedConstraint << bodyRel << "(\"" << symTable.decode(subproofTuple[0]) << "\", \""
|
| 197 | << symTable.decode(subproofTuple[1]) << "\")";
|
| 198 | }
|
| 199 |
|
| 200 | internalNode->add_child(mk<LeafNode>(joinedConstraint.str()));
|
| 201 | internalNode->setSize(internalNode->getSize() + 1);
|
| 202 | // otherwise, for a normal tuple, recurse
|
| 203 | } else {
|
| 204 | auto child =
|
| 205 | explain(bodyRel, subproofTuple, subproofRuleNum, subproofLevelNum, depthLimit - 1);
|
| 206 | internalNode->setSize(internalNode->getSize() + child->getSize());
|
| 207 | internalNode->add_child(std::move(child));
|
| 208 | }
|
| 209 |
|
| 210 | tupleCurInd = tupleEnd;
|
| 211 | }
|
| 212 |
|
| 213 | return internalNode;
|
| 214 | }
|
| 215 |
|
| 216 | Own<TreeNode> explain(
|
| 217 | std::string relName, std::vector<std::string> args, std::size_t depthLimit) override {
|
| 218 | auto tuple = argsToNums(relName, args);
|
| 219 | if (tuple.empty()) {
|
| 220 | return mk<LeafNode>("Relation not found");
|
| 221 | }
|
| 222 |
|
| 223 | std::tuple<int, int> tupleInfo = findTuple(relName, tuple);
|
| 224 |
|
| 225 | int ruleNum = std::get<0>(tupleInfo);
|
| 226 | int levelNum = std::get<1>(tupleInfo);
|
| 227 |
|
| 228 | if (ruleNum < 0 || levelNum == -1) {
|
| 229 | return mk<LeafNode>("Tuple not found");
|
| 230 | }
|
| 231 |
|
| 232 | return explain(relName, tuple, ruleNum, levelNum, depthLimit);
|
| 233 | }
|
| 234 |
|
| 235 | Own<TreeNode> explainSubproof(
|
| 236 | std::string relName, RamDomain subproofNum, std::size_t depthLimit) override {
|
| 237 | if (subproofNum >= (int)subproofs.size()) {
|
| 238 | return mk<LeafNode>("Subproof not found");
|
| 239 | }
|
| 240 |
|
| 241 | auto tup = subproofs[subproofNum];
|
| 242 |
|
| 243 | auto rel = prog.getRelation(relName);
|
| 244 |
|
| 245 | assert(rel->getAuxiliaryArity() == 2 && "unexpected auxiliary arity in provenance context");
|
| 246 |
|
| 247 | RamDomain ruleNum;
|
| 248 | ruleNum = tup[rel->getArity() - 2];
|
| 249 |
|
| 250 | RamDomain levelNum;
|
| 251 | levelNum = tup[rel->getArity() - 1];
|
| 252 |
|
| 253 | tup.erase(tup.begin() + rel->getArity() - 2, tup.end());
|
| 254 |
|
| 255 | return explain(relName, tup, ruleNum, levelNum, depthLimit);
|
| 256 | }
|
| 257 |
|
| 258 | std::vector<std::string> explainNegationGetVariables(
|
| 259 | std::string relName, std::vector<std::string> args, std::size_t ruleNum) override {
|
| 260 | std::vector<std::string> variables;
|
| 261 |
|
| 262 | // check that the tuple actually doesn't exist
|
| 263 | std::tuple<int, int> foundTuple = findTuple(relName, argsToNums(relName, args));
|
| 264 | if (std::get<0>(foundTuple) != -1 || std::get<1>(foundTuple) != -1) {
|
| 265 | // return a sentinel value
|
| 266 | return std::vector<std::string>({"@"});
|
| 267 | }
|
| 268 |
|
| 269 | // atom meta information stored for the current rule
|
| 270 | auto atoms = info[std::make_pair(relName, ruleNum)];
|
| 271 |
|
| 272 | // the info stores the set of atoms, if there is only 1 atom, then it must be the head, so it must be
|
| 273 | // a fact
|
| 274 | if (atoms.size() <= 1) {
|
| 275 | return std::vector<std::string>({"@fact"});
|
| 276 | }
|
| 277 |
|
| 278 | // atoms[0] represents variables in the head atom
|
| 279 | auto headVariables = splitString(atoms[0], ',');
|
| 280 |
|
| 281 | auto isVariable = [&](std::string arg) {
|
| 282 | return !(isNumber(arg.c_str()) || arg[0] == '\"' || arg == "_");
|
| 283 | };
|
| 284 |
|
| 285 | // check that head variable bindings make sense, i.e. for a head like a(x, x), make sure both x are
|
| 286 | // the same value
|
| 287 | std::map<std::string, std::string> headVariableMapping;
|
| 288 | for (std::size_t i = 0; i < headVariables.size(); i++) {
|
| 289 | if (!isVariable(headVariables[i])) {
|
| 290 | continue;
|
| 291 | }
|
| 292 |
|
| 293 | if (headVariableMapping.find(headVariables[i]) == headVariableMapping.end()) {
|
| 294 | headVariableMapping[headVariables[i]] = args[i];
|
| 295 | } else {
|
| 296 | if (headVariableMapping[headVariables[i]] != args[i]) {
|
| 297 | return std::vector<std::string>({"@non_matching"});
|
| 298 | }
|
| 299 | }
|
| 300 | }
|
| 301 |
|
| 302 | // get body variables
|
| 303 | std::vector<std::string> uniqueBodyVariables;
|
| 304 | for (auto it = atoms.begin() + 1; it < atoms.end(); it++) {
|
| 305 | auto atomRepresentation = splitString(*it, ',');
|
| 306 |
|
| 307 | // atomRepresentation.begin() + 1 because the first element is the relation name of the atom
|
| 308 | // which is not relevant for finding variables
|
| 309 | for (auto atomIt = atomRepresentation.begin() + 1; atomIt < atomRepresentation.end(); atomIt++) {
|
| 310 | if (!isVariable(*atomIt)) {
|
| 311 | continue;
|
| 312 | }
|
| 313 |
|
| 314 | if (!contains(uniqueBodyVariables, *atomIt) && !contains(headVariables, *atomIt)) {
|
| 315 | uniqueBodyVariables.push_back(*atomIt);
|
| 316 | }
|
| 317 | }
|
| 318 | }
|
| 319 |
|
| 320 | return uniqueBodyVariables;
|
| 321 | }
|
| 322 |
|
| 323 | Own<TreeNode> explainNegation(std::string relName, std::size_t ruleNum,
|
| 324 | const std::vector<std::string>& tuple,
|
| 325 | std::map<std::string, std::string>& bodyVariables) override {
|
| 326 | // construct a vector of unique variables that occur in the rule
|
| 327 | std::vector<std::string> uniqueVariables;
|
| 328 |
|
| 329 | // we also need to know the type of each variable
|
| 330 | std::map<std::string, char> variableTypes;
|
| 331 |
|
| 332 | // atom meta information stored for the current rule
|
| 333 | auto atoms = info.at(std::make_pair(relName, ruleNum));
|
| 334 |
|
| 335 | // atoms[0] represents variables in the head atom
|
| 336 | auto headVariables = splitString(atoms[0], ',');
|
| 337 |
|
| 338 | uniqueVariables.insert(uniqueVariables.end(), headVariables.begin(), headVariables.end());
|
| 339 |
|
| 340 | auto isVariable = [&](std::string arg) {
|
| 341 | return !(isNumber(arg.c_str()) || arg[0] == '\"' || arg == "_");
|
| 342 | };
|
| 343 |
|
| 344 | // get body variables
|
| 345 | for (auto it = atoms.begin() + 1; it < atoms.end(); it++) {
|
| 346 | auto atomRepresentation = splitString(*it, ',');
|
| 347 |
|
| 348 | // atomRepresentation.begin() + 1 because the first element is the relation name of the atom
|
| 349 | // which is not relevant for finding variables
|
| 350 | for (auto atomIt = atomRepresentation.begin() + 1; atomIt < atomRepresentation.end(); atomIt++) {
|
| 351 | if (!contains(uniqueVariables, *atomIt) && !contains(headVariables, *atomIt)) {
|
| 352 | // ignore non-variables
|
| 353 | if (!isVariable(*atomIt)) {
|
| 354 | continue;
|
| 355 | }
|
| 356 |
|
| 357 | uniqueVariables.push_back(*atomIt);
|
| 358 |
|
| 359 | if (!contains(constraintList, atomRepresentation[0])) {
|
| 360 | // store type of variable
|
| 361 | auto currentRel = prog.getRelation(atomRepresentation[0]);
|
| 362 | assert(currentRel != nullptr &&
|
| 363 | ("relation " + atomRepresentation[0] + " doesn't exist").c_str());
|
| 364 | variableTypes[*atomIt] =
|
| 365 | *currentRel->getAttrType(atomIt - atomRepresentation.begin() - 1);
|
| 366 | } else if (atomIt->find("agg_") != std::string::npos) {
|
| 367 | variableTypes[*atomIt] = 'i';
|
| 368 | }
|
| 369 | }
|
| 370 | }
|
| 371 | }
|
| 372 |
|
| 373 | std::vector<RamDomain> args;
|
| 374 |
|
| 375 | std::size_t varCounter = 0;
|
| 376 |
|
| 377 | // construct arguments to pass in to the subroutine
|
| 378 | // - this contains the variable bindings selected by the user
|
| 379 |
|
| 380 | // add number representation of tuple
|
| 381 | auto tupleNums = argsToNums(relName, tuple);
|
| 382 | args.insert(args.end(), tupleNums.begin(), tupleNums.end());
|
| 383 | varCounter += tuple.size();
|
| 384 |
|
| 385 | while (varCounter < uniqueVariables.size()) {
|
| 386 | auto var = uniqueVariables[varCounter];
|
| 387 | auto varValue = bodyVariables[var];
|
| 388 | if (variableTypes[var] == 's') {
|
| 389 | if (varValue.size() >= 2 && varValue[0] == '"' && varValue[varValue.size() - 1] == '"') {
|
| 390 | auto originalStr = varValue.substr(1, varValue.size() - 2);
|
| 391 | args.push_back(symTable.encode(originalStr));
|
| 392 | } else {
|
| 393 | // assume no quotation marks
|
| 394 | args.push_back(symTable.encode(varValue));
|
| 395 | }
|
| 396 | } else {
|
| 397 | args.push_back(std::stoi(varValue));
|
| 398 | }
|
| 399 |
|
| 400 | varCounter++;
|
| 401 | }
|
| 402 |
|
| 403 | // set up return and error vectors for subroutine calling
|
| 404 | std::vector<RamDomain> ret;
|
| 405 |
|
| 406 | // execute subroutine to get subproofs
|
| 407 | prog.executeSubroutine(relName + "_" + std::to_string(ruleNum) + "_negation_subproof", args, ret);
|
| 408 |
|
| 409 | // ensure the subroutine returns the correct number of results
|
| 410 | assert(ret.size() == atoms.size() - 1);
|
| 411 |
|
| 412 | // construct tree nodes
|
| 413 | std::stringstream joinedArgsStr;
|
| 414 | joinedArgsStr << join(tuple, ",");
|
| 415 | auto internalNode = mk<InnerNode>(
|
| 416 | relName + "(" + joinedArgsStr.str() + ")", "(R" + std::to_string(ruleNum) + ")");
|
| 417 |
|
| 418 | // store the head tuple in bodyVariables so we can print
|
| 419 | for (std::size_t i = 0; i < headVariables.size(); i++) {
|
| 420 | bodyVariables[headVariables[i]] = tuple[i];
|
| 421 | }
|
| 422 |
|
| 423 | // traverse return vector and construct child nodes
|
| 424 | // making sure we display existent and non-existent tuples correctly
|
| 425 | int literalCounter = 1;
|
| 426 | for (RamDomain returnCounter : ret) {
|
| 427 | // check what the next contained atom is
|
| 428 | bool atomExists = true;
|
| 429 | if (returnCounter == 0) {
|
| 430 | atomExists = false;
|
| 431 | }
|
| 432 |
|
| 433 | // get the relation of the current atom
|
| 434 | auto atomRepresentation = splitString(atoms[literalCounter], ',');
|
| 435 | std::string bodyRel = atomRepresentation[0];
|
| 436 |
|
| 437 | // check whether the current atom is a constraint
|
| 438 | bool isConstraint = contains(constraintList, bodyRel);
|
| 439 |
|
| 440 | // handle negated atom names
|
| 441 | auto bodyRelAtomName = bodyRel;
|
| 442 | if (bodyRel[0] == '!' && bodyRel != "!=") {
|
| 443 | bodyRelAtomName = bodyRel.substr(1);
|
| 444 | }
|
| 445 |
|
| 446 | // construct a label for a node containing a literal (either constraint or atom)
|
| 447 | std::stringstream childLabel;
|
| 448 | if (isConstraint) {
|
| 449 | // for a binary constraint, display the corresponding values and do not recurse
|
| 450 | assert(atomRepresentation.size() == 3 && "not a binary constraint");
|
| 451 |
|
| 452 | childLabel << bodyVariables[atomRepresentation[1]] << " " << bodyRel << " "
|
| 453 | << bodyVariables[atomRepresentation[2]];
|
| 454 | } else {
|
| 455 | childLabel << bodyRel << "(";
|
| 456 | for (std::size_t i = 1; i < atomRepresentation.size(); i++) {
|
| 457 | // if it's a non-variable, print either _ for unnamed, or constant value
|
| 458 | if (!isVariable(atomRepresentation[i])) {
|
| 459 | childLabel << atomRepresentation[i];
|
| 460 | } else {
|
| 461 | childLabel << bodyVariables[atomRepresentation[i]];
|
| 462 | }
|
| 463 | if (i < atomRepresentation.size() - 1) {
|
| 464 | childLabel << ", ";
|
| 465 | }
|
| 466 | }
|
| 467 | childLabel << ")";
|
| 468 | }
|
| 469 |
|
| 470 | // build a marker for existence of body atoms
|
| 471 | if (atomExists) {
|
| 472 | childLabel << " ✓";
|
| 473 | } else {
|
| 474 | childLabel << " x";
|
| 475 | }
|
| 476 |
|
| 477 | internalNode->add_child(mk<LeafNode>(childLabel.str()));
|
| 478 | internalNode->setSize(internalNode->getSize() + 1);
|
| 479 |
|
| 480 | literalCounter++;
|
| 481 | }
|
| 482 |
|
| 483 | return internalNode;
|
| 484 | }
|
| 485 |
|
| 486 | std::string getRule(std::string relName, std::size_t ruleNum) override {
|
| 487 | auto key = make_pair(relName, ruleNum);
|
| 488 |
|
| 489 | auto rule = rules.find(key);
|
| 490 | if (rule == rules.end()) {
|
| 491 | return "Rule not found";
|
| 492 | } else {
|
| 493 | return rule->second;
|
| 494 | }
|
| 495 | }
|
| 496 |
|
| 497 | std::vector<std::string> getRules(const std::string& relName) override {
|
| 498 | std::vector<std::string> relRules;
|
| 499 | // go through all rules
|
| 500 | for (auto& rule : rules) {
|
| 501 | if (rule.first.first == relName) {
|
| 502 | relRules.push_back(rule.second);
|
| 503 | }
|
| 504 | }
|
| 505 |
|
| 506 | return relRules;
|
| 507 | }
|
| 508 |
|
| 509 | std::string measureRelation(std::string relName) override {
|
| 510 | auto rel = prog.getRelation(relName);
|
| 511 |
|
| 512 | if (rel == nullptr) {
|
| 513 | return "No relation found\n";
|
| 514 | }
|
| 515 |
|
| 516 | auto size = rel->size();
|
| 517 | int skip = size / 10;
|
| 518 |
|
| 519 | if (skip == 0) {
|
| 520 | skip = 1;
|
| 521 | }
|
| 522 |
|
| 523 | std::stringstream ss;
|
| 524 |
|
| 525 | auto before_time = std::chrono::high_resolution_clock::now();
|
| 526 |
|
| 527 | int numTuples = 0;
|
| 528 | int proc = 0;
|
| 529 | for (auto& tuple : *rel) {
|
| 530 | auto tupleStart = std::chrono::high_resolution_clock::now();
|
| 531 |
|
| 532 | if (numTuples % skip != 0) {
|
| 533 | numTuples++;
|
| 534 | continue;
|
| 535 | }
|
| 536 |
|
| 537 | std::vector<RamDomain> currentTuple;
|
| 538 | for (arity_type i = 0; i < rel->getPrimaryArity(); i++) {
|
| 539 | RamDomain n;
|
| 540 | if (*rel->getAttrType(i) == 's') {
|
| 541 | std::string s;
|
| 542 | tuple >> s;
|
| 543 | n = lookupExisting(s);
|
| 544 | } else if (*rel->getAttrType(i) == 'f') {
|
| 545 | RamFloat element;
|
| 546 | tuple >> element;
|
| 547 | n = ramBitCast(element);
|
| 548 | } else if (*rel->getAttrType(i) == 'u') {
|
| 549 | RamUnsigned element;
|
| 550 | tuple >> element;
|
| 551 | n = ramBitCast(element);
|
| 552 | } else {
|
| 553 | tuple >> n;
|
| 554 | }
|
| 555 |
|
| 556 | currentTuple.push_back(n);
|
| 557 | }
|
| 558 |
|
| 559 | RamDomain ruleNum;
|
| 560 | tuple >> ruleNum;
|
| 561 |
|
| 562 | RamDomain levelNum;
|
| 563 | tuple >> levelNum;
|
| 564 |
|
| 565 | std::cout << "Tuples expanded: "
|
| 566 | << explain(relName, currentTuple, ruleNum, levelNum, 10000)->getSize();
|
| 567 |
|
| 568 | numTuples++;
|
| 569 | proc++;
|
| 570 |
|
| 571 | auto tupleEnd = std::chrono::high_resolution_clock::now();
|
| 572 | auto tupleDuration =
|
| 573 | std::chrono::duration_cast<std::chrono::duration<double>>(tupleEnd - tupleStart);
|
| 574 |
|
| 575 | std::cout << ", Time: " << tupleDuration.count() << "\n";
|
| 576 | }
|
| 577 |
|
| 578 | auto after_time = std::chrono::high_resolution_clock::now();
|
| 579 | auto duration = std::chrono::duration_cast<std::chrono::duration<double>>(after_time - before_time);
|
| 580 |
|
| 581 | ss << "total: " << proc << " ";
|
| 582 | ss << duration.count() << std::endl;
|
| 583 |
|
| 584 | return ss.str();
|
| 585 | }
|
| 586 |
|
| 587 | void printRulesJSON(std::ostream& os) override {
|
| 588 | os << "\"rules\": [\n";
|
| 589 | bool first = true;
|
| 590 | for (auto const& cur : rules) {
|
| 591 | if (first) {
|
| 592 | first = false;
|
| 593 | } else {
|
| 594 | os << ",\n";
|
| 595 | }
|
| 596 | os << "\t{ \"rule-number\": \"(R" << cur.first.second << ")\", \"rule\": \""
|
| 597 | << stringify(cur.second) << "\"}";
|
| 598 | }
|
| 599 | os << "\n]\n";
|
| 600 | }
|
| 601 |
|
| 602 | void queryProcess(const std::vector<std::pair<std::string, std::vector<std::string>>>& rels) override {
|
| 603 | std::regex varRegex("[a-zA-Z_][a-zA-Z_0-9]*", std::regex_constants::extended);
|
| 604 | std::regex symbolRegex("\"([^\"]*)\"", std::regex_constants::extended);
|
| 605 | std::regex numberRegex("[0-9]+", std::regex_constants::extended);
|
| 606 |
|
| 607 | std::smatch argsMatcher;
|
| 608 |
|
| 609 | // map for variable name and corresponding equivalence class
|
| 610 | std::map<std::string, Equivalence> nameToEquivalence;
|
| 611 |
|
| 612 | // const constraints that solution must satisfy
|
| 613 | ConstConstraint constConstraints;
|
| 614 |
|
| 615 | // relations of tuples containing variables
|
| 616 | std::vector<Relation*> varRels;
|
| 617 |
|
| 618 | // counter for adding element to varRels
|
| 619 | std::size_t idx = 0;
|
| 620 |
|
| 621 | // parse arguments in each relation Tuple
|
| 622 | for (const auto& rel : rels) {
|
| 623 | Relation* relation = prog.getRelation(rel.first);
|
| 624 | // number/symbol index for constant arguments in tuple
|
| 625 | std::vector<RamDomain> constTuple;
|
| 626 | // relation does not exist
|
| 627 | if (relation == nullptr) {
|
| 628 | std::cout << "Relation <" << rel.first << "> does not exist" << std::endl;
|
| 629 | return;
|
| 630 | }
|
| 631 | // arity error
|
| 632 | if (relation->getPrimaryArity() != rel.second.size()) {
|
| 633 | std::cout << "<" + rel.first << "> has arity of " << relation->getPrimaryArity() << std::endl;
|
| 634 | return;
|
| 635 | }
|
| 636 |
|
| 637 | // check if args contain variable
|
| 638 | bool containVar = false;
|
| 639 | for (std::size_t j = 0; j < rel.second.size(); ++j) {
|
| 640 | // arg is a variable
|
| 641 | if (std::regex_match(rel.second[j], argsMatcher, varRegex)) {
|
| 642 | containVar = true;
|
| 643 | auto nameToEquivalenceIter = nameToEquivalence.find(argsMatcher[0]);
|
| 644 | // if variable has not shown up before, create an equivalence class for add it to
|
| 645 | // nameToEquivalence map, otherwise add its indices to corresponding equivalence class
|
| 646 | if (nameToEquivalenceIter == nameToEquivalence.end()) {
|
| 647 | nameToEquivalence.insert(
|
| 648 | {argsMatcher[0], Equivalence(*(relation->getAttrType(j)), argsMatcher[0],
|
| 649 | std::make_pair(idx, j))});
|
| 650 | } else {
|
| 651 | nameToEquivalenceIter->second.push_back(std::make_pair(idx, j));
|
| 652 | }
|
| 653 | continue;
|
| 654 | }
|
| 655 |
|
| 656 | RamDomain rd;
|
| 657 | switch (*(relation->getAttrType(j))) {
|
| 658 | case 's':
|
| 659 | if (!std::regex_match(rel.second[j], argsMatcher, symbolRegex)) {
|
| 660 | std::cout << argsMatcher.str(0) << " does not match type defined in relation"
|
| 661 | << std::endl;
|
| 662 | return;
|
| 663 | }
|
| 664 | rd = prog.getSymbolTable().encode(argsMatcher[1]);
|
| 665 | break;
|
| 666 | case 'f':
|
| 667 | if (!canBeParsedAsRamFloat(rel.second[j])) {
|
| 668 | std::cout << rel.second[j] << " does not match type defined in relation"
|
| 669 | << std::endl;
|
| 670 | return;
|
| 671 | }
|
| 672 | rd = ramBitCast(RamFloatFromString(rel.second[j]));
|
| 673 | break;
|
| 674 | case 'i':
|
| 675 | if (!canBeParsedAsRamSigned(rel.second[j])) {
|
| 676 | std::cout << rel.second[j] << " does not match type defined in relation"
|
| 677 | << std::endl;
|
| 678 | return;
|
| 679 | }
|
| 680 | rd = ramBitCast(RamSignedFromString(rel.second[j]));
|
| 681 | break;
|
| 682 | case 'u':
|
| 683 | if (!canBeParsedAsRamUnsigned(rel.second[j])) {
|
| 684 | std::cout << rel.second[j] << " does not match type defined in relation"
|
| 685 | << std::endl;
|
| 686 | return;
|
| 687 | }
|
| 688 | rd = ramBitCast(RamUnsignedFromString(rel.second[j]));
|
| 689 | break;
|
| 690 | default: continue;
|
| 691 | }
|
| 692 |
|
| 693 | constConstraints.push_back(std::make_pair(std::make_pair(idx, j), rd));
|
| 694 | if (!containVar) {
|
| 695 | constTuple.push_back(rd);
|
| 696 | }
|
| 697 | }
|
| 698 |
|
| 699 | // if tuple does not contain any variable, check if existence of the tuple
|
| 700 | if (!containVar) {
|
| 701 | bool tupleExist = containsTuple(relation, constTuple);
|
| 702 |
|
| 703 | // if relation contains this tuple, remove all related constraints
|
| 704 | if (tupleExist) {
|
| 705 | constConstraints.getConstraints().erase(constConstraints.getConstraints().end() -
|
| 706 | relation->getArity() +
|
| 707 | relation->getAuxiliaryArity(),
|
| 708 | constConstraints.getConstraints().end());
|
| 709 | // otherwise, there is no solution for given query
|
| 710 | } else {
|
| 711 | std::cout << "false." << std::endl;
|
| 712 | std::cout << "Tuple " << rel.first << "(";
|
| 713 | for (std::size_t l = 0; l < rel.second.size() - 1; ++l) {
|
| 714 | std::cout << rel.second[l] << ", ";
|
| 715 | }
|
| 716 | std::cout << rel.second.back() << ") does not exist" << std::endl;
|
| 717 | return;
|
| 718 | }
|
| 719 | } else {
|
| 720 | varRels.push_back(relation);
|
| 721 | ++idx;
|
| 722 | }
|
| 723 | }
|
| 724 |
|
| 725 | // if varRels size is 0, all given tuples only contain constant args and exist, no variable to
|
| 726 | // decode, Output true and return
|
| 727 | if (varRels.size() == 0) {
|
| 728 | std::cout << "true." << std::endl;
|
| 729 | return;
|
| 730 | }
|
| 731 |
|
| 732 | // find solution for parameterised query
|
| 733 | findQuerySolution(varRels, nameToEquivalence, constConstraints);
|
| 734 | }
|
| 735 |
|
| 736 | private:
|
| 737 | std::map<std::pair<std::string, std::size_t>, std::vector<std::string>> info;
|
| 738 | std::map<std::pair<std::string, std::size_t>, std::string> rules;
|
| 739 | std::vector<std::vector<RamDomain>> subproofs;
|
| 740 | std::vector<std::string> constraintList = {
|
| 741 | "=", "!=", "<", "<=", ">=", ">", "match", "contains", "not_match", "not_contains"};
|
| 742 |
|
| 743 | RamDomain lookupExisting(const std::string& symbol) {
|
| 744 | auto Res = symTable.findOrInsert(symbol);
|
| 745 | if (Res.second) {
|
| 746 | fatal("Error string did not exist before call to `SymbolTable::findOrInsert`: `%s`", symbol);
|
| 747 | }
|
| 748 | return Res.first;
|
| 749 | }
|
| 750 |
|
| 751 | std::tuple<int, int> findTuple(const std::string& relName, std::vector<RamDomain> tup) {
|
| 752 | auto rel = prog.getRelation(relName);
|
| 753 |
|
| 754 | if (rel == nullptr) {
|
| 755 | return std::make_tuple(-1, -1);
|
| 756 | }
|
| 757 |
|
| 758 | // find correct tuple
|
| 759 | for (auto& tuple : *rel) {
|
| 760 | bool match = true;
|
| 761 | std::vector<RamDomain> currentTuple;
|
| 762 |
|
| 763 | for (arity_type i = 0; i < rel->getPrimaryArity(); i++) {
|
| 764 | RamDomain n;
|
| 765 | if (*rel->getAttrType(i) == 's') {
|
| 766 | std::string s;
|
| 767 | tuple >> s;
|
| 768 | n = lookupExisting(s);
|
| 769 | } else if (*rel->getAttrType(i) == 'f') {
|
| 770 | RamFloat element;
|
| 771 | tuple >> element;
|
| 772 | n = ramBitCast(element);
|
| 773 | } else if (*rel->getAttrType(i) == 'u') {
|
| 774 | RamUnsigned element;
|
| 775 | tuple >> element;
|
| 776 | n = ramBitCast(element);
|
| 777 | } else {
|
| 778 | tuple >> n;
|
| 779 | }
|
| 780 |
|
| 781 | currentTuple.push_back(n);
|
| 782 |
|
| 783 | if (n != tup[i]) {
|
| 784 | match = false;
|
| 785 | break;
|
| 786 | }
|
| 787 | }
|
| 788 |
|
| 789 | if (match) {
|
| 790 | RamDomain ruleNum;
|
| 791 | tuple >> ruleNum;
|
| 792 |
|
| 793 | RamDomain levelNum;
|
| 794 | tuple >> levelNum;
|
| 795 |
|
| 796 | return std::make_tuple(ruleNum, levelNum);
|
| 797 | }
|
| 798 | }
|
| 799 |
|
| 800 | // if no tuple exists
|
| 801 | return std::make_tuple(-1, -1);
|
| 802 | }
|
| 803 |
|
| 804 | /*
|
| 805 | * Find solution for parameterised query satisfying constant constraints and equivalence constraints
|
| 806 | * @param varRels, reference to vector of relation of tuple contains at least one variable in its
|
| 807 | * arguments
|
| 808 | * @param nameToEquivalence, reference to variable name and corresponding equivalence class
|
| 809 | * @param constConstraints, reference to const constraints must be satisfied
|
| 810 | * */
|
| 811 | void findQuerySolution(const std::vector<Relation*>& varRels,
|
| 812 | const std::map<std::string, Equivalence>& nameToEquivalence,
|
| 813 | const ConstConstraint& constConstraints) {
|
| 814 | // vector of iterators for relations in varRels
|
| 815 | std::vector<Relation::iterator> varRelationIterators;
|
| 816 | for (auto relation : varRels) {
|
| 817 | varRelationIterators.push_back(relation->begin());
|
| 818 | }
|
| 819 |
|
| 820 | std::size_t solutionCount = 0;
|
| 821 | std::stringstream solution;
|
| 822 |
|
| 823 | // iterate through the vector of iterators to find solution
|
| 824 | while (true) {
|
| 825 | bool isSolution = true;
|
| 826 |
|
| 827 | // vector contains the tuples the iterators currently points to
|
| 828 | std::vector<tuple> element;
|
| 829 | for (auto it : varRelationIterators) {
|
| 830 | element.push_back(*it);
|
| 831 | }
|
| 832 | // check if tuple satisfies variable equivalence
|
| 833 | for (auto var : nameToEquivalence) {
|
| 834 | if (!var.second.verify(element)) {
|
| 835 | isSolution = false;
|
| 836 | break;
|
| 837 | }
|
| 838 | }
|
| 839 | if (isSolution) {
|
| 840 | // check if tuple satisfies constant constraints
|
| 841 | isSolution = constConstraints.verify(element);
|
| 842 | }
|
| 843 |
|
| 844 | if (isSolution) {
|
| 845 | // print previous solution (if any)
|
| 846 | if (solutionCount != 0) {
|
| 847 | std::cout << solution.str() << std::endl;
|
| 848 | }
|
| 849 | solution.str(std::string()); // reset solution and process
|
| 850 |
|
| 851 | std::size_t c = 0;
|
| 852 | for (auto&& var : nameToEquivalence) {
|
| 853 | auto idx = var.second.getFirstIdx();
|
| 854 | auto raw = element[idx.first][idx.second];
|
| 855 |
|
| 856 | solution << var.second.getSymbol() << " = ";
|
| 857 | switch (var.second.getType()) {
|
| 858 | case 'i': solution << ramBitCast<RamSigned>(raw); break;
|
| 859 | case 'f': solution << ramBitCast<RamFloat>(raw); break;
|
| 860 | case 'u': solution << ramBitCast<RamUnsigned>(raw); break;
|
| 861 | case 's': solution << prog.getSymbolTable().decode(raw); break;
|
| 862 | default: fatal("invalid type: `%c`", var.second.getType());
|
| 863 | }
|
| 864 |
|
| 865 | if (++c < nameToEquivalence.size()) {
|
| 866 | solution << ", ";
|
| 867 | }
|
| 868 | }
|
| 869 |
|
| 870 | solutionCount++;
|
| 871 | // query has more than one solution; query whether to find next solution or stop
|
| 872 | if (1 < solutionCount) {
|
| 873 | for (std::string input; getline(std::cin, input);) {
|
| 874 | if (input == ";") break; // print next solution?
|
| 875 | if (input == ".") return; // break from query?
|
| 876 |
|
| 877 | std::cout << "use ; to find next solution, use . to break from current query\n";
|
| 878 | }
|
| 879 | }
|
| 880 | }
|
| 881 |
|
| 882 | // increment the iterators
|
| 883 | std::size_t i = varRels.size() - 1;
|
| 884 | bool terminate = true;
|
| 885 | for (auto it = varRelationIterators.rbegin(); it != varRelationIterators.rend(); ++it) {
|
| 886 | if ((++(*it)) != varRels[i]->end()) {
|
| 887 | terminate = false;
|
| 888 | break;
|
| 889 | } else {
|
| 890 | (*it) = varRels[i]->begin();
|
| 891 | --i;
|
| 892 | }
|
| 893 | }
|
| 894 |
|
| 895 | if (terminate) {
|
| 896 | // if there is no solution, output false
|
| 897 | if (solutionCount == 0) {
|
| 898 | std::cout << "false." << std::endl;
|
| 899 | // otherwise print the last solution
|
| 900 | } else {
|
| 901 | std::cout << solution.str() << "." << std::endl;
|
| 902 | }
|
| 903 | break;
|
| 904 | }
|
| 905 | }
|
| 906 | }
|
| 907 |
|
| 908 | // check if constTuple exists in relation
|
| 909 | bool containsTuple(Relation* relation, const std::vector<RamDomain>& constTuple) {
|
| 910 | bool tupleExist = false;
|
| 911 | for (auto it = relation->begin(); it != relation->end(); ++it) {
|
| 912 | bool eq = true;
|
| 913 | for (std::size_t j = 0; j < constTuple.size(); ++j) {
|
| 914 | if (constTuple[j] != (*it)[j]) {
|
| 915 | eq = false;
|
| 916 | break;
|
| 917 | }
|
| 918 | }
|
| 919 | if (eq) {
|
| 920 | tupleExist = true;
|
| 921 | break;
|
| 922 | }
|
| 923 | }
|
| 924 | return tupleExist;
|
| 925 | }
|
| 926 | };
|
| 927 |
|
| 928 | } // end of namespace souffle
|