import re
from typing import List, Tuple
from models.vulnerability import Vulnerability

class VulnerabilityValidator:
    """Match vulnerabilities against free-text search terms and version bounds.

    All helpers are static. Search terms that contain a comparison operator
    (``>=``, ``>``, ``<=``, ``<``) are interpreted as version bounds; every
    other term is matched as a case-insensitive substring of the
    vulnerability's text fields.
    """

    # Non-capturing group: with a capturing group ``re.findall`` would return
    # only the last ".N" fragment (or "") instead of the whole version string.
    version_pattern = r'\b\d+(?:\.\d+)*\b'

    # Longer operators come first so ">=" is never mistaken for ">".
    _VERSION_OPERATORS = (">=", ">", "<=", "<")

    @staticmethod
    def validate(search_terms: List[str], vulnerability: "Vulnerability") -> bool:
        """Return True when every search term occurs in the vulnerability text.

        The searched text is the lower-cased concatenation of the id, title,
        url (with ``-`` replaced by spaces), description, vulnerable
        components, reference urls and weaknesses. Fields that are ``None``
        are treated as empty. An empty ``search_terms`` list matches anything.
        """
        needles = [term.lower() for term in search_terms]

        parts = [
            vulnerability.id or "",
            vulnerability.title or "",
            # Hyphenated URL slugs should match space-separated search terms.
            (vulnerability.url or "").replace("-", " "),
            vulnerability.description or "",
            " ".join(vulnerability.vulnerable_components or []),
            " ".join(vulnerability.reference_urls or []),
            " ".join(vulnerability.weaknesses or []),
        ]
        haystack = " ".join(parts).lower()

        return all(needle in haystack for needle in needles)

    @staticmethod
    def extract_versions_from_vulnerability(vulnerability: "Vulnerability") -> List[str]:
        """Collect every version-like string found in the vulnerability.

        Scans, in order: title, description, tags, vulnerable components,
        reference urls and weaknesses. Fields that are ``None`` are skipped.
        Duplicates are kept.
        """
        sources = [
            vulnerability.title or "",
            vulnerability.description or "",
            " ".join(vulnerability.tags or []),
            " ".join(vulnerability.vulnerable_components or []),
            " ".join(vulnerability.reference_urls or []),
            " ".join(vulnerability.weaknesses or []),
        ]

        versions = []
        for text in sources:
            versions.extend(VulnerabilityValidator.extract_version(text))
        return versions

    @staticmethod
    def extract_version(text: str) -> List[str]:
        """Return all substrings of ``text`` matching ``version_pattern``.

        Full matches are returned (via ``group(0)``), so the result is the
        complete version string such as ``"1.2.3"``. Empty or ``None`` text
        yields an empty list.
        """
        if not text:
            return []
        return [
            match.group(0)
            for match in re.finditer(VulnerabilityValidator.version_pattern, text)
        ]

    @staticmethod
    def normalize_version(version: str) -> List[int]:
        """Split a dotted version string into a list of ints.

        Raises:
            ValueError: if any dot-separated component is not an integer.
        """
        return [int(component) for component in version.split('.')]

    @staticmethod
    def _compare_versions(left: List[int], right: List[int]) -> int:
        """Return -1, 0 or 1 comparing two versions after zero-padding.

        Padding makes ``1`` equal to ``1.0`` instead of smaller than it.
        """
        width = max(len(left), len(right))
        padded_left = left + [0] * (width - len(left))
        padded_right = right + [0] * (width - len(right))
        return (padded_left > padded_right) - (padded_left < padded_right)

    @staticmethod
    def is_version_in_range(version: str, min_version: str = None, max_version: str = None,
                            *, min_inclusive: bool = True, max_inclusive: bool = True) -> bool:
        """Return True when ``version`` lies within the given bounds.

        Args:
            version: dotted version string to test.
            min_version: lower bound, or ``None``/empty for no lower bound.
            max_version: upper bound, or ``None``/empty for no upper bound.
            min_inclusive: when False, ``version`` must be strictly greater
                than ``min_version``.
            max_inclusive: when False, ``version`` must be strictly less
                than ``max_version``.

        Raises:
            ValueError: if any supplied version string is not purely numeric.
        """
        candidate = VulnerabilityValidator.normalize_version(version)

        if min_version:
            lower = VulnerabilityValidator.normalize_version(min_version)
            order = VulnerabilityValidator._compare_versions(candidate, lower)
            if order < 0 or (order == 0 and not min_inclusive):
                return False

        if max_version:
            upper = VulnerabilityValidator.normalize_version(max_version)
            order = VulnerabilityValidator._compare_versions(candidate, upper)
            if order > 0 or (order == 0 and not max_inclusive):
                return False

        return True

    @staticmethod
    def _parse_version_bounds(search_terms: List[str]) -> Tuple[str, bool, str, bool]:
        """Parse bounds and their inclusiveness from operator terms.

        Returns ``(min_version, min_inclusive, max_version, max_inclusive)``.
        A bound is ``None`` when no term supplies it; when several terms set
        the same bound, the last one wins.
        """
        min_version, min_inclusive = None, True
        max_version, max_inclusive = None, True

        for term in search_terms:
            if ">=" in term:
                min_version, min_inclusive = term.split(">=")[1].strip(), True
            elif ">" in term:
                min_version, min_inclusive = term.split(">")[1].strip(), False
            elif "<=" in term:
                max_version, max_inclusive = term.split("<=")[1].strip(), True
            elif "<" in term:
                max_version, max_inclusive = term.split("<")[1].strip(), False

        return min_version, min_inclusive, max_version, max_inclusive

    @staticmethod
    def parse_version_terms(search_terms: List[str]) -> Tuple[str, str]:
        """Return ``(min_version, max_version)`` parsed from operator terms.

        ``>=``/``>`` terms set the minimum and ``<=``/``<`` terms set the
        maximum; either element is ``None`` when no such term is present.
        Strictness of the operator is not reported by this method.
        """
        min_version, _, max_version, _ = VulnerabilityValidator._parse_version_bounds(search_terms)
        return min_version, max_version

    @staticmethod
    def validate_with_versions(search_terms: List[str], vulnerability: "Vulnerability") -> bool:
        """Validate text terms and version bounds against a vulnerability.

        Terms without a comparison operator must all match via ``validate``.
        If any version bound is given, the vulnerability must contain at
        least one version string and every extracted version must satisfy
        the bounds (``>`` and ``<`` are strict, ``>=`` and ``<=`` inclusive).

        Raises:
            ValueError: if a bound taken from the search terms is not a
                purely numeric dotted version.
        """
        text_terms = [
            term for term in search_terms
            if not any(op in term for op in VulnerabilityValidator._VERSION_OPERATORS)
        ]
        if not VulnerabilityValidator.validate(text_terms, vulnerability):
            return False

        (min_version, min_inclusive,
         max_version, max_inclusive) = VulnerabilityValidator._parse_version_bounds(search_terms)

        if not (min_version or max_version):
            return True

        extracted_versions = VulnerabilityValidator.extract_versions_from_vulnerability(vulnerability)
        if not extracted_versions:
            return False

        # NOTE(review): every number-like token found counts as a version, so
        # stray numbers (e.g. years) can fail the check -- confirm "all" is
        # the intended rule rather than "any".
        return all(
            VulnerabilityValidator.is_version_in_range(
                version, min_version, max_version,
                min_inclusive=min_inclusive, max_inclusive=max_inclusive,
            )
            for version in extracted_versions
        )
