| 1 | /*
|
| 2 | * Souffle - A Datalog Compiler
|
| 3 | * Copyright (c) 2021, The Souffle Developers. All rights reserved
|
| 4 | * Licensed under the Universal Permissive License v 1.0 as shown at:
|
| 5 | * - https://opensource.org/licenses/UPL
|
| 6 | * - <souffle root>/licenses/SOUFFLE-UPL.txt
|
| 7 | */
|
| 8 |
|
| 9 | /************************************************************************
|
| 10 | *
|
| 11 | * @file ReadStream.h
|
| 12 | *
|
| 13 | ***********************************************************************/
|
| 14 |
|
| 15 | #pragma once
|
| 16 |
|
| 17 | #include "souffle/RamTypes.h"
|
| 18 | #include "souffle/RecordTable.h"
|
| 19 | #include "souffle/SymbolTable.h"
|
| 20 | #include "souffle/io/SerialisationStream.h"
|
| 21 | #include "souffle/utility/ContainerUtil.h"
|
| 22 | #include "souffle/utility/MiscUtil.h"
|
| 23 | #include "souffle/utility/StringUtil.h"
|
| 24 | #include "souffle/utility/json11.h"
|
| 25 | #include <cctype>
|
| 26 | #include <cstddef>
|
| 27 | #include <map>
|
| 28 | #include <memory>
|
| 29 | #include <ostream>
|
| 30 | #include <stdexcept>
|
| 31 | #include <string>
|
| 32 | #include <vector>
|
| 33 |
|
| 34 | namespace souffle {
|
| 35 |
|
| 36 | class ReadStream : public SerialisationStream<false> {
|
| 37 | protected:
|
| 38 | ReadStream(
|
| 39 | const std::map<std::string, std::string>& rwOperation, SymbolTable& symTab, RecordTable& recTab)
|
| 40 | : SerialisationStream(symTab, recTab, rwOperation) {}
|
| 41 |
|
| 42 | public:
|
| 43 | template <typename T>
|
| 44 | void readAll(T& relation) {
|
| 45 | while (const auto next = readNextTuple()) {
|
| 46 | const RamDomain* ramDomain = next.get();
|
| 47 | relation.insert(ramDomain);
|
| 48 | }
|
| 49 | }
|
| 50 |
|
| 51 | protected:
|
| 52 | /**
|
| 53 | * Read a record from a string.
|
| 54 | *
|
| 55 | * @param source - string containing a record
|
| 56 | * @param recordTypeName - record type.
|
| 57 | * @parem pos - start parsing from this position.
|
| 58 | * @param consumed - if not nullptr: number of characters read.
|
| 59 | *
|
| 60 | */
|
| 61 | RamDomain readRecord(const std::string& source, const std::string& recordTypeName, std::size_t pos = 0,
|
| 62 | std::size_t* charactersRead = nullptr) {
|
| 63 | const std::size_t initial_position = pos;
|
| 64 |
|
| 65 | // Check if record type information are present
|
| 66 | auto&& recordInfo = types["records"][recordTypeName];
|
| 67 | if (recordInfo.is_null()) {
|
| 68 | throw std::invalid_argument("Missing record type information: " + recordTypeName);
|
| 69 | }
|
| 70 |
|
| 71 | // Handle nil case
|
| 72 | consumeWhiteSpace(source, pos);
|
| 73 | if (source.substr(pos, 3) == "nil") {
|
| 74 | if (charactersRead != nullptr) {
|
| 75 | *charactersRead = 3;
|
| 76 | }
|
| 77 | return 0;
|
| 78 | }
|
| 79 |
|
| 80 | auto&& recordTypes = recordInfo["types"];
|
| 81 | const std::size_t recordArity = recordInfo["arity"].long_value();
|
| 82 |
|
| 83 | std::vector<RamDomain> recordValues(recordArity);
|
| 84 |
|
| 85 | consumeChar(source, '[', pos);
|
| 86 |
|
| 87 | for (std::size_t i = 0; i < recordArity; ++i) {
|
| 88 | const std::string& recordType = recordTypes[i].string_value();
|
| 89 | std::size_t consumed = 0;
|
| 90 |
|
| 91 | if (i > 0) {
|
| 92 | consumeChar(source, ',', pos);
|
| 93 | }
|
| 94 | consumeWhiteSpace(source, pos);
|
| 95 | switch (recordType[0]) {
|
| 96 | case 's': {
|
| 97 | recordValues[i] = symbolTable.encode(readSymbol(source, ",]", pos, &consumed));
|
| 98 | break;
|
| 99 | }
|
| 100 | case 'i': {
|
| 101 | recordValues[i] = RamSignedFromString(source.substr(pos), &consumed);
|
| 102 | break;
|
| 103 | }
|
| 104 | case 'u': {
|
| 105 | recordValues[i] = ramBitCast(RamUnsignedFromString(source.substr(pos), &consumed));
|
| 106 | break;
|
| 107 | }
|
| 108 | case 'f': {
|
| 109 | recordValues[i] = ramBitCast(RamFloatFromString(source.substr(pos), &consumed));
|
| 110 | break;
|
| 111 | }
|
| 112 | case 'r': {
|
| 113 | recordValues[i] = readRecord(source, recordType, pos, &consumed);
|
| 114 | break;
|
| 115 | }
|
| 116 | case '+': {
|
| 117 | recordValues[i] = readADT(source, recordType, pos, &consumed);
|
| 118 | break;
|
| 119 | }
|
| 120 | default: fatal("Invalid type attribute");
|
| 121 | }
|
| 122 | pos += consumed;
|
| 123 | }
|
| 124 | consumeChar(source, ']', pos);
|
| 125 |
|
| 126 | if (charactersRead != nullptr) {
|
| 127 | *charactersRead = pos - initial_position;
|
| 128 | }
|
| 129 |
|
| 130 | return recordTable.pack(recordValues.data(), recordValues.size());
|
| 131 | }
|
| 132 |
|
| 133 | RamDomain readADT(const std::string& source, const std::string& adtName, std::size_t pos = 0,
|
| 134 | std::size_t* charactersRead = nullptr) {
|
| 135 | const std::size_t initial_position = pos;
|
| 136 |
|
| 137 | // Branch will are encoded as one of the:
|
| 138 | // [branchIdx, [branchValues...]]
|
| 139 | // [branchIdx, branchValue]
|
| 140 | // branchIdx
|
| 141 | RamDomain branchIdx = -1;
|
| 142 |
|
| 143 | auto&& adtInfo = types["ADTs"][adtName];
|
| 144 | const auto& branches = adtInfo["branches"];
|
| 145 |
|
| 146 | if (adtInfo.is_null() || !branches.is_array()) {
|
| 147 | throw std::invalid_argument("Missing ADT information: " + adtName);
|
| 148 | }
|
| 149 |
|
| 150 | // Consume initial character
|
| 151 | consumeChar(source, '$', pos);
|
| 152 | std::string constructor = readQualifiedName(source, pos);
|
| 153 |
|
| 154 | json11::Json branchInfo = [&]() -> json11::Json {
|
| 155 | for (auto branch : branches.array_items()) {
|
| 156 | ++branchIdx;
|
| 157 |
|
| 158 | if (branch["name"].string_value() == constructor) {
|
| 159 | return branch;
|
| 160 | }
|
| 161 | }
|
| 162 |
|
| 163 | throw std::invalid_argument("Missing branch information: " + constructor);
|
| 164 | }();
|
| 165 |
|
| 166 | assert(branchInfo["types"].is_array());
|
| 167 | auto branchTypes = branchInfo["types"].array_items();
|
| 168 |
|
| 169 | // Handle a branch without arguments.
|
| 170 | if (branchTypes.empty()) {
|
| 171 | if (charactersRead != nullptr) {
|
| 172 | *charactersRead = pos - initial_position;
|
| 173 | }
|
| 174 |
|
| 175 | if (adtInfo["enum"].bool_value()) {
|
| 176 | return branchIdx;
|
| 177 | }
|
| 178 |
|
| 179 | RamDomain emptyArgs = recordTable.pack(toVector<RamDomain>().data(), 0);
|
| 180 | const RamDomain record[] = {branchIdx, emptyArgs};
|
| 181 | return recordTable.pack(record, 2);
|
| 182 | }
|
| 183 |
|
| 184 | consumeChar(source, '(', pos);
|
| 185 |
|
| 186 | std::vector<RamDomain> branchArgs(branchTypes.size());
|
| 187 |
|
| 188 | for (std::size_t i = 0; i < branchTypes.size(); ++i) {
|
| 189 | auto argType = branchTypes[i].string_value();
|
| 190 | assert(!argType.empty());
|
| 191 |
|
| 192 | std::size_t consumed = 0;
|
| 193 |
|
| 194 | if (i > 0) {
|
| 195 | consumeChar(source, ',', pos);
|
| 196 | }
|
| 197 | consumeWhiteSpace(source, pos);
|
| 198 |
|
| 199 | switch (argType[0]) {
|
| 200 | case 's': {
|
| 201 | branchArgs[i] = symbolTable.encode(readSymbol(source, ",)", pos, &consumed));
|
| 202 | break;
|
| 203 | }
|
| 204 | case 'i': {
|
| 205 | branchArgs[i] = RamSignedFromString(source.substr(pos), &consumed);
|
| 206 | break;
|
| 207 | }
|
| 208 | case 'u': {
|
| 209 | branchArgs[i] = ramBitCast(RamUnsignedFromString(source.substr(pos), &consumed));
|
| 210 | break;
|
| 211 | }
|
| 212 | case 'f': {
|
| 213 | branchArgs[i] = ramBitCast(RamFloatFromString(source.substr(pos), &consumed));
|
| 214 | break;
|
| 215 | }
|
| 216 | case 'r': {
|
| 217 | branchArgs[i] = readRecord(source, argType, pos, &consumed);
|
| 218 | break;
|
| 219 | }
|
| 220 | case '+': {
|
| 221 | branchArgs[i] = readADT(source, argType, pos, &consumed);
|
| 222 | break;
|
| 223 | }
|
| 224 | default: fatal("Invalid type attribute");
|
| 225 | }
|
| 226 | pos += consumed;
|
| 227 | }
|
| 228 |
|
| 229 | consumeChar(source, ')', pos);
|
| 230 |
|
| 231 | if (charactersRead != nullptr) {
|
| 232 | *charactersRead = pos - initial_position;
|
| 233 | }
|
| 234 |
|
| 235 | // Store branch either as [branch_id, [arguments]] or [branch_id, argument].
|
| 236 | RamDomain branchValue = [&]() -> RamDomain {
|
| 237 | if (branchArgs.size() != 1) {
|
| 238 | return recordTable.pack(branchArgs.data(), branchArgs.size());
|
| 239 | } else {
|
| 240 | return branchArgs[0];
|
| 241 | }
|
| 242 | }();
|
| 243 |
|
| 244 | RamDomain rec[2] = {branchIdx, branchValue};
|
| 245 | return recordTable.pack(rec, 2);
|
| 246 | }
|
| 247 |
|
| 248 | /**
|
| 249 | * Read the next alphanumeric + ('_', '?') sequence (corresponding to IDENT).
|
| 250 | * Consume preceding whitespace.
|
| 251 | * TODO (darth_tytus): use std::string_view?
|
| 252 | */
|
| 253 | std::string readQualifiedName(const std::string& source, std::size_t& pos) {
|
| 254 | consumeWhiteSpace(source, pos);
|
| 255 | if (pos >= source.length()) {
|
| 256 | throw std::invalid_argument("Unexpected end of input");
|
| 257 | }
|
| 258 |
|
| 259 | const std::size_t bgn = pos;
|
| 260 | while (pos < source.length()) {
|
| 261 | unsigned char ch = static_cast<unsigned char>(source[pos]);
|
| 262 | bool valid = std::isalnum(ch) || ch == '_' || ch == '?' || ch == '.';
|
| 263 | if (!valid) break;
|
| 264 | ++pos;
|
| 265 | }
|
| 266 |
|
| 267 | return source.substr(bgn, pos - bgn);
|
| 268 | }
|
| 269 |
|
/**
 * Read everything from pos up to (not including) the first occurrence of any stop character.
 *
 * @param source - string to read from
 * @param stopChars - set of characters that terminate the symbol
 * @param pos - start reading from this position
 * @param charactersRead - if not nullptr: receives the length of the returned string
 * @return the characters between pos and the first stop character
 * @throws std::invalid_argument if no stop character occurs at or after pos
 */
std::string readUntil(const std::string& source, const std::string& stopChars, const std::size_t pos,
        std::size_t* charactersRead) {
    const std::size_t endOfSymbol = source.find_first_of(stopChars, pos);

    if (endOfSymbol == std::string::npos) {
        throw std::invalid_argument("Unexpected end of input");
    }

    const std::size_t lengthOfSymbol = endOfSymbol - pos;

    // The out-parameter is optional, as in readRecord/readADT.
    if (charactersRead != nullptr) {
        *charactersRead = lengthOfSymbol;
    }

    return source.substr(pos, lengthOfSymbol);
}
|
| 282 |
|
/**
 * Read a quoted symbol.
 *
 * The character at pos is taken as the quote mark; the symbol extends to the next
 * unescaped occurrence of that character. A backslash escapes the character that follows
 * it: the backslash is dropped and the following character is kept verbatim.
 *
 * @param source - string to read from
 * @param pos - position of the opening quote mark
 * @param charactersRead - if not nullptr: receives the number of characters read,
 *        including both quote marks
 * @return the symbol without the surrounding quote marks and without escaping backslashes
 * @throws std::invalid_argument if pos is past the end of source or the closing quote is missing
 */
std::string readQuotedSymbol(const std::string& source, std::size_t pos, std::size_t* charactersRead) {
    const std::size_t start = pos;
    const std::size_t end = source.length();

    // Nothing to read: report it instead of indexing past the end of the string.
    if (start >= end) {
        throw std::invalid_argument("Unexpected end of input");
    }

    const char quoteMark = source[start];
    const std::size_t startOfSymbol = start + 1;
    std::size_t endOfSymbol = std::string::npos;
    bool hasEscaped = false;

    // First pass: locate the closing quote, stepping over escaped characters.
    for (pos = startOfSymbol; pos < end; ++pos) {
        const char c = source[pos];
        if (c == quoteMark) {
            endOfSymbol = pos;
            break;
        }
        if (c == '\\') {
            // The next character is escaped and can never close the symbol.
            hasEscaped = true;
            ++pos;
        }
    }

    if (endOfSymbol == std::string::npos) {
        throw std::invalid_argument("Unexpected end of input");
    }

    if (charactersRead != nullptr) {
        // +1 accounts for the closing quote mark
        *charactersRead = (endOfSymbol + 1) - start;
    }

    const std::size_t lengthOfSymbol = endOfSymbol - startOfSymbol;

    // fast handling of symbol without escape sequence
    if (!hasEscaped) {
        return source.substr(startOfSymbol, lengthOfSymbol);
    }

    // slow handling of symbol with escape sequence: drop each escaping backslash
    std::string symbol;
    symbol.reserve(lengthOfSymbol);
    bool takeLiterally = false;
    for (std::size_t i = startOfSymbol; i < endOfSymbol; ++i) {
        const char ch = source[i];
        if (takeLiterally || ch != '\\') {
            symbol.push_back(ch);
            takeLiterally = false;
        } else {
            takeLiterally = true;
        }
    }
    return symbol;
}
|
| 343 |
|
| 344 | /**
|
| 345 | * Read the next symbol.
|
| 346 | * It is either a double-quoted symbol with backslash-escaped chars, or the
|
| 347 | * longuest sequence that do not contains any of the given stopChars.
|
| 348 | * */
|
| 349 | std::string readSymbol(const std::string& source, const std::string& stopChars, const std::size_t pos,
|
| 350 | std::size_t* charactersRead) {
|
| 351 | if (source[pos] == '"') {
|
| 352 | return readQuotedSymbol(source, pos, charactersRead);
|
| 353 | } else {
|
| 354 | return readUntil(source, stopChars, pos, charactersRead);
|
| 355 | }
|
| 356 | }
|
| 357 |
|
| 358 | /**
|
| 359 | * Read past given character, consuming any preceding whitespace.
|
| 360 | */
|
| 361 | void consumeChar(const std::string& str, char c, std::size_t& pos) {
|
| 362 | consumeWhiteSpace(str, pos);
|
| 363 | if (pos >= str.length()) {
|
| 364 | throw std::invalid_argument("Unexpected end of input");
|
| 365 | }
|
| 366 | if (str[pos] != c) {
|
| 367 | std::stringstream error;
|
| 368 | error << "Expected: \'" << c << "\', got: " << str[pos];
|
| 369 | throw std::invalid_argument(error.str());
|
| 370 | }
|
| 371 | ++pos;
|
| 372 | }
|
| 373 |
|
/**
 * Move pos forward to the first non-whitespace character at or after it.
 * pos is left unchanged if it already points at such a character or is past the end.
 */
void consumeWhiteSpace(const std::string& str, std::size_t& pos) {
    const std::size_t limit = str.length();
    for (; pos < limit; ++pos) {
        // std::isspace requires a value representable as unsigned char
        if (std::isspace(static_cast<unsigned char>(str[pos])) == 0) {
            return;
        }
    }
}
|
| 382 |
|
/**
 * Read the next tuple from the input.
 *
 * readAll() calls this repeatedly and stops at the first result that converts to false,
 * so implementations signal the end of the input that way.
 */
virtual Own<RamDomain[]> readNextTuple() = 0;
|
| 384 | };
|
| 385 |
|
| 386 | class ReadStreamFactory {
|
| 387 | public:
|
| 388 | virtual Own<ReadStream> getReader(
|
| 389 | const std::map<std::string, std::string>&, SymbolTable&, RecordTable&) = 0;
|
| 390 | virtual const std::string& getName() const = 0;
|
| 391 | virtual ~ReadStreamFactory() = default;
|
| 392 | };
|
| 393 |
|
| 394 | } /* namespace souffle */
|