from typing import Optional

from app.controller.schema.project_shema import FileAddModel
from app.database.repos.base_repo import SqlRepo
from app.database.sql.models import FileDb


class FileRepo(SqlRepo):
    """Data-access helpers for ``FileDb`` rows.

    Queries go through ``self._sql_session``, which this class does not set
    itself (presumably provided by ``SqlRepo.__init__`` — confirm). Every
    write method commits immediately; if the write or commit fails the session
    is rolled back and the original exception is re-raised.
    """

    def __init__(self, app_session):
        super().__init__(app_session)

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #

    def _save(self, file_db: FileDb) -> FileDb:
        """Persist a new row, commit, and return it refreshed from the DB."""
        try:
            self._sql_session.add(file_db)
            self._sql_session.commit()
        except Exception:
            # A failed flush/commit leaves the session in a failed transaction;
            # roll back so later calls sharing this session remain usable.
            self._sql_session.rollback()
            raise
        # Reload server-generated values (e.g. primary key) after the commit.
        self._sql_session.refresh(file_db)
        return file_db

    def _create(self, file: FileAddModel, **owner_fk: int) -> FileDb:
        """Build a ``FileDb`` from *file* linked via one owner FK and save it.

        ``owner_fk`` is exactly one of ``project_id``, ``audit_action_id`` or
        ``action_id`` as used by the public ``add*`` methods.
        """
        return self._save(FileDb(file=file.file, tag=file.tag, **owner_fk))

    # ------------------------------------------------------------------ #
    # Reads
    # ------------------------------------------------------------------ #

    def get_all(self) -> list[FileDb]:
        """Return every ``FileDb`` row."""
        return self._sql_session.query(FileDb).all()

    def get_by_id(self, file_id: int) -> Optional[FileDb]:
        """Return the row whose ``file_id`` equals *file_id*, or ``None`` if absent."""
        return (
            self._sql_session.query(FileDb)
            .filter(FileDb.file_id == file_id)
            .first()
        )

    def get_by_project_id(self, project_id: int) -> list[FileDb]:
        """Return all rows whose ``project_id`` equals *project_id*."""
        return (
            self._sql_session.query(FileDb)
            .filter(FileDb.project_id == project_id)
            .all()
        )

    def get_by_action_id(self, action_id: int) -> list[FileDb]:
        """Return all rows whose ``action_id`` equals *action_id*."""
        return (
            self._sql_session.query(FileDb)
            .filter(FileDb.action_id == action_id)
            .all()
        )

    def get_by_audit_action_id(self, audit_action_id: int, tag: str) -> list[FileDb]:
        """Return rows matching both *audit_action_id* and *tag*."""
        return (
            self._sql_session.query(FileDb)
            .filter(FileDb.audit_action_id == audit_action_id)
            .filter(FileDb.tag == tag)
            .all()
        )

    # ------------------------------------------------------------------ #
    # Writes (each commits immediately)
    # ------------------------------------------------------------------ #

    def add(self, file: FileAddModel, project_id: int) -> FileDb:
        """Create a file row linked to *project_id* and return it."""
        return self._create(file, project_id=project_id)

    def add_audit(self, file: FileAddModel, audit_action_id: int) -> FileDb:
        """Create a file row linked to *audit_action_id* and return it."""
        return self._create(file, audit_action_id=audit_action_id)

    def add_action(self, file: FileAddModel, action_id: int) -> FileDb:
        """Create a file row linked to *action_id* and return it."""
        return self._create(file, action_id=action_id)

    def delete_by_project_id(self, project_id: int) -> None:
        """Delete every row whose ``project_id`` equals *project_id* and commit.

        NOTE: this is a bulk ``Query.delete()``; per-object ORM cascades and
        delete events are not run for the removed rows.
        """
        query = self._sql_session.query(FileDb).filter(FileDb.project_id == project_id)
        try:
            query.delete()
            self._sql_session.commit()
        except Exception:
            # Leave the shared session usable after a failed delete/commit.
            self._sql_session.rollback()
            raise
