-
Notifications
You must be signed in to change notification settings - Fork 580
Docstring modules. #877
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Docstring modules. #877
Conversation
Please make sure all the checkboxes are checked:
|
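Important: Review skipped. Auto reviews are disabled on base/target branches other than the default branch. Please check the settings in the CodeRabbit UI or the `.coderabbit.yaml` file in this repository. To trigger a single review, invoke the `@coderabbitai review` command. You can disable this status message by setting the `reviews.review_status` to `false` in the CodeRabbit configuration file. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. 🪧 Tips — Chat: There are 3 ways to chat with CodeRabbit: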
Support: Need help? Create a ticket on our support page for assistance with any issues or questions. Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (invoked using PR comments):
Other keywords and placeholders
CodeRabbit Configuration File (`.coderabbit.yaml`)
|
LLM-Generated Cypher Query Injection Vulnerability @@ -1,6 +1,7 @@
from typing import Any, Optional
import logging
+import re
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.graph.networkx.adapter import NetworkXAdapter
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import render_prompt
@@ -9,9 +10,84 @@
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
logger = logging.getLogger("NaturalLanguageRetriever")
+# Allowed Cypher clauses for READ-ONLY queries
+# ONLY allow SELECT-like queries (MATCH/OPTIONAL MATCH + RETURN, WHERE, LIMIT, ORDER BY etc.)
+# Prohibit clauses that change data: CREATE, MERGE, DELETE, SET, REMOVE, CALL (except permitted read procedures)
+_ALLOWED_READONLY_CLAUSES = [
+ "MATCH",
+ "OPTIONAL MATCH",
+ "RETURN",
+ "WHERE",
+ "ORDER BY",
+ "LIMIT",
+ "SKIP",
+ "UNWIND",
+ "WITH",
+ "DISTINCT",
+]
+# Forbidden (modifying/destructive) Cypher keywords
+_FORBIDDEN_CLAUSES = [
+ "CREATE",
+ "MERGE",
+ "DELETE",
+ "DETACH",
+ "SET",
+ "REMOVE",
+ "DROP",
+ "CALL",
+ "LOAD CSV",
+ "START",
+ "FOREACH",
+ "APOC",
+ "PROFILE",
+ "EXPLAIN",
+ "GRANT",
+ "DENY",
+ "REVOKE",
+]
+def is_cypher_query_safe(query: str) -> bool:
+ """
+ Returns True if the provided Cypher query is 'safe' (read-only); False otherwise.
+
+ We enforce that it does not contain any forbidden clauses, and only permitted starting clauses.
+ """
+ if not isinstance(query, str):
+ return False
+
+ # Make matches case-insensitive, normalize query
+ normalized_query = query.upper()
+ # Remove multi-line comments and inline comments
+ normalized_query = re.sub(r"//.*?$|/\*.*?\*/", "", normalized_query, flags=re.MULTILINE | re.DOTALL)
+ # Basic check: Look for any forbidden clause as whole word
+ for forbidden in _FORBIDDEN_CLAUSES:
+ pattern = r'\b' + re.escape(forbidden) + r'\b'
+ if re.search(pattern, normalized_query):
+ return False
+
+ # Must start with a MATCH or OPTIONAL MATCH (with optional WITH/UNWIND/RETURN statements in between)
+ allowed_starter = False
+ # Find first non-empty, non-comment line
+ statements = [line.strip() for line in normalized_query.splitlines() if line.strip()]
+ joined = " ".join(statements)
+ # Check if it starts with MATCH or OPTIONAL MATCH, possibly after WITH or UNWIND
+ match_start = re.match(r'^(WITH|UNWIND|MATCH|OPTIONAL MATCH)', joined)
+ if match_start:
+ first_clause = match_start.group(1)
+ if first_clause in ("MATCH", "OPTIONAL MATCH", "WITH", "UNWIND"):
+ allowed_starter = True
+ else:
+ # Or maybe just starts with RETURN or WITH in rare cases
+ allowed_starter = joined.startswith("RETURN") or joined.startswith("WITH")
+
+ if not allowed_starter:
+ return False
+
+ return True
+
+
class NaturalLanguageRetriever(BaseRetriever):
"""
Retriever for handling natural language search.
@@ -78,8 +154,12 @@
cypher_query = await self._generate_cypher_query(
query, edge_schemas, previous_attempts
)
+ if not is_cypher_query_safe(cypher_query):
+ logger.error("Aborting execution: LLM-generated Cypher query failed safety check. Query: %s", cypher_query)
+ raise RuntimeError("Generated Cypher query is not permitted by safety policy (must be read-only).")
+
logger.info(
f"Executing generated Cypher query (attempt {attempt + 1}): {cypher_query[:100]}..."
if len(cypher_query) > 100
else cypher_query
@@ -156,5 +236,5 @@
"""
if context is None:
context = await self.get_context(query)
- return context
+ return context
\ No newline at end of file
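Explanation of Fix — Vulnerability: the LLM-generated Cypher query was executed without any validation. Fix Approach: a new module-level `is_cypher_query_safe` check rejects any query that contains a forbidden (write/admin) keyword as a whole word, or that does not begin with MATCH, OPTIONAL MATCH, WITH, UNWIND or RETURN; when the check fails the query is logged and a `RuntimeError` is raised instead of executing it. Side Effects: the keyword scan does not skip string literals or identifiers, so a read-only query that merely mentions a forbidden word (for example `SET` inside a quoted value) is rejected as well.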
Issues
|
LLM-to-Cypher Query Injection via Prompt Manipulation @@ -1,6 +1,7 @@
from typing import Any, Optional
import logging
+import re
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.graph.networkx.adapter import NetworkXAdapter
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import render_prompt
@@ -48,8 +49,48 @@
"""
)
return node_schemas, edge_schemas
+ def _is_cypher_query_safe(self, cypher_query: str) -> bool:
+ """
+ Check that a Cypher query is "safe" -- i.e., does not contain modification/destructive/admin keywords.
+ Returns True if it appears safe, False otherwise.
+ """
+ # Disallow keywords associated with data mutation or schema/admin operations
+ # Case-insensitive, ignore whitespace before/after, check for each forbidden word/phrase
+ forbidden_patterns = [
+ r"\bDETACH\s+DELETE\b",
+ r"\bDELETE\b",
+ r"\bREMOVE\b",
+ r"\bSET\b",
+ r"\bMERGE\b",
+ r"\bCREATE\b",
+ r"\bDROP\b",
+ r"\bCALL\b",
+ r"\bUSE\b",
+ r"\bLOAD\s+CSV\b",
+ r"\bAPOC\.",
+ r"\bDBMS\.",
+ r"(?<!\w)SHOW\b", # For SHOW commands like SHOW DATABASES
+ r"\bALTER\b",
+ r"\bGRANT\b",
+ r"\bDENY\b",
+ r"\bREVOKE\b",
+ r"\bSTART\b", # For START TRANSACTION etc.
+ r"\bCOMMIT\b",
+ r"\bROLLBACK\b",
+ r"\bRENAME\b",
+ r"\bCONSTRAINT\b",
+ r"\bINDEX\b",
+ r"\bTRUNCATE\b",
+ r"\bAUTH\b",
+ r"\bPASSWORD\b",
+ ]
+ pattern = re.compile("|".join(forbidden_patterns), re.IGNORECASE | re.MULTILINE)
+ if pattern.search(cypher_query):
+ return False
+ return True
+
async def _generate_cypher_query(self, query: str, edge_schemas, previous_attempts=None) -> str:
"""Generate a Cypher query using LLM based on natural language query and schema information."""
llm_client = get_llm_client()
system_prompt = render_prompt(
@@ -78,8 +119,20 @@
cypher_query = await self._generate_cypher_query(
query, edge_schemas, previous_attempts
)
+ if not isinstance(cypher_query, str):
+ raise ValueError("Generated Cypher query is not a string.")
+
+ if not self._is_cypher_query_safe(cypher_query):
+ logger.error(
+ f"Rejected potentially unsafe Cypher query (attempt {attempt + 1}): {cypher_query[:100]}..."
+ if len(cypher_query) > 100
+ else f"Rejected potentially unsafe Cypher query (attempt {attempt + 1}): {cypher_query}"
+ )
+ previous_attempts += f"Query: {cypher_query} -> Rejected: Unsafe or potentially destructive query detected\n"
+ continue # Try again for the remaining attempts
+
logger.info(
f"Executing generated Cypher query (attempt {attempt + 1}): {cypher_query[:100]}..."
if len(cypher_query) > 100
else cypher_query
@@ -156,5 +209,5 @@
"""
if context is None:
context = await self.get_context(query)
- return context
+ return context
\ No newline at end of file
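Explanation of Fix — Vulnerability and Fix Summary: the LLM-generated Cypher query was executed without any validation; the patch adds a `_is_cypher_query_safe` method and checks every generated query before it is run. Fix Approach: the method matches the query, case-insensitively, against a list of forbidden mutation/schema/admin patterns (DELETE, SET, MERGE, CREATE, DROP, CALL, `APOC.`, `DBMS.`, GRANT, and so on); a rejected query is logged, recorded in `previous_attempts`, and the loop continues with the next attempt.
Potential Impacts: the pattern scan does not skip string literals or property names, so a legitimate read-only query that mentions a word such as `INDEX`, `PASSWORD` or `SET` is rejected, and each rejection consumes one of the remaining attempts.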
Issues
|
Unsanitized Query Input Leading to Potential Graph/Vector Database Injection @@ -1,13 +1,19 @@
import asyncio
-from typing import Any, Optional
+import re
+from typing import Any, Optional, Callable, List, Dict
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
+def _is_valid_node_id(node_id: str) -> bool:
+ if not isinstance(node_id, str):
+ return False
+ # Only allow alphanumeric, -, _, and length up to 64, no spaces
+ return re.fullmatch(r"[A-Za-z0-9_\-]{1,64}", node_id) is not None
class InsightsRetriever(BaseRetriever):
"""
Retriever for handling graph connection-based insights.
@@ -20,13 +26,28 @@
- exploration_levels
- top_k
"""
- def __init__(self, exploration_levels: int = 1, top_k: int = 5):
- """Initialize retriever with exploration levels and search parameters."""
+ def __init__(
+ self,
+ exploration_levels: int = 1,
+ top_k: int = 5,
+ allowed_nodes_callback: Optional[Callable[[str], bool]] = None,
+ ):
+ """Initialize retriever with exploration levels, search parameters, and (optional) allowed nodes callback.
+
+ allowed_nodes_callback, if provided, is called with a node_id and should return True if access is permitted.
+ """
self.exploration_levels = exploration_levels
self.top_k = top_k
+ self.allowed_nodes_callback = allowed_nodes_callback
+ def _is_allowed_node(self, node_id: str) -> bool:
+ # Check allowed_nodes_callback if provided, otherwise allow by default
+ if self.allowed_nodes_callback:
+ return self.allowed_nodes_callback(node_id)
+ return True
+
async def get_context(self, query: str) -> list:
"""
Find neighbours of a given node in the graph.
@@ -44,50 +65,92 @@
--------
- list: A list of unique connections found for the queried node.
"""
- if query is None:
+ if query is None or not isinstance(query, str):
return []
- node_id = query
+ node_id = query.strip()
+ # Input validation: enforce strict node id format
+ if not _is_valid_node_id(node_id):
+ return []
+
+ if not self._is_allowed_node(node_id):
+ return []
+
graph_engine = await get_graph_engine()
+
+ # Only validated node_id is used
exact_node = await graph_engine.extract_node(node_id)
+ node_connections = []
if exact_node is not None and "id" in exact_node:
- node_connections = await graph_engine.get_connections(str(exact_node["id"]))
+ extracted_id = str(exact_node["id"])
+ if _is_valid_node_id(extracted_id) and self._is_allowed_node(extracted_id):
+ node_connections = await graph_engine.get_connections(extracted_id)
+ else:
+ return []
else:
vector_engine = get_vector_engine()
try:
results = await asyncio.gather(
- vector_engine.search("Entity_name", query_text=query, limit=self.top_k),
- vector_engine.search("EntityType_name", query_text=query, limit=self.top_k),
+ vector_engine.search(
+ "Entity_name",
+ query_text=node_id,
+ limit=self.top_k,
+ ),
+ vector_engine.search(
+ "EntityType_name",
+ query_text=node_id,
+ limit=self.top_k,
+ ),
)
except CollectionNotFoundError as error:
raise NoDataError("No data found in the system, please add data first.") from error
- results = [*results[0], *results[1]]
- relevant_results = [result for result in results if result.score < 0.5][: self.top_k]
+ # Flatten results
+ flat_results = [*results[0], *results[1]]
+ # Filter for score and valid node IDs and allowed nodes
+ relevant_results = [
+ result
+ for result in flat_results
+ if result.score < 0.5 and _is_valid_node_id(str(result.id)) and self._is_allowed_node(str(result.id))
+ ][: self.top_k]
if len(relevant_results) == 0:
return []
+ # Only fetch connections for validated, allowed IDs
node_connections_results = await asyncio.gather(
- *[graph_engine.get_connections(result.id) for result in relevant_results]
+ *[graph_engine.get_connections(str(result.id)) for result in relevant_results]
)
- node_connections = []
for neighbours in node_connections_results:
node_connections.extend(neighbours)
- unique_node_connections_map = {}
- unique_node_connections = []
+ unique_node_connections_map: Dict[str, bool] = {}
+ unique_node_connections: List[Any] = []
for node_connection in node_connections:
- if "id" not in node_connection[0] or "id" not in node_connection[2]:
+ if (
+ not isinstance(node_connection, list)
+ or len(node_connection) < 3
+ or "id" not in node_connection[0]
+ or "id" not in node_connection[2]
+ ):
continue
- unique_id = f"{node_connection[0]['id']} {node_connection[1]['relationship_name']} {node_connection[2]['id']}"
+ left_id = str(node_connection[0]["id"])
+ right_id = str(node_connection[2]["id"])
+ if not _is_valid_node_id(left_id) or not _is_valid_node_id(right_id):
+ continue
+ if not self._is_allowed_node(left_id) or not self._is_allowed_node(right_id):
+ continue
+
+ relationship = node_connection[1].get("relationship_name") if isinstance(node_connection[1], dict) else str(node_connection[1])
+ unique_id = f"{left_id} {relationship} {right_id}"
+
if unique_id not in unique_node_connections_map:
unique_node_connections_map[unique_id] = True
unique_node_connections.append(node_connection)
@@ -113,5 +176,5 @@
based on the query.
"""
if context is None:
context = await self.get_context(query)
- return context
+ return context
\ No newline at end of file
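Explanation of Fix — Vulnerability Summary: the caller-supplied query string was passed unvalidated to `graph_engine.extract_node` and to the vector searches, and the ids they returned were passed unvalidated to `graph_engine.get_connections`. Fix Summary: the query is stripped and must match `[A-Za-z0-9_\-]{1,64}`; an optional `allowed_nodes_callback` can veto individual node ids; ids coming back from the graph and vector engines are re-validated before connections are fetched or returned.
Impact: `InsightsRetriever.__init__` gains an optional `allowed_nodes_callback` parameter that defaults to `None`, so existing callers are unaffected. Potential Impact: any query that is not id-shaped (for example one containing spaces or punctuation) now returns an empty list before the vector search is attempted, and connections whose node ids do not match the pattern are silently dropped.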
Issues
|
Description
DCO Affirmation
I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.