import asyncio

import typing

if typing.TYPE_CHECKING:
    from travelogues_extraction.dataextractors.abstract import AbstractDataExtractor

from lxml import etree as lxmletree
import httpx
import pandas as pd
import pytest

from travelogues_extraction.getrecords.session import RecordRetriever
from travelogues_extraction.getrecords.acnumber_extractor import extract_ac_from_series
from travelogues_extraction.dataextractors.dataextractors.index import IndexSetter

# Spreadsheet shared by every test in this module (the tests read its
# 'Datensatznummer' column). It is loaded once, at import time. The path is
# relative, so it resolves against the current working directory of the
# pytest run.
_DUMMY_DATA_PATH = 'test/dummy_data/TravelogueD19_ALMAoutput_20200712.xlsx'
dummy_data = pd.read_excel(_DUMMY_DATA_PATH)

@pytest.mark.asyncio
async def test_record_retriever():
    """Retrieve records for the first 25 dummy rows and validate each one.

    Every yielded item must be exactly a ``RecordRetriever.Record`` whose
    ``ac_number`` is a ``str`` starting with ``'AC'``, and whose
    ``lxmlelement`` is an lxml ``_Element`` with a tag ending in ``'record'``.
    """
    reduced_series = dummy_data['Datensatznummer'][:25]
    # `async with` closes the HTTP client even when an assertion below fails;
    # a trailing explicit `aclose()` is skipped in that case and leaks it.
    async with httpx.AsyncClient() as session:
        record_retriever = RecordRetriever(
            extract_ac_from_series(reduced_series), session=session
        )
        records = [record async for record in record_retriever.generate_records()]
        assert len(records) == 25
        # Strict class checks (not isinstance): subclasses are rejected.
        assert all(type(record) is RecordRetriever.Record for record in records)
        assert all(type(record.ac_number) is str for record in records)
        assert all(record.ac_number.startswith('AC') for record in records)
        assert all(type(record.lxmlelement) is lxmletree._Element for record in records)
        assert all(record.lxmlelement.tag.endswith('record') for record in records)


def test_main_loop():
    """Run retrieve-then-extract for 25 dummy rows and check the resulting frame.

    Each AC number's record is fetched through ``RecordRetriever`` and passed to
    an ``IndexSetter`` bound to a shared DataFrame. The test expects that frame
    to end up with shape ``(25, 1)``.
    """

    async def work(ac: str, extractors: typing.List['AbstractDataExtractor'], record_retriever: RecordRetriever):
        """Fetch the record for ``ac`` and hand it to each extractor in order."""
        record = await record_retriever.get_record_from_ac_number(ac)
        for extractor in extractors:
            await extractor.write(record)

    async def dummy_controller(series: pd.Series) -> pd.DataFrame:
        """Process every AC number in ``series`` concurrently; return the filled frame."""
        async with httpx.AsyncClient() as session:
            ac_numbers = extract_ac_from_series(series)
            # NOTE(review): ac_numbers is given to RecordRetriever AND iterated
            # below; if extract_ac_from_series returns a one-shot iterator it
            # would be consumed twice — confirm it returns a re-iterable.
            record_retriever = RecordRetriever(ac_numbers, session=session)
            df = pd.DataFrame([])
            indexer = IndexSetter(df)
            # gather runs all jobs concurrently and re-raises the first
            # failure; unlike a sequential await loop it also retrieves the
            # other jobs' exceptions, so none is reported as "never retrieved".
            await asyncio.gather(
                *(work(ac, [indexer], record_retriever) for ac in ac_numbers)
            )
            return df

    reduced_series = dummy_data['Datensatznummer'][:25]

    # asyncio.run creates a fresh event loop and always closes it. The former
    # get_event_loop()/new_event_loop() fallback is deprecated outside a
    # running loop and never closed the loop it created.
    df = asyncio.run(dummy_controller(reduced_series))

    assert df.shape == (25, 1)