summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xai_api_module.py429
-rwxr-xr-xcli.py137
2 files changed, 377 insertions, 189 deletions
diff --git a/ai_api_module.py b/ai_api_module.py
index 3cf616b..f69527e 100755
--- a/ai_api_module.py
+++ b/ai_api_module.py
@@ -1,11 +1,9 @@
#!/usr/bin/env python
# COPYRIGHT 2025 Thomas Grothe
-import argparse
import requests
import json
import os
import sys
-import http.client
from util import *
#this is a python module to provide REST API access to some AI model
@@ -16,44 +14,26 @@ from util import *
### CONFIGURATION HERE (yes you could move this into a config file if desired, but i don't need that currently)
# Base URL and endpoints configuration
-
-# the key of the environment variable used to store your api key
api_key_env_key = 'GEMINI_API_KEY'
-
-# the api key, if you prefer to set it here
api_key = None
+urlbase = 'https://generativelanguage.googleapis.com/v1beta'
-urlbase = 'https://generativelanguage.googleapis.com/v1beta' #models/gemini-2.0-flash:generateContent'
endpoints = {
- 'chat': {
- 'simple': urlbase+'/chat/completions', # without custom assistant #use components/schemas/ChatCompletionRequest
- 'assistant': urlbase+'/assistant/--assistant_id--/chat/completions', #use components/schemas/AssistantChatCompletionRequest
- },
'model': {
- 'list': urlbase+'/models',
- },
- 'assistant': {
- 'list': urlbase+'/assistants'
+ 'list': '/models',
+ 'generate': '/models/{model}:generateContent',
+ 'count_tokens': '/models/{model}:countTokens',
},
- 'conversation': {
- 'list': urlbase+'/conversations',
- 'get': urlbase+'/conversation/' #/message_id
- },
- 'genContent': 'models/gemini-2.0-flash:generateContent?key='
+ 'file': {
+ 'list': '/files',
+ 'get': '/files/{name}',
+ 'delete': '/files/{name}',
+ }
}
-# Possible environment variables for API key
-
-#the default message context. this will be appended to or replaced as conversation goes, or if using a previous chat
-messages = [
- {
- 'role': 'system',
- 'content': '', #TODO system prompt instructions
- },
-]
-
-def setupURLParams():
- urlbase += f'?key={api_key}'
+# Conversation state management
+current_conversation = []
+system_instruction = None # System instruction separate from conversation
def get_api_key():
global api_key
@@ -64,120 +44,283 @@ def get_api_key():
return api_key
return None
-def list_models(output_mode = ''):
+def set_system_instruction(instruction):
+    """Store ``instruction`` as the module-wide system instruction and log it.
+
+    While a value is set, query() sends it as the request's
+    ``systemInstruction``; it is kept separately from the conversation turns.
+    """
+    global system_instruction
+    system_instruction = instruction
+    log(f"System instruction set: {instruction}")
+
+def list_models(output_mode=''):
+ """List available models from Gemini API"""
headers = {"Content-Type": "application/json"}
- url = endpoints['model']['list'] + f'?key={api_key}'
- print(f'calling {url}')
- response = requests.get(f"{endpoints['model']['list']}", headers=headers)
- models = []
-
- if response.status_code == 200:
- j = json.loads(response.text)
- if 'j' in output_mode:
- return j
- else:
- for d in j['data']:
- models.append(d)
- return [True, models]
- else:
- return [False, response]
+ url = f"{urlbase}{endpoints['model']['list']}?key={api_key}"
+
+ try:
+ response = requests.get(url, headers=headers)
+ response.raise_for_status()
+
+ data = response.json()
-def list_assistants(output_mode = ''):
- headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
- response = requests.get(f"{endpoints['assistant']['list']}", headers=headers)
- if response.status_code == 200:
- j = json.loads(response.text) # keys: assistants, total_assistants, scores,carousel_display_name,carousel_subtitle
if 'j' in output_mode:
- return j
- else:
- res = []
- for a in j['assistants']:
- res.append(a)
- return [True, res]
- else:
- return [False, response]
-
-def list_conversations(output_mode = ''):
- headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
- response = requests.get(f"{endpoints['conversation']['list']}", headers=headers)
- if response.status_code == 200:
- j = json.loads(response.text)
- if 'j' in output_mode:
- return j
- else:
- res = []
- for a in j.keys(): #iterate through assistants
- for cid in a: #j[]: #convo ids for each assistant
- print('todo')
- #TODO this gives only the first message, with its id, then we have to call /conversation/{m_id}
-
- return [True, res]
-
-def get_conversations(msg_id, output_mode = ''):
- headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
- response = requests.get(f"{endpoints['conversation']['get']}", headers=headers)
-
-
-def query(model_id, query, top_p, temperature, stream=False):
- """Call the model with some query"""
- headers = {"Content-Type": "application/json"} #, "Authorization": f"Bearer {api_key}"}
- url = urlbase+endpoints['genContent']+str(api_key)
- #attachments = [] #file attachments
-
- payload = { #todo options for other parameters
- "contents": [
- {
- "parts": [
- {
- "text": query
- }
- ]
- }
- ]
+ return [True, data]
+
+ models = []
+ for model in data.get('models', []):
+ # Filter to only generation-capable models
+ if 'generateContent' in model.get('supportedGenerationMethods', []):
+ models.append({
+ 'name': model.get('name', ''),
+ 'display_name': model.get('displayName', ''),
+ 'description': model.get('description', ''),
+ 'supported_methods': model.get('supportedGenerationMethods', []),
+ 'input_token_limit': model.get('inputTokenLimit', 0),
+ 'output_token_limit': model.get('outputTokenLimit', 0),
+ })
+ return [True, models]
+
+ except requests.exceptions.RequestException as e:
+ log(f"Error listing models: {e}")
+ return [False, str(e)]
+
+def count_tokens(model_id, text):
+ """Count tokens for given text"""
+ if not model_id:
+ model_id = "gemini-2.0-flash-exp"
+
+ headers = {"Content-Type": "application/json"}
+ endpoint = endpoints['model']['count_tokens'].format(model=model_id)
+ url = f"{urlbase}{endpoint}?key={api_key}"
+
+ payload = {
+ "contents": [{
+ "parts": [{"text": text}]
+ }]
}
- # payload = {
- # "model": model_id,
- # "top_p": top_p,
- # "temperature": temperature,
- # "stream": stream,
- # "messages": [{"role": "user", "content": query}], #TODO include previous conversation here
- # # components/schemas/ChatCompletionsMessage-Input
- # # content: string
- # # attachments:
- # # role: user, assistant, or system
- # # name:
- # # tool_call_id
- # # tool_calls
-
- #}
-
- if stream:
- response = requests.post(url, json=payload, headers=headers, stream=True)
- for line in response.iter_lines():
- if line:
- line = line.decode('utf-8')
- if line.startswith("data: ") and not line.startswith("data: [DONE]"):
- data = line[len("data: "):]
- try:
- content = json.loads(data)["choices"][0]["delta"]["content"]
- if content:
- log(content)
- sys.stdout.write(content)
- sys.stdout.flush()
- except (json.JSONDecodeError, KeyError, IndexError):
- pass
- print()
- else:
+
+ try:
response = requests.post(url, json=payload, headers=headers)
- if response.status_code == 200:
- response_json = response.json()
- #TODO handle message to append
- response = response_json['candidates'][0]['content']['parts'][0]['text']
- log(response)
- return response
- #return response_json["choices"][0]["message"]["content"]
+ response.raise_for_status()
+ data = response.json()
+ return [True, data.get('totalTokens', 0)]
+ except requests.exceptions.RequestException as e:
+ log(f"Error counting tokens: {e}")
+ return [False, str(e)]
+
+def query(model_id, query_text, top_p=0.9, temperature=0.7, stream=False, max_tokens=8192):
+    """Send ``query_text`` to the Gemini model together with the conversation so far.
+
+    The user turn is appended to ``current_conversation`` before the request;
+    the response helpers append the model turn when one is received. If the
+    call ends without a model turn being recorded (request error, blocked
+    prompt, empty or unexpected response), the user turn is removed again so
+    that later requests do not resend an unanswered prompt.
+
+    Args:
+        model_id: Model name substituted into the endpoint; falls back to
+            "gemini-2.0-flash-exp" when empty.
+        query_text: The user's message.
+        top_p: Sent as ``generationConfig.topP``.
+        temperature: Sent as ``generationConfig.temperature``.
+        stream: Use the streaming helper (which writes text to stdout as it
+            arrives) instead of the single-response helper.
+        max_tokens: Sent as ``generationConfig.maxOutputTokens``.
+
+    Returns:
+        Whatever the response helper returns (the response text, or a
+        "Prompt blocked: ..." message), or None when the request raised a
+        ``requests`` exception.
+    """
+    global current_conversation, system_instruction
+
+    if not model_id:
+        model_id = "gemini-2.0-flash-exp"
+
+    # Remember where this exchange starts so it can be rolled back.
+    turn_start = len(current_conversation)
+    current_conversation.append({
+        "role": "user",
+        "parts": [{"text": query_text}]
+    })
+
+    headers = {"Content-Type": "application/json"}
+    endpoint = endpoints['model']['generate'].format(model=model_id)
+    url = f"{urlbase}{endpoint}?key={api_key}"
+
+    payload = {
+        "contents": current_conversation,
+        "generationConfig": {
+            "temperature": temperature,
+            "topP": top_p,
+            "maxOutputTokens": max_tokens,
+        }
+    }
+
+    # Add system instruction if set
+    if system_instruction:
+        payload["systemInstruction"] = {
+            "parts": [{"text": system_instruction}]
+        }
+
+    log(f"Sending query to {model_id}: {query_text}")
+
+    try:
+        if stream:
+            return _stream_response(url, payload, headers)
+        return _non_stream_response(url, payload, headers)
+    except requests.exceptions.RequestException as e:
+        # The error text produced by requests includes the request URL, and
+        # the URL carries ?key=<api key>; keep the key out of logs and stdout.
+        detail = str(e)
+        if api_key:
+            detail = detail.replace(str(api_key), "[redacted]")
+        log(f"Error querying model: {detail}")
+        print(f"Error: {detail}")
+        return None
+    finally:
+        # Exactly one new entry means only the user turn was added: no model
+        # reply was recorded, so drop the unanswered prompt.
+        if len(current_conversation) == turn_start + 1:
+            del current_conversation[turn_start:]
+
+def _non_stream_response(url, payload, headers):
+    """POST ``payload`` once and return the text of the first candidate.
+
+    On success the candidate's content is appended to
+    ``current_conversation`` and its text parts are returned joined together.
+    A blocked prompt returns a "Prompt blocked: ..." message instead, and a
+    response without a usable candidate returns None. HTTP errors propagate
+    from ``raise_for_status()``.
+    """
+    global current_conversation
+
+    reply = requests.post(url, json=payload, headers=headers)
+    reply.raise_for_status()
+    body = reply.json()
+
+    # A blockReason inside promptFeedback means the prompt itself was rejected
+    if 'promptFeedback' in body and 'blockReason' in body['promptFeedback']:
+        blocked = f"Prompt blocked: {body['promptFeedback']['blockReason']}"
+        log(blocked)
+        return blocked
+
+    # Nothing to read when the candidate list is missing or empty
+    if 'candidates' not in body or len(body['candidates']) == 0:
+        log(f"Unexpected response structure: {body}")
+        return None
+
+    first = body['candidates'][0]
+
+    # Any finish reason other than STOP is worth a warning in the log
+    reason = first.get('finishReason', '')
+    if reason and reason != 'STOP':
+        log(f"Warning: Finish reason: {reason}")
+
+    if 'content' not in first:
+        log(f"Unexpected response structure: {body}")
+        return None
+
+    message = first['content']
+    current_conversation.append(message)
+
+    # Only parts that carry a 'text' key contribute to the returned string
+    reply_text = ''.join(
+        part['text'] for part in message.get('parts', []) if 'text' in part
+    )
+    log(f"Response: {reply_text}")
+
+    # Log token usage if available
+    if 'usageMetadata' in body:
+        usage = body['usageMetadata']
+        log(f"Token usage - Prompt: {usage.get('promptTokenCount', 0)}, "
+            f"Response: {usage.get('candidatesTokenCount', 0)}, "
+            f"Total: {usage.get('totalTokenCount', 0)}")
+
+    return reply_text
+
+def _stream_response(url, payload, headers):
+    """Stream a response over SSE, echoing text to stdout as it arrives.
+
+    ``url`` is the ``:generateContent`` URL built by query(). It is switched
+    to the ``:streamGenerateContent`` method, which is the REST method that
+    emits incremental chunks, and ``alt=sse`` is appended. Each ``data: ``
+    line is parsed as JSON; every text part of every candidate is written to
+    stdout and accumulated. A non-empty result is appended to
+    ``current_conversation`` as a "model" turn.
+
+    Returns:
+        The concatenated text (possibly empty). HTTP errors propagate from
+        ``raise_for_status()``.
+    """
+    global current_conversation
+
+    stream_url = url.replace(":generateContent", ":streamGenerateContent", 1) + "&alt=sse"
+
+    accumulated_text = []
+
+    # The context manager releases the connection even when iteration stops
+    # early because of an error.
+    with requests.post(stream_url, json=payload, headers=headers, stream=True) as response:
+        response.raise_for_status()
+
+        for raw_line in response.iter_lines():
+            if not raw_line:
+                continue
+            line = raw_line.decode('utf-8')
+            if not line.startswith("data: "):
+                continue
+            try:
+                data = json.loads(line[6:])
+            except json.JSONDecodeError:
+                # Keep streaming, but leave a trace instead of dropping it silently
+                log(f"Skipping malformed stream chunk: {line}")
+                continue
+            if 'candidates' not in data:
+                continue
+            for candidate in data['candidates']:
+                if 'content' not in candidate:
+                    continue
+                for part in candidate['content'].get('parts', []):
+                    if 'text' in part:
+                        text = part['text']
+                        accumulated_text.append(text)
+                        sys.stdout.write(text)
+                        sys.stdout.flush()
+
+    print()  # New line after streaming
+
+    # Add assistant response to conversation
+    full_response = ''.join(accumulated_text)
+    if full_response:
+        current_conversation.append({
+            "role": "model",
+            "parts": [{"text": full_response}]
+        })
+        log(f"Streamed response: {full_response}")
+
+    return full_response
+
+def save_conversation(filepath=None):
+    """Write the current conversation and system instruction to a JSON file.
+
+    Args:
+        filepath: Destination file. When None, a name of the form
+            ``{logdir}/conversation_{tnow()}.json`` is used.
+
+    Returns:
+        The path that was written.
+    """
+    if filepath is None:
+        filepath = f"{logdir}/conversation_{tnow()}.json"
+
+    # A bare filename has an empty dirname, and os.makedirs('') raises
+    # FileNotFoundError, so only create a directory when there is one.
+    directory = os.path.dirname(filepath)
+    if directory:
+        os.makedirs(directory, exist_ok=True)
+
+    conversation_data = {
+        "system_instruction": system_instruction,
+        "messages": current_conversation,
+        "saved_at": tnow()
+    }
+
+    with open(filepath, 'w', encoding='utf-8') as f:
+        json.dump(conversation_data, f, indent=2)
+
+    log(f"Conversation saved to {filepath}")
+    return filepath
+
+def load_conversation(filepath):
+ """Load conversation from file"""
+ global current_conversation, system_instruction
+
+ try:
+ with open(filepath, 'r') as f:
+ data = json.load(f)
+
+ # Handle both old and new formats
+ if isinstance(data, list):
+ # Old format: just messages
+ current_conversation = data
+ system_instruction = None
else:
- print(f"Error: {response.status_code}")
- print(response.text)
- log(response.text)
- return None
+ # New format: with metadata
+ current_conversation = data.get('messages', [])
+ system_instruction = data.get('system_instruction')
+
+ log(f"Conversation loaded from {filepath}")
+ return True
+ except Exception as e:
+ log(f"Error loading conversation: {e}")
+ return False
+
+def list_saved_conversations(directory=None):
+    """List the saved conversation files found in ``directory``.
+
+    Files named ``conversation_*.json`` are read. Both layouts are accepted:
+    a dict with 'messages' / 'saved_at' keys, and a bare list of messages
+    (reported with saved_at 'unknown'). A file that cannot be read or parsed
+    is logged and skipped instead of failing the whole listing.
+
+    Args:
+        directory: Directory to scan; defaults to ``logdir``.
+
+    Returns:
+        ``[True, conversations]`` with one dict per file (filename, filepath,
+        saved_at, message_count), sorted by saved_at descending, or
+        ``[False, error_text]`` when the directory cannot be listed.
+    """
+    if directory is None:
+        directory = logdir
+
+    try:
+        filenames = os.listdir(directory)
+    except OSError as e:
+        log(f"Error listing conversations: {e}")
+        return [False, str(e)]
+
+    conversations = []
+    for filename in filenames:
+        if not (filename.startswith('conversation_') and filename.endswith('.json')):
+            continue
+        filepath = os.path.join(directory, filename)
+        try:
+            with open(filepath, 'r') as f:
+                data = json.load(f)
+        except (OSError, ValueError) as e:
+            # json.JSONDecodeError is a ValueError; one bad file should not
+            # hide the others.
+            log(f"Skipping unreadable conversation file {filepath}: {e}")
+            continue
+
+        if isinstance(data, list):
+            # Bare list of messages: there is no metadata to read
+            saved_at, messages = 'unknown', data
+        elif isinstance(data, dict):
+            saved_at = data.get('saved_at', 'unknown')
+            messages = data.get('messages', [])
+        else:
+            log(f"Skipping unrecognized conversation file {filepath}")
+            continue
+
+        conversations.append({
+            'filename': filename,
+            'filepath': filepath,
+            'saved_at': saved_at,
+            'message_count': len(messages) if isinstance(messages, list) else 0,
+        })
+
+    # Compare as strings so entries without a timestamp ('unknown') can be
+    # ordered against any timestamp type.
+    conversations.sort(key=lambda x: str(x['saved_at']), reverse=True)
+    return [True, conversations]
+
+def clear_conversation():
+    """Reset the module state: empty the conversation and drop the system instruction."""
+    global current_conversation, system_instruction
+    current_conversation, system_instruction = [], None
+    log("Conversation cleared")
diff --git a/cli.py b/cli.py
index 843dd68..b940d55 100755
--- a/cli.py
+++ b/cli.py
@@ -3,11 +3,12 @@
import argparse
import os
import sys
+import json
import ai_api_module
#this is the CLI program used to interact with the GAI python module.
-default_model=None
+default_model = "gemini-2.0-flash-exp"
def main():
"""Main function to parse arguments and execute commands"""
@@ -15,69 +16,95 @@ def main():
parser.add_argument(
"--apikey", default=os.getenv("GEMINI_API_KEY"), help="API key for the service"
)
- parser.add_argument("--modelid", default=default_model, help=f"ID of the model to use (default={default_model}")
+ parser.add_argument("--modelid", default=default_model, help=f"ID of the model to use (default={default_model})")
parser.add_argument("--list-models", action="store_true", help="list the available models")
- parser.add_argument("--assistantid", default="default_chatbot", help="ID of the assistant")
- parser.add_argument("--list-assistants", "--list-a", "--list-assistant", action='store_true', help='list all the available assistants')
- parser.add_argument('--list-conversations', '--list-conversation', '--list-convo', '--list-c', action='store_true', help='list previous conversations')
+
+ # Conversation management
+ parser.add_argument("--new", "-n", action="store_true", help="Start a new conversation")
+ parser.add_argument("--save-conversation", metavar="FILE", help="Save conversation to file")
+ parser.add_argument("--load-conversation", metavar="FILE", help="Load conversation from file")
+ parser.add_argument("--list-conversations", action="store_true", help="List saved conversations")
+ parser.add_argument("--clear", action="store_true", help="Clear conversation history")
+ parser.add_argument("--system-instruction", metavar="TEXT", help="Set system instruction for the conversation")
+
+ # Query parameters
parser.add_argument("--query", "-q", default="", help="Query to be sent to the assistant")
parser.add_argument("--topp", type=float, default=0.9, help="Top P value for the model")
- parser.add_argument(
- "--temperature", type=float, default=0.7, help="Temperature value for the model"
- )
+ parser.add_argument("--temperature", type=float, default=0.7, help="Temperature value for the model")
+ parser.add_argument("--max-tokens", type=int, default=8192, help="Maximum output tokens")
parser.add_argument("--stream", action="store_true", help="Enable streaming response")
parser.add_argument("--no-stream", action="store_true", help="Disable streaming response")
- parser.add_argument("--output_mode", "-o", nargs=1, action="store", default='text', help="just output the json")
+ parser.add_argument("--output_mode", "-o", default='text', help="Output mode (text or json)")
+ parser.add_argument("--count-tokens", action="store_true", help="Count tokens in query without sending")
- #parser.add_argument("--file-health", action="store_true", help="")
- #parser.add_argument("--file-list", action="store_true", help="List all uploaded files")
- #parser.add_argument("--file-upload", help="Upload a file to the service")
- #parser.add_argument("--file-status", help="Check status of a file by ID")
- #parser.add_argument("--file-purpose", default="assistants", help="Purpose of the file")
- #parser.add_argument("--attach-files", nargs="*", help="File IDs to attach to the query")
-
- parser.add_argument("--new", "-n", help="force new session and new chat", action="store_true", default=False) #TODO: solidify a good approach here
-
parser.add_argument("extra_query", nargs="*")
args = parser.parse_args()
- # Try to get API key if not provided
+ # Get API key
if args.apikey:
ai_api_module.api_key = args.apikey
else:
ai_api_module.get_api_key()
if not ai_api_module.api_key:
- print("Error: API key not provided")
+ print("Error: API key not provided. Set GEMINI_API_KEY environment variable.")
sys.exit(1)
- if args.list_models:
- res = ai_api_module.list_models(args.output_mode)
- if res[0]:
- for m in res[1]:
- print(str(m))
- #print(f" ModelID: {m['id']}\n Description: {m['description']}\n Tokens: {m['max_completion_tokens']}\n Modalities: {str(m['modalities'])}")
- print()
+ # Handle system instruction
+ if args.system_instruction:
+ ai_api_module.set_system_instruction(args.system_instruction)
+ print(f"System instruction set")
+
+ # Handle conversation management
+ if args.new:
+ ai_api_module.clear_conversation()
+ print("Started new conversation")
+
+ if args.load_conversation:
+ if ai_api_module.load_conversation(args.load_conversation):
+ print(f"Loaded conversation from {args.load_conversation}")
else:
- print('error: ')
- print(str(res[1].text))
- return
-
+ print(f"Failed to load conversation from {args.load_conversation}")
+ sys.exit(1)
+
if args.list_conversations:
- res = ai_api_module.list_conversations(args.output_mode)
- print(res)
- return
-
- if args.list_assistants:
- res = ai_api_module.list_assistants(args.output_mode)
- if res[0]:
- for a in res[1]:
- print(f" ID: {a['id']}\n Name: {a['display_name']}")
- print()
+ success, result = ai_api_module.list_saved_conversations()
+ if success:
+ if not result:
+ print("No saved conversations found")
+ else:
+ print("Saved conversations:")
+ for conv in result:
+ print(f" {conv['filename']}")
+ print(f" Saved: {conv['saved_at']}")
+ print(f" Messages: {conv['message_count']}")
+ print()
else:
- print('error: ' + str(res[1]))
+ print(f"Error: {result}")
+ return
+
+ if args.clear:
+ ai_api_module.clear_conversation()
+ print("Conversation history cleared")
+ return
+
+ # List models
+ if args.list_models:
+ success, result = ai_api_module.list_models(args.output_mode)
+ if success:
+ for model in result:
+ if args.output_mode == 'json':
+ print(json.dumps(model, indent=2))
+ else:
+ print(f"Model: {model['display_name']}")
+ print(f" Name: {model['name']}")
+ print(f" Description: {model.get('description', 'N/A')}")
+ print(f" Input tokens: {model.get('input_token_limit', 'N/A')}")
+ print(f" Output tokens: {model.get('output_token_limit', 'N/A')}")
+ print()
+ else:
+ print(f'Error: {result}')
return
-
# Process query
if args.query or args.extra_query:
@@ -87,17 +114,35 @@ def main():
query += " " + " ".join(args.extra_query)
else:
query = " ".join(args.extra_query)
-
+
+ # Count tokens if requested
+ if args.count_tokens:
+ success, count = ai_api_module.count_tokens(args.modelid, query)
+ if success:
+ print(f"Token count: {count}")
+ else:
+ print(f"Error counting tokens: {count}")
+ return
+
+ # Determine streaming
+ use_stream = args.stream and not args.no_stream
+
response = ai_api_module.query(
args.modelid,
query,
args.topp,
args.temperature,
- not args.no_stream and args.stream,
+ use_stream,
+ args.max_tokens,
)
- if response and not args.stream:
+ if response and not use_stream:
print(response)
+
+ # Auto-save conversation if requested
+ if args.save_conversation:
+ filepath = ai_api_module.save_conversation(args.save_conversation)
+ print(f"\nConversation saved to {filepath}")
else:
parser.print_help()