Memory updates
All checks were successful
CI / Build and push Docker image (push) Successful in 1m33s
All checks were successful
CI / Build and push Docker image (push) Successful in 1m33s
This commit is contained in:
parent
b634377ddc
commit
0c01d65a6c
3 changed files with 272 additions and 298 deletions
|
@ -2,132 +2,66 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Memory(BaseModel):
|
||||
"""A timestamped fact that references one or more entities."""
|
||||
|
||||
id: str = Field(..., description="Unique identifier for this memory")
|
||||
content: str = Field(..., description="The actual fact or information being stored")
|
||||
entities: list[str] = Field(..., description="List of entity names this memory references")
|
||||
timestamp: str = Field(..., description="ISO timestamp when this memory was created")
|
||||
|
||||
|
||||
class Entity(BaseModel):
|
||||
"""Entity model for knowledge graph."""
|
||||
"""Simple entity reference."""
|
||||
|
||||
name: str = Field(..., description="The name of the entity")
|
||||
entity_type: str = Field(..., description="The type of the entity")
|
||||
observations: list[str] = Field(
|
||||
..., description="An array of observation contents associated with the entity"
|
||||
)
|
||||
entity_type: str = Field(default="general", description="The type of the entity")
|
||||
memory_count: int = Field(default=0, description="Number of memories referencing this entity")
|
||||
|
||||
|
||||
class Relation(BaseModel):
|
||||
"""Relation model for knowledge graph."""
|
||||
|
||||
from_: str = Field(
|
||||
...,
|
||||
alias="from",
|
||||
description="The name of the entity where the relation starts",
|
||||
)
|
||||
to: str = Field(..., description="The name of the entity where the relation ends")
|
||||
relation_type: str = Field(..., description="The type of the relation")
|
||||
|
||||
|
||||
class KnowledgeGraph(BaseModel):
|
||||
"""Knowledge graph containing entities and relations."""
|
||||
class MemoryGraph(BaseModel):
|
||||
"""Collection of memories and entities."""
|
||||
|
||||
memories: list[Memory]
|
||||
entities: list[Entity]
|
||||
relations: list[Relation]
|
||||
|
||||
|
||||
class EntityWrapper(BaseModel):
|
||||
"""Wrapper for entity data."""
|
||||
class CreateMemoryRequest(BaseModel):
|
||||
"""Request to create a new memory."""
|
||||
|
||||
type: Literal["entity"]
|
||||
name: str
|
||||
entity_type: str
|
||||
observations: list[str]
|
||||
content: str = Field(..., description="The fact or information to store")
|
||||
entities: list[str] = Field(..., description="List of entity names this memory references")
|
||||
|
||||
|
||||
class RelationWrapper(BaseModel):
|
||||
"""Wrapper for relation data."""
|
||||
class SearchMemoryRequest(BaseModel):
|
||||
"""Request to search memories."""
|
||||
|
||||
type: Literal["relation"]
|
||||
from_: str = Field(..., alias="from")
|
||||
to: str
|
||||
relation_type: str
|
||||
query: str = Field(..., description="Search term to find in memory content or entity names")
|
||||
limit: int = Field(default=10, description="Maximum number of memories to return")
|
||||
|
||||
|
||||
class CreateEntitiesRequest(BaseModel):
|
||||
"""Request model for creating entities."""
|
||||
class GetEntityRequest(BaseModel):
|
||||
"""Request to get memories for specific entities."""
|
||||
|
||||
entities: list[Entity] = Field(..., description="List of entities to create")
|
||||
entities: list[str] = Field(..., description="List of entity names to retrieve memories for")
|
||||
limit: int = Field(default=5, description="Maximum number of memories to return")
|
||||
|
||||
|
||||
class CreateRelationsRequest(BaseModel):
|
||||
"""Request model for creating relations."""
|
||||
class DeleteMemoryRequest(BaseModel):
|
||||
"""Request to delete specific memories."""
|
||||
|
||||
relations: list[Relation] = Field(
|
||||
..., description="List of relations to create. All must be in active voice."
|
||||
)
|
||||
memory_ids: list[str] = Field(..., description="List of memory IDs to delete")
|
||||
|
||||
|
||||
class ObservationItem(BaseModel):
|
||||
"""Item for adding observations."""
|
||||
class MemorySummary(BaseModel):
|
||||
"""Summary statistics about stored memories."""
|
||||
|
||||
entity_name: str = Field(..., description="The name of the entity to add the observations to")
|
||||
contents: list[str] = Field(..., description="An array of observation contents to add")
|
||||
|
||||
|
||||
class DeletionItem(BaseModel):
|
||||
"""Item for deleting observations."""
|
||||
|
||||
entity_name: str = Field(..., description="The name of the entity containing the observations")
|
||||
observations: list[str] = Field(..., description="An array of observations to delete")
|
||||
|
||||
|
||||
class AddObservationsRequest(BaseModel):
|
||||
"""Request model for adding observations."""
|
||||
|
||||
observations: list[ObservationItem] = Field(
|
||||
...,
|
||||
description=(
|
||||
"A list of observation additions, each specifying an entity and contents to add"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class DeleteObservationsRequest(BaseModel):
|
||||
"""Request model for deleting observations."""
|
||||
|
||||
deletions: list[DeletionItem] = Field(
|
||||
...,
|
||||
description=(
|
||||
"A list of observation deletions, each specifying an entity and observations to remove"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class DeleteEntitiesRequest(BaseModel):
|
||||
"""Request model for deleting entities."""
|
||||
|
||||
entity_names: list[str] = Field(..., description="An array of entity names to delete")
|
||||
|
||||
|
||||
class DeleteRelationsRequest(BaseModel):
|
||||
"""Request model for deleting relations."""
|
||||
|
||||
relations: list[Relation] = Field(..., description="An array of relations to delete")
|
||||
|
||||
|
||||
class SearchNodesRequest(BaseModel):
|
||||
"""Request model for searching nodes."""
|
||||
|
||||
query: str = Field(
|
||||
...,
|
||||
description=(
|
||||
"The search query to match against entity names, types, and observation content"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class OpenNodesRequest(BaseModel):
|
||||
"""Request model for opening specific nodes."""
|
||||
|
||||
names: list[str] = Field(..., description="An array of entity names to retrieve")
|
||||
total_memories: int = Field(..., description="Total number of memories stored")
|
||||
total_entities: int = Field(..., description="Total number of unique entities")
|
||||
oldest_memory: str | None = Field(..., description="ISO timestamp of oldest memory")
|
||||
latest_memory: str | None = Field(..., description="ISO timestamp of latest memory")
|
||||
memory_timespan_days: int | None = Field(..., description="Days between oldest and latest memory")
|
||||
top_entities: list[Entity] = Field(..., description="Most frequently referenced entities")
|
|
@ -2,243 +2,278 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from openapi_mcp_server.tools.base import BaseTool
|
||||
|
||||
from .models import (
|
||||
AddObservationsRequest,
|
||||
CreateEntitiesRequest,
|
||||
CreateRelationsRequest,
|
||||
DeleteEntitiesRequest,
|
||||
DeleteObservationsRequest,
|
||||
DeleteRelationsRequest,
|
||||
CreateMemoryRequest,
|
||||
DeleteMemoryRequest,
|
||||
Entity,
|
||||
KnowledgeGraph,
|
||||
OpenNodesRequest,
|
||||
Relation,
|
||||
SearchNodesRequest,
|
||||
GetEntityRequest,
|
||||
Memory,
|
||||
MemoryGraph,
|
||||
MemorySummary,
|
||||
SearchMemoryRequest,
|
||||
)
|
||||
from .storage import (
|
||||
generate_memory_id,
|
||||
get_current_timestamp,
|
||||
read_memory_graph,
|
||||
save_memory_graph,
|
||||
)
|
||||
from .storage import read_graph_file, save_graph
|
||||
|
||||
|
||||
class MemoryTool(BaseTool):
|
||||
"""Knowledge graph memory system tool."""
|
||||
"""Simplified memory system for storing timestamped facts."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the memory tool."""
|
||||
super().__init__(
|
||||
name="memory",
|
||||
description="A structured knowledge graph memory system",
|
||||
description="A simple memory system for storing timestamped facts about entities",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_entities(req: CreateEntitiesRequest) -> list[Entity]:
|
||||
"""Create multiple entities in the graph.
|
||||
def create_memory(req: CreateMemoryRequest) -> Memory:
|
||||
"""Store a new memory/fact.
|
||||
|
||||
Returns:
|
||||
list: List of newly created entities.
|
||||
Memory: The newly created memory with auto-generated timestamp and ID.
|
||||
"""
|
||||
graph = read_graph_file()
|
||||
existing_names = {e.name for e in graph.entities}
|
||||
new_entities = [e for e in req.entities if e.name not in existing_names]
|
||||
graph.entities.extend(new_entities)
|
||||
save_graph(graph)
|
||||
return new_entities
|
||||
graph = read_memory_graph()
|
||||
|
||||
# Create new memory with auto-generated timestamp and ID
|
||||
memory = Memory(
|
||||
id=generate_memory_id(),
|
||||
content=req.content,
|
||||
entities=req.entities,
|
||||
timestamp=get_current_timestamp(),
|
||||
)
|
||||
|
||||
graph.memories.append(memory)
|
||||
|
||||
# Update entity counts and ensure entities exist
|
||||
entity_dict = {e.name: e for e in graph.entities}
|
||||
for entity_name in req.entities:
|
||||
if entity_name in entity_dict:
|
||||
entity_dict[entity_name].memory_count += 1
|
||||
else:
|
||||
# Create new entity
|
||||
new_entity = Entity(name=entity_name, entity_type="general", memory_count=1)
|
||||
graph.entities.append(new_entity)
|
||||
entity_dict[entity_name] = new_entity
|
||||
|
||||
save_memory_graph(graph)
|
||||
return memory
|
||||
|
||||
@staticmethod
|
||||
def create_relations(req: CreateRelationsRequest) -> list[Relation]:
|
||||
"""Create multiple relations between entities.
|
||||
def get_all_memories(limit: int = 20) -> MemoryGraph:
|
||||
"""Get all memories and entities.
|
||||
|
||||
Returns:
|
||||
list: List of newly created relations.
|
||||
MemoryGraph: All stored memories and entities, sorted by timestamp (newest first).
|
||||
"""
|
||||
graph = read_graph_file()
|
||||
existing = {(r.from_, r.to, r.relation_type) for r in graph.relations}
|
||||
new = [r for r in req.relations if (r.from_, r.to, r.relation_type) not in existing]
|
||||
graph.relations.extend(new)
|
||||
save_graph(graph)
|
||||
return new
|
||||
graph = read_memory_graph()
|
||||
# Sort memories by timestamp (newest first)
|
||||
graph.memories.sort(key=lambda m: m.timestamp, reverse=True)
|
||||
if limit and len(graph.memories) > limit:
|
||||
graph.memories = graph.memories[:limit]
|
||||
return graph
|
||||
|
||||
@staticmethod
|
||||
def add_observations(req: AddObservationsRequest) -> list[dict[str, str | list[str]]]:
|
||||
"""Add new observations to existing entities.
|
||||
def search_memories(req: SearchMemoryRequest) -> MemoryGraph:
|
||||
"""Search memories by content or entity names.
|
||||
|
||||
Returns:
|
||||
list: List of entities with their added observations.
|
||||
|
||||
Raises:
|
||||
HTTPException: If entity is not found.
|
||||
MemoryGraph: Filtered memories matching the search query.
|
||||
"""
|
||||
graph = read_graph_file()
|
||||
results = []
|
||||
graph = read_memory_graph()
|
||||
query = req.query.lower()
|
||||
|
||||
for obs in req.observations:
|
||||
name = obs.entity_name.lower()
|
||||
contents = obs.contents
|
||||
entity = next((e for e in graph.entities if e.name == name), None)
|
||||
if not entity:
|
||||
raise HTTPException(status_code=404, detail=f"Entity {name} not found")
|
||||
added = [c for c in contents if c not in entity.observations]
|
||||
entity.observations.extend(added)
|
||||
results.append({"entity_name": name, "added_observations": added})
|
||||
matching_memories = []
|
||||
for memory in graph.memories:
|
||||
# Search in content
|
||||
if query in memory.content.lower():
|
||||
matching_memories.append(memory)
|
||||
continue
|
||||
# Search in entity names
|
||||
if any(query in entity.lower() for entity in memory.entities):
|
||||
matching_memories.append(memory)
|
||||
|
||||
save_graph(graph)
|
||||
return results
|
||||
# Sort by timestamp (newest first) and apply limit
|
||||
matching_memories.sort(key=lambda m: m.timestamp, reverse=True)
|
||||
if len(matching_memories) > req.limit:
|
||||
matching_memories = matching_memories[: req.limit]
|
||||
|
||||
# Get entities referenced in matching memories
|
||||
referenced_entities = set()
|
||||
for memory in matching_memories:
|
||||
referenced_entities.update(memory.entities)
|
||||
|
||||
matching_entities = [e for e in graph.entities if e.name in referenced_entities]
|
||||
|
||||
return MemoryGraph(memories=matching_memories, entities=matching_entities)
|
||||
|
||||
@staticmethod
|
||||
def delete_entities(req: DeleteEntitiesRequest) -> dict[str, str]:
|
||||
"""Delete entities and associated relations.
|
||||
def get_entity_memories(req: GetEntityRequest) -> MemoryGraph:
|
||||
"""Get memories for specific entities.
|
||||
|
||||
Returns:
|
||||
dict: Success message indicating entities were deleted.
|
||||
MemoryGraph: Memories that reference the specified entities.
|
||||
"""
|
||||
graph = read_graph_file()
|
||||
graph.entities = [e for e in graph.entities if e.name not in req.entity_names]
|
||||
graph.relations = [
|
||||
r
|
||||
for r in graph.relations
|
||||
if r.from_ not in req.entity_names and r.to not in req.entity_names
|
||||
graph = read_memory_graph()
|
||||
|
||||
# Check if memory references any of the requested entities
|
||||
matching_memories = [
|
||||
memory
|
||||
for memory in graph.memories
|
||||
if any(entity in memory.entities for entity in req.entities)
|
||||
]
|
||||
save_graph(graph)
|
||||
return {"message": "Entities deleted successfully"}
|
||||
|
||||
# Sort by timestamp (newest first) and apply limit
|
||||
matching_memories.sort(key=lambda m: m.timestamp, reverse=True)
|
||||
if len(matching_memories) > req.limit:
|
||||
matching_memories = matching_memories[: req.limit]
|
||||
|
||||
# Get the requested entities
|
||||
matching_entities = [e for e in graph.entities if e.name in req.entities]
|
||||
|
||||
return MemoryGraph(memories=matching_memories, entities=matching_entities)
|
||||
|
||||
@staticmethod
|
||||
def delete_observations(req: DeleteObservationsRequest) -> dict[str, str]:
|
||||
"""Delete specific observations from entities.
|
||||
def delete_memories(req: DeleteMemoryRequest) -> dict[str, str]:
|
||||
"""Delete specific memories by ID.
|
||||
|
||||
Returns:
|
||||
dict: Success message indicating observations were deleted.
|
||||
dict: Success message with count of deleted memories.
|
||||
"""
|
||||
graph = read_graph_file()
|
||||
graph = read_memory_graph()
|
||||
original_count = len(graph.memories)
|
||||
|
||||
for deletion in req.deletions:
|
||||
name = deletion.entity_name.lower()
|
||||
to_delete = deletion.observations
|
||||
entity = next((e for e in graph.entities if e.name == name), None)
|
||||
if entity:
|
||||
entity.observations = [obs for obs in entity.observations if obs not in to_delete]
|
||||
|
||||
save_graph(graph)
|
||||
return {"message": "Observations deleted successfully"}
|
||||
|
||||
@staticmethod
|
||||
def delete_relations(req: DeleteRelationsRequest) -> dict[str, str]:
|
||||
"""Delete relations from the graph.
|
||||
|
||||
Returns:
|
||||
dict: Success message indicating relations were deleted.
|
||||
"""
|
||||
graph = read_graph_file()
|
||||
del_set = {(r.from_, r.to, r.relation_type) for r in req.relations}
|
||||
graph.relations = [
|
||||
r for r in graph.relations if (r.from_, r.to, r.relation_type) not in del_set
|
||||
# Remove memories and track which entities were affected
|
||||
affected_entities = set()
|
||||
graph.memories = [
|
||||
m
|
||||
for m in graph.memories
|
||||
if m.id not in req.memory_ids or affected_entities.update(m.entities)
|
||||
]
|
||||
save_graph(graph)
|
||||
return {"message": "Relations deleted successfully"}
|
||||
|
||||
@staticmethod
|
||||
def read_graph() -> KnowledgeGraph:
|
||||
"""Read entire knowledge graph.
|
||||
deleted_count = original_count - len(graph.memories)
|
||||
|
||||
Returns:
|
||||
KnowledgeGraph: The complete knowledge graph.
|
||||
"""
|
||||
return read_graph_file()
|
||||
# Recalculate entity memory counts
|
||||
entity_counts = {}
|
||||
for memory in graph.memories:
|
||||
for entity_name in memory.entities:
|
||||
entity_counts[entity_name] = entity_counts.get(entity_name, 0) + 1
|
||||
|
||||
@staticmethod
|
||||
def search_nodes(req: SearchNodesRequest) -> KnowledgeGraph:
|
||||
"""Search for nodes by keyword.
|
||||
|
||||
Returns:
|
||||
KnowledgeGraph: Filtered knowledge graph containing matching entities.
|
||||
"""
|
||||
graph = read_graph_file()
|
||||
entities = [
|
||||
e
|
||||
# Update entity counts and remove entities with zero memories
|
||||
graph.entities = [
|
||||
Entity(
|
||||
name=e.name, entity_type=e.entity_type, memory_count=entity_counts.get(e.name, 0)
|
||||
)
|
||||
for e in graph.entities
|
||||
if req.query.lower() in e.name.lower()
|
||||
or req.query.lower() in e.entity_type.lower()
|
||||
or any(req.query.lower() in o.lower() for o in e.observations)
|
||||
if entity_counts.get(e.name, 0) > 0
|
||||
]
|
||||
names = {e.name for e in entities}
|
||||
relations = [r for r in graph.relations if r.from_ in names and r.to in names]
|
||||
|
||||
return KnowledgeGraph(entities=entities, relations=relations)
|
||||
save_memory_graph(graph)
|
||||
return {"message": f"Deleted {deleted_count} memories"}
|
||||
|
||||
@staticmethod
|
||||
def open_nodes(req: OpenNodesRequest) -> KnowledgeGraph:
|
||||
"""Open specific nodes by name.
|
||||
def get_summary() -> MemorySummary:
|
||||
"""Get summary statistics about stored memories.
|
||||
|
||||
Returns:
|
||||
KnowledgeGraph: Knowledge graph containing requested entities and their relations.
|
||||
MemorySummary: Statistics about memories and entities.
|
||||
"""
|
||||
graph = read_graph_file()
|
||||
entities = [e for e in graph.entities if e.name in req.names]
|
||||
names = {e.name for e in entities}
|
||||
relations = [r for r in graph.relations if r.from_ in names and r.to in names]
|
||||
return KnowledgeGraph(entities=entities, relations=relations)
|
||||
graph = read_memory_graph()
|
||||
|
||||
if not graph.memories:
|
||||
return MemorySummary(
|
||||
total_memories=0,
|
||||
total_entities=0,
|
||||
oldest_memory=None,
|
||||
latest_memory=None,
|
||||
memory_timespan_days=None,
|
||||
top_entities=[],
|
||||
)
|
||||
|
||||
# Sort memories by timestamp
|
||||
sorted_memories = sorted(graph.memories, key=lambda m: m.timestamp)
|
||||
oldest = sorted_memories[0].timestamp
|
||||
latest = sorted_memories[-1].timestamp
|
||||
|
||||
# Calculate timespan
|
||||
try:
|
||||
oldest_dt = datetime.fromisoformat(oldest)
|
||||
latest_dt = datetime.fromisoformat(latest)
|
||||
timespan_days = (latest_dt - oldest_dt).days
|
||||
except ValueError:
|
||||
timespan_days = None
|
||||
|
||||
# Get top entities by memory count
|
||||
top_entities = sorted(graph.entities, key=lambda e: e.memory_count, reverse=True)[:10]
|
||||
|
||||
return MemorySummary(
|
||||
total_memories=len(graph.memories),
|
||||
total_entities=len(graph.entities),
|
||||
oldest_memory=oldest,
|
||||
latest_memory=latest,
|
||||
memory_timespan_days=timespan_days,
|
||||
top_entities=top_entities,
|
||||
)
|
||||
|
||||
def get_router(self) -> APIRouter:
|
||||
"""Return the FastAPI router for memory tool endpoints."""
|
||||
router = APIRouter()
|
||||
|
||||
router.add_api_route(
|
||||
"/create_entities",
|
||||
MemoryTool.create_entities,
|
||||
"/create",
|
||||
MemoryTool.create_memory,
|
||||
methods=["POST"],
|
||||
summary="Store new entities (people, places, concepts) with their properties and observations",
|
||||
response_model=Memory,
|
||||
summary="STORE MEMORY: Save a new fact/memory about one or more entities. Timestamp is auto-generated.",
|
||||
)
|
||||
|
||||
router.add_api_route(
|
||||
"/create_relations",
|
||||
MemoryTool.create_relations,
|
||||
methods=["POST"],
|
||||
summary="Connect entities with relationships (works_at, lives_in, knows, etc.)",
|
||||
)
|
||||
router.add_api_route(
|
||||
"/add_observations",
|
||||
MemoryTool.add_observations,
|
||||
methods=["POST"],
|
||||
summary="Add new facts or observations to entities you've already stored",
|
||||
)
|
||||
router.add_api_route(
|
||||
"/delete_entities",
|
||||
MemoryTool.delete_entities,
|
||||
methods=["POST"],
|
||||
summary="Remove entities and all their connections from memory",
|
||||
)
|
||||
router.add_api_route(
|
||||
"/delete_observations",
|
||||
MemoryTool.delete_observations,
|
||||
methods=["POST"],
|
||||
summary="Remove specific facts or observations from entities",
|
||||
)
|
||||
router.add_api_route(
|
||||
"/delete_relations",
|
||||
MemoryTool.delete_relations,
|
||||
methods=["POST"],
|
||||
summary="Remove specific relationships between entities",
|
||||
)
|
||||
router.add_api_route(
|
||||
"/read_graph",
|
||||
MemoryTool.read_graph,
|
||||
"/all",
|
||||
MemoryTool.get_all_memories,
|
||||
methods=["GET"],
|
||||
response_model=KnowledgeGraph,
|
||||
summary="Retrieve all stored entities and relationships from memory",
|
||||
response_model=MemoryGraph,
|
||||
summary="GET ALL MEMORIES: Retrieve all stored memories and entities, sorted by timestamp (newest first).",
|
||||
)
|
||||
|
||||
router.add_api_route(
|
||||
"/search_nodes",
|
||||
MemoryTool.search_nodes,
|
||||
"/search",
|
||||
MemoryTool.search_memories,
|
||||
methods=["POST"],
|
||||
response_model=KnowledgeGraph,
|
||||
summary="Find entities by searching names, types, or observations",
|
||||
response_model=MemoryGraph,
|
||||
summary="SEARCH MEMORIES: Find memories by searching content or entity names. Returns matching memories sorted by time.",
|
||||
)
|
||||
|
||||
router.add_api_route(
|
||||
"/open_nodes",
|
||||
MemoryTool.open_nodes,
|
||||
"/entity",
|
||||
MemoryTool.get_entity_memories,
|
||||
methods=["POST"],
|
||||
response_model=KnowledgeGraph,
|
||||
summary="Retrieve specific entities and their connections by exact name",
|
||||
response_model=MemoryGraph,
|
||||
summary="GET ENTITY MEMORIES: Retrieve all memories that reference specific entities, sorted by timestamp.",
|
||||
)
|
||||
|
||||
router.add_api_route(
|
||||
"/delete",
|
||||
MemoryTool.delete_memories,
|
||||
methods=["POST"],
|
||||
summary="DELETE MEMORIES: Remove specific memories by their IDs. Entity counts are automatically updated.",
|
||||
)
|
||||
|
||||
router.add_api_route(
|
||||
"/stats",
|
||||
MemoryTool.get_summary,
|
||||
methods=["GET"],
|
||||
response_model=MemorySummary,
|
||||
summary="MEMORY STATS: Get summary statistics - total memories, entities, timespan, and top entities by frequency.",
|
||||
)
|
||||
|
||||
return router
|
||||
|
|
|
@ -4,11 +4,12 @@ from __future__ import annotations
|
|||
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from .models import Entity, KnowledgeGraph, Relation
|
||||
from .models import Entity, Memory, MemoryGraph
|
||||
|
||||
MEMORY_FILE_PATH_ENV = os.getenv("MEMORY_FILE_PATH", "memory.json")
|
||||
MEMORY_FILE_PATH = Path(
|
||||
|
@ -18,58 +19,62 @@ MEMORY_FILE_PATH = Path(
|
|||
)
|
||||
|
||||
|
||||
def read_graph_file() -> KnowledgeGraph:
|
||||
"""Read the knowledge graph from file.
|
||||
def read_memory_graph() -> MemoryGraph:
|
||||
"""Read the memory graph from file.
|
||||
|
||||
Returns:
|
||||
KnowledgeGraph: The knowledge graph loaded from storage.
|
||||
MemoryGraph: The memory graph loaded from storage.
|
||||
|
||||
Raises:
|
||||
HTTPException: If the memory file is not found or permission is denied.
|
||||
"""
|
||||
if not MEMORY_FILE_PATH.exists():
|
||||
return KnowledgeGraph(entities=[], relations=[])
|
||||
return MemoryGraph(memories=[], entities=[])
|
||||
|
||||
try:
|
||||
with MEMORY_FILE_PATH.open(encoding="utf-8") as f:
|
||||
lines = [line for line in f if line.strip()]
|
||||
entities = []
|
||||
relations = []
|
||||
|
||||
for line in lines:
|
||||
item = json.loads(line)
|
||||
if item["type"] == "entity":
|
||||
entities.append(
|
||||
Entity(
|
||||
name=item["name"],
|
||||
entity_type=item["entity_type"],
|
||||
observations=item["observations"],
|
||||
)
|
||||
)
|
||||
elif item["type"] == "relation":
|
||||
relations.append(Relation(**item))
|
||||
|
||||
return KnowledgeGraph(entities=entities, relations=relations)
|
||||
data = json.load(f)
|
||||
|
||||
memories = [Memory(**m) for m in data.get("memories", [])]
|
||||
entities = [Entity(**e) for e in data.get("entities", [])]
|
||||
|
||||
return MemoryGraph(memories=memories, entities=entities)
|
||||
except PermissionError as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Permission denied when reading memory file: {MEMORY_FILE_PATH}",
|
||||
) from e
|
||||
except json.JSONDecodeError:
|
||||
# Handle legacy format or corrupted file
|
||||
return MemoryGraph(memories=[], entities=[])
|
||||
|
||||
|
||||
def save_graph(graph: KnowledgeGraph) -> None:
|
||||
"""Save the knowledge graph to file.
|
||||
def save_memory_graph(graph: MemoryGraph) -> None:
|
||||
"""Save the memory graph to file.
|
||||
|
||||
Raises:
|
||||
HTTPException: If the memory file is not found or permission is denied.
|
||||
"""
|
||||
lines = [json.dumps({"type": "entity", **e.model_dump()}) for e in graph.entities] + [
|
||||
json.dumps({"type": "relation", **r.model_dump(by_alias=True)}) for r in graph.relations
|
||||
]
|
||||
data = {
|
||||
"memories": [m.model_dump() for m in graph.memories],
|
||||
"entities": [e.model_dump() for e in graph.entities]
|
||||
}
|
||||
|
||||
try:
|
||||
MEMORY_FILE_PATH.write_text("\n".join(lines), encoding="utf-8")
|
||||
with MEMORY_FILE_PATH.open("w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
except PermissionError as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Permission denied when writing to memory file: {MEMORY_FILE_PATH}",
|
||||
) from e
|
||||
|
||||
|
||||
def generate_memory_id() -> str:
|
||||
"""Generate a unique ID for a memory."""
|
||||
return f"mem_{datetime.utcnow().strftime('%Y%m%d_%H%M%S_%f')}"
|
||||
|
||||
|
||||
def get_current_timestamp() -> str:
|
||||
"""Get current UTC timestamp in ISO format."""
|
||||
return datetime.utcnow().isoformat() + "Z"
|
Loading…
Add table
Add a link
Reference in a new issue