Source code for pygeodes.utils.query

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""This module contains the Query class"""
# -----------------------------------------------------------------------------
# Copyright (c) 2024, CNES
#
# REFERENCES:
# https://cnes.fr/
# -----------------------------------------------------------------------------

# stdlib imports -------------------------------------------------------
from typing import List
import warnings
from time import perf_counter
from tqdm import tqdm

# third-party imports -----------------------------------------------
from whoosh.filedb.filestore import RamStorage
from whoosh.fields import TEXT, Schema, STORED
from whoosh.qparser import MultifieldParser, OrGroup

# local imports ---------------------------------------------------
from pygeodes.utils.consts import (
    KNOWN_COLLECTION_REQUESTABLE_ARGUMENTS,
    KNOWN_ITEM_REQUESTABLE_ARGUMENTS,
    REQUESTABLE_ARGS_FILEPATH,
)
from pygeodes.utils.logger import logger
from pygeodes.utils.io import load_json


def get_requestable_args():
    """Return the requestable-arguments mapping loaded from ``REQUESTABLE_ARGS_FILEPATH``."""
    requestable = load_json(REQUESTABLE_ARGS_FILEPATH)
    return requestable
class Argument:
    """A single queryable field and the comparison constraints applied to it.

    Each operator method records its constraint in ``self.queries`` under the
    operator's key (``"eq"``, ``"lte"``, ``"gte"``, ``"contains"``, ``"in"``);
    the accumulated constraints are exported with :meth:`to_dict`.
    """

    def __init__(self, name: str):
        self.name = name  # name of the requestable field
        self.queries = {}  # operator key -> operand value

    def eq(self, value):
        """Constrain the field to equal ``value``."""
        self.queries["eq"] = value

    def lte(self, value):
        """Constrain the field to be lower than or equal to ``value``."""
        self.queries["lte"] = value

    def gte(self, value):
        """Constrain the field to be greater than or equal to ``value``."""
        self.queries["gte"] = value

    def contains(self, value):
        """Constrain the field to contain ``value``."""
        self.queries["contains"] = value

    def is_in(self, value: list):
        """Constrain the field to be one of the items of ``value``.

        Raises:
            TypeError: if ``value`` is not a list. (TypeError subclasses
                Exception, so existing ``except Exception`` callers still work.)
        """
        # isinstance is the idiomatic type check (was `not type(value) is list`)
        if not isinstance(value, list):
            raise TypeError(
                f"is_in argument must be a list type, not {type(value)}"
            )
        self.queries["in"] = value

    def to_dict(self):
        """Return the collected constraints as a plain dict."""
        return self.queries
[docs]class Query:
[docs] def __init__(self): self.args: List[Argument] = []
[docs] def add(self, argument: Argument): self.args.append(argument)
[docs] def check(self): names = [arg.name for arg in self.args] for name in names: if (count := names.count(name)) > 1: raise Exception( f"Argument {name} appears {count} times in your query" )
[docs] def check_for_collection(self): self.check() for arg in self.args: if arg.name not in KNOWN_COLLECTION_REQUESTABLE_ARGUMENTS: raise Exception( f"Argument {arg.name} cannot be queried in collections" )
[docs] def check_for_item(self): self.check() for arg in self.args: if arg.name not in KNOWN_ITEM_REQUESTABLE_ARGUMENTS: raise Exception( f"Argument {arg.name} cannot be queried in items" )
[docs] def to_dict(self): dico = {} for arg in self.args: dico[arg.name] = arg.to_dict()
def full_text_search_in_jsons(
    jsons: List[dict],
    search_term: str,
    key_field: str,
    fields_to_index: set,
    return_index: bool = False,
):
    """Full-text search over a list of JSON objects using an in-memory whoosh index.

    Builds a RAM-backed whoosh index over ``fields_to_index`` (stringified with
    ``str()``), runs ``search_term`` through a ``MultifieldParser`` with OR
    grouping, and maps the hits back to the original dicts via ``key_field``.

    Args:
        jsons: JSON objects to index; every object must have a non-None value
            at ``key_field`` and at every field of ``fields_to_index``
            (enforced by assertions below).
        search_term: the free-text query string.
        key_field: dotted dico-path uniquely identifying each object.
            NOTE(review): uniqueness is assumed, not checked — duplicate keys
            silently collapse in ``dico``.
        fields_to_index: set of dotted dico-paths to make searchable.
        return_index: when True, also return the whoosh index object.

    Returns:
        The matching original dicts (in result order), or ``(results, index)``
        when ``return_index`` is True.
    """
    from pygeodes.utils.formatting import (
        get_from_dico_path,
    )  # to avoid circular import

    begin = perf_counter()

    # verifications
    # NOTE(review): `assert` is stripped under `python -O`; these validations
    # would then silently disappear.
    for json_obj in jsons:
        assert (
            get_from_dico_path(key_field, json_obj) is not None
        ), f"{key_field=} has value None"
        for field_to_index in fields_to_index:
            assert (
                get_from_dico_path(field_to_index, json_obj) is not None
            ), f"{field_to_index=} has value None {json_obj=}"

    # key value -> original object, used to map search hits back to inputs
    dico = {
        get_from_dico_path(key_field, json_obj): json_obj
        for json_obj in jsons
    }

    with warnings.catch_warnings():  # because sometimes whoosh raises a unexpected warning
        # The key field must be retrievable from results: stored TEXT if it is
        # also searched, otherwise STORED only (kept but not indexed).
        schema_components = {field: TEXT for field in fields_to_index}
        if key_field in fields_to_index:
            schema_components[key_field] = TEXT(stored=True)
        else:
            schema_components[key_field] = STORED
        schema = Schema(**schema_components)

        ix = RamStorage().create_index(schema)
        writer = ix.writer()
        for json_obj in tqdm(dico.values(), "Indexing"):
            # whoosh TEXT fields require strings, hence the str() coercion
            to_add = {
                field: str(get_from_dico_path(field, json_obj))
                for field in fields_to_index
            }
            to_add[key_field] = str(get_from_dico_path(key_field, json_obj))
            writer.add_document(**to_add)
        writer.commit()

        query = MultifieldParser(
            [field for field in fields_to_index], ix.schema, group=OrGroup
        )
        with ix.searcher() as searcher:
            # NOTE(review): `query` is rebound from parser to parsed query here
            # — shadowing that works but hurts readability.
            query = query.parse(search_term)
            results = searcher.search(query, terms=True)
            # NOTE(review): print() where the rest of the module uses `logger`
            print(
                f"Matched terms for {search_term=} : {results.matched_terms()}"
            )
            ids = [result.get(key_field) for result in results]
            # dico.get: a hit whose key is somehow absent maps to None
            res = [dico.get(_id) for _id in ids]

    end = perf_counter()
    logger.debug(
        f"Proceeded to full-text search on {len(jsons)} objects in {end - begin} seconds"
    )
    if return_index:
        return res, ix
    else:
        return res