// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

#include "parquet/arrow/fuzz_internal.h"

#include <cstdint>
#include <memory>
#include <optional>
#include <random>
#include <string>
#include <string_view>
#include <unordered_map>
#include <utility>
#include <vector>

#include "arrow/io/memory.h"
#include "arrow/table.h"
#include "arrow/util/base64.h"
#include "arrow/util/fuzz_internal.h"
#include "arrow/util/string.h"
#include "parquet/arrow/reader.h"
#include "parquet/bloom_filter.h"
#include "parquet/bloom_filter_reader.h"
#include "parquet/page_index.h"
#include "parquet/properties.h"

namespace parquet::fuzzing::internal {

using ::arrow::MemoryPool;
using ::arrow::Status;
using ::arrow::Table;
using ::arrow::util::SecureString;
using ::parquet::arrow::FileReader;

namespace {

// Key metadata starting with this prefix carries the key bytes themselves,
// base64-encoded right after the prefix (produced by MakeEncryptionKey and
// decoded by FuzzDecryptionKeyRetriever).
constexpr std::string_view kInlineKeyPrefix{"inline:"};

// Fixed keys used by the encrypted sample files of parquet-testing, by key id.
// Reference:
// https://github.com/apache/parquet-testing/blob/master/data/README.md#encrypted-files
const std::unordered_map<std::string, SecureString> kTestingKeys{
    {"kf", SecureString("0123456789012345")},
    {"kc1", SecureString("1234567890123450")},
    {"kc2", SecureString("1234567890123451")},
};

}  // namespace

// Generates a random key of `key_len` bytes together with key metadata that
// embeds the key itself (kInlineKeyPrefix followed by the base64-encoded key),
// so FuzzDecryptionKeyRetriever can recover it without any external key store.
EncryptionKey MakeEncryptionKey(int key_len) {
  // One engine, seeded once and kept across calls: every call yields a new key,
  // while the overall sequence of generated keys stays reproducible.
  static std::default_random_engine engine(/*seed=*/42);
  std::uniform_int_distribution<unsigned int> byte_dist(0, 255);

  // Fill the raw key with random byte values.
  std::string raw_key(key_len, '\0');
  for (std::string::size_type pos = 0; pos < raw_key.size(); ++pos) {
    raw_key[pos] = static_cast<char>(static_cast<uint8_t>(byte_dist(engine)));
  }

  // Encode before moving `raw_key` into the SecureString below.
  std::string key_metadata(kInlineKeyPrefix);
  key_metadata.append(::arrow::util::base64_encode(raw_key));

  return {SecureString(std::move(raw_key)), std::move(key_metadata)};
}

// Resolves decryption keys for fuzzing. It knows the fixed parquet-testing keys
// (kTestingKeys) and the inline keys produced by MakeEncryptionKey; any other
// key id makes GetKey throw ParquetException.
class FuzzDecryptionKeyRetriever : public DecryptionKeyRetriever {
 public:
  SecureString GetKey(const std::string& key_id) override {
    // First choice: a well-known key from the parquet-testing sample files.
    if (const auto found = kTestingKeys.find(key_id); found != kTestingKeys.end()) {
      return found->second;
    }
    // Otherwise the id must be inline key metadata from MakeEncryptionKey.
    if (!key_id.starts_with(kInlineKeyPrefix)) {
      throw ParquetException("Unknown fuzz encryption key_id");
    }
    const std::string encoded_key = key_id.substr(kInlineKeyPrefix.size());
    return SecureString(::arrow::util::base64_decode(encoded_key));
  }
};

// Returns a new key retriever able to resolve both the parquet-testing keys and
// the inline keys generated by MakeEncryptionKey.
std::shared_ptr<DecryptionKeyRetriever> MakeKeyRetriever() {
  std::shared_ptr<DecryptionKeyRetriever> retriever =
      std::make_shared<FuzzDecryptionKeyRetriever>();
  return retriever;
}

namespace {

// Reads every row group through the Arrow reader and fully validates each table
// that was read successfully. It does not stop at the first failure. Per-row-group
// statuses are combined with `&=`, which is expected to keep the first error.
Status FuzzReadData(std::unique_ptr<FileReader> reader) {
  Status overall = Status::OK();
  for (int rg = 0; rg < reader->num_row_groups(); ++rg) {
    std::shared_ptr<Table> table;
    Status read_status = reader->ReadRowGroup(rg, &table);
    if (!read_status.ok()) {
      overall &= read_status;
      continue;
    }
    // Only a successfully read table is validated.
    overall &= table->ValidateFull();
  }
  return overall;
}

// Forces decoding of the typed min/max values of a column index. The decoded
// values are discarded; only the act of decoding matters for fuzzing.
template <typename DType>
Status FuzzReadTypedColumnIndex(const TypedColumnIndex<DType>* index) {
  static_cast<void>(index->min_values());
  static_cast<void>(index->max_values());
  return Status::OK();
}

// Exercises every accessor of a column index, including the typed min/max values
// for the column's physical type (`descr->physical_type()`). The PARQUET catch
// macros are expected to map a thrown Parquet exception to an error Status.
Status FuzzReadColumnIndex(const ColumnIndex* index, const ColumnDescriptor* descr) {
  BEGIN_PARQUET_CATCH_EXCEPTIONS
  // Type-independent parts of the index.
  index->definition_level_histograms();
  index->repetition_level_histograms();
  index->null_pages();
  index->null_counts();
  index->non_null_page_indices();
  index->encoded_min_values();
  index->encoded_max_values();
  // Type-dependent parts: downcast to the typed index matching the column.
  switch (descr->physical_type()) {
    case Type::BOOLEAN:
      return FuzzReadTypedColumnIndex(dynamic_cast<const BoolColumnIndex*>(index));
    case Type::INT32:
      return FuzzReadTypedColumnIndex(dynamic_cast<const Int32ColumnIndex*>(index));
    case Type::INT64:
      return FuzzReadTypedColumnIndex(dynamic_cast<const Int64ColumnIndex*>(index));
    case Type::INT96:
      return FuzzReadTypedColumnIndex(
          dynamic_cast<const TypedColumnIndex<Int96Type>*>(index));
    case Type::FLOAT:
      return FuzzReadTypedColumnIndex(dynamic_cast<const FloatColumnIndex*>(index));
    case Type::DOUBLE:
      return FuzzReadTypedColumnIndex(dynamic_cast<const DoubleColumnIndex*>(index));
    case Type::FIXED_LEN_BYTE_ARRAY:
      return FuzzReadTypedColumnIndex(dynamic_cast<const FLBAColumnIndex*>(index));
    case Type::BYTE_ARRAY:
      return FuzzReadTypedColumnIndex(dynamic_cast<const ByteArrayColumnIndex*>(index));
    case Type::UNDEFINED:
      break;
  }
  END_PARQUET_CATCH_EXCEPTIONS
  return Status::OK();
}

// Decodes the offset index and the column index of column `column` in the row
// group served by `reader`. Either index may be absent (null), in which case it
// is skipped. The Status of the column index pass is returned. Exceptions thrown
// while fetching or decoding are handled by the PARQUET catch macros.
Status FuzzReadPageIndex(RowGroupPageIndexReader* reader, const SchemaDescriptor* schema,
                         int column) {
  Status result;
  BEGIN_PARQUET_CATCH_EXCEPTIONS
  // Offset index: page locations and unencoded byte array sizes.
  const auto offset_index = reader->GetOffsetIndex(column);
  if (offset_index != nullptr) {
    offset_index->page_locations();
    offset_index->unencoded_byte_array_data_bytes();
  }
  // Column index: delegated, together with the column's descriptor.
  const auto column_index = reader->GetColumnIndex(column);
  if (column_index != nullptr) {
    const ColumnDescriptor* descr = schema->Column(column);
    result &= FuzzReadColumnIndex(column_index.get(), descr);
  }
  END_PARQUET_CATCH_EXCEPTIONS
  return result;
}

// Builds the ReaderProperties used for fuzzing. Allocations go through `pool`.
// Decryption properties are installed that resolve keys through
// MakeKeyRetriever(), and plaintext (unencrypted) files are allowed as well.
ReaderProperties MakeFuzzReaderProperties(MemoryPool* pool) {
  FileDecryptionProperties::Builder builder;
  builder.key_retriever(MakeKeyRetriever());
  builder.plaintext_files_allowed();
  // XXX Cannot set an AAD prefix, as that would fail on files
  // that store their own AAD prefix.
  auto decryption_properties = builder.build();

  ReaderProperties properties(pool);
  properties.file_decryption_properties(decryption_properties);
  return properties;
}

}  // namespace

Status FuzzReader(const uint8_t* data, int64_t size) {
  Status st;

  auto buffer = std::make_shared<::arrow::Buffer>(data, size);
  auto file = std::make_shared<::arrow::io::BufferReader>(buffer);
  auto pool = ::arrow::internal::fuzzing_memory_pool();
  auto reader_properties = MakeFuzzReaderProperties(pool);

  std::default_random_engine rng(/*seed*/ 42);

  // Read Parquet file metadata only once, which will reduce iteration time slightly
  std::shared_ptr<FileMetaData> pq_md;
  BEGIN_PARQUET_CATCH_EXCEPTIONS {
    int num_row_groups, num_columns;
    auto pq_file_reader = ParquetFileReader::Open(file, reader_properties);
    {
      // Read some additional metadata (often lazy-decoded, such as statistics)
      pq_md = pq_file_reader->metadata();
      num_row_groups = pq_md->num_row_groups();
      num_columns = pq_md->num_columns();
      for (int i = 0; i < num_row_groups; ++i) {
        auto rg = pq_md->RowGroup(i);
        rg->sorting_columns();
        for (int j = 0; j < num_columns; ++j) {
          auto col = rg->ColumnChunk(j);
          col->encoded_statistics();
          col->statistics();
          col->geo_statistics();
          col->size_statistics();
          col->key_value_metadata();
          col->encodings();
          col->encoding_stats();
        }
      }
    }
    {
      // Read and decode bloom filters
      try {
        auto& bloom_reader = pq_file_reader->GetBloomFilterReader();
        std::uniform_int_distribution<uint64_t> hash_dist;
        for (int i = 0; i < num_row_groups; ++i) {
          auto bloom_rg = bloom_reader.RowGroup(i);
          for (int j = 0; j < num_columns; ++j) {
            std::unique_ptr<BloomFilter> bloom;
            bloom = bloom_rg->GetColumnBloomFilter(j);
            // If the column has a bloom filter, find a bunch of random hashes
            if (bloom != nullptr) {
              for (int k = 0; k < 100; ++k) {
                bloom->FindHash(hash_dist(rng));
              }
            }
          }
        }
      } catch (const ParquetException& exc) {
        // XXX we just want to ignore encrypted bloom filters and validate the
        // rest of the file; there is no better way of doing this until GH-46597
        // is done.
        // (also see GH-48334 for reading encrypted bloom filters)
        if (std::string_view(exc.what())
                .find("BloomFilter decryption is not yet supported") ==
            std::string_view::npos) {
          throw;
        }
      }
    }
    {
      // Read and decode page indexes
      auto index_reader = pq_file_reader->GetPageIndexReader();
      for (int i = 0; i < num_row_groups; ++i) {
        auto index_rg = index_reader->RowGroup(i);
        if (index_rg) {
          for (int j = 0; j < num_columns; ++j) {
            st &= FuzzReadPageIndex(index_rg.get(), pq_md->schema(), j);
          }
        }
      }
    }
  }
  END_PARQUET_CATCH_EXCEPTIONS

  // Note that very small batch sizes probably make fuzzing slower
  for (auto batch_size : std::vector<std::optional<int>>{std::nullopt, 13, 300}) {
    ArrowReaderProperties properties;
    if (batch_size) {
      properties.set_batch_size(batch_size.value());
    }

    std::unique_ptr<ParquetFileReader> pq_file_reader;
    BEGIN_PARQUET_CATCH_EXCEPTIONS
    pq_file_reader = ParquetFileReader::Open(file, reader_properties, pq_md);
    END_PARQUET_CATCH_EXCEPTIONS

    auto arrow_reader_result =
        FileReader::Make(pool, std::move(pq_file_reader), properties);
    RETURN_NOT_OK(arrow_reader_result.status());
    auto reader = std::move(*arrow_reader_result);
    st &= FuzzReadData(std::move(reader));
  }
  return st;
}

}  // namespace parquet::fuzzing::internal
