from typing import List

from app.controller.schema.species_schema import SpecieInfo
from app.database.repos.base_repo import SqlRepo
from app.database.sql.models import SpeciesInfoDb, SpeciesDb


class SpeciesRepo(SqlRepo):
    def __init__(self, app_session):
        super().__init__(app_session)

    def get_species_info(self) -> dict[str, SpecieInfo]:
        species_info = {}
        for specie in self._sql_session.query(SpeciesInfoDb).all():
            species_info[str(specie.key)] = SpecieInfo(
                key=specie.key,
                name=specie.name,
                year1=specie.year1,
                year2=specie.year2,
                year3=specie.year3,
                year4=specie.year4,
                year5=specie.year5,
                year6=specie.year6,
                year7=specie.year7,
                year8=specie.year8,
                year9=specie.year9,
                year10=specie.year10,
                ndvi_index=specie.ndvi_index,
            )
        return species_info

    def update_specie_info(self, specie_key, specie_info: SpecieInfo):
        specie_db = self._sql_session.query(SpeciesInfoDb).filter(SpeciesInfoDb.key == specie_key).first()
        if specie_db:
            specie_db.year1 = specie_info.year1
            specie_db.year2 = specie_info.year2
            specie_db.year3 = specie_info.year3
            specie_db.year4 = specie_info.year4
            specie_db.year5 = specie_info.year5
            specie_db.year6 = specie_info.year6
            specie_db.year7 = specie_info.year7
            specie_db.year8 = specie_info.year8
            specie_db.year9 = specie_info.year9
            specie_db.year10 = specie_info.year10
            specie_db.ndvi_index = specie_info.ndvi_index
            self._sql_session.commit()
        else:
            new_specie = SpeciesInfoDb(
                key=specie_info.key,
                name=specie_info.name,
                year1=specie_info.year1,
                year2=specie_info.year2,
                year3=specie_info.year3,
                year4=specie_info.year4,
                year5=specie_info.year5,
                year6=specie_info.year6,
                year7=specie_info.year7,
                year8=specie_info.year8,
                year9=specie_info.year9,
                year10=specie_info.year10,
                ndvi_index=specie_info.ndvi_index,
            )
            self._sql_session.add(new_specie)
            self._sql_session.commit()
        return specie_db

    def get_species_for_action(self, action_id: int) -> List[SpeciesDb]:
        return self._sql_session.query(SpeciesDb).filter(SpeciesDb.action_id == action_id).all()
