Memory updates
All checks were successful
CI / Build and push Docker image (push) Successful in 1m33s

This commit is contained in:
Tom Foster 2025-06-16 18:29:51 +01:00
parent b634377ddc
commit 0c01d65a6c
3 changed files with 272 additions and 298 deletions

View file

@ -2,132 +2,66 @@
from __future__ import annotations
from typing import Literal
from pydantic import BaseModel, Field
class Memory(BaseModel):
"""A timestamped fact that references one or more entities."""
id: str = Field(..., description="Unique identifier for this memory")
content: str = Field(..., description="The actual fact or information being stored")
entities: list[str] = Field(..., description="List of entity names this memory references")
timestamp: str = Field(..., description="ISO timestamp when this memory was created")
class Entity(BaseModel):
"""Entity model for knowledge graph."""
"""Simple entity reference."""
name: str = Field(..., description="The name of the entity")
entity_type: str = Field(..., description="The type of the entity")
observations: list[str] = Field(
..., description="An array of observation contents associated with the entity"
)
entity_type: str = Field(default="general", description="The type of the entity")
memory_count: int = Field(default=0, description="Number of memories referencing this entity")
class Relation(BaseModel):
"""Relation model for knowledge graph."""
from_: str = Field(
...,
alias="from",
description="The name of the entity where the relation starts",
)
to: str = Field(..., description="The name of the entity where the relation ends")
relation_type: str = Field(..., description="The type of the relation")
class KnowledgeGraph(BaseModel):
"""Knowledge graph containing entities and relations."""
class MemoryGraph(BaseModel):
"""Collection of memories and entities."""
memories: list[Memory]
entities: list[Entity]
relations: list[Relation]
class EntityWrapper(BaseModel):
"""Wrapper for entity data."""
class CreateMemoryRequest(BaseModel):
"""Request to create a new memory."""
type: Literal["entity"]
name: str
entity_type: str
observations: list[str]
content: str = Field(..., description="The fact or information to store")
entities: list[str] = Field(..., description="List of entity names this memory references")
class RelationWrapper(BaseModel):
"""Wrapper for relation data."""
class SearchMemoryRequest(BaseModel):
"""Request to search memories."""
type: Literal["relation"]
from_: str = Field(..., alias="from")
to: str
relation_type: str
query: str = Field(..., description="Search term to find in memory content or entity names")
limit: int = Field(default=10, description="Maximum number of memories to return")
class CreateEntitiesRequest(BaseModel):
"""Request model for creating entities."""
class GetEntityRequest(BaseModel):
"""Request to get memories for specific entities."""
entities: list[Entity] = Field(..., description="List of entities to create")
entities: list[str] = Field(..., description="List of entity names to retrieve memories for")
limit: int = Field(default=5, description="Maximum number of memories to return")
class CreateRelationsRequest(BaseModel):
"""Request model for creating relations."""
class DeleteMemoryRequest(BaseModel):
"""Request to delete specific memories."""
relations: list[Relation] = Field(
..., description="List of relations to create. All must be in active voice."
)
memory_ids: list[str] = Field(..., description="List of memory IDs to delete")
class ObservationItem(BaseModel):
"""Item for adding observations."""
class MemorySummary(BaseModel):
"""Summary statistics about stored memories."""
entity_name: str = Field(..., description="The name of the entity to add the observations to")
contents: list[str] = Field(..., description="An array of observation contents to add")
class DeletionItem(BaseModel):
"""Item for deleting observations."""
entity_name: str = Field(..., description="The name of the entity containing the observations")
observations: list[str] = Field(..., description="An array of observations to delete")
class AddObservationsRequest(BaseModel):
"""Request model for adding observations."""
observations: list[ObservationItem] = Field(
...,
description=(
"A list of observation additions, each specifying an entity and contents to add"
),
)
class DeleteObservationsRequest(BaseModel):
"""Request model for deleting observations."""
deletions: list[DeletionItem] = Field(
...,
description=(
"A list of observation deletions, each specifying an entity and observations to remove"
),
)
class DeleteEntitiesRequest(BaseModel):
"""Request model for deleting entities."""
entity_names: list[str] = Field(..., description="An array of entity names to delete")
class DeleteRelationsRequest(BaseModel):
"""Request model for deleting relations."""
relations: list[Relation] = Field(..., description="An array of relations to delete")
class SearchNodesRequest(BaseModel):
"""Request model for searching nodes."""
query: str = Field(
...,
description=(
"The search query to match against entity names, types, and observation content"
),
)
class OpenNodesRequest(BaseModel):
"""Request model for opening specific nodes."""
names: list[str] = Field(..., description="An array of entity names to retrieve")
total_memories: int = Field(..., description="Total number of memories stored")
total_entities: int = Field(..., description="Total number of unique entities")
oldest_memory: str | None = Field(..., description="ISO timestamp of oldest memory")
latest_memory: str | None = Field(..., description="ISO timestamp of latest memory")
memory_timespan_days: int | None = Field(..., description="Days between oldest and latest memory")
top_entities: list[Entity] = Field(..., description="Most frequently referenced entities")

