from flask import Flask, request, jsonify, send_from_directory, send_file
import torch
import re
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import os
import sys
import time
from typing import List, Optional
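# Flask app: serves the static UI and exposes /chat (RAG-backed generation),
# /login, /update_rag, /tokenize, and /download/<filename>.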
app = Flask(__name__, template_folder='.', static_folder='.')
# Serve static assets (JS/CSS/images) from the project root.
@app.route('/<path:filename>')
def serve_static(filename):
    return send_from_directory('.', filename)
# Hugging Face token handling
TOKEN = os.environ.get("HUGGINGFACE_HUB_TOKEN")
if TOKEN:
    print("[INFO] Using Hugging Face token from ENV.")
else:
    # Do not hardcode tokens in source; gated repos require the ENV var.
    print("[WARNING] No HUGGINGFACE_HUB_TOKEN found in ENV; gated models may fail to load.")
# Model name and local cache
model_name = 'meta-llama/Llama-3.2-3B-Instruct'
cache_dir = "./model_cache"
# from_pretrained() transparently downloads on a cache miss, so a single
# load path covers both the cached and first-run cases.
print(f"Loading tokenizer for {model_name} from {cache_dir}...")
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    cache_dir=cache_dir,
    use_fast=True,
    token=TOKEN
)
tokenizer.pad_token = tokenizer.eos_token  # Llama defines no pad token by default
print(f"Loading model {model_name} from {cache_dir}...")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    cache_dir=cache_dir,
    device_map="cpu",  # device_map places the model, so no extra .to() call is needed
    # Default torch_dtype (float32) for CPU inference
    low_cpu_mem_usage=True,
    token=TOKEN
)
print("Model loaded successfully!")
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
text_chunks = []
faiss_index = None
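# Greedily packs non-empty lines into chunks of at most `max_length` words
# (the limit is a word count, not a character count).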
def chunk_text(text: str, max_length: int = 100) -> List[str]:
chunks = []
current_chunk = ""
for line in text.splitlines():
line = line.strip()
if not line:
continue
if len(current_chunk.split()) + len(line.split()) <= max_length:
current_chunk += " " + line
else:
if current_chunk:
chunks.append(current_chunk.strip())
current_chunk = line
if current_chunk:
chunks.append(current_chunk.strip())
return chunks
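# (Re)builds the in-memory RAG store: reads the file, chunks it, embeds every
# chunk, and indexes the embeddings in a flat L2 FAISS index.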
def setup_rag(data_path: Optional[str] = None):
global text_chunks, faiss_index
text_chunks = []
faiss_index = None
if not data_path or not os.path.isfile(data_path):
print(f"[RAG] No valid data file at {data_path}", file=sys.stderr)
return
with open(data_path, "r", encoding="utf-8") as f:
full_text = f.read()
if not full_text.strip():
print(f"[RAG] File {data_path} is empty or whitespace only", file=sys.stderr)
return
text_chunks = chunk_text(full_text, max_length=100)
print(f"[RAG] Loaded {len(text_chunks)} chunks from {data_path}", file=sys.stderr)
    # Encode all chunks in one batched call; encode() returns an (n, dim) float array.
    embs_np = np.asarray(
        embedding_model.encode(text_chunks, convert_to_tensor=False, show_progress_bar=False),
        dtype=np.float32,
    )
if embs_np.shape[0] == 0:
print("[RAG] No embeddings generated", file=sys.stderr)
return
dim = embs_np.shape[1]
faiss_index = faiss.IndexFlatL2(dim)
faiss_index.add(embs_np)
print(f"[RAG] Index built with {len(text_chunks)} chunks from {data_path} (dim={dim})", file=sys.stderr)
# Pre-load RAG data from QuintonData.txt
setup_rag("QuintonData.txt")
def question_mentions_quinton_or_scale(prompt: str) -> bool:
plower = prompt.lower()
keywords = ["quinton", "calmus", "scale ai", "scaleai"]
return any(kw in plower for kw in keywords)
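# Generates an answer, optionally grounded in retrieved context. RAG is used
# only when the prompt mentions one of the keywords above AND an index exists.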
def rag_generate(
prompt: str,
k: int = 3,
    max_new_tokens: int = 240,
temperature: float = 0.6,
top_p: float = 0.9,
top_k: int = 40
) -> str:
device = torch.device("cpu")
start_time = time.time()
print(f"DEBUG: Starting rag_generate with prompt: {prompt}", file=sys.stderr)
use_rag = question_mentions_quinton_or_scale(prompt) and faiss_index is not None and len(text_chunks) > 0
print(f"DEBUG: use_rag={use_rag}, faiss_index={faiss_index}, text_chunks length={len(text_chunks)}", file=sys.stderr)
if use_rag:
print(f"DEBUG: Encoding prompt for RAG...", file=sys.stderr)
emb = embedding_model.encode([prompt], convert_to_tensor=False, show_progress_bar=False)
print(f"DEBUG: Searching FAISS index...", file=sys.stderr)
        distances, indices = faiss_index.search(np.array(emb, dtype=np.float32), min(k, len(text_chunks)))
        # FAISS pads missing results with -1; a bare `i < len(...)` check would let
        # -1 through and silently return the last chunk, so bound both sides.
        top_chunks = [text_chunks[i] for i in indices[0] if 0 <= i < len(text_chunks)]
context_text = "\n".join(top_chunks)
print(f"DEBUG: RAG context: {context_text}", file=sys.stderr)
final_prompt = (
f"Using the following context:\n{context_text}\n\n"
f"Answer the question: {prompt}\n\n"
"Provide a concise response in 1-2 sentences, summarizing or rephrasing the key points relevant to the question without copying the context verbatim:\n"
)
else:
final_prompt = (
f"Answer the question: {prompt}\n\n"
"Provide a concise response in 1-2 sentences:\n"
)
print(f"DEBUG: Final prompt:\n{final_prompt}", file=sys.stderr)
try:
inputs = tokenizer(final_prompt, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
print(f"DEBUG: Tokenized inputs: {inputs}", file=sys.stderr)
with torch.no_grad():
outputs = model.generate(
inputs["input_ids"],
attention_mask=inputs.get("attention_mask"),
max_new_tokens=max_new_tokens,
do_sample=True,
temperature=temperature,
top_p=top_p,
top_k=top_k,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id
)
print(f"DEBUG: Model outputs: {outputs}", file=sys.stderr)
raw_out = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
print(f"DEBUG: Raw output from model: {raw_out}", file=sys.stderr)
if use_rag:
marker = "Provide a concise response in 1-2 sentences, summarizing or rephrasing the key points relevant to the question without copying the context verbatim:"
else:
marker = "Provide a concise response in 1-2 sentences:"
answer_start = raw_out.find(marker) + len(marker) if marker in raw_out else 0
answer_text = raw_out[answer_start:].strip()
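        # Post-process: drop bullet lines, strip most punctuation, and trim to at
        # most two sentences, matching the "1-2 sentences" instruction in the prompt.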
lines = answer_text.splitlines()
out_lines = []
for ln in lines:
lstrip = ln.strip()
if not lstrip or lstrip.startswith("-"):
continue
out_lines.append(lstrip.replace("- ", ""))
final_text = " ".join(out_lines).strip()
final_text = final_text.replace(prompt, "").strip()
        final_text = re.sub(r'[^\w\s\.\-\'"]', '', final_text)  # Keep only word chars, whitespace, and . - ' " (drops commas, ?, ! etc.)
sentences = [s.strip() for s in final_text.split(".") if s.strip()]
if len(sentences) > 2:
final_text = ". ".join(sentences[:2]) + "."
elif len(sentences) == 1 and not final_text.endswith("."):
final_text += "."
elif len(sentences) == 0:
final_text = "No relevant answer could be generated from the context."
print(f"DEBUG: Final sanitized response: {final_text}", file=sys.stderr)
print(f"Prompt: {prompt}", file=sys.stdout)
print(f"Response: {final_text}", file=sys.stdout)
print(f"Time taken: {time.time() - start_time:.2f} seconds\n", file=sys.stdout)
return final_text
    except Exception as e:
        error_msg = f"Error in rag_generate: {e}"
        print(f"DEBUG: {error_msg}", file=sys.stderr)
        raise RuntimeError(error_msg) from e  # Chain so the original traceback survives
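# Serve the chat UI shell.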
@app.route('/')
def index():
try:
with open('index.html', 'r', encoding='utf-8') as f:
return f.read()
except FileNotFoundError:
return "index.html not found.", 404
except Exception as e:
return f"Error reading index.html: {e}", 500
@app.route('/login', methods=['POST'])
def login():
if request.is_json:
data = request.get_json()
password = data.get('password', '')
else:
password = request.form.get('password', '')
correct_password = "ScaleAI2025"
if password == correct_password:
return jsonify({'success': True})
return jsonify({'success': False})
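# Chat endpoint, e.g.:
#   curl -X POST http://localhost:5000/chat -H 'Content-Type: application/json' \
#        -d '{"prompt": "Who is Quinton?"}'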
@app.route('/chat', methods=['POST'])
def chat():
try:
data = request.get_json()
prompt = data.get('prompt', None)
print(f"DEBUG: Received prompt: {prompt}", file=sys.stderr)
if not prompt:
return jsonify({'error': 'No prompt provided'}), 400
response = rag_generate(prompt)
print(f"DEBUG: Sending response: {response}", file=sys.stderr)
return jsonify({'response': response})
except Exception as e:
error_msg = f"Failed to generate response: {str(e)}"
print(f"ERROR: {error_msg}", file=sys.stderr)
return jsonify({'error': error_msg}), 500
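# Rebuild the RAG index from a data file, e.g.:
#   curl -X POST http://localhost:5000/update_rag -H 'Content-Type: application/json' \
#        -d '{"data_path": "QuintonData.txt"}'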
@app.route('/update_rag', methods=['POST'])
def update_rag():
data = request.get_json()
data_path = data.get('data_path', "QuintonData.txt")
setup_rag(data_path)
return jsonify({'status': f"RAG updated with {data_path}"})
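# Tokenize text and return a |token|-delimited view of the original string.
# The find()-based alignment skips tokens whose decoded text cannot be located
# in the input, so the markup is best-effort.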
@app.route('/tokenize', methods=['POST'])
def tokenize():
data = request.get_json()
text = data.get('text', '')
if not text:
return jsonify({'error': 'No text provided'}), 400
tokens = tokenizer.encode(text, add_special_tokens=False)
token_strs = [tokenizer.decode([t], skip_special_tokens=True) for t in tokens]
out = ""
idx = 0
for ts in token_strs:
fi = text.find(ts, idx)
if fi == -1:
continue
out += text[idx:fi]
out += f"|{ts}|"
idx = fi + len(ts)
out += text[idx:]
return jsonify({
'original_text': text,
'tokenized_text': out,
'tokens': [str(t) for t in tokens]
})
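# File download. Caution: send_file() with a client-supplied path permits path
# traversal; prefer send_from_directory() with a fixed directory in production.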
@app.route('/download/<path:filename>')
def download_file(filename):
try:
return send_file(filename, as_attachment=True)
except FileNotFoundError:
return "File not found", 404
except Exception as e:
print(f"[ERROR] {e}", file=sys.stderr)
return f"Error: {e}", 500
if __name__ == '__main__':
app.run(host='0.0.0.0', port=5000, debug=True, use_reloader=False)