import os
import uuid
import subprocess
import re
import numpy as np
import mysql.connector
from flask import Flask, request, jsonify, render_template
from flask_cors import CORS
from resemblyzer import VoiceEncoder, preprocess_wav
from numpy.linalg import norm
import whisper

# ----- CONFIG -----
# MySQL connection settings. Each value can be overridden through an
# environment variable so credentials do not have to be committed to source
# control; the defaults reproduce the original local-development setup.
DB_CONFIG = {
    "host": os.environ.get("DB_HOST", "localhost"),
    "user": os.environ.get("DB_USER", "root"),
    "password": os.environ.get("DB_PASSWORD", ""),
    "database": os.environ.get("DB_NAME", "voice_auth"),
}

# Directory holding temporary uploaded / converted audio files.
UPLOAD_FOLDER = "uploads"
os.makedirs(UPLOAD_FOLDER, exist_ok=True)

# Minimum cosine similarity for /auth to accept the best-matching user
# (overridable via the SIMILARITY_THRESHOLD environment variable).
SIMILARITY_THRESHOLD = float(os.environ.get("SIMILARITY_THRESHOLD", "0.75"))

# ----- APP -----
app = Flask(__name__, static_folder="static", template_folder="templates")
# NOTE(review): with flask_cors defaults this accepts requests from any
# origin; consider restricting origins for the authentication endpoints.
CORS(app)

# Both models are loaded once at import time and reused by every request.
encoder = VoiceEncoder()
# Whisper model size ("base", "small", "medium", ...), overridable via WHISPER_MODEL.
model_whisper = whisper.load_model(os.environ.get("WHISPER_MODEL", "base"))

# ----- DB connection -----
def get_db_connection():
    """Open and return a brand-new MySQL connection configured by ``DB_CONFIG``.

    The caller owns the returned connection and must close it.
    """
    connect = mysql.connector.connect
    return connect(**DB_CONFIG)

def ensure_tables():
    """Create the ``users`` table if it does not already exist.

    Each user row stores a UUID primary key, a creation timestamp, optional
    first/last names and the raw bytes of the voice embedding. The cursor and
    connection are always released, even when the DDL statement fails.
    """
    conn = get_db_connection()
    try:
        cur = conn.cursor()
        try:
            cur.execute("""
            CREATE TABLE IF NOT EXISTS users (
                id VARCHAR(36) PRIMARY KEY,
                created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
                first_name VARCHAR(100) DEFAULT NULL,
                last_name VARCHAR(100) DEFAULT NULL,
                voice_embedding LONGBLOB NOT NULL
            )
            """)
            conn.commit()
        finally:
            cur.close()
    finally:
        conn.close()

# Run once at import time so the routes can assume the table exists.
ensure_tables()

# ----- UTILS -----
def convert_to_wav_pcm(input_path, output_path):
    """Transcode *input_path* into a 16 kHz, mono, 16-bit PCM WAV at *output_path*.

    Runs the ``ffmpeg`` executable (``-y`` overwrites an existing output file)
    with stdout and stderr captured. Raises ``subprocess.CalledProcessError``
    when ffmpeg exits with a non-zero status.
    """
    resample = ["-ar", "16000"]      # 16 kHz sample rate
    downmix = ["-ac", "1"]           # single (mono) channel
    codec = ["-c:a", "pcm_s16le"]    # signed 16-bit little-endian PCM
    ffmpeg_args = ["ffmpeg", "-y", "-i", input_path, *resample, *downmix, *codec, output_path]
    subprocess.run(
        ffmpeg_args,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        check=True,
    )

def cosine_similarity(a, b):
    """Return the cosine similarity of vectors *a* and *b* as a Python float.

    When either vector has zero norm the similarity is undefined; the plain
    formula would produce ``nan``. A ``nan`` never wins a ``>`` comparison and
    is not valid strict JSON in a response, so ``0.0`` is returned instead.
    """
    denominator = norm(a) * norm(b)
    if denominator == 0:
        return 0.0
    return float(np.dot(a, b) / denominator)

