import datetime
import urllib
import urllib.parse

import pymongo
import pytest

from helpers.client import QueryRuntimeException
from helpers.cluster import ClickHouseCluster
from helpers.config_cluster import mongo_pass
from helpers.test_tools import TSV


@pytest.fixture(scope="module")
def started_cluster(request):
    """Start a one-node ClickHouse cluster with MongoDB enabled for this module.

    Yields the running ``ClickHouseCluster``. The cluster is shut down after the
    module's tests finish, and also when instance setup or start-up fails.
    """
    # Construct outside the ``try``: if the constructor itself raises, ``cluster``
    # would be unbound in ``finally`` and a NameError would hide the real error.
    cluster = ClickHouseCluster(__file__)
    try:
        cluster.add_instance(
            "node",
            with_mongo=True,
            main_configs=["configs/named_collections.xml"],
            user_configs=["configs/users.xml"],
        )
        cluster.start()
        yield cluster
    finally:
        cluster.shutdown()


def get_mongo_connection(started_cluster, secure=False, with_credentials=True):
    """Return a ``pymongo.MongoClient`` connected from the test host to MongoDB.

    :param started_cluster: running cluster; provides the host-side ports.
    :param secure: connect as ``root`` to ``mongo_secure_port`` over TLS with
        certificate and hostname validation disabled. Takes precedence over
        ``with_credentials``.
    :param with_credentials: when not ``secure``, connect as ``root`` to
        ``mongo_port``; otherwise connect without credentials to
        ``mongo_no_cred_port``.
    """
    if not secure and not with_credentials:
        return pymongo.MongoClient(f"mongodb://localhost:{started_cluster.mongo_no_cred_port}")

    # The password may contain URI-reserved characters, so percent-encode it.
    credentials = f"root:{urllib.parse.quote_plus(mongo_pass)}"
    if secure:
        return pymongo.MongoClient(
            f"mongodb://{credentials}@localhost:{started_cluster.mongo_secure_port}"
            "/?tls=true&tlsAllowInvalidCertificates=true&tlsAllowInvalidHostnames=true"
        )
    return pymongo.MongoClient(f"mongodb://{credentials}@localhost:{started_cluster.mongo_port}")


def test_simple_select(started_cluster):
    """Query a 100-document collection via the host/database/collection/user/password form.

    ``structure`` is exercised both as a named and as a positional argument.
    """
    client = get_mongo_connection(started_cluster)
    database = client["test_simple_select"]
    try:
        database.command("createUser", "root", pwd=mongo_pass, roles=["readWrite"])
    except pymongo.errors.OperationFailure:
        # Deliberately ignored (presumably the user already exists).
        pass
    collection = database["simple_table"]
    collection.insert_many([{"key": n, "data": hex(n * n)} for n in range(0, 100)])

    node = started_cluster.instances["node"]
    args = f"'mongo1:27017', 'test_simple_select', 'simple_table', 'root', '{mongo_pass}'"
    named = f"mongodb({args}, structure='key UInt64, data String')"
    positional = f"mongodb({args}, 'key UInt64, data String')"
    expected_sum = f"{sum(range(0, 100))}\n"

    assert node.query(f"SELECT COUNT() FROM {named}") == "100\n"
    assert node.query(f"SELECT sum(key) FROM {named}") == expected_sum
    assert node.query(f"SELECT sum(key) FROM {positional}") == expected_sum
    assert node.query(f"SELECT data FROM {named} WHERE key = 42") == f"{hex(42 * 42)}\n"
    collection.drop()


