#include <cstdint>
#include <fstream>
#include <iostream>
#include <random>
#include <string>
#include <utility>

#include <sdsl/nearest_neighbour_dictionary.hpp>
#include <sdsl/nn_dict_dynamic.hpp>
#include <sdsl/util.hpp>

#include <gtest/gtest.h>

namespace
{

// Directory prefix under which the tests create their temporary files; assigned from argv[1] in main().
std::string temp_dir;

// The fixture for testing class nn_dict_dynamic.
class nn_dict_dynamic_test : public ::testing::Test
{
protected:
    nn_dict_dynamic_test()
    {}

    virtual ~nn_dict_dynamic_test()
    {}

    virtual void SetUp()
    {}

    virtual void TearDown()
    {}
};

// Checks that `nndd` answers prev() and next() consistently with `bv` for every position.
//
// A nearest_neighbour_dictionary built from `bv` serves as the reference. It is only queried
// where a neighbour exists: prev() from the first set bit onwards, next() up to the last set bit.
// Outside that range `nndd` is expected to return the reference's size() as the "no such bit" value.
//
// \param bv   Plain bit vector holding the expected content.
// \param nndd Dictionary under test; expected to mirror `bv`.
//
// The checks use ASSERT_*, so the first mismatch returns from this helper; the calling test
// body itself continues.
void compare_bv_and_nndd(sdsl::bit_vector const & bv, sdsl::nn_dict_dynamic const & nndd)
{
    sdsl::nearest_neighbour_dictionary<32> exp(bv);

    // With no set bit there is no first or last one to select, so the reference must not be
    // queried at all: every prev()/next() of `nndd` has to report "no such bit".
    if (sdsl::util::cnt_one_bits(bv) == 0)
    {
        for (uint64_t i = 0; i < exp.size(); ++i)
        {
            ASSERT_EQ(exp.size(), nndd.prev(i));
            ASSERT_EQ(exp.size(), nndd.next(i));
        }
        return;
    }

    uint64_t const first_one = exp.select(1);
    uint64_t const last_one = exp.select(exp.rank(exp.size()));

    // Before the first set bit: no predecessor exists, a successor does.
    for (uint64_t i = 0; i < first_one; ++i)
    {
        ASSERT_EQ(exp.size(), nndd.prev(i));
        ASSERT_EQ(exp.next(i), nndd.next(i));
    }
    // Between first and last set bit (inclusive): both neighbours exist.
    for (uint64_t i = first_one; i <= last_one; ++i)
    {
        ASSERT_EQ(exp.prev(i), nndd.prev(i));
        ASSERT_EQ(exp.next(i), nndd.next(i));
    }
    // After the last set bit: a predecessor exists, no successor does.
    for (uint64_t i = last_one + 1; i < exp.size(); ++i)
    {
        ASSERT_EQ(exp.prev(i), nndd.prev(i));
        ASSERT_EQ(exp.size(), nndd.next(i));
    }
}

//! Test copy/move construction and copy/move assignment
TEST_F(nn_dict_dynamic_test, constructors)
{
    static_assert(sdsl::util::is_regular<sdsl::nn_dict_dynamic>::value, "Type is not regular");
    uint64_t const testsize = 100000;
    sdsl::bit_vector bv(testsize, 0);
    sdsl::nn_dict_dynamic nndd(testsize);

    // Fill nndd: toggle testsize / 4 pseudo-random positions in both structures so that bv
    // stays the expected content of nndd. The engine is default-seeded, so the sequence is
    // the same on every run.
    std::mt19937_64 rng;
    std::uniform_int_distribution<uint64_t> distribution(0, testsize - 1);
    for (uint64_t i = 0; i < testsize / 4; ++i)
    {
        uint64_t const value = distribution(rng);
        bool const flipped = !bv[value];
        bv[value] = flipped;
        nndd[value] = flipped;
    }

    // Copy-constructor
    sdsl::nn_dict_dynamic nndd2(nndd);
    compare_bv_and_nndd(bv, nndd2);

    // Move-constructor; the moved-from object is inspected on purpose and expected to be empty.
    sdsl::nn_dict_dynamic nndd3(std::move(nndd2));
    compare_bv_and_nndd(bv, nndd3);
    ASSERT_EQ(uint64_t{0}, nndd2.size());

    // Copy-Assign
    sdsl::nn_dict_dynamic nndd4;
    nndd4 = nndd3;
    compare_bv_and_nndd(bv, nndd4);

    // Move-Assign; again the moved-from object is expected to be empty.
    sdsl::nn_dict_dynamic nndd5;
    nndd5 = std::move(nndd4);
    compare_bv_and_nndd(bv, nndd5);
    ASSERT_EQ(uint64_t{0}, nndd4.size());
}

//! Test operations next and prev on increasingly modified content
TEST_F(nn_dict_dynamic_test, next_and_prev)
{
    uint64_t const testsize = 100000;
    sdsl::bit_vector bv(testsize, 0);
    sdsl::nn_dict_dynamic nndd(testsize);

    // bv and nndd are not reset between rounds: each round toggles `ones` further positions on
    // top of the previous content. `ones` doubles every round and also seeds the engine.
    for (uint64_t ones = 1; ones < testsize; ones *= 2)
    {
        std::mt19937_64 rng(ones);
        std::uniform_int_distribution<uint64_t> distribution(0, testsize - 1);
        for (uint64_t i = 0; i < ones; ++i)
        {
            uint64_t const value = distribution(rng);
            bool const flipped = !bv[value];
            bv[value] = flipped;
            nndd[value] = flipped;
        }

        // Force two fixed positions to 1 so that at least two bits are set in every round,
        // whatever the random toggles did.
        uint64_t const quarter = testsize / 4;
        uint64_t const three_quarters = 3 * testsize / 4;
        bv[quarter] = 1;
        nndd[quarter] = 1;
        bv[three_quarters] = 1;
        nndd[three_quarters] = 1;

        compare_bv_and_nndd(bv, nndd);
    }
}

//! Test that a dictionary stored to a file and loaded again still mirrors the bit vector
TEST_F(nn_dict_dynamic_test, serialize_and_load)
{
    std::string const file_name = temp_dir + "/nn_dict_dynamic";
    uint64_t const testsize = 100000;
    sdsl::bit_vector bv(testsize, 0);
    {
        // Build the dictionary by toggling testsize / 4 pseudo-random positions (default-seeded,
        // hence reproducible) in both structures, then write it out.
        std::mt19937_64 rng;
        std::uniform_int_distribution<uint64_t> distribution(0, testsize - 1);
        sdsl::nn_dict_dynamic nndd(testsize);
        for (uint64_t i = 0; i < testsize / 4; ++i)
        {
            uint64_t const value = distribution(rng);
            bool const flipped = !bv[value];
            bv[value] = flipped;
            nndd[value] = flipped;
        }
        // Without a stored file nothing below is meaningful, so stop right here on failure.
        ASSERT_TRUE(sdsl::store_to_file(nndd, file_name)) << "could not store to " << file_name;
    }
    {
        sdsl::nn_dict_dynamic nndd(0);
        bool const loaded = sdsl::load_from_file(nndd, file_name);
        EXPECT_TRUE(loaded) << "could not load from " << file_name;
        // Only compare after a successful load: otherwise nndd is still the size-0 dictionary
        // and would be queried for positions it does not have.
        if (loaded)
        {
            compare_bv_and_nndd(bv, nndd);
        }
    }
    // Reached even if loading failed, so the stored file is cleaned up.
    sdsl::remove(file_name);
}

#if SDSL_HAS_CEREAL
// Writes `l` to `temp_file` through `out_archive_t`, reads it back through `in_archive_t`
// and expects the reloaded dictionary to compare equal to `l`.
//
// \tparam in_archive_t  Archive type used for reading.
// \tparam out_archive_t Archive type used for writing.
// \param l              Dictionary to round-trip.
// \param temp_file      Path of the file used for the round trip; it is (over)written.
template <typename in_archive_t, typename out_archive_t>
void do_serialisation(sdsl::nn_dict_dynamic const & l, std::string const & temp_file)
{
    {
        std::ofstream os{temp_file, std::ios::binary};
        // Fail with a clear message instead of handing an unopened stream to the archive.
        ASSERT_TRUE(os.is_open()) << "cannot open " << temp_file << " for writing";
        out_archive_t oarchive{os};
        oarchive(l);
    } // The archive and the stream go out of scope here, before the file is read back.

    {
        std::ifstream is{temp_file, std::ios::binary};
        ASSERT_TRUE(is.is_open()) << "cannot open " << temp_file << " for reading";
        sdsl::nn_dict_dynamic in_l(0);
        in_archive_t iarchive{is};
        iarchive(in_l);
        EXPECT_EQ(l, in_l);
    }
}

// Round-trips a filled dictionary through each of the four cereal archive formats.
TEST_F(nn_dict_dynamic_test, cereal)
{
    // NOTE(review): presumably "@/" selects sdsl's in-memory file system, which the std::fstream
    // based archives in do_serialisation cannot use - confirm. The test is a no-op in that case.
    if (temp_dir == "@/")
    {
        return;
    }

    std::string const file_name = temp_dir + "/nn_dict_dynamic";
    uint64_t const testsize = 100000;
    sdsl::nn_dict_dynamic nndd(testsize);
    sdsl::bit_vector bv(testsize, 0);

    // Toggle testsize / 4 pseudo-random positions (default-seeded, hence reproducible).
    // bv only tracks the current state of each position so that it can be flipped.
    std::mt19937_64 rng;
    std::uniform_int_distribution<uint64_t> distribution(0, testsize - 1);
    for (uint64_t i = 0; i < testsize / 4; ++i)
    {
        uint64_t const value = distribution(rng);
        bool const flipped = !bv[value];
        bv[value] = flipped;
        nndd[value] = flipped;
    }

    // Every call overwrites file_name with its own archive format before reading it back.
    do_serialisation<cereal::BinaryInputArchive, cereal::BinaryOutputArchive>(nndd, file_name);
    do_serialisation<cereal::PortableBinaryInputArchive, cereal::PortableBinaryOutputArchive>(nndd, file_name);
    do_serialisation<cereal::JSONInputArchive, cereal::JSONOutputArchive>(nndd, file_name);
    do_serialisation<cereal::XMLInputArchive, cereal::XMLOutputArchive>(nndd, file_name);

    sdsl::remove(file_name);
}
#endif // SDSL_HAS_CEREAL

} // namespace

// Test driver. Expects the directory for temporary files as the first argument that is left
// after gtest has processed the command line; prints a usage line and returns 1 otherwise.
int main(int argc, char ** argv)
{
    // Keep this ahead of the argc check: gtest is handed argc/argv first and removes the
    // flags it recognises, so only the remaining arguments are counted below.
    ::testing::InitGoogleTest(&argc, argv);
    if (argc >= 2)
    {
        temp_dir = argv[1];
        return RUN_ALL_TESTS();
    }
    // LCOV_EXCL_START
    std::cout << "Usage: " << argv[0] << " tmp_dir" << std::endl;
    return 1;
    // LCOV_EXCL_STOP
}
