Source code for redisvl.extensions.router.schema

from enum import Enum
from typing import Dict, List, Optional

from pydantic.v1 import BaseModel, Field, validator

from redisvl.extensions.constants import ROUTE_VECTOR_FIELD_NAME
from redisvl.schema import IndexSchema


class Route(BaseModel):
    """Model representing a routing path with associated metadata and thresholds."""

    name: str
    """The name of the route."""
    references: List[str]
    """List of reference phrases for the route."""
    metadata: Dict[str, str] = Field(default={})
    """Metadata associated with the route."""
    distance_threshold: Optional[float] = Field(default=None)
    """Distance threshold for matching the route."""

    @validator("name")
    def name_must_not_be_empty(cls, v):
        """Reject a route name that is empty or only whitespace.

        Returns:
            The name exactly as supplied (it is not stripped).

        Raises:
            ValueError: If ``v`` is falsy or strips to an empty string.
        """
        if not v or not v.strip():
            raise ValueError("Route name must not be empty")
        return v

    @validator("references")
    def references_must_not_be_empty(cls, v):
        """Require at least one reference, none of them blank.

        Returns:
            The list of references exactly as supplied.

        Raises:
            ValueError: If the list is empty, or if any entry strips to an
                empty string.
        """
        if not v:
            raise ValueError("References must not be empty")
        if any(not ref.strip() for ref in v):
            raise ValueError("All references must be non-empty strings")
        return v

    @validator("distance_threshold")
    def distance_threshold_must_be_positive(cls, v):
        """Require the per-route threshold, when given, to be strictly positive.

        ``None`` is passed through unchanged.

        Raises:
            ValueError: If ``v`` is not ``None`` and is less than or equal to 0.
        """
        if v is not None and v <= 0:
            raise ValueError("Route distance threshold must be greater than zero")
        return v
class RouteMatch(BaseModel):
    """Model representing a matched route with distance information."""

    # Both fields default to None, so RouteMatch() can be built with no arguments.
    name: Optional[str] = None
    """The matched route name."""
    distance: Optional[float] = Field(default=None)
    """The vector distance between the statement and the matched route."""
class DistanceAggregationMethod(Enum):
    """Enumeration for distance aggregation methods."""

    # Each member's value is the same string as its name, so a member can be
    # looked up either by name (``DistanceAggregationMethod["avg"]``) or by
    # value (``DistanceAggregationMethod("avg")``).
    avg = "avg"
    """Compute the average of the vector distances."""
    min = "min"
    """Compute the minimum of the vector distances."""
    sum = "sum"
    """Compute the sum of the vector distances."""
class RoutingConfig(BaseModel):
    """Configuration for routing behavior."""

    distance_threshold: float = Field(default=0.5)
    """The threshold for semantic distance."""
    max_k: int = Field(default=1)
    """The maximum number of top matches to return."""
    aggregation_method: DistanceAggregationMethod = Field(
        default=DistanceAggregationMethod.avg
    )
    """Aggregation method to use to classify queries."""

    @validator("max_k")
    def max_k_must_be_positive(cls, v):
        """Require ``max_k`` to be at least 1.

        Raises:
            ValueError: If ``v`` is less than or equal to 0.
        """
        if v <= 0:
            raise ValueError("max_k must be a positive integer")
        return v

    @validator("distance_threshold")
    def distance_threshold_must_be_valid(cls, v):
        """Require the threshold to satisfy ``0 < v <= 1``.

        The lower bound is exclusive (0 is rejected) and the upper bound is
        inclusive (1 is accepted).

        Raises:
            ValueError: If ``v`` is less than or equal to 0, or greater than 1.
        """
        if v <= 0 or v > 1:
            raise ValueError("distance_threshold must be between 0 and 1")
        return v
class SemanticRouterIndexSchema(IndexSchema):
    """Customized index schema for SemanticRouter."""

    @classmethod
    def from_params(
        cls, name: str, vector_dims: int, dtype: str
    ) -> "SemanticRouterIndexSchema":
        """Create an index schema based on router name and vector dimensions.

        The schema has three fields: a ``route_name`` tag field, a
        ``reference`` text field, and a vector field named by
        ``ROUTE_VECTOR_FIELD_NAME`` that uses the ``flat`` algorithm with the
        ``cosine`` distance metric.

        Args:
            name (str): The name of the index. The same value is also used as
                the index prefix.
            vector_dims (int): The dimensions of the vectors.
            dtype (str): The value stored as the vector field's ``datatype``
                attribute.

        Returns:
            SemanticRouterIndexSchema: The constructed index schema.
        """
        return cls(
            index={"name": name, "prefix": name},  # type: ignore
            fields=[  # type: ignore
                {"name": "route_name", "type": "tag"},
                {"name": "reference", "type": "text"},
                {
                    "name": ROUTE_VECTOR_FIELD_NAME,
                    "type": "vector",
                    "attrs": {
                        "algorithm": "flat",
                        "dims": vector_dims,
                        "distance_metric": "cosine",
                        "datatype": dtype,
                    },
                },
            ],
        )