def test_simple_select_uri(started_cluster):
    """Query a 100-document collection addressed by a ``mongodb://`` connection URI.

    ``structure`` is exercised both as a named and as a positional argument.
    """
    client = get_mongo_connection(started_cluster)
    database = client["test_simple_select_uri"]
    try:
        database.command("createUser", "root", pwd=mongo_pass, roles=["readWrite"])
    except pymongo.errors.OperationFailure:
        # Deliberately ignored (presumably the user already exists).
        pass
    collection = database["simple_table"]
    collection.insert_many([{"key": n, "data": hex(n * n)} for n in range(0, 100)])

    node = started_cluster.instances["node"]
    # Percent-encode the password so it is safe inside the URI's userinfo part.
    uri = f"mongodb://root:{urllib.parse.quote_plus(mongo_pass)}@mongo1:27017/test_simple_select_uri"
    named = f"mongodb('{uri}', 'simple_table', structure='key UInt64, data String')"
    positional = f"mongodb('{uri}', 'simple_table', 'key UInt64, data String')"
    expected_sum = f"{sum(range(0, 100))}\n"

    assert node.query(f"SELECT COUNT() FROM {named}") == "100\n"
    assert node.query(f"SELECT sum(key) FROM {named}") == expected_sum
    assert node.query(f"SELECT sum(key) FROM {positional}") == expected_sum
    assert node.query(f"SELECT data FROM {named} WHERE key = 42") == f"{hex(42 * 42)}\n"
    collection.drop()


def test_complex_data_type(started_cluster):
    """Read documents with a nested ``dict`` field declared as ``Map(UInt64, String)``."""
    client = get_mongo_connection(started_cluster)
    database = client["test_complex_data_type"]
    try:
        database.command("createUser", "root", pwd=mongo_pass, roles=["readWrite"])
    except pymongo.errors.OperationFailure:
        # Deliberately ignored (presumably the user already exists).
        pass
    complex_table = database["complex_table"]
    complex_table.insert_many(
        [{"key": n, "data": hex(n * n), "dict": {"a": n, "b": str(n)}} for n in range(0, 100)]
    )

    node = started_cluster.instances["node"]

    # (query, expected output) pairs, checked in order; the SQL text is unchanged.
    expectations = [
        (
            f"""
            SELECT COUNT()
            FROM mongodb('mongo1:27017',
                         'test_complex_data_type',
                         'complex_table',
                         'root',
                         '{mongo_pass}',
                         structure='key UInt64, data String, dict Map(UInt64, String)')""",
            "100\n",
        ),
        (
            f"""
            SELECT sum(key)
            FROM mongodb('mongo1:27017',
                         'test_complex_data_type',
                         'complex_table',
                         'root',
                         '{mongo_pass}',
                         structure='key UInt64, data String, dict Map(UInt64, String)')""",
            str(sum(range(0, 100))) + "\n",
        ),
        (
            f"""
            SELECT data
            FROM mongodb('mongo1:27017',
                         'test_complex_data_type',
                         'complex_table',
                         'root',
                         '{mongo_pass}',
                         structure='key UInt64, data String, dict Map(UInt64, String)')
            WHERE key = 42
            """,
            hex(42 * 42) + "\n",
        ),
    ]
    for query, expected in expectations:
        assert node.query(query) == expected
    complex_table.drop()


def test_incorrect_data_type(started_cluster):
    """Selecting a field that is not declared in ``structure`` must raise."""
    client = get_mongo_connection(started_cluster)
    database = client["test_incorrect_data_type"]
    try:
        database.command("createUser", "root", pwd=mongo_pass, roles=["readWrite"])
    except pymongo.errors.OperationFailure:
        # Deliberately ignored (presumably the user already exists).
        pass
    strange_table = database["strange_table"]
    strange_table.insert_many(
        [{"key": n, "data": hex(n * n), "aaaa": "Hello"} for n in range(0, 100)]
    )

    node = started_cluster.instances["node"]
    # ``aaaa`` is present in every document but absent from ``structure``.
    query = f"SELECT aaaa FROM mongodb('mongo1:27017', 'test_incorrect_data_type', 'strange_table', 'root', '{mongo_pass}', structure='key UInt64, data String')"
    with pytest.raises(QueryRuntimeException):
        node.query(query)

    strange_table.drop()


