Source code for redisvl.query.query

from typing import Any, Dict, List, Optional, Union

from redis.commands.search.query import Query as RedisQuery

from redisvl.query.filter import FilterExpression
from redisvl.redis.utils import array_to_buffer


class BaseQuery(RedisQuery):
    """Base query class used to subclass many query types.

    Subclasses must override ``_build_query_string``; ``set_filter`` calls it
    to regenerate the query string whenever the filter changes.
    """

    # Class-level fallbacks so both attributes resolve even if they are read
    # before an instance-level value has been assigned.
    _params: Dict[str, Any] = {}
    _filter_expression: Union[str, FilterExpression] = FilterExpression("*")

    def __init__(self, query_string: str = "*"):
        """
        Initialize the BaseQuery class.

        Args:
            query_string (str, optional): The query string to use. Defaults to '*'.
        """
        super().__init__(query_string)
        # Give every instance its own params dict. Without this, queries built
        # without params would all return the single class-level dict from
        # ``params``, and a mutation through one query would leak into the
        # others. Subclasses may assign ``_params`` before delegating here, so
        # an existing instance-level value is left untouched.
        if "_params" not in vars(self):
            self._params = {}

    def __str__(self) -> str:
        """Return the string representation of the query."""
        return " ".join(str(arg) for arg in self.get_args())

    def _build_query_string(self) -> str:
        """Build the full Redis query string."""
        raise NotImplementedError("Must be implemented by subclasses")

    def set_filter(
        self, filter_expression: Optional[Union[str, FilterExpression]] = None
    ) -> None:
        """Set the filter expression for the query.

        The query string is rebuilt from the new filter afterwards.

        Args:
            filter_expression (Optional[Union[str, FilterExpression]], optional): The filter
                expression or query string to use on the query. ``None`` resets
                the filter to match everything (``*``).

        Raises:
            TypeError: If filter_expression is not a valid FilterExpression or string.
        """
        if filter_expression is None:
            # Default filter to match everything
            self._filter_expression = FilterExpression("*")
        elif isinstance(filter_expression, (FilterExpression, str)):
            self._filter_expression = filter_expression
        else:
            raise TypeError(
                "filter_expression must be of type FilterExpression or string or None"
            )

        # Reset the query string
        self._query_string = self._build_query_string()

    @property
    def filter(self) -> Union[str, FilterExpression]:
        """The filter expression for the query."""
        return self._filter_expression

    @property
    def query(self) -> "BaseQuery":
        """Return self as the query object."""
        return self

    @property
    def params(self) -> Dict[str, Any]:
        """Return the query parameters."""
        return self._params


class FilterQuery(BaseQuery):
    def __init__(
        self,
        filter_expression: Optional[Union[str, FilterExpression]] = None,
        return_fields: Optional[List[str]] = None,
        num_results: int = 10,
        dialect: int = 2,
        sort_by: Optional[str] = None,
        in_order: bool = False,
        params: Optional[Dict[str, Any]] = None,
    ):
        """A query for running a filtered search with a filter expression.

        Args:
            filter_expression (Optional[Union[str, FilterExpression]]): The
                optional filter expression to query with. Defaults to '*'.
            return_fields (Optional[List[str]], optional): The fields to return.
            num_results (int, optional): The number of results to return.
                Defaults to 10.
            dialect (int, optional): The query dialect. Defaults to 2.
            sort_by (Optional[str], optional): The field to order the results by.
                Defaults to None.
            in_order (bool, optional): Requires the terms in the field to have the
                same order as the terms in the query filter. Defaults to False.
            params (Optional[Dict[str, Any]], optional): The parameters for the
                query. Defaults to None.

        Raises:
            TypeError: If filter_expression is not a FilterExpression, a string,
                or None.
        """
        # The filter must be in place before the query string can be built.
        self.set_filter(filter_expression)

        if params:
            self._params = params
        self._num_results = num_results

        # Hand the query string derived from the filter to the base query.
        super().__init__(self._build_query_string())

        # Apply the optional query settings.
        if return_fields:
            self.return_fields(*return_fields)
        self.paging(0, num_results)
        self.dialect(dialect)
        if sort_by:
            self.sort_by(sort_by)
        if in_order:
            self.in_order()

    def _build_query_string(self) -> str:
        """Return the filter as a plain query string."""
        expression = self._filter_expression
        return (
            str(expression)
            if isinstance(expression, FilterExpression)
            else expression
        )
