rustGPT/src/core/session.rs

379 lines
13 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use anyhow::{Context, Result};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::PathBuf;
use crate::config::Config;
/// System prompt placed as the first message of every session
/// (`Session::new`, `Session::load`, `Session::clear_messages`).
///
/// It asks the model to format answers for a terminal: 80-column-friendly
/// plain text or bullet lists, fenced code blocks, and no trailing spaces or
/// control characters. The trailing `\` on each line continues the literal and
/// strips the next line's leading whitespace.
///
/// NOTE(review): "80column" looks like it lost a hyphen (or holds an invisible
/// Unicode character flagged by the hosting site) — confirm the intended text.
const SYSTEM_PROMPT: &str = "You are an AI assistant running in a terminal (CLI) environment. \
Optimise all answers for 80column readability, prefer plain text, \
ASCII art or concise bullet lists over heavy markup, and wrap code \
snippets in fenced blocks when helpful. Do not emit trailing spaces or \
control characters.";
/// A single chat message, kept in `Session::messages` and serialized as part
/// of `SessionData`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
/// Speaker role; this module creates `"system"`, `"user"` and `"assistant"`.
pub role: String,
/// Message text.
pub content: String,
}
/// Summary counters produced by `Session::get_stats`.
#[derive(Debug, Clone)]
pub struct ConversationStats {
/// Number of messages in the session (the system prompt included when present).
pub total_messages: usize,
/// Messages whose role is `"user"`.
pub user_messages: usize,
/// Messages whose role is `"assistant"`.
pub assistant_messages: usize,
/// Sum of `content.len()` over all messages — UTF-8 bytes, not chars.
pub total_characters: usize,
/// `total_characters / total_messages` (integer division), or 0 when empty.
pub average_message_length: usize,
}
/// One entry of a session directory listing (`Session::list_sessions_lazy`).
#[derive(Debug, Clone)]
pub struct SessionInfo {
/// Session name: the file stem of the session file.
pub name: String,
/// File modification time, converted to UTC.
pub last_modified: DateTime<Utc>,
/// Model name; only filled for detailed listings when the file parsed.
pub model: Option<String>,
/// Number of stored messages; same availability as `model`.
pub message_count: Option<usize>,
/// File size in bytes from the file's metadata.
pub file_size: Option<u64>,
}
/// JSON representation of a session as written by `Session::save` and read
/// by `Session::load`. The session name is not stored here; it comes from the
/// file name.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionData {
/// Model identifier the session was using.
pub model: String,
/// Full message history, including the system prompt.
pub messages: Vec<Message>,
// The two flags below have no serde default: a JSON document missing them
// fails to deserialize.
pub enable_web_search: bool,
pub enable_reasoning_summary: bool,
/// Falls back to `default_reasoning_effort()` ("medium") when absent from the JSON.
#[serde(default = "default_reasoning_effort")]
pub reasoning_effort: String,
/// Set to `Utc::now()` every time the session is saved.
pub updated_at: DateTime<Utc>,
}
/// Reasoning effort used when a stored session does not specify one
/// (serde fallback for `SessionData::reasoning_effort`).
fn default_reasoning_effort() -> String {
    String::from("medium")
}
/// In-memory chat session: its settings plus the running message history.
///
/// `messages[0]` is kept as the system prompt: `new` and `clear_messages` push
/// it, `load` re-inserts it if missing, and history truncation preserves it.
#[derive(Debug, Clone)]
pub struct Session {
/// Session name; also determines the file name on disk.
pub name: String,
/// Model identifier used for this conversation.
pub model: String,
/// Conversation history, system prompt first.
pub messages: Vec<Message>,
pub enable_web_search: bool,
pub enable_reasoning_summary: bool,
/// Reasoning effort setting, e.g. "medium".
pub reasoning_effort: String,
}
impl Session {
pub fn new(name: String, model: String) -> Self {
let mut session = Self {
name,
model,
messages: Vec::new(),
enable_web_search: true,
enable_reasoning_summary: false,
reasoning_effort: "medium".to_string(),
};
// Add system prompt as first message
session.messages.push(Message {
role: "system".to_string(),
content: SYSTEM_PROMPT.to_string(),
});
session
}
pub fn sessions_dir() -> Result<PathBuf> {
let config = Config::load().unwrap_or_default();
let home = dirs::home_dir().context("Could not find home directory")?;
let sessions_dir = home.join(&config.session.sessions_dir_name);
if !sessions_dir.exists() {
fs::create_dir_all(&sessions_dir)
.with_context(|| format!("Failed to create sessions directory: {:?}", sessions_dir))?;
}
Ok(sessions_dir)
}
pub fn session_path(name: &str) -> Result<PathBuf> {
let config = Config::load().unwrap_or_default();
Ok(Self::sessions_dir()?.join(format!("{}.{}", name, config.session.file_extension)))
}
pub fn save(&self) -> Result<()> {
let data = SessionData {
model: self.model.clone(),
messages: self.messages.clone(),
enable_web_search: self.enable_web_search,
enable_reasoning_summary: self.enable_reasoning_summary,
reasoning_effort: self.reasoning_effort.clone(),
updated_at: Utc::now(),
};
let path = Self::session_path(&self.name)?;
let tmp_path = path.with_extension("tmp");
let json_data = serde_json::to_string_pretty(&data)
.context("Failed to serialize session data")?;
fs::write(&tmp_path, json_data)
.with_context(|| format!("Failed to write session to {:?}", tmp_path))?;
fs::rename(&tmp_path, &path)
.with_context(|| format!("Failed to rename {:?} to {:?}", tmp_path, path))?;
Ok(())
}
pub fn load(name: &str) -> Result<Self> {
let path = Self::session_path(name)?;
if !path.exists() {
return Err(anyhow::anyhow!("Session '{}' does not exist", name));
}
let json_data = fs::read_to_string(&path)
.with_context(|| format!("Failed to read session from {:?}", path))?;
let data: SessionData = serde_json::from_str(&json_data)
.with_context(|| format!("Failed to parse session data from {:?}", path))?;
let mut session = Self {
name: name.to_string(),
model: data.model,
messages: data.messages,
enable_web_search: data.enable_web_search,
enable_reasoning_summary: data.enable_reasoning_summary,
reasoning_effort: data.reasoning_effort,
};
// Ensure system prompt is present
if session.messages.is_empty() || session.messages[0].role != "system" {
session.messages.insert(0, Message {
role: "system".to_string(),
content: SYSTEM_PROMPT.to_string(),
});
}
Ok(session)
}
pub fn add_user_message(&mut self, content: String) {
self.messages.push(Message {
role: "user".to_string(),
content,
});
self.truncate_history_if_needed();
}
pub fn add_assistant_message(&mut self, content: String) {
self.messages.push(Message {
role: "assistant".to_string(),
content,
});
self.truncate_history_if_needed();
}
/// Truncates conversation history to stay within configured limits
fn truncate_history_if_needed(&mut self) {
let config = Config::load().unwrap_or_default();
let max_history = config.limits.max_conversation_history;
// Always preserve the system prompt (first message)
if self.messages.len() > max_history + 1 {
let system_prompt = self.messages[0].clone();
let messages_to_keep = max_history;
let start_index = self.messages.len() - messages_to_keep;
// Keep the most recent messages
let mut new_messages = vec![system_prompt];
new_messages.extend_from_slice(&self.messages[start_index..]);
self.messages = new_messages;
}
}
/// Truncates individual messages that exceed reasonable length
pub fn truncate_long_messages(&mut self) {
const MAX_MESSAGE_LENGTH: usize = 10000; // 10k characters per message
const TRUNCATION_NOTICE: &str = "\n\n[Message truncated for performance...]";
for message in &mut self.messages {
if message.content.len() > MAX_MESSAGE_LENGTH {
message.content.truncate(MAX_MESSAGE_LENGTH - TRUNCATION_NOTICE.len());
message.content.push_str(TRUNCATION_NOTICE);
}
}
}
/// Gets conversation statistics
pub fn get_stats(&self) -> ConversationStats {
let total_messages = self.messages.len();
let user_messages = self.messages.iter().filter(|m| m.role == "user").count();
let assistant_messages = self.messages.iter().filter(|m| m.role == "assistant").count();
let total_chars: usize = self.messages.iter().map(|m| m.content.len()).sum();
let avg_message_length = if total_messages > 0 { total_chars / total_messages } else { 0 };
ConversationStats {
total_messages,
user_messages,
assistant_messages,
total_characters: total_chars,
average_message_length: avg_message_length,
}
}
pub fn clear_messages(&mut self) {
self.messages.clear();
// Re-add system prompt
self.messages.push(Message {
role: "system".to_string(),
content: SYSTEM_PROMPT.to_string(),
});
}
pub fn list_sessions() -> Result<Vec<(String, DateTime<Utc>)>> {
let sessions = Self::list_sessions_lazy(false)?;
Ok(sessions.into_iter().map(|s| (s.name, s.last_modified)).collect())
}
/// Lists sessions with lazy loading - only loads full data if detailed=true
pub fn list_sessions_lazy(detailed: bool) -> Result<Vec<SessionInfo>> {
let config = Config::load().unwrap_or_default();
let sessions_dir = Self::sessions_dir()?;
if !sessions_dir.exists() {
return Ok(Vec::new());
}
let mut sessions = Vec::new();
let mut count = 0;
for entry in fs::read_dir(&sessions_dir)? {
let entry = entry?;
let path = entry.path();
// Respect max sessions limit for performance
if count >= config.limits.max_sessions_to_list {
break;
}
if let Some(extension) = path.extension() {
if extension == config.session.file_extension.as_str() {
if let Some(name) = path.file_stem().and_then(|s| s.to_str()) {
let metadata = entry.metadata()?;
let modified = metadata.modified()?;
let datetime = DateTime::<Utc>::from(modified);
let file_size = metadata.len();
let (model, message_count) = if detailed {
// Only load session data if detailed info is requested
match Self::get_session_metadata(name) {
Ok((model, count)) => (Some(model), Some(count)),
Err(_) => (None, None)
}
} else {
(None, None)
};
sessions.push(SessionInfo {
name: name.to_string(),
last_modified: datetime,
model,
message_count,
file_size: Some(file_size),
});
count += 1;
}
}
}
}
sessions.sort_by(|a, b| b.last_modified.cmp(&a.last_modified)); // Sort by modification time, newest first
Ok(sessions)
}
/// Gets metadata from a session file without loading all messages
fn get_session_metadata(name: &str) -> Result<(String, usize)> {
let path = Self::session_path(name)?;
if !path.exists() {
return Err(anyhow::anyhow!("Session '{}' does not exist", name));
}
let json_data = fs::read_to_string(&path)
.with_context(|| format!("Failed to read session from {:?}", path))?;
// Parse only the fields we need for metadata
let data: SessionData = serde_json::from_str(&json_data)
.with_context(|| format!("Failed to parse session data from {:?}", path))?;
Ok((data.model, data.messages.len()))
}
/// Optimizes session data in memory by removing redundant information
pub fn optimize_for_memory(&mut self) {
// Truncate very long messages for memory efficiency
self.truncate_long_messages();
// Remove excessive whitespace from messages
for message in &mut self.messages {
message.content = message.content
.lines()
.map(|line| line.trim())
.collect::<Vec<_>>()
.join("\n")
.trim()
.to_string();
}
}
/// Checks if session needs cleanup based on size and age
pub fn needs_cleanup(&self) -> bool {
let stats = self.get_stats();
let config = Config::load().unwrap_or_default();
// Check if conversation is too long
if stats.total_messages > config.limits.max_conversation_history * 2 {
return true;
}
// Check if total character count is excessive (>1MB)
if stats.total_characters > 1_000_000 {
return true;
}
false
}
/// Performs aggressive cleanup for memory optimization
pub fn cleanup_for_memory(&mut self) {
let config = Config::load().unwrap_or_default();
let target_messages = config.limits.max_conversation_history / 2;
if self.messages.len() > target_messages + 1 {
let system_prompt = self.messages[0].clone();
let start_index = self.messages.len() - target_messages;
let mut new_messages = vec![system_prompt];
new_messages.extend_from_slice(&self.messages[start_index..]);
self.messages = new_messages;
}
self.optimize_for_memory();
}
pub fn delete_session(name: &str) -> Result<()> {
let path = Self::session_path(name)?;
if !path.exists() {
return Err(anyhow::anyhow!("Session '{}' does not exist", name));
}
fs::remove_file(&path)
.with_context(|| format!("Failed to delete session file: {:?}", path))?;
Ok(())
}
}