def test_secure_connection(started_cluster):
    """Query the TLS-enabled MongoDB with certificate/hostname checks disabled via ``options``.

    ``structure`` and ``options`` are passed by name in the first two queries and
    positionally in the last two.
    """
    client = get_mongo_connection(started_cluster, secure=True)
    database = client["test_secure_connection"]
    try:
        database.command("createUser", "root", pwd=mongo_pass, roles=["readWrite"])
    except pymongo.errors.OperationFailure:
        # Deliberately ignored (presumably the user already exists).
        pass
    collection = database["simple_table"]
    collection.insert_many([{"key": n, "data": hex(n * n)} for n in range(0, 100)])

    node = started_cluster.instances["node"]
    expected_sum = str(sum(range(0, 100))) + "\n"

    # (query, expected output) pairs, checked in order; the SQL text is unchanged.
    expectations = [
        (
            f"""SELECT COUNT()
               FROM mongodb('mongo_secure:27017',
                            'test_secure_connection',
                            'simple_table',
                            'root',
                            '{mongo_pass}',
                            structure='key UInt64, data String',
                            options='tls=true&tlsAllowInvalidCertificates=true&tlsAllowInvalidHostnames=true')""",
            "100\n",
        ),
        (
            f"""SELECT sum(key)
               FROM mongodb('mongo_secure:27017',
                            'test_secure_connection',
                            'simple_table',
                            'root',
                            '{mongo_pass}',
                            structure='key UInt64, data String',
                            options='tls=true&tlsAllowInvalidCertificates=true&tlsAllowInvalidHostnames=true')""",
            expected_sum,
        ),
        (
            f"""SELECT sum(key)
               FROM mongodb('mongo_secure:27017',
                            'test_secure_connection',
                            'simple_table',
                            'root',
                            '{mongo_pass}',
                            'key UInt64, data String',
                            'tls=true&tlsAllowInvalidCertificates=true&tlsAllowInvalidHostnames=true')""",
            expected_sum,
        ),
        (
            f"""SELECT data
               FROM mongodb('mongo_secure:27017',
                            'test_secure_connection',
                            'simple_table',
                            'root',
                            '{mongo_pass}',
                            'key UInt64, data String',
                            'tls=true&tlsAllowInvalidCertificates=true&tlsAllowInvalidHostnames=true')
               WHERE key = 42""",
            hex(42 * 42) + "\n",
        ),
    ]
    for query, expected in expectations:
        assert node.query(query) == expected
    collection.drop()


def test_secure_connection_with_validation(started_cluster):
    """With only ``tls=true`` in ``options`` (no invalid-certificate allowance) the query must fail.

    NOTE(review): presumably the test server's certificate does not validate from
    the ClickHouse node — confirm.
    """
    client = get_mongo_connection(started_cluster, secure=True)
    database = client["test_secure_connection_with_validation"]
    try:
        database.command("createUser", "root", pwd=mongo_pass, roles=["readWrite"])
    except pymongo.errors.OperationFailure:
        # Deliberately ignored (presumably the user already exists).
        pass
    collection = database["simple_table"]
    collection.insert_many([{"key": n, "data": hex(n * n)} for n in range(0, 100)])

    node = started_cluster.instances["node"]
    query = f"""SELECT COUNT() FROM mongodb('mongo_secure:27017',
                   'test_secure_connection_with_validation',
                   'simple_table',
                   'root',
                   '{mongo_pass}',
                   structure='key UInt64, data String',
                   options='tls=true')"""
    with pytest.raises(QueryRuntimeException):
        node.query(query)

    collection.drop()


def test_secure_connection_uri(started_cluster):
    """Query the TLS-enabled MongoDB through a URI that carries the TLS options.

    ``structure`` is passed by name in the first two queries and positionally in
    the last two, so both call forms of the URI variant are covered.
    """
    mongo_connection = get_mongo_connection(started_cluster, secure=True)
    db = mongo_connection["test_secure_connection_uri"]
    try:
        db.command("createUser", "root", pwd=mongo_pass, roles=["readWrite"])
    except pymongo.errors.OperationFailure:
        # Deliberately ignored (presumably the user already exists).
        pass
    simple_mongo_table = db["simple_table"]
    simple_mongo_table.insert_many([{"key": i, "data": hex(i * i)} for i in range(0, 100)])

    node = started_cluster.instances["node"]
    # Password is percent-encoded for the URI; certificate and hostname
    # validation are disabled through URI query parameters.
    uri = (
        f"mongodb://root:{urllib.parse.quote_plus(mongo_pass)}@mongo_secure:27017/test_secure_connection_uri"
        "?tls=true&tlsAllowInvalidCertificates=true&tlsAllowInvalidHostnames=true"
    )
    named = f"mongodb('{uri}', 'simple_table', structure='key UInt64, data String')"
    positional = f"mongodb('{uri}', 'simple_table', 'key UInt64, data String')"
    expected_sum = str(sum(range(0, 100))) + "\n"

    assert node.query(f"SELECT COUNT() FROM {named}") == "100\n"
    assert node.query(f"SELECT sum(key) FROM {named}") == expected_sum
    assert node.query(f"SELECT sum(key) FROM {positional}") == expected_sum
    assert node.query(f"SELECT data FROM {positional} WHERE key = 42") == hex(42 * 42) + "\n"
    simple_mongo_table.drop()


