summaryrefslogtreecommitdiff
path: root/ai_api_module.py
diff options
context:
space:
mode:
Diffstat (limited to 'ai_api_module.py')
-rwxr-xr-xai_api_module.py429
1 files changed, 286 insertions, 143 deletions
diff --git a/ai_api_module.py b/ai_api_module.py
index 3cf616b..f69527e 100755
--- a/ai_api_module.py
+++ b/ai_api_module.py
@@ -1,11 +1,9 @@
#!/usr/bin/env python
# COPYRIGHT 2025 Thomas Grothe
-import argparse
import requests
import json
import os
import sys
-import http.client
from util import *
#this is a python module to provide REST API access to some AI model
@@ -16,44 +14,26 @@ from util import *
### CONFIGURATION HERE (yes you could move this into a config file if desired, but i don't need that currently)
# Base URL and endpoints configuration
-
-# the key of the environment variable used to store your api key
+# Name of the environment variable that stores the API key
api_key_env_key = 'GEMINI_API_KEY'
-
-# the api key, if you prefer to set it here
+# The API key itself; left as None unless it is set directly in this file
api_key = None
+# Base URL that is prepended to every endpoint path below when a request URL is built
+urlbase = 'https://generativelanguage.googleapis.com/v1beta'
-urlbase = 'https://generativelanguage.googleapis.com/v1beta' #models/gemini-2.0-flash:generateContent'
+# Endpoint paths grouped by resource. Paths are relative to urlbase, and the
+# {model} / {name} placeholders are filled in with str.format() by the caller.
endpoints = {
- 'chat': {
- 'simple': urlbase+'/chat/completions', # without custom assistant #use components/schemas/ChatCompletionRequest
- 'assistant': urlbase+'/assistant/--assistant_id--/chat/completions', #use components/schemas/AssistantChatCompletionRequest
- },
'model': {
- 'list': urlbase+'/models',
- },
- 'assistant': {
- 'list': urlbase+'/assistants'
+ 'list': '/models',
+ 'generate': '/models/{model}:generateContent',
+ 'count_tokens': '/models/{model}:countTokens',
},
- 'conversation': {
- 'list': urlbase+'/conversations',
- 'get': urlbase+'/conversation/' #/message_id
- },
- 'genContent': 'models/gemini-2.0-flash:generateContent?key='
+ # NOTE(review): no function visible in this diff references the 'file' paths yet
+ 'file': {
+ 'list': '/files',
+ 'get': '/files/{name}',
+ 'delete': '/files/{name}',
+ }
}
-# Possible environment variables for API key
-
-#the default message context. this will be appended to or replaced as conversation goes, or if using a previous chat
-messages = [
- {
- 'role': 'system',
- 'content': '', #TODO system prompt instructions
- },
-]
-
-def setupURLParams():
- urlbase += f'?key={api_key}'
+# Conversation state management (module-level, shared by the functions below)
+# Ordered list of turn dicts that query() sends as the request "contents";
+# query() appends user turns and the response handlers append the replies.
+current_conversation = []
+# Optional system instruction text; query() sends it as "systemInstruction"
+# when it is truthy, and it is never stored inside current_conversation.
+system_instruction = None # System instruction separate from conversation
def get_api_key():
global api_key
@@ -64,120 +44,283 @@ def get_api_key():
return api_key
return None
-def list_models(output_mode = ''):
+def set_system_instruction(instruction):
+    """Store `instruction` as the module-wide system instruction and log it."""
+    global system_instruction
+    system_instruction = instruction
+    message = f"System instruction set: {instruction}"
+    log(message)
+
+def list_models(output_mode=''):
+ """List available models from Gemini API
+
+ Sends one GET request to the model-list endpoint, passing the module-level
+ api_key as the `key` query parameter.
+
+ output_mode: when it contains the letter 'j', the parsed JSON response is
+ returned unfiltered instead of the summary list.
+
+ Returns [True, data] on success, where data is either the raw JSON or a list
+ of dicts (name, display_name, description, supported_methods,
+ input_token_limit, output_token_limit), one per model whose
+ supportedGenerationMethods include 'generateContent'.
+ Returns [False, error_string] when requests raises a RequestException.
+
+ NOTE(review): only a single request is made; if the API paginates this
+ listing, later pages are not fetched - confirm against the API reference.
+ """
headers = {"Content-Type": "application/json"}
- url = endpoints['model']['list'] + f'?key={api_key}'
- print(f'calling {url}')
- response = requests.get(f"{endpoints['model']['list']}", headers=headers)
- models = []
-
- if response.status_code == 200:
- j = json.loads(response.text)
- if 'j' in output_mode:
- return j
- else:
- for d in j['data']:
- models.append(d)
- return [True, models]
- else:
- return [False, response]
+ url = f"{urlbase}{endpoints['model']['list']}?key={api_key}"
+
+ try:
+ response = requests.get(url, headers=headers)
+ # Turn HTTP error statuses into exceptions handled by the except clause below
+ response.raise_for_status()
+
+ data = response.json()
-def list_assistants(output_mode = ''):
- headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
- response = requests.get(f"{endpoints['assistant']['list']}", headers=headers)
- if response.status_code == 200:
- j = json.loads(response.text) # keys: assistants, total_assistants, scores,carousel_display_name,carousel_subtitle
if 'j' in output_mode:
- return j
- else:
- res = []
- for a in j['assistants']:
- res.append(a)
- return [True, res]
- else:
- return [False, response]
-
-def list_conversations(output_mode = ''):
- headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
- response = requests.get(f"{endpoints['conversation']['list']}", headers=headers)
- if response.status_code == 200:
- j = json.loads(response.text)
- if 'j' in output_mode:
- return j
- else:
- res = []
- for a in j.keys(): #iterate through assistants
- for cid in a: #j[]: #convo ids for each assistant
- print('todo')
- #TODO this gives only the first message, with its id, then we have to call /conversation/{m_id}
-
- return [True, res]
-
-def get_conversations(msg_id, output_mode = ''):
- headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
- response = requests.get(f"{endpoints['conversation']['get']}", headers=headers)
-
-
-def query(model_id, query, top_p, temperature, stream=False):
- """Call the model with some query"""
- headers = {"Content-Type": "application/json"} #, "Authorization": f"Bearer {api_key}"}
- url = urlbase+endpoints['genContent']+str(api_key)
- #attachments = [] #file attachments
-
- payload = { #todo options for other parameters
- "contents": [
- {
- "parts": [
- {
- "text": query
- }
- ]
- }
- ]
+ return [True, data]
+
+ models = []
+ for model in data.get('models', []):
+ # Filter to only generation-capable models
+ if 'generateContent' in model.get('supportedGenerationMethods', []):
+ models.append({
+ 'name': model.get('name', ''),
+ 'display_name': model.get('displayName', ''),
+ 'description': model.get('description', ''),
+ 'supported_methods': model.get('supportedGenerationMethods', []),
+ 'input_token_limit': model.get('inputTokenLimit', 0),
+ 'output_token_limit': model.get('outputTokenLimit', 0),
+ })
+ return [True, models]
+
+ except requests.exceptions.RequestException as e:
+ log(f"Error listing models: {e}")
+ return [False, str(e)]
+
+def count_tokens(model_id, text):
+ """Count tokens for given text
+
+ model_id: model to count against; a falsy value falls back to
+ "gemini-2.0-flash-exp". It is substituted after '/models/' in the endpoint
+ path, so it is presumably the bare id without a 'models/' prefix - confirm.
+ text: the text sent as the single part of a single content entry.
+
+ Returns [True, total] on success, where total is the response's
+ 'totalTokens' value (0 when that key is absent), or [False, error_string]
+ when requests raises a RequestException.
+ """
+ if not model_id:
+ model_id = "gemini-2.0-flash-exp"
+
+ headers = {"Content-Type": "application/json"}
+ endpoint = endpoints['model']['count_tokens'].format(model=model_id)
+ url = f"{urlbase}{endpoint}?key={api_key}"
+
+ payload = {
+ "contents": [{
+ "parts": [{"text": text}]
+ }]
}
- # payload = {
- # "model": model_id,
- # "top_p": top_p,
- # "temperature": temperature,
- # "stream": stream,
- # "messages": [{"role": "user", "content": query}], #TODO include previous conversation here
- # # components/schemas/ChatCompletionsMessage-Input
- # # content: string
- # # attachments:
- # # role: user, assistant, or system
- # # name:
- # # tool_call_id
- # # tool_calls
-
- #}
-
- if stream:
- response = requests.post(url, json=payload, headers=headers, stream=True)
- for line in response.iter_lines():
- if line:
- line = line.decode('utf-8')
- if line.startswith("data: ") and not line.startswith("data: [DONE]"):
- data = line[len("data: "):]
- try:
- content = json.loads(data)["choices"][0]["delta"]["content"]
- if content:
- log(content)
- sys.stdout.write(content)
- sys.stdout.flush()
- except (json.JSONDecodeError, KeyError, IndexError):
- pass
- print()
- else:
+
+ try:
response = requests.post(url, json=payload, headers=headers)
- if response.status_code == 200:
- response_json = response.json()
- #TODO handle message to append
- response = response_json['candidates'][0]['content']['parts'][0]['text']
- log(response)
- return response
- #return response_json["choices"][0]["message"]["content"]
+ # Turn HTTP error statuses into exceptions handled by the except clause below
+ response.raise_for_status()
+ data = response.json()
+ return [True, data.get('totalTokens', 0)]
+ except requests.exceptions.RequestException as e:
+ log(f"Error counting tokens: {e}")
+ return [False, str(e)]
+
+def query(model_id, query_text, top_p=0.9, temperature=0.7, stream=False, max_tokens=8192):
+    """Send a query to the Gemini model with conversation context.
+
+    The new user turn is appended to the module-level current_conversation and
+    the whole conversation is sent as the request "contents". If no model reply
+    ends up recorded (request error, blocked prompt, unexpected or empty
+    response), the user turn is removed again so that a failed query does not
+    leave a dangling turn that would be resent with every later request.
+
+    model_id: model to call; a falsy value falls back to "gemini-2.0-flash-exp".
+    query_text: text of the user turn.
+    top_p, temperature, max_tokens: forwarded as generationConfig topP,
+        temperature and maxOutputTokens.
+    stream: when true, the reply is handled by _stream_response (which prints
+        text as it arrives); otherwise by _non_stream_response.
+
+    Returns whatever the selected handler returns (the reply text, or a
+    "Prompt blocked: ..." message from the non-streaming handler), or None when
+    requests raises a RequestException or the handler returns None.
+    """
+    global current_conversation, system_instruction
+
+    if not model_id:
+        model_id = "gemini-2.0-flash-exp"
+
+    # Keep a handle on the new user turn so it can be rolled back below.
+    user_message = {
+        "role": "user",
+        "parts": [{"text": query_text}]
+    }
+    current_conversation.append(user_message)
+
+    headers = {"Content-Type": "application/json"}
+    endpoint = endpoints['model']['generate'].format(model=model_id)
+    url = f"{urlbase}{endpoint}?key={api_key}"
+
+    payload = {
+        "contents": current_conversation,
+        "generationConfig": {
+            "temperature": temperature,
+            "topP": top_p,
+            "maxOutputTokens": max_tokens,
+        }
+    }
+
+    # Add system instruction if set
+    if system_instruction:
+        payload["systemInstruction"] = {
+            "parts": [{"text": system_instruction}]
+        }
+
+    log(f"Sending query to {model_id}: {query_text}")
+
+    result = None
+    try:
+        if stream:
+            result = _stream_response(url, payload, headers)
+        else:
+            result = _non_stream_response(url, payload, headers)
+    except requests.exceptions.RequestException as e:
+        log(f"Error querying model: {e}")
+        print(f"Error: {e}")
+    finally:
+        # The handlers append the model's reply on success. If our user turn is
+        # still the last entry, nothing was recorded for it, so drop it to keep
+        # the stored history consistent.
+        if current_conversation and current_conversation[-1] is user_message:
+            current_conversation.pop()
+
+    return result
+
+def _non_stream_response(url, payload, headers):
+    """POST the payload, record the model's reply and return its text.
+
+    Returns the concatenated text parts of the first candidate, a
+    "Prompt blocked: ..." message when the response carries a block reason, or
+    None when the response has no usable candidate content. HTTP error statuses
+    are raised by raise_for_status().
+    """
+    global current_conversation
+
+    reply = requests.post(url, json=payload, headers=headers)
+    reply.raise_for_status()
+    data = reply.json()
+
+    # A blocked prompt is reported through promptFeedback.blockReason
+    if 'promptFeedback' in data and 'blockReason' in data['promptFeedback']:
+        error_msg = f"Prompt blocked: {data['promptFeedback']['blockReason']}"
+        log(error_msg)
+        return error_msg
+
+    has_candidates = 'candidates' in data and len(data['candidates']) > 0
+    if not has_candidates:
+        log(f"Unexpected response structure: {data}")
+        return None
+
+    first = data['candidates'][0]
+
+    # Anything other than a normal stop is worth a warning in the log
+    finish_reason = first.get('finishReason', '')
+    if finish_reason and finish_reason not in ['STOP', '']:
+        log(f"Warning: Finish reason: {finish_reason}")
+
+    if 'content' not in first:
+        log(f"Unexpected response structure: {data}")
+        return None
+
+    # Store the reply turn exactly as the API returned it
+    assistant_message = first['content']
+    current_conversation.append(assistant_message)
+
+    response_text = ''.join(
+        piece['text']
+        for piece in assistant_message.get('parts', [])
+        if 'text' in piece
+    )
+    log(f"Response: {response_text}")
+
+    # Log token usage if available
+    if 'usageMetadata' in data:
+        usage = data['usageMetadata']
+        log(f"Token usage - Prompt: {usage.get('promptTokenCount', 0)}, "
+            f"Response: {usage.get('candidatesTokenCount', 0)}, "
+            f"Total: {usage.get('totalTokenCount', 0)}")
+
+    return response_text
+
+def _stream_response(url, payload, headers):
+    """Stream the model's reply, echoing text to stdout as it arrives.
+
+    url: the request URL built by query(); its ':generateContent' method is
+        switched to ':streamGenerateContent', the API's streaming method, and
+        'alt=sse' is added so the reply arrives as server-sent events.
+    payload, headers: forwarded unchanged to requests.post.
+
+    Every 'data: ' event is parsed as JSON and the text parts of its candidates
+    are written to stdout and accumulated. When any text was received, one
+    "model" turn holding the full text is appended to current_conversation.
+
+    Returns the accumulated text ('' when nothing was received). HTTP error
+    statuses are raised by raise_for_status().
+    """
+    global current_conversation
+
+    stream_url = url.replace(":generateContent", ":streamGenerateContent", 1)
+    separator = "&" if "?" in stream_url else "?"
+    stream_url = f"{stream_url}{separator}alt=sse"
+
+    event_prefix = "data: "
+    accumulated_text = []
+
+    # The context manager releases the connection even if parsing or the
+    # consumer is interrupted part-way through the stream.
+    with requests.post(stream_url, json=payload, headers=headers, stream=True) as response:
+        response.raise_for_status()
+
+        for raw_line in response.iter_lines():
+            if not raw_line:
+                continue
+            line = raw_line.decode('utf-8')
+            if not line.startswith(event_prefix):
+                continue
+
+            try:
+                data = json.loads(line[len(event_prefix):])
+            except json.JSONDecodeError:
+                # Keep streaming, but leave a trace instead of dropping it silently
+                log(f"Skipping malformed stream chunk: {line}")
+                continue
+            if not isinstance(data, dict):
+                continue
+
+            for candidate in data.get('candidates') or []:
+                if 'content' not in candidate:
+                    continue
+                for part in candidate['content'].get('parts', []):
+                    if 'text' in part:
+                        text = part['text']
+                        accumulated_text.append(text)
+                        sys.stdout.write(text)
+                        sys.stdout.flush()
+
+    print()  # New line after streaming
+
+    # Add assistant response to conversation
+    full_response = ''.join(accumulated_text)
+    if full_response:
+        current_conversation.append({
+            "role": "model",
+            "parts": [{"text": full_response}]
+        })
+        log(f"Streamed response: {full_response}")
+
+    return full_response
+
+def save_conversation(filepath=None):
+    """Save current conversation to file.
+
+    filepath: destination path; when None, a file named
+        conversation_<tnow()>.json inside logdir is used.
+
+    The file holds a JSON object with the system instruction, the list of
+    messages and the save time. Missing parent directories are created.
+
+    Returns the path that was written.
+    """
+    if filepath is None:
+        filepath = f"{logdir}/conversation_{tnow()}.json"
+
+    # A bare filename has an empty dirname, and os.makedirs('') raises
+    # FileNotFoundError, so only create the directory when there is one.
+    parent_dir = os.path.dirname(filepath)
+    if parent_dir:
+        os.makedirs(parent_dir, exist_ok=True)
+
+    conversation_data = {
+        "system_instruction": system_instruction,
+        "messages": current_conversation,
+        "saved_at": tnow()
+    }
+
+    with open(filepath, 'w') as f:
+        json.dump(conversation_data, f, indent=2)
+
+    log(f"Conversation saved to {filepath}")
+    return filepath
+
+def load_conversation(filepath):
+ """Load conversation from file
+
+ Replaces the module-level current_conversation and system_instruction with
+ the contents of the JSON file at filepath.
+
+ Two layouts are accepted: a bare JSON list (treated as the messages, with
+ no system instruction) or an object with 'messages' and
+ 'system_instruction' keys.
+
+ Returns True on success, or False after logging when any exception occurs
+ while opening or parsing the file.
+ """
+ global current_conversation, system_instruction
+
+ try:
+ with open(filepath, 'r') as f:
+ data = json.load(f)
+
+ # Handle both old and new formats
+ if isinstance(data, list):
+ # Old format: just messages
+ current_conversation = data
+ system_instruction = None
else:
- print(f"Error: {response.status_code}")
- print(response.text)
- log(response.text)
- return None
+ # New format: with metadata
+ current_conversation = data.get('messages', [])
+ system_instruction = data.get('system_instruction')
+
+ log(f"Conversation loaded from {filepath}")
+ return True
+ except Exception as e:
+ log(f"Error loading conversation: {e}")
+ return False
+
+def list_saved_conversations(directory=None):
+    """List all saved conversation files.
+
+    directory: folder to scan; defaults to logdir. Only files named
+        conversation_*.json are considered.
+
+    Both layouts accepted by load_conversation are understood: an object with
+    'messages' / 'saved_at' keys, and the older bare list of messages (reported
+    with saved_at 'unknown'). A file that cannot be read or parsed is logged
+    and skipped rather than failing the whole listing.
+
+    Returns [True, conversations] with one dict per file (filename, filepath,
+    saved_at, message_count), newest saved_at first, or [False, error_string]
+    when the directory cannot be listed or another unexpected error occurs.
+    """
+    if directory is None:
+        directory = logdir
+
+    try:
+        conversations = []
+        for filename in os.listdir(directory):
+            if not (filename.startswith('conversation_') and filename.endswith('.json')):
+                continue
+
+            filepath = os.path.join(directory, filename)
+            try:
+                with open(filepath, 'r') as f:
+                    data = json.load(f)
+            except (OSError, ValueError) as e:
+                # ValueError covers json.JSONDecodeError and bad encodings
+                log(f"Skipping unreadable conversation file {filepath}: {e}")
+                continue
+
+            if isinstance(data, dict):
+                saved_at = data.get('saved_at', 'unknown')
+                messages = data.get('messages', [])
+            elif isinstance(data, list):
+                # Old format: the file is just the list of messages
+                saved_at = 'unknown'
+                messages = data
+            else:
+                log(f"Skipping conversation file with unexpected content: {filepath}")
+                continue
+
+            conversations.append({
+                'filename': filename,
+                'filepath': filepath,
+                'saved_at': saved_at,
+                'message_count': len(messages) if isinstance(messages, list) else 0,
+            })
+
+        # Sort by saved_at; comparing as strings keeps the 'unknown' entries
+        # of old-format files sortable next to real timestamps.
+        conversations.sort(key=lambda x: str(x['saved_at']), reverse=True)
+        return [True, conversations]
+    except Exception as e:
+        log(f"Error listing conversations: {e}")
+        return [False, str(e)]
+
+def clear_conversation():
+    """Reset the stored conversation and system instruction to their empty defaults."""
+    global current_conversation, system_instruction
+    current_conversation, system_instruction = [], None
+    log("Conversation cleared")