Source code for redisvl.utils.rerank.cohere

import os
from typing import Any, Dict, List, Optional, Tuple, Union

from pydantic.v1 import PrivateAttr

from redisvl.utils.rerank.base import BaseReranker


class CohereReranker(BaseReranker):
    """
    The CohereReranker class uses Cohere's API to rerank documents based on an
    input query.

    This reranker is designed to interact with Cohere's /rerank API,
    requiring an API key for authentication. The key can be provided
    directly in the `api_config` dictionary or through the `COHERE_API_KEY`
    environment variable. User must obtain an API key from Cohere's website
    (https://dashboard.cohere.com/). Additionally, the `cohere` python
    client must be installed with `pip install cohere`.

    .. code-block:: python

        from redisvl.utils.rerank import CohereReranker

        # set up the Cohere reranker with some configuration
        reranker = CohereReranker(rank_by=["content"], limit=2)
        # rerank raw search results based on user input/query
        results = reranker.rank(
            query="your input query text here",
            docs=[
                {"content": "document 1"},
                {"content": "document 2"},
                {"content": "document 3"}
            ]
        )
    """

    # Synchronous and asynchronous Cohere clients; both are created in
    # _initialize_clients() from the same API key.
    _client: Any = PrivateAttr()
    _aclient: Any = PrivateAttr()

    def __init__(
        self,
        model: str = "rerank-english-v3.0",
        rank_by: Optional[List[str]] = None,
        limit: int = 5,
        return_score: bool = True,
        api_config: Optional[Dict] = None,
    ) -> None:
        """
        Initialize the CohereReranker with specified model, ranking criteria,
        and API configuration.

        Parameters:
            model (str): The identifier for the Cohere model used for reranking.
                Defaults to 'rerank-english-v3.0'.
            rank_by (Optional[List[str]]): Optional list of keys specifying the
                attributes in the documents that should be considered for
                ranking. Required (here or as a per-call override) when the
                documents are dictionaries; not used when they are plain
                strings.
            limit (int): The maximum number of results to return after
                reranking. Expected to be a positive integer.
            return_score (bool): Whether to return scores alongside the
                reranked results.
            api_config (Optional[Dict], optional): Dictionary containing the
                API key under the "api_key" entry. Defaults to None, in which
                case the `COHERE_API_KEY` environment variable is used.

        Raises:
            ImportError: If the cohere library is not installed.
            ValueError: If the API key is not provided.
        """
        super().__init__(
            model=model, rank_by=rank_by, limit=limit, return_score=return_score
        )
        self._initialize_clients(api_config)

    def _initialize_clients(self, api_config: Optional[Dict]):
        """
        Setup the Cohere clients using the provided API key or an
        environment variable.

        Parameters:
            api_config (Optional[Dict]): Optional configuration whose
                "api_key" entry, when present and non-empty, takes precedence
                over the `COHERE_API_KEY` environment variable.

        Raises:
            ImportError: If the cohere library is not installed.
            ValueError: If no API key can be found in either location.
        """
        # Dynamic import of the cohere module so that it stays an optional
        # dependency of the package.
        try:
            from cohere import AsyncClient, Client
        except ImportError as err:
            raise ImportError(
                "Cohere reranker requires the cohere library. "
                "Please install with `pip install cohere`"
            ) from err

        # Fetch the API key from api_config, falling back to the environment
        # variable whenever the config does not carry a usable key (including
        # the case of a non-empty config without an "api_key" entry).
        api_key = (api_config or {}).get("api_key") or os.getenv("COHERE_API_KEY")
        if not api_key:
            raise ValueError(
                "Cohere API key is required. "
                "Provide it in api_config or set the COHERE_API_KEY environment variable."
            )
        self._client = Client(api_key=api_key, client_name="redisvl")
        self._aclient = AsyncClient(api_key=api_key, client_name="redisvl")

    def _preprocess(
        self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs
    ):
        """
        Prepare and validate reranking config based on provided input and
        optional overrides.

        Parameters:
            query (str): The user's search query.
            docs (Union[List[Dict[str, Any]], List[str]]): The documents to
                rerank.
            **kwargs: Optional per-call overrides: `limit`, `return_score`,
                `rank_by` (a single field name or a list of them) and
                `max_chunks_per_doc`.

        Returns:
            Tuple[Dict[str, Any], bool]: The keyword arguments for the Cohere
            rerank call and the effective `return_score` flag.

        Raises:
            ValueError: If every document is a dictionary and no `rank_by`
                fields are available.
        """
        limit = kwargs.get("limit", self.limit)
        return_score = kwargs.get("return_score", self.return_score)
        max_chunks_per_doc = kwargs.get("max_chunks_per_doc")
        rank_by = kwargs.get("rank_by", self.rank_by) or []
        # Accept a single field name as shorthand for a one-element list.
        rank_by = [rank_by] if isinstance(rank_by, str) else rank_by

        reranker_kwargs = {
            "model": self.model,
            "query": query,
            "top_n": limit,
            "documents": docs,
            "max_chunks_per_doc": max_chunks_per_doc,
        }
        # if we are working with list of dicts
        if all(isinstance(doc, dict) for doc in docs):
            if rank_by:
                reranker_kwargs["rank_fields"] = rank_by
            else:
                raise ValueError(
                    "If reranking dictionary-like docs, "
                    "you must provide a list of rank_by fields"
                )

        return reranker_kwargs, return_score

    @staticmethod
    def _postprocess(
        docs: Union[List[Dict[str, Any]], List[str]],
        rankings: Any,
    ) -> Tuple[List[Any], List[float]]:
        """
        Post-process the initial list of documents to include ranking scores,
        if specified.

        Parameters:
            docs (Union[List[Dict[str, Any]], List[str]]): The documents that
                were submitted for reranking.
            rankings (Any): The rerank response; its `results` attribute is
                iterated, and each item must expose `index` and
                `relevance_score`.

        Returns:
            Tuple[List[Any], List[float]]: The documents in the order given by
            `rankings.results`, and the matching relevance scores.
        """
        reranked_docs, scores = [], []
        for item in rankings.results:
            scores.append(item.relevance_score)
            # `index` refers to the position in the originally submitted docs.
            reranked_docs.append(docs[item.index])
        return reranked_docs, scores

    def rank(
        self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs
    ) -> Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]:
        """
        Rerank documents based on the provided query using the Cohere rerank
        API.

        This method processes the user's query and the provided documents to
        rerank them in a manner that is potentially more relevant to the
        query's context.

        Parameters:
            query (str): The user's search query.
            docs (Union[List[Dict[str, Any]], List[str]]): The list of
                documents to be ranked, either as dictionaries or strings.
            **kwargs: Optional per-call overrides: `limit`, `return_score`,
                `rank_by` and `max_chunks_per_doc`.

        Returns:
            Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]:
            The reranked list of documents, together with the list of
            relevance scores when `return_score` is enabled.

        Raises:
            ValueError: If the documents are dictionaries and no `rank_by`
                fields are available.
        """
        reranker_kwargs, return_score = self._preprocess(query, docs, **kwargs)
        rankings = self._client.rerank(**reranker_kwargs)
        reranked_docs, scores = self._postprocess(docs, rankings)
        if return_score:
            return reranked_docs, scores
        return reranked_docs

    async def arank(
        self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs
    ) -> Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]:
        """
        Asynchronously rerank documents based on the provided query using the
        Cohere rerank API.

        This method processes the user's query and the provided documents to
        rerank them in a manner that is potentially more relevant to the
        query's context. It awaits the asynchronous Cohere client instead of
        calling the synchronous one.

        Parameters:
            query (str): The user's search query.
            docs (Union[List[Dict[str, Any]], List[str]]): The list of
                documents to be ranked, either as dictionaries or strings.
            **kwargs: Optional per-call overrides: `limit`, `return_score`,
                `rank_by` and `max_chunks_per_doc`.

        Returns:
            Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]:
            The reranked list of documents, together with the list of
            relevance scores when `return_score` is enabled.

        Raises:
            ValueError: If the documents are dictionaries and no `rank_by`
                fields are available.
        """
        reranker_kwargs, return_score = self._preprocess(query, docs, **kwargs)
        rankings = await self._aclient.rerank(**reranker_kwargs)
        reranked_docs, scores = self._postprocess(docs, rankings)
        if return_score:
            return reranked_docs, scores
        return reranked_docs