def test_no_credentials(started_cluster):
    """Read from the credential-less MongoDB instance, passing empty user and password."""
    client = get_mongo_connection(started_cluster, with_credentials=False)
    collection = client["test_no_credentials"]["simple_table"]
    collection.insert_many([{"key": n, "data": hex(n * n)} for n in range(0, 100)])

    node = started_cluster.instances["node"]
    row_count = node.query(
        "SELECT count() FROM mongodb('mongo_no_cred:27017', 'test_no_credentials', 'simple_table', '', '', structure='key UInt64, data String')"
    )
    assert row_count == "100\n"
    collection.drop()


def test_auth_source(started_cluster):
    """Authenticate with a user defined in ``admin`` by passing ``options='authSource=admin'``.

    Without ``authSource`` the query is expected to fail. With it, the query must
    count the 100 documents of ``test_auth_source.simple_table``.
    """
    mongo_connection = get_mongo_connection(started_cluster, with_credentials=False)
    admin_db = mongo_connection["admin"]
    try:
        admin_db.command(
            "createUser",
            "root",
            pwd=mongo_pass,
            roles=[{"role": "userAdminAnyDatabase", "db": "admin"}, "readWriteAnyDatabase"],
        )
    except pymongo.errors.OperationFailure:
        # Deliberately ignored (presumably the user already exists).
        pass
    # Same collection name in ``admin`` with a different row count (50 vs 100),
    # presumably so that reading from the wrong database would be detected.
    admin_mongo_table = admin_db["simple_table"]
    admin_mongo_table.insert_many([{"key": i, "data": hex(i * i)} for i in range(0, 50)])

    db = mongo_connection["test_auth_source"]
    simple_mongo_table = db["simple_table"]
    simple_mongo_table.insert_many([{"key": i, "data": hex(i * i)} for i in range(0, 100)])

    node = started_cluster.instances["node"]
    with pytest.raises(QueryRuntimeException):
        node.query(
            f"SELECT count() FROM mongodb('mongo_no_cred:27017', 'test_auth_source', 'simple_table', 'root', '{mongo_pass}', structure='key UInt64, data String')"
        )

    assert (
        node.query(
            f"SELECT count() FROM mongodb('mongo_no_cred:27017', 'test_auth_source', 'simple_table', 'root', '{mongo_pass}', structure='key UInt64, data String', options='authSource=admin')"
        )
        == "100\n"
    )

    # Drop both collections, including the one created in the shared ``admin`` database.
    simple_mongo_table.drop()
    admin_mongo_table.drop()


def test_missing_columns(started_cluster):
    """Documents lacking ``data`` must read as NULL through a ``Nullable(String)`` column."""
    client = get_mongo_connection(started_cluster)
    database = client["test_missing_columns"]
    try:
        database.command("createUser", "root", pwd=mongo_pass, roles=["readWrite"])
    except pymongo.errors.OperationFailure:
        # Deliberately ignored (presumably the user already exists).
        pass
    collection = database["simple_table"]
    complete_docs = [{"key": n, "data": hex(n * n)} for n in range(0, 10)]
    partial_docs = [{"key": n} for n in range(0, 10)]
    collection.insert_many(complete_docs + partial_docs)

    node = started_cluster.instances["node"]
    null_count = node.query(
        f"SELECT count() FROM mongodb('mongo1:27017', 'test_missing_columns', 'simple_table', 'root', '{mongo_pass}', structure='key UInt64, data Nullable(String)') WHERE isNull(data)"
    )
    assert null_count == "10\n"
    collection.drop()