View file

@ -2,243 +2,278 @@
from __future__ import annotations
from fastapi import APIRouter, HTTPException
from datetime import datetime
from fastapi import APIRouter
from openapi_mcp_server.tools.base import BaseTool
from .models import (
AddObservationsRequest,
CreateEntitiesRequest,
CreateRelationsRequest,
DeleteEntitiesRequest,
DeleteObservationsRequest,
DeleteRelationsRequest,
CreateMemoryRequest,
DeleteMemoryRequest,
Entity,
KnowledgeGraph,
OpenNodesRequest,
Relation,
SearchNodesRequest,
GetEntityRequest,
Memory,
MemoryGraph,
MemorySummary,
SearchMemoryRequest,
)
from .storage import (
generate_memory_id,
get_current_timestamp,
read_memory_graph,
save_memory_graph,
)
from .storage import read_graph_file, save_graph
class MemoryTool(BaseTool):
"""Knowledge graph memory system tool."""
"""Simplified memory system for storing timestamped facts."""
def __init__(self) -> None:
"""Initialize the memory tool."""
super().__init__(
name="memory",
description="A structured knowledge graph memory system",
description="A simple memory system for storing timestamped facts about entities",
)
@staticmethod
def create_entities(req: CreateEntitiesRequest) -> list[Entity]:
"""Create multiple entities in the graph.
def create_memory(req: CreateMemoryRequest) -> Memory:
"""Store a new memory/fact.
Returns:
list: List of newly created entities.
Memory: The newly created memory with auto-generated timestamp and ID.
"""
graph = read_graph_file()
existing_names = {e.name for e in graph.entities}
new_entities = [e for e in req.entities if e.name not in existing_names]
graph.entities.extend(new_entities)
save_graph(graph)
return new_entities
graph = read_memory_graph()
# Create new memory with auto-generated timestamp and ID
memory = Memory(
id=generate_memory_id(),
content=req.content,
entities=req.entities,
timestamp=get_current_timestamp(),
)
graph.memories.append(memory)
# Update entity counts and ensure entities exist
entity_dict = {e.name: e for e in graph.entities}
for entity_name in req.entities:
if entity_name in entity_dict:
entity_dict[entity_name].memory_count += 1
else:
# Create new entity
new_entity = Entity(name=entity_name, entity_type="general", memory_count=1)
graph.entities.append(new_entity)
entity_dict[entity_name] = new_entity
save_memory_graph(graph)
return memory
@staticmethod
def create_relations(req: CreateRelationsRequest) -> list[Relation]:
"""Create multiple relations between entities.
def get_all_memories(limit: int = 20) -> MemoryGraph:
"""Get all memories and entities.
Returns:
list: List of newly created relations.
MemoryGraph: All stored memories and entities, sorted by timestamp (newest first).
"""
graph = read_graph_file()
existing = {(r.from_, r.to, r.relation_type) for r in graph.relations}
new = [r for r in req.relations if (r.from_, r.to, r.relation_type) not in existing]
graph.relations.extend(new)
save_graph(graph)
return new
graph = read_memory_graph()
# Sort memories by timestamp (newest first)
graph.memories.sort(key=lambda m: m.timestamp, reverse=True)
if limit and len(graph.memories) > limit:
graph.memories = graph.memories[:limit]
return graph
@staticmethod
def add_observations(req: AddObservationsRequest) -> list[dict[str, str | list[str]]]:
"""Add new observations to existing entities.
def search_memories(req: SearchMemoryRequest) -> MemoryGraph:
"""Search memories by content or entity names.
Returns:
list: List of entities with their added observations.
Raises:
HTTPException: If entity is not found.
MemoryGraph: Filtered memories matching the search query.
"""
graph = read_graph_file()
results = []
graph = read_memory_graph()
query = req.query.lower()
for obs in req.observations:
name = obs.entity_name.lower()
contents = obs.contents
entity = next((e for e in graph.entities if e.name == name), None)
if not entity:
raise HTTPException(status_code=404, detail=f"Entity {name} not found")
added = [c for c in contents if c not in entity.observations]
entity.observations.extend(added)
results.append({"entity_name": name, "added_observations": added})
matching_memories = []
for memory in graph.memories:
# Search in content
if query in memory.content.lower():
matching_memories.append(memory)
continue
# Search in entity names
if any(query in entity.lower() for entity in memory.entities):
matching_memories.append(memory)
save_graph(graph)
return results
# Sort by timestamp (newest first) and apply limit
matching_memories.sort(key=lambda m: m.timestamp, reverse=True)
if len(matching_memories) > req.limit:
matching_memories = matching_memories[: req.limit]
# Get entities referenced in matching memories
referenced_entities = set()
for memory in matching_memories:
referenced_entities.update(memory.entities)
matching_entities = [e for e in graph.entities if e.name in referenced_entities]
return MemoryGraph(memories=matching_memories, entities=matching_entities)
@staticmethod
def delete_entities(req: DeleteEntitiesRequest) -> dict[str, str]:
"""Delete entities and associated relations.
def get_entity_memories(req: GetEntityRequest) -> MemoryGraph:
"""Get memories for specific entities.
Returns:
dict: Success message indicating entities were deleted.
MemoryGraph: Memories that reference the specified entities.
"""
graph = read_graph_file()
graph.entities = [e for e in graph.entities if e.name not in req.entity_names]
graph.relations = [
r
for r in graph.relations
if r.from_ not in req.entity_names and r.to not in req.entity_names
graph = read_memory_graph()
# Check if memory references any of the requested entities
matching_memories = [
memory
for memory in graph.memories
if any(entity in memory.entities for entity in req.entities)
]
save_graph(graph)
return {"message": "Entities deleted successfully"}
# Sort by timestamp (newest first) and apply limit
matching_memories.sort(key=lambda m: m.timestamp, reverse=True)
if len(matching_memories) > req.limit:
matching_memories = matching_memories[: req.limit]
# Get the requested entities
matching_entities = [e for e in graph.entities if e.name in req.entities]
return MemoryGraph(memories=matching_memories, entities=matching_entities)
@staticmethod
def delete_observations(req: DeleteObservationsRequest) -> dict[str, str]:
"""Delete specific observations from entities.
def delete_memories(req: DeleteMemoryRequest) -> dict[str, str]:
"""Delete specific memories by ID.
Returns:
dict: Success message indicating observations were deleted.
dict: Success message with count of deleted memories.
"""
graph = read_graph_file()
graph = read_memory_graph()
original_count = len(graph.memories)
for deletion in req.deletions:
name = deletion.entity_name.lower()
to_delete = deletion.observations
entity = next((e for e in graph.entities if e.name == name), None)
if entity:
entity.observations = [obs for obs in entity.observations if obs not in to_delete]
save_graph(graph)
return {"message": "Observations deleted successfully"}
@staticmethod
def delete_relations(req: DeleteRelationsRequest) -> dict[str, str]:
"""Delete relations from the graph.
Returns:
dict: Success message indicating relations were deleted.
"""
graph = read_graph_file()
del_set = {(r.from_, r.to, r.relation_type) for r in req.relations}
graph.relations = [
r for r in graph.relations if (r.from_, r.to, r.relation_type) not in del_set
# Remove memories and track which entities were affected
affected_entities = set()
graph.memories = [
m
for m in graph.memories
if m.id not in req.memory_ids or affected_entities.update(m.entities)
]
save_graph(graph)
return {"message": "Relations deleted successfully"}
@staticmethod
def read_graph() -> KnowledgeGraph:
"""Read entire knowledge graph.
deleted_count = original_count - len(graph.memories)
Returns:
KnowledgeGraph: The complete knowledge graph.
"""
return read_graph_file()
# Recalculate entity memory counts
entity_counts = {}
for memory in graph.memories:
for entity_name in memory.entities:
entity_counts[entity_name] = entity_counts.get(entity_name, 0) + 1
@staticmethod
def search_nodes(req: SearchNodesRequest) -> KnowledgeGraph:
"""Search for nodes by keyword.
Returns:
KnowledgeGraph: Filtered knowledge graph containing matching entities.
"""
graph = read_graph_file()
entities = [
e
# Update entity counts and remove entities with zero memories
graph.entities = [
Entity(
name=e.name, entity_type=e.entity_type, memory_count=entity_counts.get(e.name, 0)
)
for e in graph.entities
if req.query.lower() in e.name.lower()
or req.query.lower() in e.entity_type.lower()
or any(req.query.lower() in o.lower() for o in e.observations)
if entity_counts.get(e.name, 0) > 0
]
names = {e.name for e in entities}
relations = [r for r in graph.relations if r.from_ in names and r.to in names]
return KnowledgeGraph(entities=entities, relations=relations)
save_memory_graph(graph)
return {"message": f"Deleted {deleted_count} memories"}
@staticmethod
def open_nodes(req: OpenNodesRequest) -> KnowledgeGraph:
"""Open specific nodes by name.
def get_summary() -> MemorySummary:
"""Get summary statistics about stored memories.
Returns:
KnowledgeGraph: Knowledge graph containing requested entities and their relations.
MemorySummary: Statistics about memories and entities.
"""
graph = read_graph_file()
entities = [e for e in graph.entities if e.name in req.names]
names = {e.name for e in entities}
relations = [r for r in graph.relations if r.from_ in names and r.to in names]
return KnowledgeGraph(entities=entities, relations=relations)
graph = read_memory_graph()
if not graph.memories:
return MemorySummary(
total_memories=0,
total_entities=0,
oldest_memory=None,
latest_memory=None,
memory_timespan_days=None,
top_entities=[],
)
# Sort memories by timestamp
sorted_memories = sorted(graph.memories, key=lambda m: m.timestamp)
oldest = sorted_memories[0].timestamp
latest = sorted_memories[-1].timestamp
# Calculate timespan
try:
oldest_dt = datetime.fromisoformat(oldest)
latest_dt = datetime.fromisoformat(latest)
timespan_days = (latest_dt - oldest_dt).days
except ValueError:
timespan_days = None
# Get top entities by memory count
top_entities = sorted(graph.entities, key=lambda e: e.memory_count, reverse=True)[:10]
return MemorySummary(
total_memories=len(graph.memories),
total_entities=len(graph.entities),
oldest_memory=oldest,
latest_memory=latest,
memory_timespan_days=timespan_days,
top_entities=top_entities,
)
def get_router(self) -> APIRouter:
"""Return the FastAPI router for memory tool endpoints."""
router = APIRouter()
router.add_api_route(
"/create_entities",
MemoryTool.create_entities,
"/create",
MemoryTool.create_memory,
methods=["POST"],
summary="Store new entities (people, places, concepts) with their properties and observations",
response_model=Memory,
summary="STORE MEMORY: Save a new fact/memory about one or more entities. Timestamp is auto-generated.",
)
router.add_api_route(
"/create_relations",
MemoryTool.create_relations,
methods=["POST"],
summary="Connect entities with relationships (works_at, lives_in, knows, etc.)",
)
router.add_api_route(
"/add_observations",
MemoryTool.add_observations,
methods=["POST"],
summary="Add new facts or observations to entities you've already stored",
)
router.add_api_route(
"/delete_entities",
MemoryTool.delete_entities,
methods=["POST"],
summary="Remove entities and all their connections from memory",
)
router.add_api_route(
"/delete_observations",
MemoryTool.delete_observations,
methods=["POST"],
summary="Remove specific facts or observations from entities",
)
router.add_api_route(
"/delete_relations",
MemoryTool.delete_relations,
methods=["POST"],
summary="Remove specific relationships between entities",
)
router.add_api_route(
"/read_graph",
MemoryTool.read_graph,
"/all",
MemoryTool.get_all_memories,
methods=["GET"],
response_model=KnowledgeGraph,
summary="Retrieve all stored entities and relationships from memory",
response_model=MemoryGraph,
summary="GET ALL MEMORIES: Retrieve all stored memories and entities, sorted by timestamp (newest first).",
)
router.add_api_route(
"/search_nodes",
MemoryTool.search_nodes,
"/search",
MemoryTool.search_memories,
methods=["POST"],
response_model=KnowledgeGraph,
summary="Find entities by searching names, types, or observations",
response_model=MemoryGraph,
summary="SEARCH MEMORIES: Find memories by searching content or entity names. Returns matching memories sorted by time.",
)
router.add_api_route(
"/open_nodes",
MemoryTool.open_nodes,
"/entity",
MemoryTool.get_entity_memories,
methods=["POST"],
response_model=KnowledgeGraph,
summary="Retrieve specific entities and their connections by exact name",
response_model=MemoryGraph,
summary="GET ENTITY MEMORIES: Retrieve all memories that reference specific entities, sorted by timestamp.",
)
router.add_api_route(
"/delete",
MemoryTool.delete_memories,
methods=["POST"],
summary="DELETE MEMORIES: Remove specific memories by their IDs. Entity counts are automatically updated.",
)
router.add_api_route(
"/stats",
MemoryTool.get_summary,
methods=["GET"],
response_model=MemorySummary,
summary="MEMORY STATS: Get summary statistics - total memories, entities, timespan, and top entities by frequency.",
)
return router

View file

@ -4,11 +4,12 @@ from __future__ import annotations
import json
import os
from datetime import datetime
from pathlib import Path
from fastapi import HTTPException
from .models import Entity, KnowledgeGraph, Relation
from .models import Entity, Memory, MemoryGraph
MEMORY_FILE_PATH_ENV = os.getenv("MEMORY_FILE_PATH", "memory.json")
MEMORY_FILE_PATH = Path(
@ -18,58 +19,62 @@ MEMORY_FILE_PATH = Path(
)
def read_graph_file() -> KnowledgeGraph:
"""Read the knowledge graph from file.
def read_memory_graph() -> MemoryGraph:
"""Read the memory graph from file.
Returns:
KnowledgeGraph: The knowledge graph loaded from storage.
MemoryGraph: The memory graph loaded from storage.
Raises:
HTTPException: If the memory file is not found or permission is denied.
"""
if not MEMORY_FILE_PATH.exists():
return KnowledgeGraph(entities=[], relations=[])
return MemoryGraph(memories=[], entities=[])
try:
with MEMORY_FILE_PATH.open(encoding="utf-8") as f:
lines = [line for line in f if line.strip()]
entities = []
relations = []
for line in lines:
item = json.loads(line)
if item["type"] == "entity":
entities.append(
Entity(
name=item["name"],
entity_type=item["entity_type"],
observations=item["observations"],
)
)
elif item["type"] == "relation":
relations.append(Relation(**item))
return KnowledgeGraph(entities=entities, relations=relations)
data = json.load(f)
memories = [Memory(**m) for m in data.get("memories", [])]
entities = [Entity(**e) for e in data.get("entities", [])]
return MemoryGraph(memories=memories, entities=entities)
except PermissionError as e:
raise HTTPException(
status_code=500,
detail=f"Permission denied when reading memory file: {MEMORY_FILE_PATH}",
) from e
except json.JSONDecodeError:
# Handle legacy format or corrupted file
return MemoryGraph(memories=[], entities=[])
def save_graph(graph: KnowledgeGraph) -> None:
"""Save the knowledge graph to file.
def save_memory_graph(graph: MemoryGraph) -> None:
"""Save the memory graph to file.
Raises:
HTTPException: If the memory file is not found or permission is denied.
"""
lines = [json.dumps({"type": "entity", **e.model_dump()}) for e in graph.entities] + [
json.dumps({"type": "relation", **r.model_dump(by_alias=True)}) for r in graph.relations
]
data = {
"memories": [m.model_dump() for m in graph.memories],
"entities": [e.model_dump() for e in graph.entities]
}
try:
MEMORY_FILE_PATH.write_text("\n".join(lines), encoding="utf-8")
with MEMORY_FILE_PATH.open("w", encoding="utf-8") as f:
json.dump(data, f, indent=2)
except PermissionError as e:
raise HTTPException(
status_code=500,
detail=f"Permission denied when writing to memory file: {MEMORY_FILE_PATH}",
) from e
def generate_memory_id() -> str:
"""Generate a unique ID for a memory."""
return f"mem_{datetime.utcnow().strftime('%Y%m%d_%H%M%S_%f')}"
def get_current_timestamp() -> str:
"""Get current UTC timestamp in ISO format."""
return datetime.utcnow().isoformat() + "Z"