Centralize provider capabilities

Christopher 2025-08-25 01:27:09 -04:00
parent 1ac7914646
commit 4618f273f3
2 changed files with 200 additions and 98 deletions
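In short, the commit replaces scattered per-provider match statements in the chat CLI with a single ProviderCapabilities lookup defined next to the Provider enum. A minimal sketch of the resulting call-site pattern, using the names from the diff below (the warn parameter is a hypothetical stand-in for self.display.print_warning):

fn warn_on_unsupported_features(model: &str, extended_thinking: bool, warn: impl Fn(&str)) {
    let provider = get_provider_for_model(model);
    let capabilities = provider.capabilities();
    // One capability check per feature, instead of a match over the provider enum
    // at every call site.
    if extended_thinking && !capabilities.supports_extended_thinking {
        warn(&format!(
            "Extended thinking is not supported by {:?} models",
            provider
        ));
    }
}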

View File

@ -93,7 +93,6 @@ impl ChatCLI {
self.session.add_user_message(message.to_string());
self.session.save()?;
// Clone data needed for the API call before getting mutable client reference
let model = self.session.model.clone();
let messages = self.session.messages.clone();
@ -125,10 +124,8 @@ impl ChatCLI {
}) as StreamCallback
};
let client = self.get_client()?.clone();
let response = client
.chat_completion_stream(
&self.session.model,
&self.session.messages,
@ -159,7 +156,6 @@ impl ChatCLI {
let client = self.get_client()?.clone();
let response = client
.chat_completion(
&self.session.model,
&self.session.messages,
@ -460,27 +456,33 @@ impl ChatCLI {
// Check model compatibility
let model = self.session.model.clone();
let provider = get_provider_for_model(&model);
let reasoning_enabled = self.session.enable_reasoning_summary;
let capabilities = provider.capabilities();
// Show compatibility warnings based on provider
match provider {
crate::core::provider::Provider::Anthropic => {
if reasoning_enabled {
self.display.print_warning(
"Reasoning summaries are not supported by Anthropic models",
);
}
if self.session.enable_extended_thinking {
// Extended thinking is supported by Anthropic models
}
// Web search is now supported by Anthropic models
}
crate::core::provider::Provider::OpenAI => {
if self.session.enable_extended_thinking {
self.display
.print_warning("Extended thinking is not supported by OpenAI models");
}
// OpenAI models generally support other features
if self.session.enable_reasoning_summary && !capabilities.supports_reasoning_summaries {
self.display.print_warning(&format!(
"Reasoning summaries are not supported by {:?} models",
provider
));
}
if self.session.enable_extended_thinking && !capabilities.supports_extended_thinking {
self.display.print_warning(&format!(
"Extended thinking is not supported by {:?} models",
provider
));
}
if self.session.enable_web_search && !capabilities.supports_web_search {
self.display.print_warning(&format!(
"Web search is not supported by {:?} models",
provider
));
}
if let Some(min) = capabilities.min_thinking_budget {
if self.session.thinking_budget_tokens < min {
self.display.print_warning(&format!(
"Minimum thinking budget is {} tokens for {:?} models",
min, provider
));
}
}
@ -548,15 +550,12 @@ impl ChatCLI {
.print_command_result(&format!("Extended thinking {}", state));
let provider = get_provider_for_model(&self.session.model);
match provider {
crate::core::provider::Provider::OpenAI => {
self.display.print_warning(
"Extended thinking is not supported by OpenAI models",
);
}
crate::core::provider::Provider::Anthropic => {
// Supported
}
let capabilities = provider.capabilities();
if !capabilities.supports_extended_thinking {
self.display.print_warning(&format!(
"Extended thinking is not supported by {:?} models",
provider
));
}
}
Some("Set Thinking Budget") => {
@ -575,16 +574,19 @@ impl ChatCLI {
));
let provider = get_provider_for_model(&self.session.model);
match provider {
crate::core::provider::Provider::OpenAI => {
self.display.print_warning(
"Extended thinking is not supported by OpenAI models",
);
}
crate::core::provider::Provider::Anthropic => {
if budget < 1024 {
self.display.print_warning("Minimum thinking budget is 1024 tokens for Anthropic models");
}
let capabilities = provider.capabilities();
if !capabilities.supports_extended_thinking {
self.display.print_warning(&format!(
"Extended thinking is not supported by {:?} models",
provider
));
} else if let Some(min) = capabilities.min_thinking_budget {
if budget < min {
self.display.print_warning(&format!(
"Minimum thinking budget is {} tokens for {:?} models",
min, provider
));
}
}
}

View File

@ -6,6 +6,14 @@ pub enum Provider {
Anthropic,
}
#[derive(Debug, Clone, Copy)]
pub struct ProviderCapabilities {
pub supports_web_search: bool,
pub supports_reasoning_summaries: bool,
pub supports_extended_thinking: bool,
pub min_thinking_budget: Option<u32>,
}
pub struct ModelInfo {
pub model_id: &'static str,
pub display_name: &'static str,
@ -18,6 +26,23 @@ impl Provider {
Provider::Anthropic => "anthropic",
}
}
pub fn capabilities(&self) -> ProviderCapabilities {
match self {
Provider::OpenAI => ProviderCapabilities {
supports_web_search: true,
supports_reasoning_summaries: true,
supports_extended_thinking: false,
min_thinking_budget: None,
},
Provider::Anthropic => ProviderCapabilities {
supports_web_search: true,
supports_reasoning_summaries: false,
supports_extended_thinking: true,
min_thinking_budget: Some(1024),
},
}
}
}
pub fn get_supported_models() -> HashMap<Provider, Vec<&'static str>> {
@ -55,22 +80,63 @@ pub fn get_supported_models() -> HashMap<Provider, Vec<&'static str>> {
pub fn get_model_info_list() -> Vec<ModelInfo> {
vec![
// OpenAI models
ModelInfo { model_id: "gpt-4.1", display_name: "GPT-4.1" },
ModelInfo { model_id: "gpt-4.1-mini", display_name: "GPT-4.1 Mini" },
ModelInfo { model_id: "gpt-4o", display_name: "GPT-4o" },
ModelInfo { model_id: "gpt-5", display_name: "GPT-5" },
ModelInfo { model_id: "gpt-5-chat-latest", display_name: "GPT-5 Chat Latest" },
ModelInfo { model_id: "o1", display_name: "o1" },
ModelInfo { model_id: "o3", display_name: "o3" },
ModelInfo { model_id: "o4-mini", display_name: "o4 Mini" },
ModelInfo { model_id: "o3-mini", display_name: "o3 Mini" },
ModelInfo {
model_id: "gpt-4.1",
display_name: "GPT-4.1",
},
ModelInfo {
model_id: "gpt-4.1-mini",
display_name: "GPT-4.1 Mini",
},
ModelInfo {
model_id: "gpt-4o",
display_name: "GPT-4o",
},
ModelInfo {
model_id: "gpt-5",
display_name: "GPT-5",
},
ModelInfo {
model_id: "gpt-5-chat-latest",
display_name: "GPT-5 Chat Latest",
},
ModelInfo {
model_id: "o1",
display_name: "o1",
},
ModelInfo {
model_id: "o3",
display_name: "o3",
},
ModelInfo {
model_id: "o4-mini",
display_name: "o4 Mini",
},
ModelInfo {
model_id: "o3-mini",
display_name: "o3 Mini",
},
// Anthropic models with friendly names
ModelInfo { model_id: "claude-opus-4-1-20250805", display_name: "Claude Opus 4.1" },
ModelInfo { model_id: "claude-sonnet-4-20250514", display_name: "Claude Sonnet 4.0" },
ModelInfo { model_id: "claude-3-7-sonnet-20250219", display_name: "Claude 3.7 Sonnet" },
ModelInfo { model_id: "claude-3-5-haiku-20241022", display_name: "Claude 3.5 Haiku" },
ModelInfo { model_id: "claude-3-haiku-20240307", display_name: "Claude 3.0 Haiku" },
ModelInfo {
model_id: "claude-opus-4-1-20250805",
display_name: "Claude Opus 4.1",
},
ModelInfo {
model_id: "claude-sonnet-4-20250514",
display_name: "Claude Sonnet 4.0",
},
ModelInfo {
model_id: "claude-3-7-sonnet-20250219",
display_name: "Claude 3.7 Sonnet",
},
ModelInfo {
model_id: "claude-3-5-haiku-20241022",
display_name: "Claude 3.5 Haiku",
},
ModelInfo {
model_id: "claude-3-haiku-20240307",
display_name: "Claude 3.0 Haiku",
},
]
}
@ -156,7 +222,9 @@ mod tests {
assert!(gpt5_info.is_some());
assert_eq!(gpt5_info.unwrap().display_name, "GPT-5");
let claude_info = model_infos.iter().find(|info| info.model_id == "claude-sonnet-4-20250514");
let claude_info = model_infos
.iter()
.find(|info| info.model_id == "claude-sonnet-4-20250514");
assert!(claude_info.is_some());
assert_eq!(claude_info.unwrap().display_name, "Claude Sonnet 4.0");
}
@ -165,18 +233,30 @@ mod tests {
fn test_get_display_name_for_model() {
// Test known models
assert_eq!(get_display_name_for_model("gpt-5"), "GPT-5");
assert_eq!(get_display_name_for_model("claude-sonnet-4-20250514"), "Claude Sonnet 4.0");
assert_eq!(
get_display_name_for_model("claude-sonnet-4-20250514"),
"Claude Sonnet 4.0"
);
assert_eq!(get_display_name_for_model("o1"), "o1");
// Test unknown model (should return the model_id itself)
assert_eq!(get_display_name_for_model("unknown-model-123"), "unknown-model-123");
assert_eq!(
get_display_name_for_model("unknown-model-123"),
"unknown-model-123"
);
}
#[test]
fn test_get_model_id_from_display_name() {
// Test known display names
assert_eq!(get_model_id_from_display_name("GPT-5"), Some("gpt-5".to_string()));
assert_eq!(get_model_id_from_display_name("Claude Sonnet 4.0"), Some("claude-sonnet-4-20250514".to_string()));
assert_eq!(
get_model_id_from_display_name("GPT-5"),
Some("gpt-5".to_string())
);
assert_eq!(
get_model_id_from_display_name("Claude Sonnet 4.0"),
Some("claude-sonnet-4-20250514".to_string())
);
assert_eq!(get_model_id_from_display_name("o1"), Some("o1".to_string()));
// Test unknown display name
@ -205,12 +285,24 @@ mod tests {
assert_eq!(get_provider_for_model("gpt-4.1"), Provider::OpenAI);
// Test Anthropic models
assert_eq!(get_provider_for_model("claude-sonnet-4-20250514"), Provider::Anthropic);
assert_eq!(get_provider_for_model("claude-3-5-haiku-20241022"), Provider::Anthropic);
assert_eq!(get_provider_for_model("claude-opus-4-1-20250805"), Provider::Anthropic);
assert_eq!(
get_provider_for_model("claude-sonnet-4-20250514"),
Provider::Anthropic
);
assert_eq!(
get_provider_for_model("claude-3-5-haiku-20241022"),
Provider::Anthropic
);
assert_eq!(
get_provider_for_model("claude-opus-4-1-20250805"),
Provider::Anthropic
);
// Test unknown model (should default to OpenAI)
assert_eq!(get_provider_for_model("unknown-model-123"), Provider::OpenAI);
assert_eq!(
get_provider_for_model("unknown-model-123"),
Provider::OpenAI
);
}
#[test]
@ -264,7 +356,10 @@ mod tests {
for info in model_infos {
assert!(!info.model_id.is_empty(), "Model ID should not be empty");
assert!(!info.display_name.is_empty(), "Display name should not be empty");
assert!(
!info.display_name.is_empty(),
"Display name should not be empty"
);
// Display name should be different from model_id for most cases
// (though some might be the same like "o1")
@ -280,14 +375,19 @@ mod tests {
// All models in get_all_models should be in supported_models
for model in &all_models {
let found = supported_models.values().any(|models| models.contains(model));
let found = supported_models
.values()
.any(|models| models.contains(model));
assert!(found, "Model {} not found in supported_models", model);
}
// All models in model_infos should be in all_models
for info in &model_infos {
assert!(all_models.contains(&info.model_id),
"Model {} from model_infos not found in all_models", info.model_id);
assert!(
all_models.contains(&info.model_id),
"Model {} from model_infos not found in all_models",
info.model_id
);
}
// All models in all_models should have corresponding model_info