class CountQuery(BaseQuery):
    def __init__(
        self,
        filter_expression: Optional[Union[str, FilterExpression]] = None,
        dialect: int = 2,
        params: Optional[Dict[str, Any]] = None,
    ):
        """A query for a simple count operation provided some filter expression.

        Args:
            filter_expression (Optional[Union[str, FilterExpression]]): The filter
                expression to query with. Defaults to None.
            dialect (int, optional): The query dialect. Defaults to 2.
            params (Optional[Dict[str, Any]], optional): The parameters for the
                query. Defaults to None.

        Raises:
            TypeError: If filter_expression is not a FilterExpression, a string,
                or None.

        .. code-block:: python

            from redisvl.query import CountQuery
            from redisvl.query.filter import Tag

            t = Tag("brand") == "Nike"
            query = CountQuery(filter_expression=t)

            count = index.query(query)
        """
        # The filter must be in place before the query string can be built.
        self.set_filter(filter_expression)

        if params:
            self._params = params

        # Hand the query string derived from the filter to the base query.
        super().__init__(self._build_query_string())

        # Count-only settings: no document content and an empty page (0, 0).
        self.no_content()
        self.paging(0, 0)
        self.dialect(dialect)

    def _build_query_string(self) -> str:
        """Return the filter as a plain query string."""
        expression = self._filter_expression
        return (
            str(expression)
            if isinstance(expression, FilterExpression)
            else expression
        )
class BaseVectorQuery:
    """Constants shared by the vector query classes."""

    # Alias under which the vector distance is yielded by the query.
    DISTANCE_ID: str = "vector_distance"

    # Name of the query parameter that carries the vector.
    VECTOR_PARAM: str = "vector"
class VectorQuery(BaseVectorQuery, BaseQuery):
    def __init__(
        self,
        vector: Union[List[float], bytes],
        vector_field_name: str,
        return_fields: Optional[List[str]] = None,
        filter_expression: Optional[Union[str, FilterExpression]] = None,
        dtype: str = "float32",
        num_results: int = 10,
        return_score: bool = True,
        dialect: int = 2,
        sort_by: Optional[str] = None,
        in_order: bool = False,
    ):
        """A query for running a vector search along with an optional filter
        expression.

        Args:
            vector (Union[List[float], bytes]): The vector to perform the vector
                search with, as a list of floats or an already-encoded buffer.
            vector_field_name (str): The name of the vector field to search
                against in the database.
            return_fields (List[str]): The declared fields to return with search
                results.
            filter_expression (Union[str, FilterExpression], optional): A filter
                to apply along with the vector search. Defaults to None.
            dtype (str, optional): The dtype of the vector. Defaults to
                "float32".
            num_results (int, optional): The top k results to return from the
                vector search. Defaults to 10.
            return_score (bool, optional): Whether to return the vector
                distance. Defaults to True.
            dialect (int, optional): The RediSearch query dialect.
                Defaults to 2.
            sort_by (Optional[str]): The field to order the results by. Defaults
                to None. Results will be ordered by vector distance.
            in_order (bool): Requires the terms in the field to have
                the same order as the terms in the query filter, regardless of
                the offsets between them. Defaults to False.

        Raises:
            TypeError: If filter_expression is not a FilterExpression, a string,
                or None.

        Note:
            Learn more about vector queries in Redis: https://redis.io/docs/interact/search-and-query/search/vectors/#knn-search
        """
        # These attributes feed _build_query_string / params, so they must be
        # set before set_filter() rebuilds the query string.
        self._vector = vector
        self._dtype = dtype
        self._vector_field_name = vector_field_name
        self._num_results = num_results
        self.set_filter(filter_expression)

        super().__init__(self._build_query_string())

        # Apply the query modifiers.
        if return_fields:
            self.return_fields(*return_fields)
        self.paging(0, num_results)
        self.dialect(dialect)
        if return_score:
            self.return_fields(self.DISTANCE_ID)
        # Fall back to ordering by vector distance when no field is given.
        self.sort_by(sort_by or self.DISTANCE_ID)
        if in_order:
            self.in_order()

    def _build_query_string(self) -> str:
        """Build the full query string for vector search with optional filtering."""
        filter_expression = (
            str(self._filter_expression)
            if isinstance(self._filter_expression, FilterExpression)
            else self._filter_expression
        )
        return f"{filter_expression}=>[KNN {self._num_results} @{self._vector_field_name} ${self.VECTOR_PARAM} AS {self.DISTANCE_ID}]"

    @property
    def params(self) -> Dict[str, Any]:
        """Return the parameters for the query.

        Returns:
            Dict[str, Any]: The parameters for the query.
        """
        # Bytes are passed through untouched; anything else is encoded with
        # the configured dtype.
        buffer = (
            self._vector
            if isinstance(self._vector, bytes)
            else array_to_buffer(self._vector, dtype=self._dtype)
        )
        return {self.VECTOR_PARAM: buffer}