def test_oid(started_cluster):
    """Check ObjectId handling for ``_id`` and for a caller-designated oid column.

    ``_id`` is declared as ``String``. Filters with valid ObjectId strings are
    expected to match, while literals that are not valid ObjectIds (or are not
    strings) are expected to raise. Both the host/port and the URI forms of the
    table function are covered.
    """
    mongo_connection = get_mongo_connection(started_cluster)
    db = mongo_connection["test_oid"]
    # Recreate the user with a fixed password that is used unencoded in the URI below.
    db.command("dropAllUsersFromDatabase")
    db.command("createUser", "root", pwd="clickhouse", roles=["readWrite"])
    oid_mongo_table = db["oid_table"]
    inserted_result = oid_mongo_table.insert_many(
        [
            {"key": "oid1"},
            {"key": "oid2"},
            {"key": "oid3"},
        ]
    )
    ids = inserted_result.inserted_ids

    node = started_cluster.instances["node"]
    table_definitions = [
        "mongodb('mongo1:27017', 'test_oid', 'oid_table', 'root', 'clickhouse', '_id String, key String')",
        "mongodb('mongodb://root:clickhouse@mongo1:27017/test_oid', 'oid_table', '_id String, key String')",
    ]

    for table_definition in table_definitions:
        assert node.query(f"SELECT COUNT() FROM {table_definition}") == "3\n"

        assert node.query(f"SELECT _id FROM {table_definition} WHERE _id = '{str(ids[0])}'") == f"{str(ids[0])}\n"
        assert node.query(f"SELECT key FROM {table_definition} WHERE _id = '{str(ids[0])}'") == "oid1\n"
        assert (
            node.query(f"SELECT key FROM {table_definition} WHERE _id != '{str(ids[0])}' ORDER BY key")
            == "oid2\noid3\n"
        )
        assert (
            node.query(f"SELECT key FROM {table_definition} WHERE _id in ['{ids[0]}', '{ids[1]}'] ORDER BY key")
            == "oid1\noid2\n"
        )
        assert (
            node.query(f"SELECT key FROM {table_definition} WHERE _id not in ['{ids[0]}', '{ids[1]}'] ORDER BY key")
            == "oid3\n"
        )

        with pytest.raises(QueryRuntimeException):
            node.query(f"SELECT * FROM {table_definition} WHERE _id = 'not-oid'")
        with pytest.raises(QueryRuntimeException):
            node.query(f"SELECT * FROM {table_definition} WHERE _id != 'not-oid'")
        with pytest.raises(QueryRuntimeException):
            node.query(f"SELECT * FROM {table_definition} WHERE _id = 1234567")
        with pytest.raises(QueryRuntimeException):
            node.query(f"SELECT * FROM {table_definition} WHERE _id != 1234567")
        with pytest.raises(QueryRuntimeException):
            node.query(f"SELECT key FROM {table_definition} WHERE _id in ['{ids[0]}', 'not-oid']")
        with pytest.raises(QueryRuntimeException):
            node.query(f"SELECT key FROM {table_definition} WHERE _id not in ['{ids[0]}', 'not-oid']")
        with pytest.raises(QueryRuntimeException):
            node.query(f"SELECT key FROM {table_definition} WHERE _id in ['nope', 'not-oid']")
        with pytest.raises(QueryRuntimeException):
            node.query(f"SELECT key FROM {table_definition} WHERE _id not in ['nope', 'not-oid']")

    # The trailing 'key' argument (presumably the oid-columns list) makes ``key``
    # be treated like ``_id``, so filtering it with a non-ObjectId literal is
    # expected to fail. The quoting must be balanced: an unterminated 'key
    # literal would make the query a syntax error and the check below would pass
    # without exercising oid handling at all.
    table_definition = (
        "mongodb('mongo1:27017', 'test_oid', 'oid_table', 'root', 'clickhouse', '_id String, key String', "
        "'', 'key')"
    )
    with pytest.raises(QueryRuntimeException):
        node.query(f"SELECT * FROM {table_definition} WHERE key = 'not-oid'")

    table_definition = (
        "mongodb('mongodb://root:clickhouse@mongo1:27017/test_oid', 'oid_table', '_id String, key String', "
        "'key')"
    )
    with pytest.raises(QueryRuntimeException):
        node.query(f"SELECT * FROM {table_definition} WHERE key = 'not-oid'")

    oid_mongo_table.drop()

