#include <Interpreters/BloomFilter.h>
#include <city.h>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnLowCardinality.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypeLowCardinality.h>
#include <libdivide.h>


namespace DB
{
/// Error codes referenced by this file. They are only declared here (extern);
/// the definitions are provided elsewhere.
namespace ErrorCodes
{
    extern const int BAD_ARGUMENTS;
}

/// Multiplier and increment used by BloomFilter::find / BloomFilter::add to derive
/// the seed of the second hash from the filter seed: SEED_GEN_A * seed + SEED_GEN_B.
static constexpr UInt64 SEED_GEN_A = 845897321;
static constexpr UInt64 SEED_GEN_B = 217728422;

/// Largest filter_size accepted by BloomFilterParameters (2^30).
static constexpr UInt64 MAX_BLOOM_FILTER_SIZE = 1ULL << 30;


/// Stores the bloom filter parameters and validates them.
/// Throws BAD_ARGUMENTS, checking in this order:
///  1. the size is zero;
///  2. the number of hash functions is zero;
///  3. the size is greater than MAX_BLOOM_FILTER_SIZE.
BloomFilterParameters::BloomFilterParameters(size_t filter_size_, size_t filter_hashes_, size_t seed_)
    : filter_size(filter_size_)
    , filter_hashes(filter_hashes_)
    , seed(seed_)
{
    if (!filter_size)
    {
        throw Exception(ErrorCodes::BAD_ARGUMENTS, "The size of bloom filter cannot be zero");
    }

    if (!filter_hashes)
    {
        throw Exception(ErrorCodes::BAD_ARGUMENTS, "The number of hash functions for bloom filter cannot be zero");
    }

    if (MAX_BLOOM_FILTER_SIZE < filter_size)
    {
        throw Exception(ErrorCodes::BAD_ARGUMENTS, "The size of bloom filter cannot be more than {}", MAX_BLOOM_FILTER_SIZE);
    }
}


/// Builds a filter from a parameter set by forwarding its size, hash count and seed
/// to the three-argument constructor.
BloomFilter::BloomFilter(const BloomFilterParameters & params)
    : BloomFilter(
        params.filter_size,
        params.filter_hashes,
        params.seed)
{
}

/// Creates a zero-filled filter.
///  size_   - filter size in bytes; bit positions are taken modulo 8 * size_;
///  hashes_ - number of bit positions probed per element;
///  seed_   - seed passed to the hash functions.
/// The storage holds enough UnderType words to cover size_ bytes (rounded up).
/// size_ and hashes_ are expected to be non-zero; this is checked with chassert only.
BloomFilter::BloomFilter(size_t size_, size_t hashes_, size_t seed_)
    : size(size_)
    , hashes(hashes_)
    , seed(seed_)
    , words((size_ + sizeof(UnderType) - 1) / sizeof(UnderType))
    , modulus(8 * size_)
    , divider(modulus)
    , filter(words, 0)
{
    chassert(size_ != 0);
    chassert(hashes_ != 0);
}

/// Changes the filter size to size_ bytes and recomputes everything derived from it:
/// the word count, the bit modulus (8 * size_) and the precomputed divider for that modulus.
/// The storage is resized to the new word count; existing words are not cleared or re-hashed here.
void BloomFilter::resize(size_t size_)
{
    const size_t bytes_per_word = sizeof(UnderType);

    size = size_;
    modulus = 8 * size_;
    /// Round up so that a partial trailing word is still allocated.
    words = (size_ + bytes_per_word - 1) / bytes_per_word;
    divider = libdivide::divider<size_t, libdivide::BRANCHFREE>(modulus);

    filter.resize(words);
}

bool BloomFilter::find(const char * data, size_t len)
{
    size_t hash1 = CityHash_v1_0_2::CityHash64WithSeed(data, len, seed);
    size_t hash2 = CityHash_v1_0_2::CityHash64WithSeed(data, len, SEED_GEN_A * seed + SEED_GEN_B);

    size_t acc = hash1;
    for (size_t i = 0; i < hashes; ++i)
    {
        /// It accumulates in the loop as follows:
        /// pos = (hash1 + hash2 * i + i * i) % (8 * size)
        size_t pos = fastMod(acc + i * i);
        if (!(filter[pos / word_bits] & (1ULL << (pos % word_bits))))
            return false;
        acc += hash2;
    }
    return true;
}

void BloomFilter::add(const char * data, size_t len)
{
    size_t hash1 = CityHash_v1_0_2::CityHash64WithSeed(data, len, seed);
    size_t hash2 = CityHash_v1_0_2::CityHash64WithSeed(data, len, SEED_GEN_A * seed + SEED_GEN_B);

    size_t acc = hash1;
    for (size_t i = 0; i < hashes; ++i)
    {
        /// It accumulates in the loop as follows:
        /// pos = (hash1 + hash2 * i + i * i) % (8 * size)
        size_t pos = fastMod(acc + i * i);
        filter[pos / word_bits] |= (1ULL << (pos % word_bits));
        acc += hash2;
    }
}

void BloomFilter::clear()
{
    filter.assign(words, 0);
}

bool BloomFilter::contains(const BloomFilter & bf)
{
    for (size_t i = 0; i < words; ++i)
    {
        if ((filter[i] & bf.filter[i]) != bf.filter[i])
            return false;
    }
    return true;
}

/// Returns 1 if all `words` words of the filter are zero, and 0 otherwise.
/// (The result is a boolean converted to the declared UInt64 return type.)
UInt64 BloomFilter::isEmpty() const
{
    /// Skip over the leading zero words; stopping early means a non-zero word was found.
    size_t idx = 0;
    while (idx < words && filter[idx] == 0)
        ++idx;
    return idx == words;
}

/// Returns the number of bytes reserved by the storage: its capacity (not its size)
/// multiplied by the size of one word.
size_t BloomFilter::memoryUsageBytes() const
{
    const size_t reserved_words = filter.capacity();
    return reserved_words * sizeof(UnderType);
}

/// Compares the bit contents of two filters word by word.
/// Filters with a different number of words are never equal. Without this check the loop
/// below would index `b.filter` past its end when `b` is smaller than `a`, and would report
/// equality for a larger `b` that merely shares a prefix with `a`.
/// The hash count and the seed are not compared; only the stored bits are.
bool operator== (const BloomFilter & a, const BloomFilter & b)
{
    if (a.words != b.words)
        return false;

    for (size_t i = 0; i < a.words; ++i)
        if (a.filter[i] != b.filter[i])
            return false;
    return true;
}

void BloomFilter::addHashWithSeed(const UInt64 & hash, const UInt64 & hash_seed)
{
    size_t pos = fastMod(CityHash_v1_0_2::Hash128to64(CityHash_v1_0_2::uint128(hash, hash_seed)));
    filter[pos / word_bits] |= (1ULL << (pos % word_bits));
}

bool BloomFilter::findHashWithSeed(const UInt64 & hash, const UInt64 & hash_seed)
{
    size_t pos = fastMod(CityHash_v1_0_2::Hash128to64(CityHash_v1_0_2::uint128(hash, hash_seed)));
    return bool(filter[pos / word_bits] & (1ULL << (pos % word_bits)));
}

/// Strips Array, Nullable and LowCardinality wrappers from a data type and returns the
/// innermost type. Only one level of Array nesting is accepted: an Array whose nested type
/// is itself an Array causes a BAD_ARGUMENTS exception.
/// Any other type is returned unchanged.
DataTypePtr BloomFilter::getPrimitiveType(const DataTypePtr & data_type)
{
    const auto * raw_type = data_type.get();

    if (const auto * array_type = typeid_cast<const DataTypeArray *>(raw_type))
    {
        const auto & element_type = array_type->getNestedType();

        /// Array(Array(...)) is rejected.
        if (typeid_cast<const DataTypeArray *>(element_type.get()))
            throw Exception(ErrorCodes::BAD_ARGUMENTS, "Unexpected type {} of bloom filter index.", data_type->getName());

        return getPrimitiveType(element_type);
    }

    if (const auto * nullable_type = typeid_cast<const DataTypeNullable *>(raw_type))
        return getPrimitiveType(nullable_type->getNestedType());

    if (const auto * low_cardinality_type = typeid_cast<const DataTypeLowCardinality *>(raw_type))
        return getPrimitiveType(low_cardinality_type->getDictionaryType());

    return data_type;
}

/// Column counterpart of getPrimitiveType(): recursively unwraps a column until it is
/// neither an Array, a Nullable nor a LowCardinality column.
///  - Array          -> its data column;
///  - Nullable       -> its nested column;
///  - LowCardinality -> the result of convertToFullColumnIfLowCardinality().
/// Unlike getPrimitiveType(), nested arrays are not rejected here.
ColumnPtr BloomFilter::getPrimitiveColumn(const ColumnPtr & column)
{
    const auto * raw_column = column.get();

    if (const auto * array_col = typeid_cast<const ColumnArray *>(raw_column))
    {
        return getPrimitiveColumn(array_col->getDataPtr());
    }

    if (const auto * nullable_col = typeid_cast<const ColumnNullable *>(raw_column))
    {
        return getPrimitiveColumn(nullable_col->getNestedColumnPtr());
    }

    if (const auto * low_cardinality_col = typeid_cast<const ColumnLowCardinality *>(raw_column))
    {
        return getPrimitiveColumn(low_cardinality_col->convertToFullColumnIfLowCardinality());
    }

    return column;
}

}