# One name token: Unicode letters only (accents allowed, no digits or
# underscores), optionally hyphenated, e.g. "Hélène" or "Jean-Pierre".
_NAME_TOKEN = r"[^\W\d_]+(?:-[^\W\d_]+)*"
# Matches "je m'appelle <first> <last>", accepting either the ASCII or the
# typographic apostrophe and any amount of whitespace between words.
_INTRO_PATTERN = re.compile(
    r"je m['’]\s*appelle\s+(" + _NAME_TOKEN + r")\s+(" + _NAME_TOKEN + r")",
    re.IGNORECASE,
)


def _capitalize_name(name):
    """Capitalize each hyphen-separated part of *name* ("jean-pierre" -> "Jean-Pierre")."""
    return "-".join(part.capitalize() for part in name.split("-"))


def extract_names_from_text(text):
    """Extract ``(first_name, last_name)`` from a French self-introduction.

    Looks for "je m'appelle <Prénom> <Nom>" anywhere in *text*
    (case-insensitive) and returns both names capitalized. Returns
    ``(None, None)`` when the phrase is not found.
    """
    match = _INTRO_PATTERN.search(text)
    if match is None:
        return None, None
    return _capitalize_name(match.group(1)), _capitalize_name(match.group(2))

def compute_embedding(audio_path):
    """Return the speaker embedding of the audio file at *audio_path*.

    The file is loaded and normalised with Resemblyzer's ``preprocess_wav``,
    then embedded by the shared module-level ``encoder``.
    """
    return encoder.embed_utterance(preprocess_wav(audio_path))

# ----- ROUTES -----
@app.route("/")
def index():
    """Serve the front-end page rendered from ``templates/index.html``."""
    page = "index.html"
    return render_template(page)

@app.route("/enroll", methods=["POST"])
def enroll():
    """Enrol a new speaker from an uploaded recording.

    Expects a multipart upload field ``audio`` in which the speaker says
    "Je m'appelle <Prénom> <Nom>". The recording goes through these steps:

    1. It is converted to 16 kHz mono WAV.
    2. It is transcribed with Whisper (French) to extract the names.
    3. It is embedded with the voice encoder.
    4. The result is inserted as a new ``users`` row.

    Returns JSON with ``success`` and ``message``, plus ``user_id`` on
    success. Status codes:

    - 400 when no file is sent or no names are detected.
    - 500 on ffmpeg or any other internal error.

    Temporary files are removed on every path.
    """
    if "audio" not in request.files:
        return jsonify({"success": False, "message": "Aucun fichier audio reçu"}), 400
    file = request.files["audio"]

    tmp_name = str(uuid.uuid4())
    raw_path = os.path.join(UPLOAD_FOLDER, tmp_name + ".webm")
    wav_path = os.path.join(UPLOAD_FOLDER, tmp_name + "_conv.wav")

    try:
        # Saved inside the try so a partially written upload is still cleaned up.
        file.save(raw_path)
        convert_to_wav_pcm(raw_path, wav_path)

        # Whisper transcription (French) to recover the spoken first/last name.
        result = model_whisper.transcribe(wav_path, language="fr")
        texte = result["text"].strip()
        print("Texte transcrit :", texte)

        first_name, last_name = extract_names_from_text(texte)
        if not first_name or not last_name:
            return jsonify({
                "success": False,
                "message": "Nom et prénom non détectés dans la voix. Merci de dire clairement 'Je m'appelle Prénom Nom'."
            }), 400

        # Pin the stored layout to float32: /auth decodes blobs with dtype=np.float32.
        emb_bytes = np.asarray(compute_embedding(wav_path), dtype=np.float32).tobytes()

        user_id = str(uuid.uuid4())

        conn = get_db_connection()
        try:
            cur = conn.cursor()
            try:
                cur.execute(
                    "INSERT INTO users (id, first_name, last_name, voice_embedding) VALUES (%s, %s, %s, %s)",
                    (user_id, first_name, last_name, emb_bytes)
                )
                conn.commit()
            finally:
                cur.close()
        finally:
            # Always release the connection, even if the INSERT fails.
            conn.close()

        return jsonify({
            "success": True,
            "message": f"Enrôlement réussi pour {first_name} {last_name}",
            "user_id": user_id
        })

    except subprocess.CalledProcessError as e:
        return jsonify({"success": False, "message": "Erreur ffmpeg : " + str(e)}), 500
    except Exception as e:
        return jsonify({"success": False, "message": "Erreur interne : " + str(e)}), 500
    finally:
        # Best-effort cleanup: a removal failure is logged, never raised.
        for p in (raw_path, wav_path):
            try:
                if os.path.exists(p):
                    os.remove(p)
            except OSError as err:
                app.logger.warning("Could not remove temporary file %s: %s", p, err)

