diff options
| -rwxr-xr-x | ai_api_module.py | 429 | ||||
| -rwxr-xr-x | cli.py | 137 |
2 files changed, 377 insertions, 189 deletions
diff --git a/ai_api_module.py b/ai_api_module.py index 3cf616b..f69527e 100755 --- a/ai_api_module.py +++ b/ai_api_module.py @@ -1,11 +1,9 @@ #!/usr/bin/env python # COPYRIGHT 2025 Thomas Grothe -import argparse import requests import json import os import sys -import http.client from util import * #this is a python module to provide REST API access to some AI model @@ -16,44 +14,26 @@ from util import * ### CONFIGURATION HERE (yes you could move this into a config file if desired, but i don't need that currently) # Base URL and endpoints configuration - -# the key of the environment variable used to store your api key api_key_env_key = 'GEMINI_API_KEY' - -# the api key, if you prefer to set it here api_key = None +urlbase = 'https://generativelanguage.googleapis.com/v1beta' -urlbase = 'https://generativelanguage.googleapis.com/v1beta' #models/gemini-2.0-flash:generateContent' endpoints = { - 'chat': { - 'simple': urlbase+'/chat/completions', # without custom assistant #use components/schemas/ChatCompletionRequest - 'assistant': urlbase+'/assistant/--assistant_id--/chat/completions', #use components/schemas/AssistantChatCompletionRequest - }, 'model': { - 'list': urlbase+'/models', - }, - 'assistant': { - 'list': urlbase+'/assistants' + 'list': '/models', + 'generate': '/models/{model}:generateContent', + 'count_tokens': '/models/{model}:countTokens', }, - 'conversation': { - 'list': urlbase+'/conversations', - 'get': urlbase+'/conversation/' #/message_id - }, - 'genContent': 'models/gemini-2.0-flash:generateContent?key=' + 'file': { + 'list': '/files', + 'get': '/files/{name}', + 'delete': '/files/{name}', + } } -# Possible environment variables for API key - -#the default message context. this will be appended to or replaced as conversation goes, or if using a previous chat -messages = [ - { - 'role': 'system', - 'content': '', #TODO system prompt instructions - }, -] - -def setupURLParams(): - urlbase += f'?key={api_key}' +# Conversation state management +current_conversation = [] +system_instruction = None # System instruction separate from conversation def get_api_key(): global api_key @@ -64,120 +44,283 @@ def get_api_key(): return api_key return None -def list_models(output_mode = ''): +def set_system_instruction(instruction): + """Set the system instruction for the conversation""" + global system_instruction + system_instruction = instruction + log(f"System instruction set: {instruction}") + +def list_models(output_mode=''): + """List available models from Gemini API""" headers = {"Content-Type": "application/json"} - url = endpoints['model']['list'] + f'?key={api_key}' - print(f'calling {url}') - response = requests.get(f"{endpoints['model']['list']}", headers=headers) - models = [] - - if response.status_code == 200: - j = json.loads(response.text) - if 'j' in output_mode: - return j - else: - for d in j['data']: - models.append(d) - return [True, models] - else: - return [False, response] + url = f"{urlbase}{endpoints['model']['list']}?key={api_key}" + + try: + response = requests.get(url, headers=headers) + response.raise_for_status() + + data = response.json() -def list_assistants(output_mode = ''): - headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} - response = requests.get(f"{endpoints['assistant']['list']}", headers=headers) - if response.status_code == 200: - j = json.loads(response.text) # keys: assistants, total_assistants, scores,carousel_display_name,carousel_subtitle if 'j' in output_mode: - return j - else: - res = [] - for a in j['assistants']: - res.append(a) - return [True, res] - else: - return [False, response] - -def list_conversations(output_mode = ''): - headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} - response = requests.get(f"{endpoints['conversation']['list']}", headers=headers) - if response.status_code == 200: - j = json.loads(response.text) - if 'j' in output_mode: - return j - else: - res = [] - for a in j.keys(): #iterate through assistants - for cid in a: #j[]: #convo ids for each assistant - print('todo') - #TODO this gives only the first message, with its id, then we have to call /conversation/{m_id} - - return [True, res] - -def get_conversations(msg_id, output_mode = ''): - headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} - response = requests.get(f"{endpoints['conversation']['get']}", headers=headers) - - -def query(model_id, query, top_p, temperature, stream=False): - """Call the model with some query""" - headers = {"Content-Type": "application/json"} #, "Authorization": f"Bearer {api_key}"} - url = urlbase+endpoints['genContent']+str(api_key) - #attachments = [] #file attachments - - payload = { #todo options for other parameters - "contents": [ - { - "parts": [ - { - "text": query - } - ] - } - ] + return [True, data] + + models = [] + for model in data.get('models', []): + # Filter to only generation-capable models + if 'generateContent' in model.get('supportedGenerationMethods', []): + models.append({ + 'name': model.get('name', ''), + 'display_name': model.get('displayName', ''), + 'description': model.get('description', ''), + 'supported_methods': model.get('supportedGenerationMethods', []), + 'input_token_limit': model.get('inputTokenLimit', 0), + 'output_token_limit': model.get('outputTokenLimit', 0), + }) + return [True, models] + + except requests.exceptions.RequestException as e: + log(f"Error listing models: {e}") + return [False, str(e)] + +def count_tokens(model_id, text): + """Count tokens for given text""" + if not model_id: + model_id = "gemini-2.0-flash-exp" + + headers = {"Content-Type": "application/json"} + endpoint = endpoints['model']['count_tokens'].format(model=model_id) + url = f"{urlbase}{endpoint}?key={api_key}" + + payload = { + "contents": [{ + "parts": [{"text": text}] + }] } - # payload = { - # "model": model_id, - # "top_p": top_p, - # "temperature": temperature, - # "stream": stream, - # "messages": [{"role": "user", "content": query}], #TODO include previous conversation here - # # components/schemas/ChatCompletionsMessage-Input - # # content: string - # # attachments: - # # role: user, assistant, or system - # # name: - # # tool_call_id - # # tool_calls - - #} - - if stream: - response = requests.post(url, json=payload, headers=headers, stream=True) - for line in response.iter_lines(): - if line: - line = line.decode('utf-8') - if line.startswith("data: ") and not line.startswith("data: [DONE]"): - data = line[len("data: "):] - try: - content = json.loads(data)["choices"][0]["delta"]["content"] - if content: - log(content) - sys.stdout.write(content) - sys.stdout.flush() - except (json.JSONDecodeError, KeyError, IndexError): - pass - print() - else: + + try: response = requests.post(url, json=payload, headers=headers) - if response.status_code == 200: - response_json = response.json() - #TODO handle message to append - response = response_json['candidates'][0]['content']['parts'][0]['text'] - log(response) - return response - #return response_json["choices"][0]["message"]["content"] + response.raise_for_status() + data = response.json() + return [True, data.get('totalTokens', 0)] + except requests.exceptions.RequestException as e: + log(f"Error counting tokens: {e}") + return [False, str(e)] + +def query(model_id, query_text, top_p=0.9, temperature=0.7, stream=False, max_tokens=8192): + """Send a query to the Gemini model with conversation context""" + global current_conversation, system_instruction + + if not model_id: + model_id = "gemini-2.0-flash-exp" + + # Add user message to conversation + current_conversation.append({ + "role": "user", + "parts": [{"text": query_text}] + }) + + headers = {"Content-Type": "application/json"} + endpoint = endpoints['model']['generate'].format(model=model_id) + url = f"{urlbase}{endpoint}?key={api_key}" + + payload = { + "contents": current_conversation, + "generationConfig": { + "temperature": temperature, + "topP": top_p, + "maxOutputTokens": max_tokens, + } + } + + # Add system instruction if set + if system_instruction: + payload["systemInstruction"] = { + "parts": [{"text": system_instruction}] + } + + log(f"Sending query to {model_id}: {query_text}") + + try: + if stream: + return _stream_response(url, payload, headers) + else: + return _non_stream_response(url, payload, headers) + + except requests.exceptions.RequestException as e: + log(f"Error querying model: {e}") + print(f"Error: {e}") + return None + +def _non_stream_response(url, payload, headers): + """Handle non-streaming response""" + global current_conversation + + response = requests.post(url, json=payload, headers=headers) + response.raise_for_status() + + data = response.json() + + # Check for prompt feedback (safety issues) + if 'promptFeedback' in data: + feedback = data['promptFeedback'] + if 'blockReason' in feedback: + error_msg = f"Prompt blocked: {feedback['blockReason']}" + log(error_msg) + return error_msg + + if 'candidates' in data and len(data['candidates']) > 0: + candidate = data['candidates'][0] + + # Check finish reason + finish_reason = candidate.get('finishReason', '') + if finish_reason and finish_reason not in ['STOP', '']: + log(f"Warning: Finish reason: {finish_reason}") + + if 'content' in candidate: + assistant_message = candidate['content'] + current_conversation.append(assistant_message) + + # Extract text from parts + text_parts = [] + for part in assistant_message.get('parts', []): + if 'text' in part: + text_parts.append(part['text']) + + response_text = ''.join(text_parts) + log(f"Response: {response_text}") + + # Log token usage if available + if 'usageMetadata' in data: + metadata = data['usageMetadata'] + log(f"Token usage - Prompt: {metadata.get('promptTokenCount', 0)}, " + f"Response: {metadata.get('candidatesTokenCount', 0)}, " + f"Total: {metadata.get('totalTokenCount', 0)}") + + return response_text + + log(f"Unexpected response structure: {data}") + return None + +def _stream_response(url, payload, headers): + """Handle streaming response""" + global current_conversation + + # Gemini streaming uses SSE with alt=sse parameter + stream_url = url + "&alt=sse" + + response = requests.post(stream_url, json=payload, headers=headers, stream=True) + response.raise_for_status() + + accumulated_text = [] + + for line in response.iter_lines(): + if line: + line = line.decode('utf-8') + if line.startswith("data: "): + data_str = line[6:] + try: + data = json.loads(data_str) + if 'candidates' in data: + for candidate in data['candidates']: + if 'content' in candidate: + for part in candidate['content'].get('parts', []): + if 'text' in part: + text = part['text'] + accumulated_text.append(text) + sys.stdout.write(text) + sys.stdout.flush() + except json.JSONDecodeError: + pass + + print() # New line after streaming + + # Add assistant response to conversation + full_response = ''.join(accumulated_text) + if full_response: + current_conversation.append({ + "role": "model", + "parts": [{"text": full_response}] + }) + log(f"Streamed response: {full_response}") + + return full_response + +def save_conversation(filepath=None): + """Save current conversation to file""" + global system_instruction + + if filepath is None: + filepath = f"{logdir}/conversation_{tnow()}.json" + + # Ensure directory exists + os.makedirs(os.path.dirname(filepath), exist_ok=True) + + conversation_data = { + "system_instruction": system_instruction, + "messages": current_conversation, + "saved_at": tnow() + } + + with open(filepath, 'w') as f: + json.dump(conversation_data, f, indent=2) + + log(f"Conversation saved to {filepath}") + return filepath + +def load_conversation(filepath): + """Load conversation from file""" + global current_conversation, system_instruction + + try: + with open(filepath, 'r') as f: + data = json.load(f) + + # Handle both old and new formats + if isinstance(data, list): + # Old format: just messages + current_conversation = data + system_instruction = None else: - print(f"Error: {response.status_code}") - print(response.text) - log(response.text) - return None + # New format: with metadata + current_conversation = data.get('messages', []) + system_instruction = data.get('system_instruction') + + log(f"Conversation loaded from {filepath}") + return True + except Exception as e: + log(f"Error loading conversation: {e}") + return False + +def list_saved_conversations(directory=None): + """List all saved conversation files""" + if directory is None: + directory = logdir + + try: + conversations = [] + for filename in os.listdir(directory): + if filename.startswith('conversation_') and filename.endswith('.json'): + filepath = os.path.join(directory, filename) + with open(filepath, 'r') as f: + data = json.load(f) + + conversations.append({ + 'filename': filename, + 'filepath': filepath, + 'saved_at': data.get('saved_at', 'unknown'), + 'message_count': len(data.get('messages', data if isinstance(data, list) else [])), + }) + + # Sort by saved_at + conversations.sort(key=lambda x: x['saved_at'], reverse=True) + return [True, conversations] + except Exception as e: + log(f"Error listing conversations: {e}") + return [False, str(e)] + +def clear_conversation(): + """Clear current conversation context""" + global current_conversation, system_instruction + current_conversation = [] + system_instruction = None + log("Conversation cleared") @@ -3,11 +3,12 @@ import argparse import os import sys +import json import ai_api_module #this is the CLI program used to interact with the GAI python module. -default_model=None +default_model = "gemini-2.0-flash-exp" def main(): """Main function to parse arguments and execute commands""" @@ -15,69 +16,95 @@ def main(): parser.add_argument( "--apikey", default=os.getenv("GEMINI_API_KEY"), help="API key for the service" ) - parser.add_argument("--modelid", default=default_model, help=f"ID of the model to use (default={default_model}") + parser.add_argument("--modelid", default=default_model, help=f"ID of the model to use (default={default_model})") parser.add_argument("--list-models", action="store_true", help="list the available models") - parser.add_argument("--assistantid", default="default_chatbot", help="ID of the assistant") - parser.add_argument("--list-assistants", "--list-a", "--list-assistant", action='store_true', help='list all the available assistants') - parser.add_argument('--list-conversations', '--list-conversation', '--list-convo', '--list-c', action='store_true', help='list previous conversations') + + # Conversation management + parser.add_argument("--new", "-n", action="store_true", help="Start a new conversation") + parser.add_argument("--save-conversation", metavar="FILE", help="Save conversation to file") + parser.add_argument("--load-conversation", metavar="FILE", help="Load conversation from file") + parser.add_argument("--list-conversations", action="store_true", help="List saved conversations") + parser.add_argument("--clear", action="store_true", help="Clear conversation history") + parser.add_argument("--system-instruction", metavar="TEXT", help="Set system instruction for the conversation") + + # Query parameters parser.add_argument("--query", "-q", default="", help="Query to be sent to the assistant") parser.add_argument("--topp", type=float, default=0.9, help="Top P value for the model") - parser.add_argument( - "--temperature", type=float, default=0.7, help="Temperature value for the model" - ) + parser.add_argument("--temperature", type=float, default=0.7, help="Temperature value for the model") + parser.add_argument("--max-tokens", type=int, default=8192, help="Maximum output tokens") parser.add_argument("--stream", action="store_true", help="Enable streaming response") parser.add_argument("--no-stream", action="store_true", help="Disable streaming response") - parser.add_argument("--output_mode", "-o", nargs=1, action="store", default='text', help="just output the json") + parser.add_argument("--output_mode", "-o", default='text', help="Output mode (text or json)") + parser.add_argument("--count-tokens", action="store_true", help="Count tokens in query without sending") - #parser.add_argument("--file-health", action="store_true", help="") - #parser.add_argument("--file-list", action="store_true", help="List all uploaded files") - #parser.add_argument("--file-upload", help="Upload a file to the service") - #parser.add_argument("--file-status", help="Check status of a file by ID") - #parser.add_argument("--file-purpose", default="assistants", help="Purpose of the file") - #parser.add_argument("--attach-files", nargs="*", help="File IDs to attach to the query") - - parser.add_argument("--new", "-n", help="force new session and new chat", action="store_true", default=False) #TODO: solidify a good approach here - parser.add_argument("extra_query", nargs="*") args = parser.parse_args() - # Try to get API key if not provided + # Get API key if args.apikey: ai_api_module.api_key = args.apikey else: ai_api_module.get_api_key() if not ai_api_module.api_key: - print("Error: API key not provided") + print("Error: API key not provided. Set GEMINI_API_KEY environment variable.") sys.exit(1) - if args.list_models: - res = ai_api_module.list_models(args.output_mode) - if res[0]: - for m in res[1]: - print(str(m)) - #print(f" ModelID: {m['id']}\n Description: {m['description']}\n Tokens: {m['max_completion_tokens']}\n Modalities: {str(m['modalities'])}") - print() + # Handle system instruction + if args.system_instruction: + ai_api_module.set_system_instruction(args.system_instruction) + print(f"System instruction set") + + # Handle conversation management + if args.new: + ai_api_module.clear_conversation() + print("Started new conversation") + + if args.load_conversation: + if ai_api_module.load_conversation(args.load_conversation): + print(f"Loaded conversation from {args.load_conversation}") else: - print('error: ') - print(str(res[1].text)) - return - + print(f"Failed to load conversation from {args.load_conversation}") + sys.exit(1) + if args.list_conversations: - res = ai_api_module.list_conversations(args.output_mode) - print(res) - return - - if args.list_assistants: - res = ai_api_module.list_assistants(args.output_mode) - if res[0]: - for a in res[1]: - print(f" ID: {a['id']}\n Name: {a['display_name']}") - print() + success, result = ai_api_module.list_saved_conversations() + if success: + if not result: + print("No saved conversations found") + else: + print("Saved conversations:") + for conv in result: + print(f" {conv['filename']}") + print(f" Saved: {conv['saved_at']}") + print(f" Messages: {conv['message_count']}") + print() else: - print('error: ' + str(res[1])) + print(f"Error: {result}") + return + + if args.clear: + ai_api_module.clear_conversation() + print("Conversation history cleared") + return + + # List models + if args.list_models: + success, result = ai_api_module.list_models(args.output_mode) + if success: + for model in result: + if args.output_mode == 'json': + print(json.dumps(model, indent=2)) + else: + print(f"Model: {model['display_name']}") + print(f" Name: {model['name']}") + print(f" Description: {model.get('description', 'N/A')}") + print(f" Input tokens: {model.get('input_token_limit', 'N/A')}") + print(f" Output tokens: {model.get('output_token_limit', 'N/A')}") + print() + else: + print(f'Error: {result}') return - # Process query if args.query or args.extra_query: @@ -87,17 +114,35 @@ def main(): query += " " + " ".join(args.extra_query) else: query = " ".join(args.extra_query) - + + # Count tokens if requested + if args.count_tokens: + success, count = ai_api_module.count_tokens(args.modelid, query) + if success: + print(f"Token count: {count}") + else: + print(f"Error counting tokens: {count}") + return + + # Determine streaming + use_stream = args.stream and not args.no_stream + response = ai_api_module.query( args.modelid, query, args.topp, args.temperature, - not args.no_stream and args.stream, + use_stream, + args.max_tokens, ) - if response and not args.stream: + if response and not use_stream: print(response) + + # Auto-save conversation if requested + if args.save_conversation: + filepath = ai_api_module.save_conversation(args.save_conversation) + print(f"\nConversation saved to {filepath}") else: parser.print_help() |