def test_datetime_condition(started_cluster):
    """Filter a ``DateTime`` column using string literals and ``toDateTime`` in many predicate shapes."""
    client = get_mongo_connection(started_cluster, with_credentials=False)
    collection = client["test_datetime_condition"]["datetime_mongo_table"]
    # Keys 0, 1, 2 at midnight on 2025-01-11, 2025-01-15 and 2025-01-20.
    collection.insert_many(
        [
            {"key": key, "timestamp": datetime.datetime(2025, 1, day, 0, 0, 0)}
            for key, day in enumerate((11, 15, 20))
        ]
    )

    node = started_cluster.instances["node"]
    table_func = "mongodb('mongo_no_cred:27017', 'test_datetime_condition', 'datetime_mongo_table', '', '', structure='key UInt64, timestamp DateTime')"

    assert TSV(node.query(f"SELECT count(), any(toTypeName(timestamp)) FROM {table_func}")) == TSV("3\tDateTime\n")

    # (WHERE clause, expected row count), checked in order.
    cases = [
        ("timestamp = '2025-01-11 00:00:00'", 1),
        ("timestamp = toDateTime('2025-01-11 00:00:00')", 1),
        ("timestamp > '2025-01-11 00:00:00'", 2),
        ("timestamp > toDateTime('2025-01-11 00:00:00')", 2),
        ("timestamp > '2025-01-11 00:00:00' AND key = 2", 1),
        ("timestamp > toDateTime('2025-01-11 00:00:00') AND key = 2", 1),
        ("timestamp > '2025-01-11 00:00:00' OR key = 0", 3),
        ("timestamp > toDateTime('2025-01-11 00:00:00') OR key = 0", 3),
        ("timestamp IN '2025-01-11 00:00:00'", 1),
        ("timestamp IN ('2025-01-11 00:00:00', '2025-01-15 00:00:00', '2025-01-30 00:00:00')", 2),
        ("timestamp IN ['2025-01-11 00:00:00', '2025-01-15 00:00:00', '2025-01-30 00:00:00']", 2),
        ("'2025-01-11 00:00:00' = timestamp", 1),
        ("toDateTime('2025-01-11 00:00:00') = timestamp", 1),
        ("'2025-01-11 00:00:00' < timestamp", 2),
        ("toDateTime('2025-01-11 00:00:00') < timestamp", 2),
    ]
    for condition, expected in cases:
        assert TSV(node.query(f"SELECT count() FROM {table_func} WHERE {condition}")) == TSV(f"{expected}\n")

    collection.drop()


def test_limit(started_cluster):
    """Check LIMIT / OFFSET results under both values of ``mongodb_throw_on_unsupported_query``."""
    client = get_mongo_connection(started_cluster, with_credentials=False)
    collection = client["test_limit"]["group_by_limit_mongo_table"]
    # Ten groups (key 0..9) of ten values each; group 0 sums to 0+10+...+90 = 450.
    collection.insert_many([{"key": v % 10, "value": v} for v in range(0, 100)])

    node = started_cluster.instances["node"]
    table_func = "mongodb('mongo_no_cred:27017', 'test_limit', 'group_by_limit_mongo_table', '', '', structure='key UInt64, value UInt64')"

    # The GROUP BY query runs with the throw setting disabled (presumably because
    # it cannot be fully pushed down to MongoDB); the plain ORDER BY/LIMIT ones
    # run with it enabled.
    cases = [
        (
            f"SELECT sum(value) FROM {table_func} GROUP BY key ORDER BY key LIMIT 1 SETTINGS mongodb_throw_on_unsupported_query = 0",
            "450\n",
        ),
        (
            f"SELECT value FROM {table_func} ORDER BY value LIMIT 5 SETTINGS mongodb_throw_on_unsupported_query = 1",
            "0\n1\n2\n3\n4\n",
        ),
        (
            f"SELECT value FROM {table_func} ORDER BY value LIMIT 5 OFFSET 5 SETTINGS mongodb_throw_on_unsupported_query = 1",
            "5\n6\n7\n8\n9\n",
        ),
    ]
    for query, expected in cases:
        assert TSV(node.query(query)) == TSV(expected)

    collection.drop()