@app.route("/auth", methods=["POST"])
def auth():
    """Identify the speaker of an uploaded recording among enrolled users.

    Expects a multipart upload field ``audio``. The recording is handled as
    follows:

    1. It is converted to 16 kHz mono WAV and embedded.
    2. It is compared with cosine similarity against every stored embedding.
    3. The best match is accepted when its score reaches
       ``SIMILARITY_THRESHOLD``.

    Returns JSON:

    - On a match: ``user_id``, ``score`` and the names.
    - 401 with the best ``score`` when no user matches.
    - 400 when no file is sent.
    - 500 on ffmpeg or other internal errors.

    Temporary files are removed on every path.
    """
    if "audio" not in request.files:
        return jsonify({"success": False, "message": "Aucun fichier audio reçu"}), 400

    file = request.files["audio"]
    tmp_name = str(uuid.uuid4())
    raw_path = os.path.join(UPLOAD_FOLDER, tmp_name + ".webm")
    wav_path = os.path.join(UPLOAD_FOLDER, tmp_name + "_conv.wav")

    try:
        # Saved inside the try so a partially written upload is still cleaned up.
        file.save(raw_path)
        convert_to_wav_pcm(raw_path, wav_path)
        # Same float32 layout as the stored blobs decoded below.
        emb = np.asarray(compute_embedding(wav_path), dtype=np.float32)

        conn = get_db_connection()
        try:
            cur = conn.cursor()
            try:
                cur.execute("SELECT id, voice_embedding, first_name, last_name FROM users")
                rows = cur.fetchall()
            finally:
                cur.close()
        finally:
            # Always release the connection, even if the query fails.
            conn.close()

        best = {"id": None, "score": -1, "first_name": None, "last_name": None}
        for user_id, emb_bytes, first_name, last_name in rows:
            # Skip blobs that cannot hold an embedding of the query's size
            # (corrupt rows or a different dtype). One bad row must not break
            # authentication for everyone via frombuffer/np.dot errors.
            if emb_bytes is None or len(emb_bytes) != emb.nbytes:
                continue
            stored = np.frombuffer(emb_bytes, dtype=np.float32)
            score = cosine_similarity(emb, stored)
            if score > best["score"]:
                best = {
                    "id": user_id,
                    "score": score,
                    "first_name": first_name,
                    "last_name": last_name
                }

        if best["score"] >= SIMILARITY_THRESHOLD:
            return jsonify({
                "success": True,
                "user_id": best["id"],
                "score": best["score"],
                "first_name": best["first_name"],
                "last_name": best["last_name"],
                "message": "Authentification réussie."
            })
        return jsonify({
            "success": False,
            "score": best["score"],
            "message": "Voix non reconnue."
        }), 401

    except subprocess.CalledProcessError as e:
        return jsonify({"success": False, "message": "Erreur ffmpeg : " + str(e)}), 500
    except Exception as e:
        return jsonify({"success": False, "message": "Erreur interne : " + str(e)}), 500
    finally:
        # Best-effort cleanup: a removal failure is logged, never raised.
        for p in (raw_path, wav_path):
            try:
                if os.path.exists(p):
                    os.remove(p)
            except OSError as err:
                app.logger.warning("Could not remove temporary file %s: %s", p, err)

# ----- LANCEMENT -----
if __name__ == "__main__":
    # The interactive Werkzeug debugger can execute arbitrary code sent from a
    # browser, so it must not be exposed on all interfaces by default.
    # Opt in explicitly with FLASK_DEBUG=1 for local development.
    debug_enabled = os.environ.get("FLASK_DEBUG", "0") == "1"
    app.run(
        host=os.environ.get("HOST", "0.0.0.0"),
        port=int(os.environ.get("PORT", "5000")),
        debug=debug_enabled,
    )