class VectorRangeQuery(BaseVectorQuery, BaseQuery):
    # Name of the query parameter that carries the distance threshold.
    DISTANCE_THRESHOLD_PARAM: str = "distance_threshold"

    def __init__(
        self,
        vector: Union[List[float], bytes],
        vector_field_name: str,
        return_fields: Optional[List[str]] = None,
        filter_expression: Optional[Union[str, FilterExpression]] = None,
        dtype: str = "float32",
        distance_threshold: float = 0.2,
        num_results: int = 10,
        return_score: bool = True,
        dialect: int = 2,
        sort_by: Optional[str] = None,
        in_order: bool = False,
    ):
        """A query for running a filtered vector search based on semantic
        distance threshold.

        Args:
            vector (Union[List[float], bytes]): The vector to perform the range
                query with, as a list of floats or an already-encoded buffer.
            vector_field_name (str): The name of the vector field to search
                against in the database.
            return_fields (List[str]): The declared fields to return with search
                results.
            filter_expression (Union[str, FilterExpression], optional): A filter
                to apply along with the range query. Defaults to None.
            dtype (str, optional): The dtype of the vector. Defaults to
                "float32".
            distance_threshold (float): The threshold for vector distance
                (an int is accepted as well). A smaller threshold indicates a
                stricter semantic search. Defaults to 0.2.
            num_results (int): The MAX number of results to return.
                Defaults to 10.
            return_score (bool, optional): Whether to return the vector
                distance. Defaults to True.
            dialect (int, optional): The RediSearch query dialect.
                Defaults to 2.
            sort_by (Optional[str]): The field to order the results by. Defaults
                to None. Results will be ordered by vector distance.
            in_order (bool): Requires the terms in the field to have
                the same order as the terms in the query filter, regardless of
                the offsets between them. Defaults to False.

        Raises:
            TypeError: If filter_expression is not a FilterExpression, a string,
                or None, or if distance_threshold is not an int or float.

        Note:
            Learn more about vector range queries: https://redis.io/docs/interact/search-and-query/search/vectors/#range-query
        """
        # These attributes feed _build_query_string / params, so they must be
        # set before set_filter() rebuilds the query string.
        self._vector = vector
        self._dtype = dtype
        self._vector_field_name = vector_field_name
        self._num_results = num_results
        self.set_distance_threshold(distance_threshold)
        self.set_filter(filter_expression)

        super().__init__(self._build_query_string())

        # Apply the query modifiers.
        if return_fields:
            self.return_fields(*return_fields)
        self.paging(0, num_results)
        self.dialect(dialect)
        if return_score:
            self.return_fields(self.DISTANCE_ID)
        # Fall back to ordering by vector distance when no field is given.
        self.sort_by(sort_by or self.DISTANCE_ID)
        if in_order:
            self.in_order()

    def _build_query_string(self) -> str:
        """Build the full query string for vector range queries with optional filtering"""
        filter_expression = (
            str(self._filter_expression)
            if isinstance(self._filter_expression, FilterExpression)
            else self._filter_expression
        )
        base_query = f"@{self._vector_field_name}:[VECTOR_RANGE ${self.DISTANCE_THRESHOLD_PARAM} ${self.VECTOR_PARAM}]"

        if filter_expression != "*":
            # A real filter is combined with the range clause inside parentheses.
            return f"({base_query}=>{{$yield_distance_as: {self.DISTANCE_ID}}} {filter_expression})"
        return f"{base_query}=>{{$yield_distance_as: {self.DISTANCE_ID}}}"

    def set_distance_threshold(self, distance_threshold: float):
        """Set the distance threshold for the query.

        Args:
            distance_threshold (float): vector distance

        Raises:
            TypeError: If distance_threshold is not an int or float.
        """
        if isinstance(distance_threshold, (float, int)):
            self._distance_threshold = distance_threshold
            return
        raise TypeError("distance_threshold must be of type int or float")

    @property
    def distance_threshold(self) -> float:
        """Return the distance threshold for the query.

        Returns:
            float: The distance threshold for the query.
        """
        return self._distance_threshold

    @property
    def params(self) -> Dict[str, Any]:
        """Return the parameters for the query.

        Returns:
            Dict[str, Any]: The parameters for the query.
        """
        # Bytes are passed through untouched; anything else is encoded with
        # the configured dtype.
        buffer = (
            self._vector
            if isinstance(self._vector, bytes)
            else array_to_buffer(self._vector, dtype=self._dtype)
        )
        return {
            self.VECTOR_PARAM: buffer,
            self.DISTANCE_THRESHOLD_PARAM: self._distance_threshold,
        }
class RangeQuery(VectorRangeQuery):
    """Alias of VectorRangeQuery, kept for backwards compatibility."""