added streaming support

This commit is contained in:
leach
2025-08-19 00:20:29 -04:00
parent 6d0592bda5
commit 0faecbf657
7 changed files with 969 additions and 30 deletions

View File

@@ -6,6 +6,8 @@ use crate::core::{
ChatClient, Session,
};
use crate::utils::{Display, InputHandler, SessionAction};
use std::future::Future;
use std::pin::Pin;
pub struct ChatCLI {
session: Session,
@@ -82,8 +84,6 @@ impl ChatCLI {
self.session.add_user_message(message.to_string());
self.session.save()?;
let spinner = self.display.show_spinner("Thinking");
// Clone data needed for the API call before getting mutable client reference
let model = self.session.model.clone();
let messages = self.session.messages.clone();
@@ -91,27 +91,75 @@ impl ChatCLI {
let enable_reasoning_summary = self.session.enable_reasoning_summary;
let reasoning_effort = self.session.reasoning_effort.clone();
let client = self.get_client()?;
// Check if we should use streaming before getting client
let should_use_streaming = {
let client = self.get_client()?;
client.supports_streaming()
};
match client
.chat_completion(
&model,
&messages,
enable_web_search,
enable_reasoning_summary,
&reasoning_effort,
)
.await
{
Ok(response) => {
spinner.finish("Done");
self.display.print_assistant_response(&response);
self.session.add_assistant_message(response);
self.session.save()?;
if should_use_streaming {
print!("{} ", console::style("🤖").magenta());
use std::io::{self, Write};
io::stdout().flush().ok();
let stream_callback = {
use crate::core::StreamCallback;
Box::new(move |chunk: &str| {
print!("{}", chunk);
use std::io::{self, Write};
io::stdout().flush().ok();
Box::pin(async move {}) as Pin<Box<dyn Future<Output = ()> + Send>>
}) as StreamCallback
};
let client = self.get_client()?;
match client
.chat_completion_stream(
&model,
&messages,
enable_web_search,
enable_reasoning_summary,
&reasoning_effort,
stream_callback,
)
.await
{
Ok(response) => {
println!(); // Add newline after streaming
self.session.add_assistant_message(response);
self.session.save()?;
}
Err(e) => {
println!(); // Add newline after failed streaming
self.display.print_error(&format!("Streaming failed: {}", e));
return Err(e);
}
}
Err(e) => {
spinner.finish_with_error("Failed");
return Err(e);
} else {
// Fallback to non-streaming
let spinner = self.display.show_spinner("Thinking");
let client = self.get_client()?;
match client
.chat_completion(
&model,
&messages,
enable_web_search,
enable_reasoning_summary,
&reasoning_effort,
)
.await
{
Ok(response) => {
spinner.finish("Done");
self.display.print_assistant_response(&response);
self.session.add_assistant_message(response);
self.session.save()?;
}
Err(e) => {
spinner.finish_with_error("Failed");
return Err(e);
}
}
}
@@ -150,6 +198,9 @@ impl ChatCLI {
"/tools" => {
self.tools_manager().await?;
}
"/history" => {
self.handle_history_command(&parts)?;
}
_ => {
self.display.print_error(&format!("Unknown command: {} (see /help)", parts[0]));
}
@@ -391,4 +442,52 @@ impl ChatCLI {
Ok(())
}
/// Handle the `/history [user|assistant] [number]` command: list past
/// conversation messages, optionally filtered by role and limited to the
/// most recent N entries. Unknown parameters print usage and return Ok.
fn handle_history_command(&mut self, parts: &[&str]) -> Result<()> {
    let mut role_filter: Option<&str> = None;
    let mut max_messages: Option<usize> = None;

    // Everything after the command name is either a role keyword or a count.
    for &arg in &parts[1..] {
        if arg == "user" || arg == "assistant" {
            role_filter = Some(arg);
        } else if let Ok(count) = arg.parse::<usize>() {
            max_messages = Some(count);
        } else {
            self.display.print_error(&format!("Invalid parameter: {}", arg));
            self.display.print_info("Usage: /history [user|assistant] [number]");
            return Ok(());
        }
    }

    // Index 0 holds the system prompt; keep original indices for display.
    let mut history: Vec<(usize, &crate::core::Message)> = self
        .session
        .messages
        .iter()
        .enumerate()
        .skip(1)
        .filter(|(_, msg)| role_filter.map_or(true, |role| msg.role == role))
        .collect();

    // Keep only the last `max_messages` entries when a limit was given.
    if let Some(count) = max_messages {
        let first_kept = history.len().saturating_sub(count);
        history.drain(..first_kept);
    }

    if history.is_empty() {
        self.display.print_info("No messages to display");
    } else {
        self.display.print_conversation_history(&history);
    }
    Ok(())
}
}

View File

@@ -4,10 +4,15 @@ use serde::Deserialize;
use serde_json::{json, Value};
use std::env;
use std::time::Duration;
use futures::stream::StreamExt;
use std::future::Future;
use std::pin::Pin;
use crate::config::Config;
use super::{provider::Provider, session::Message};
pub type StreamCallback = Box<dyn Fn(&str) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync>;
#[derive(Debug)]
pub enum ChatClient {
OpenAI(OpenAIClient),
@@ -32,6 +37,30 @@ impl ChatClient {
}
}
}
/// Streaming variant of `chat_completion`: dispatches to the provider's
/// streaming implementation, invoking `stream_callback` for each text chunk
/// as it arrives, and returns the fully accumulated response text.
///
/// NOTE(review): on the Anthropic path this falls back to the blocking
/// `chat_completion`, so `stream_callback` is never invoked there — callers
/// should gate on `supports_streaming()` before relying on the callback.
pub async fn chat_completion_stream(
    &self,
    model: &str,
    messages: &[Message],
    enable_web_search: bool,
    enable_reasoning_summary: bool,
    reasoning_effort: &str,
    stream_callback: StreamCallback,
) -> Result<String> {
    match self {
        ChatClient::OpenAI(client) => {
            client.chat_completion_stream(model, messages, enable_web_search, enable_reasoning_summary, reasoning_effort, stream_callback).await
        }
        ChatClient::Anthropic(_) => {
            // Fallback to non-streaming for Anthropic
            self.chat_completion(model, messages, enable_web_search, enable_reasoning_summary, reasoning_effort).await
        }
    }
}
/// True when the underlying provider has a real streaming path (currently
/// only OpenAI); used to choose between `chat_completion_stream` and the
/// blocking `chat_completion`.
pub fn supports_streaming(&self) -> bool {
    matches!(self, ChatClient::OpenAI(_))
}
pub fn supports_feature(&self, feature: &str) -> bool {
match self {
@@ -96,6 +125,40 @@ struct FunctionCall {
arguments: String,
}
// Streaming response structures
// One choice inside a streaming chunk; `delta` carries the incremental payload.
#[derive(Deserialize, Debug)]
struct StreamingChoice {
    delta: StreamingDelta,
    // Present on the final chunk of a choice; not consumed by this client yet.
    #[allow(dead_code)]
    finish_reason: Option<String>,
}

// Incremental payload of one SSE chunk: text and/or tool-call fragments.
#[derive(Deserialize, Debug)]
struct StreamingDelta {
    // Incremental assistant text, when this delta carries text.
    content: Option<String>,
    // Incremental tool-call fragments; accumulated by the caller per `index`.
    tool_calls: Option<Vec<StreamingToolCall>>,
}

// A fragment of a tool call; id/type/function arrive piecemeal across deltas
// and are keyed by `index` for reassembly.
#[derive(Deserialize, Debug)]
struct StreamingToolCall {
    index: usize,
    id: Option<String>,
    #[serde(rename = "type")]
    tool_type: Option<String>,
    function: Option<StreamingFunction>,
}

// Partial function-call data: `name` and `arguments` are streamed as pieces
// to be concatenated.
#[derive(Deserialize, Debug)]
struct StreamingFunction {
    name: Option<String>,
    arguments: Option<String>,
}

// Top-level shape of one `data:` event from the Chat Completions stream.
#[derive(Deserialize, Debug)]
struct StreamingResponse {
    choices: Vec<StreamingChoice>,
}
// Responses API structures
#[derive(Deserialize)]
struct ResponsesApiResponse {
@@ -428,6 +491,263 @@ impl OpenAIClient {
Ok(final_content)
}
/// Stream a chat completion from the OpenAI Chat Completions API, invoking
/// `stream_callback` with each content delta as it arrives, and return the
/// full accumulated response text.
///
/// GPT-5 requests with web search enabled are routed to the Responses API
/// instead. Tool calls streamed by the model are reassembled and surfaced
/// as a textual note (web search is not actually executed in this CLI).
///
/// # Errors
/// Fails if the HTTP request cannot be sent, the API returns a non-success
/// status, the stream cannot be read, or a line is not valid UTF-8.
pub async fn chat_completion_stream(
    &self,
    model: &str,
    messages: &[Message],
    enable_web_search: bool,
    _enable_reasoning_summary: bool,
    reasoning_effort: &str,
    stream_callback: StreamCallback,
) -> Result<String> {
    // Use Responses API for GPT-5 with web search, otherwise Chat Completions.
    if enable_web_search && model.starts_with("gpt-5") {
        return self
            .responses_api_stream(model, messages, reasoning_effort, stream_callback)
            .await;
    }

    let url = format!("{}/chat/completions", self.base_url);
    let mut payload = json!({
        "model": model,
        "messages": Self::convert_messages(messages),
        "stream": true
    });

    // Advertise the web_search tool so the model may request searches.
    if enable_web_search {
        payload["tools"] = json!([{
            "type": "function",
            "function": {
                "name": "web_search",
                "description": "Search the web for current information on any topic",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "The search query to find relevant information"
                        }
                    },
                    "required": ["query"]
                }
            }
        }]);
        payload["tool_choice"] = json!("auto");
    }

    // reasoning_effort is only accepted for GPT-5 family models.
    if model.starts_with("gpt-5") && ["low", "medium", "high"].contains(&reasoning_effort) {
        payload["reasoning_effort"] = json!(reasoning_effort);
    }

    let response = self
        .client
        .post(&url)
        .header("Authorization", format!("Bearer {}", self.api_key))
        .header("Content-Type", "application/json")
        .json(&payload)
        .send()
        .await
        .context("Failed to send request to OpenAI API")?;

    if !response.status().is_success() {
        let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
        return Err(anyhow::anyhow!("OpenAI API error: {}", error_text));
    }

    let mut full_response = String::new();
    // index -> (id, function name, arguments), concatenated across deltas.
    let mut tool_calls_buffer: std::collections::HashMap<usize, (String, String, String)> =
        std::collections::HashMap::new();
    let mut stream = response.bytes_stream();

    // SSE events may straddle network chunks at both line and UTF-8
    // boundaries; buffer raw bytes and only process complete lines so no
    // delta is silently dropped by a mid-event split.
    let mut pending: Vec<u8> = Vec::new();

    'sse: while let Some(chunk) = stream.next().await {
        let chunk = chunk.context("Failed to read chunk from stream")?;
        pending.extend_from_slice(&chunk);

        // Drain every complete line currently buffered.
        while let Some(newline_pos) = pending.iter().position(|&b| b == b'\n') {
            let line_bytes: Vec<u8> = pending.drain(..=newline_pos).collect();
            let line = std::str::from_utf8(&line_bytes)
                .context("Failed to parse chunk as UTF-8")?
                .trim_end_matches('\n')
                .trim_end_matches('\r');

            let data = match line.strip_prefix("data: ") {
                Some(d) => d,
                None => continue, // not an SSE data line
            };
            if data == "[DONE]" {
                // End-of-stream sentinel: stop reading entirely.
                break 'sse;
            }
            if data.trim().is_empty() {
                continue;
            }

            // Malformed or non-delta events are skipped, matching the
            // tolerant behavior expected for SSE consumers.
            let streaming_response = match serde_json::from_str::<StreamingResponse>(data) {
                Ok(r) => r,
                Err(_) => continue,
            };
            if let Some(choice) = streaming_response.choices.first() {
                // Incremental text: accumulate and forward to the caller.
                if let Some(content) = &choice.delta.content {
                    if !content.is_empty() {
                        full_response.push_str(content);
                        stream_callback(content).await;
                    }
                }
                // Tool-call fragments: concatenate per index for reassembly.
                if let Some(tool_calls) = &choice.delta.tool_calls {
                    for tool_call in tool_calls {
                        let entry = tool_calls_buffer
                            .entry(tool_call.index)
                            .or_insert((String::new(), String::new(), String::new()));
                        if let Some(id) = &tool_call.id {
                            entry.0.push_str(id);
                        }
                        if let Some(function) = &tool_call.function {
                            if let Some(name) = &function.name {
                                entry.1.push_str(name);
                            }
                            if let Some(args) = &function.arguments {
                                entry.2.push_str(args);
                            }
                        }
                    }
                }
            }
        }
    }

    // Surface any reassembled tool calls (search itself is not executed).
    if !tool_calls_buffer.is_empty() && enable_web_search {
        for (_, (_id, name, args)) in tool_calls_buffer {
            if name == "web_search" {
                if let Ok(parsed_args) = serde_json::from_str::<serde_json::Value>(&args) {
                    if let Some(query) = parsed_args.get("query").and_then(|q| q.as_str()) {
                        let tool_message = format!(
                            "\n\n[Web Search Request: \"{}\"]\nNote: Web search functionality is not implemented in this CLI. The AI wanted to search for: {}",
                            query, query
                        );
                        full_response.push_str(&tool_message);
                        stream_callback(&tool_message).await;
                    }
                }
            }
        }
    }

    Ok(full_response)
}
/// Stream a completion from the OpenAI Responses API (used for GPT-5 with
/// web search). Forwards `response.output_text.delta` events to
/// `stream_callback` and returns the accumulated text.
///
/// NOTE(review): each network chunk is decoded and line-split independently,
/// so an SSE event that straddles a chunk boundary (or a UTF-8 sequence
/// split mid-character) will fail to parse and be dropped/error — consider
/// buffering bytes until a full line is available.
///
/// # Errors
/// Fails on request/transport errors, a non-success HTTP status, or when the
/// stream ends without producing any content.
async fn responses_api_stream(
    &self,
    model: &str,
    messages: &[Message],
    reasoning_effort: &str,
    stream_callback: StreamCallback,
) -> Result<String> {
    let url = format!("{}/responses", self.base_url);

    // Convert messages to input text (simple approach for now): system
    // messages are excluded; remaining contents are newline-joined.
    let input_text = messages
        .iter()
        .filter(|msg| msg.role != "system")
        .map(|msg| msg.content.as_str())
        .collect::<Vec<_>>()
        .join("\n");

    let mut payload = json!({
        "model": model,
        "tools": [{"type": "web_search_preview"}],
        "input": input_text,
        "stream": true
    });

    // Add reasoning effort for GPT-5 models (only known levels are sent).
    if ["low", "medium", "high"].contains(&reasoning_effort) {
        payload["reasoning"] = json!({
            "effort": reasoning_effort
        });
    }

    let response = self
        .client
        .post(&url)
        .header("Authorization", format!("Bearer {}", self.api_key))
        .header("Content-Type", "application/json")
        .json(&payload)
        .send()
        .await
        .context("Failed to send request to OpenAI Responses API")?;

    if !response.status().is_success() {
        let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
        return Err(anyhow::anyhow!("OpenAI Responses API error: {}", error_text));
    }

    let mut full_response = String::new();
    let mut stream = response.bytes_stream();

    while let Some(chunk) = stream.next().await {
        let chunk = chunk.context("Failed to read chunk from Responses API stream")?;
        let chunk_str = std::str::from_utf8(&chunk)
            .context("Failed to parse Responses API chunk as UTF-8")?;

        // Parse server-sent events for Responses API
        for line in chunk_str.lines() {
            if line.starts_with("data: ") {
                let data = &line[6..];
                // NOTE(review): this only exits the per-chunk line loop; the
                // outer while keeps awaiting (harmless — the stream ends).
                if data == "[DONE]" {
                    break;
                }
                // Skip empty data lines
                if data.trim().is_empty() {
                    continue;
                }
                // Try to parse streaming events; unknown event types and
                // malformed JSON are ignored.
                match serde_json::from_str::<serde_json::Value>(data) {
                    Ok(response_data) => {
                        // Check for streaming delta events
                        if let Some(event_type) = response_data.get("type").and_then(|t| t.as_str()) {
                            if event_type == "response.output_text.delta" {
                                if let Some(delta) = response_data.get("delta").and_then(|d| d.as_str()) {
                                    if !delta.is_empty() {
                                        full_response.push_str(delta);
                                        stream_callback(delta).await;
                                    }
                                }
                            }
                        }
                    }
                    Err(_) => {
                        continue;
                    }
                }
            } else if !line.trim().is_empty() {
                // NOTE(review): empty branch — leftover debug scaffolding.
            }
        }
    }

    if full_response.is_empty() {
        return Err(anyhow::anyhow!("No content found in Responses API stream response"));
    }
    Ok(full_response)
}
pub fn supports_feature(&self, feature: &str) -> bool {
match feature {
"web_search" | "reasoning_summary" | "reasoning_effort" => true,

View File

@@ -2,6 +2,6 @@ pub mod session;
pub mod client;
pub mod provider;
pub use session::Session;
pub use client::{ChatClient, create_client};
pub use session::{Session, Message};
pub use client::{ChatClient, create_client, StreamCallback};
pub use provider::get_provider_for_model;

View File

@@ -1,14 +1,23 @@
use console::{style, Term};
use std::io::{self, Write};
use syntect::easy::HighlightLines;
use syntect::parsing::SyntaxSet;
use syntect::highlighting::{ThemeSet, Style};
use syntect::util::{as_24_bit_terminal_escaped, LinesWithEndings};
use regex::Regex;
/// Terminal renderer: owns the console handle plus the syntect assets used
/// for code-block syntax highlighting.
pub struct Display {
    term: Term,
    // Syntax definitions, loaded once at construction.
    syntax_set: SyntaxSet,
    // Bundled color themes ("base16-ocean.dark" is used for code blocks).
    theme_set: ThemeSet,
}
impl Display {
/// Build a `Display` bound to stdout, loading syntect's default syntax
/// definitions and themes once up front.
pub fn new() -> Self {
    Self {
        term: Term::stdout(),
        syntax_set: SyntaxSet::load_defaults_newlines(),
        theme_set: ThemeSet::load_defaults(),
    }
}
@@ -41,7 +50,115 @@ impl Display {
}
pub fn print_assistant_response(&self, content: &str) {
println!("{} {}", style("🤖").magenta(), content);
print!("{} ", style("🤖").magenta());
self.print_formatted_content_with_pagination(content);
}
/// Print `content` through the formatter, pausing after each screenful when
/// the content is taller than the terminal. Enter continues; 'q' truncates.
///
/// NOTE(review): each page is formatted independently, so a fenced code
/// block that spans a page boundary loses its highlighting — confirm this
/// trade-off is acceptable.
fn print_formatted_content_with_pagination(&self, content: &str) {
    let lines: Vec<&str> = content.lines().collect();
    // console::Term::size() returns (rows, columns); .0 is the height.
    let terminal_height = self.term.size().0 as usize;
    let lines_per_page = terminal_height.saturating_sub(5); // Leave space for prompts
    if lines.len() <= lines_per_page {
        // Short content, no pagination needed
        self.print_formatted_content(content);
        return;
    }
    let mut current_line = 0;
    while current_line < lines.len() {
        let end_line = (current_line + lines_per_page).min(lines.len());
        let page_content = lines[current_line..end_line].join("\n");
        self.print_formatted_content(&page_content);
        if end_line < lines.len() {
            print!("\n{} ", style("Press Enter to continue, 'q' to finish...").dim());
            io::stdout().flush().ok();
            let mut input = String::new();
            // Any read error is treated as "continue" — best-effort prompt.
            if io::stdin().read_line(&mut input).is_ok() {
                if input.trim().to_lowercase() == "q" {
                    println!("{}", style("(response truncated)").dim());
                    break;
                }
            }
        }
        current_line = end_line;
    }
}
/// Print `content`, rendering ```lang fenced code blocks with syntax
/// highlighting and `inline code` spans with background highlighting;
/// everything else prints as-is, followed by a trailing newline.
///
/// NOTE(review): the regex is recompiled on every call; hoisting it into a
/// lazily-initialized static would avoid the repeated compilation cost.
fn print_formatted_content(&self, content: &str) {
    // Regex to match code blocks with optional language specifier.
    // unwrap() is safe: the pattern is a fixed literal.
    let code_block_regex = Regex::new(r"```(\w+)?\n([\s\S]*?)\n```").unwrap();
    let mut last_end = 0;
    // Process code blocks
    for captures in code_block_regex.captures_iter(content) {
        let full_match = captures.get(0).unwrap();
        let lang = captures.get(1).map(|m| m.as_str()).unwrap_or("text");
        let code = captures.get(2).unwrap().as_str();
        // Print text before code block
        let before_text = &content[last_end..full_match.start()];
        self.print_text_with_inline_code(before_text);
        // Print code block with syntax highlighting
        self.print_code_block(code, lang);
        last_end = full_match.end();
    }
    // Print remaining text after last code block
    let remaining_text = &content[last_end..];
    self.print_text_with_inline_code(remaining_text);
    println!(); // Add newline at the end
}
/// Print `text`, rendering backtick-delimited `inline code` spans with a
/// black background and white foreground; surrounding text prints verbatim.
fn print_text_with_inline_code(&self, text: &str) {
    // Fixed literal pattern — unwrap() cannot fail.
    let backtick_span = Regex::new(r"`([^`]+)`").unwrap();
    let mut cursor = 0;

    for caps in backtick_span.captures_iter(text) {
        let whole = caps.get(0).unwrap();
        let inner = caps.get(1).unwrap().as_str();
        // Plain text up to the opening backtick.
        print!("{}", &text[cursor..whole.start()]);
        // The code span itself, highlighted.
        print!("{}", style(inner).on_black().white());
        cursor = whole.end();
    }

    // Whatever trails the final span (or the whole text if none matched).
    print!("{}", &text[cursor..]);
}
fn print_code_block(&self, code: &str, language: &str) {
// Find the appropriate syntax definition
let syntax = self.syntax_set.find_syntax_by_extension(language)
.or_else(|| self.syntax_set.find_syntax_by_name(language))
.unwrap_or_else(|| self.syntax_set.find_syntax_plain_text());
// Use a dark theme for code highlighting
let theme = &self.theme_set.themes["base16-ocean.dark"];
println!("{}", style("```").dim());
let mut highlighter = HighlightLines::new(syntax, theme);
for line in LinesWithEndings::from(code) {
let ranges: Vec<(Style, &str)> = highlighter.highlight_line(line, &self.syntax_set).unwrap();
let escaped = as_24_bit_terminal_escaped(&ranges[..], false);
print!("{}", escaped);
}
println!("{}", style("```").dim());
}
pub fn print_command_result(&self, message: &str) {
@@ -74,8 +191,15 @@ Available Commands:
/new <session_name> - Create a new session
/switch - Interactive session manager (switch/delete)
/clear - Clear current conversation
/history [user|assistant] [number] - View conversation history
/tools - Interactive tool and feature manager
Input Features:
Multi-line input - End your message with '\' to enter multi-line mode
Keyboard shortcuts - Ctrl+C: Cancel, Ctrl+D: Exit, Arrow keys: History
Syntax highlighting - Code blocks are automatically highlighted
Pagination - Long responses are paginated (press Enter or 'q')
Environment Variables:
OPENAI_API_KEY - Required for OpenAI models
ANTHROPIC_API_KEY - Required for Anthropic models
@@ -96,6 +220,41 @@ Supported Models:
io::stdout().flush().ok();
SpinnerHandle::new()
}
/// Render the `/history` listing: a header, then each (original_index,
/// message) pair with a role icon and label, paginated like any other
/// long output.
pub fn print_conversation_history(&self, messages: &[(usize, &crate::core::Message)]) {
    println!("{}", style("📜 Conversation History").bold().cyan());
    // Bug fix: the original repeated the EMPTY string ("".repeat(50)),
    // which renders nothing — the separator glyph was evidently lost.
    // "─" restores a visible horizontal rule; confirm the intended glyph.
    println!("{}", style("─".repeat(50)).dim());

    let mut content = String::new();
    for (original_index, message) in messages {
        let role_icon = match message.role.as_str() {
            "user" => "👤",
            "assistant" => "🤖",
            _ => "💬",
        };
        let role_name = match message.role.as_str() {
            "user" => "User",
            "assistant" => "Assistant",
            _ => &message.role,
        };
        // The original also prefixed each header with style("").dim() — an
        // empty styled string that only produced a stray leading space; it
        // is dropped here.
        content.push_str(&format!(
            "\n{} {} [Message #{}]\n{}\n",
            role_icon,
            style(role_name).bold(),
            original_index,
            style("─".repeat(40)).dim()
        ));
        content.push_str(&message.content);
        content.push_str("\n\n");
    }
    self.print_formatted_content_with_pagination(&content);
}
}
impl Default for Display {

View File

@@ -1,6 +1,6 @@
use anyhow::Result;
use dialoguer::{theme::ColorfulTheme, Select};
use rustyline::{error::ReadlineError, DefaultEditor};
use rustyline::{error::ReadlineError, DefaultEditor, KeyEvent, Cmd};
pub struct InputHandler {
editor: DefaultEditor,
@@ -8,8 +8,13 @@ pub struct InputHandler {
impl InputHandler {
pub fn new() -> Result<Self> {
// Use a simpler configuration approach
let mut editor = DefaultEditor::new()?;
// Configure key bindings for better UX
editor.bind_sequence(KeyEvent::ctrl('C'), Cmd::Interrupt);
editor.bind_sequence(KeyEvent::ctrl('D'), Cmd::EndOfFile);
// Try to load history file
let history_file = dirs::home_dir()
.map(|home| home.join(".chat_cli_history"))
@@ -23,8 +28,13 @@ impl InputHandler {
pub fn read_line(&mut self, prompt: &str) -> Result<Option<String>> {
match self.editor.readline(prompt) {
Ok(line) => {
let _ = self.editor.add_history_entry(&line);
Ok(Some(line))
// Check if user wants to enter multi-line mode
if line.trim().ends_with("\\") {
self.read_multiline_input(line.trim_end_matches("\\").to_string())
} else {
let _ = self.editor.add_history_entry(&line);
Ok(Some(line))
}
}
Err(ReadlineError::Interrupted) => {
println!("^C");
@@ -38,6 +48,34 @@ impl InputHandler {
}
}
/// Collect additional lines after a trailing-'\' line triggered multi-line
/// mode. Input ends on a line containing only "." or on Ctrl+D (EOF);
/// Ctrl+C aborts and returns Ok(None), discarding what was typed. The
/// joined message is recorded as a single history entry.
pub fn read_multiline_input(&mut self, initial_line: String) -> Result<Option<String>> {
    let mut lines = vec![initial_line];
    println!("Multi-line mode: Type your message. End with a line containing only '.' or press Ctrl+D");
    loop {
        match self.editor.readline("... ") {
            Ok(line) => {
                // A lone "." terminates input and is not included.
                if line.trim() == "." {
                    break;
                }
                lines.push(line);
            }
            Err(ReadlineError::Interrupted) => {
                println!("^C");
                return Ok(None);
            }
            Err(ReadlineError::Eof) => {
                // Ctrl+D finishes the message with what was typed so far.
                break;
            }
            Err(err) => return Err(anyhow::anyhow!("Error reading input: {}", err)),
        }
    }
    let full_message = lines.join("\n");
    // History write failures are non-fatal; ignore the result.
    let _ = self.editor.add_history_entry(&full_message);
    Ok(Some(full_message))
}
pub fn save_history(&mut self) -> Result<()> {
let history_file = dirs::home_dir()
.map(|home| home.join(".chat_cli_history"))