﻿/* ****************************************************************************
 *
 * Copyright (c) Jeff Hardy 2010-2012.
 *
 * This source code is subject to terms and conditions of the Apache License, Version 2.0. A 
 * copy of the license can be found in the License.html file at the root of this distribution. If 
 * you cannot locate the  Apache License, Version 2.0, please send an email to 
 * dlr@microsoft.com. By using this source code in any fashion, you are agreeing to be bound 
 * by the terms of the Apache License, Version 2.0.
 *
 * You must not remove this notice, or any other, from this software.
 *
 *
 * ***************************************************************************/

using System;
using System.Collections;
using System.Diagnostics;
using Community.CsharpSqlite;
using IronPython.Runtime;
using IronPython.Runtime.Operations;
using IronPython.Runtime.Types;

using sqlite3_stmt = Community.CsharpSqlite.Sqlite3.Vdbe;

namespace IronPython.SQLite
{
    /// <summary>
    /// Coarse classification of a SQL statement, chosen from the statement's leading keyword.
    /// </summary>
    enum StatementType
    {
        /// <summary>Not classified yet.</summary>
        Unknown = 0,

        /// <summary>Statement begins with SELECT.</summary>
        Select = 1,

        /// <summary>Statement begins with INSERT.</summary>
        Insert = 2,

        /// <summary>Statement begins with UPDATE.</summary>
        Update = 3,

        /// <summary>Statement begins with DELETE.</summary>
        Delete = 4,

        /// <summary>Statement begins with REPLACE.</summary>
        Replace = 5,

        /// <summary>Any other statement.</summary>
        Other = 6,
    }

    /// <summary>
    /// A single compiled SQLite statement together with the SQL text it was compiled from.
    /// Binds Python parameters (named mappings or positional sequences) to the statement,
    /// running registered adapters / <c>__conform__</c> on values before binding.
    /// </summary>
    [DebuggerDisplay("{sql}")]
    internal class Statement
    {
        // Identifier assigned at construction; it is not read anywhere in this class.
        private readonly Guid uniqueid;
        private readonly Sqlite3.sqlite3 db;
        internal sqlite3_stmt st;

        // SQL text this statement was compiled from (also used by the debugger display).
        private readonly string sql;

        // True once parameters have been bound, so the next bind clears the old bindings first.
        private bool bound = false;

        // Within this class: set by MarkDirty(), cleared by a successful Reset() or by SqliteFinalize().
        internal bool in_use = false;

        /// <summary>
        /// SQL text left over after the first statement, as reported by <c>sqlite3_prepare</c>.
        /// </summary>
        public string Tail { get; private set; }

        /// <summary>
        /// Compiles <paramref name="operation"/> on the database handle of <paramref name="connection"/>.
        /// </summary>
        /// <param name="connection">Connection whose database handle is used to prepare the statement.</param>
        /// <param name="operation">SQL text; whatever follows the first statement is stored in <see cref="Tail"/>.</param>
        /// <remarks>If preparation fails, throws the error produced by <c>PythonSQLite.GetSqliteError</c>.</remarks>
        public Statement(PythonSQLite.Connection connection, string operation)
        {
            this.uniqueid = Guid.NewGuid();

            this.db = connection.db;
            this.sql = operation;

            this.st = null;
            string tail = null;
            if(Sqlite3.sqlite3_prepare(this.db, this.sql, -1, ref this.st, ref tail) != Sqlite3.SQLITE_OK /*TODO: || too much sql */)
            {
                // Release whatever prepare left in st before surfacing the error.
                Sqlite3.sqlite3_finalize(this.st);
                this.st = null;
                throw PythonSQLite.GetSqliteError(this.db, null);
            }

            this.Tail = tail;
        }

        /// <summary>
        /// Wraps an already-prepared handle. The instance owns <paramref name="stmt"/> and
        /// finalizes it from its finalizer unless <see cref="st"/> has been cleared first.
        /// </summary>
        private Statement(Sqlite3.sqlite3 db, sqlite3_stmt stmt, string operation, string tail)
        {
            this.uniqueid = Guid.NewGuid();

            this.db = db;
            this.sql = operation;

            this.st = stmt;
            this.Tail = tail;
        }

        // Safety net: finalize the prepared statement if it was never explicitly finalized.
        ~Statement()
        {
            if(this.st != null)
            {
                Sqlite3.sqlite3_finalize(this.st);
            }

            this.st = null;
        }

        private StatementType _type = StatementType.Unknown;

        /// <summary>
        /// Classification of the statement by the first keyword of its SQL text
        /// (leading whitespace ignored, case-insensitive). Computed once, then cached.
        /// </summary>
        public StatementType StatementType
        {
            get
            {
                if(this._type == StatementType.Unknown)
                    this._type = Classify(this.sql);

                return this._type;
            }
        }

        // Maps the leading SQL keyword to a StatementType. Ordinal comparison is used because
        // SQL keywords are ASCII identifiers, not linguistic text.
        private static StatementType Classify(string sqlText)
        {
            string s = sqlText.TrimStart();

            if(s.StartsWith("select", StringComparison.OrdinalIgnoreCase))
                return StatementType.Select;
            if(s.StartsWith("insert", StringComparison.OrdinalIgnoreCase))
                return StatementType.Insert;
            if(s.StartsWith("update", StringComparison.OrdinalIgnoreCase))
                return StatementType.Update;
            if(s.StartsWith("delete", StringComparison.OrdinalIgnoreCase))
                return StatementType.Delete;
            if(s.StartsWith("replace", StringComparison.OrdinalIgnoreCase))
                return StatementType.Replace;

            return StatementType.Other;
        }

        /// <summary>
        /// Binds <paramref name="parameters"/> to the statement's placeholders, clearing any
        /// previous bindings first.
        /// </summary>
        /// <param name="context">Code context used to invoke Python adapters.</param>
        /// <param name="parameters">
        /// <c>null</c> (allowed only when the statement has no placeholders), an
        /// <see cref="IDictionary"/> for named placeholders, or an <see cref="IList"/> for
        /// positional placeholders. Anything else raises a ProgrammingError.
        /// </param>
        public void BindParameters(CodeContext context, object parameters)
        {
            if(bound)
            {
                this.ClearParameters();
                bound = false;
            }

            int num_params_needed = Sqlite3.sqlite3_bind_parameter_count(this.st);

            if(parameters == null)
            {
                if(num_params_needed > 0)
                    throw PythonSQLite.MakeProgrammingError("parameters are required but not specified.");
                else
                    return;
            }

            // Mappings are checked before sequences, so a type that is both binds by name.
            if(parameters is IDictionary)
                BindParameters(context, (IDictionary)parameters, num_params_needed);
            else if(parameters is IList)
                BindParameters(context, (IList)parameters, num_params_needed);
            else
                throw PythonSQLite.MakeProgrammingError("unknown parameter type");

            bound = true;
        }

        // Binds named placeholders 1..num_params_needed by looking each name up in args.
        private void BindParameters(CodeContext context, IDictionary args, int num_params_needed)
        {
            for(int i = 1; i <= num_params_needed; ++i)
            {
                string binding_name = Sqlite3.sqlite3_bind_parameter_name(this.st, i);
                if(string.IsNullOrEmpty(binding_name))
                    throw PythonSQLite.MakeProgrammingError("Binding {0} has no name, but you supplied a dictionary (which has only names).".Format(i));

                // The reported name includes its prefix character (e.g. ':'); strip it to get the key.
                binding_name = binding_name.Substring(1);

                // Contains + indexer rather than a single lookup: a stored null (Python None) is a valid value.
                if(args.Contains(binding_name))
                    BindParameter(context, i, MaybeAdapt(context, args[binding_name]));
                else
                    throw PythonSQLite.MakeProgrammingError("You did not supply a value for binding {0}.".Format(i));
            }
        }

        // Binds positional placeholders; args must supply exactly one value per placeholder.
        private void BindParameters(CodeContext context, IList args, int num_params_needed)
        {
            if(num_params_needed != args.Count)
                throw PythonSQLite.MakeProgrammingError("Incorrect number of bindings supplied.");

            for(int i = 0; i < args.Count; ++i)
            {
                // SQLite parameter indexes are 1-based.
                BindParameter(context, i + 1, MaybeAdapt(context, args[i]));
            }
        }

        // Binds one (already adapted) value at the 1-based index, choosing the sqlite3_bind_*
        // call from the value's runtime type. Unsupported types and non-OK result codes raise InterfaceError.
        private void BindParameter(CodeContext context, int index, object arg)
        {
            int rc;
            if(arg == null)
                rc = Sqlite3.sqlite3_bind_null(this.st, index);
            else if(arg is int)
                rc = Sqlite3.sqlite3_bind_int(this.st, index, (int)arg);
            else if(arg is bool)
                rc = Sqlite3.sqlite3_bind_int(this.st, index, (bool)arg ? 1 : 0);
            else if(arg is long)
                rc = Sqlite3.sqlite3_bind_int64(this.st, index, (long)arg);
            else if(arg is Microsoft.Scripting.Math.BigInteger)
                rc = Sqlite3.sqlite3_bind_int64(this.st, index, ((Microsoft.Scripting.Math.BigInteger)arg).ToInt64());
            else if(arg is System.Numerics.BigInteger)
                rc = Sqlite3.sqlite3_bind_int64(this.st, index, (long)((System.Numerics.BigInteger)arg));
            else if(arg is float)
                rc = Sqlite3.sqlite3_bind_double(this.st, index, (float)arg);
            else if(arg is double)
                rc = Sqlite3.sqlite3_bind_double(this.st, index, (double)arg);
            else if(arg is string)
                rc = Sqlite3.sqlite3_bind_text(this.st, index, (string)arg, -1, Sqlite3.SQLITE_TRANSIENT);
            else if(arg is byte[])
            {
                byte[] blob = (byte[])arg;
                // Pass the exact length: the SQLite API leaves a negative blob length undefined.
                rc = Sqlite3.sqlite3_bind_blob(this.st, index, blob, blob.Length, Sqlite3.SQLITE_TRANSIENT);
            }
            else if(arg is PythonBuffer)
            {
                //TODO: see if there is a better way to do this
                PythonBuffer buffer = (PythonBuffer)arg;
                string s = buffer.__getslice__(0, null).ToString();
                byte[] bytes = PythonSQLite.Latin1.GetBytes(s);

                rc = Sqlite3.sqlite3_bind_blob(this.st, index, bytes, bytes.Length, Sqlite3.SQLITE_TRANSIENT);
            }
            else
                throw PythonSQLite.MakeInterfaceError("Unable to bind parameter {0} - unsupported type {1}".Format(index, arg.GetType()));

            if(rc != Sqlite3.SQLITE_OK)
                throw PythonSQLite.MakeInterfaceError("Unable to bind parameter {0}: {1}".Format(index, Sqlite3.sqlite3_errmsg(db)));
        }

        // Returns the adapted value if adaptation applies to it, otherwise the value itself.
        private object MaybeAdapt(CodeContext context, object value)
        {
            return NeedsAdaptation(value) ? AdaptValue(context, value) : value;
        }

        // Values whose type BindParameter handles natively are adapted only when an adapter is
        // registered for their exact Python type; all other values always go through AdaptValue.
        private static bool NeedsAdaptation(object value)
        {
            // TODO The check for primitive types could probably be cached like pysqlite does
            bool natively_bindable =
                value == null ||
                value is int ||
                value is bool ||
                value is long ||
                value is Microsoft.Scripting.Math.BigInteger ||
                value is System.Numerics.BigInteger ||
                value is float ||
                value is double ||
                value is string ||
                value is byte[] ||
                value is PythonBuffer;

            if(!natively_bindable)
                return true;

            return PythonSQLite.adapters.ContainsKey(MakeAdapterKey(value, GetPrepareProtocol()));
        }

        // Python type object for PrepareProtocol, used as the protocol half of adapter keys.
        private static object GetPrepareProtocol()
        {
            return DynamicHelpers.GetPythonTypeFromType(typeof(PythonSQLite.PrepareProtocol));
        }

        // Builds the (python type of value, protocol) tuple used to index PythonSQLite.adapters.
        private static object MakeAdapterKey(object value, object proto)
        {
            object type = DynamicHelpers.GetPythonType(value);
            return new PythonTuple(new[] { type, proto });
        }

        // Applies a registered adapter if one exists; otherwise tries value.__conform__(proto).
        // Falls back to the unmodified value when neither produces a result.
        private object AdaptValue(CodeContext context, object value)
        {
            object proto = GetPrepareProtocol();

            object adapter;
            if(PythonSQLite.adapters.TryGetValue(MakeAdapterKey(value, proto), out adapter))
            {
                return PythonCalls.Call(context, adapter, value);
            }

            // TODO: Use proto? Any value whatsoever?

            object conform;
            if(context.LanguageContext.Operations.TryGetMember(value, "__conform__", out conform))
            {
                object adapted = PythonCalls.Call(context, conform, proto);
                if(adapted != null)
                {
                    return adapted;
                }
            }

            return value;
        }

        /// <summary>Steps the underlying statement once via <c>Util.Step</c> and returns its result code.</summary>
        public int RawStep()
        {
            return Util.Step(st);
        }

        /// <summary>
        /// Finalizes the underlying statement (if still present) and marks this statement as not in use.
        /// </summary>
        /// <returns>The <c>sqlite3_finalize</c> result, or <c>SQLITE_OK</c> if already finalized.</returns>
        public int SqliteFinalize()
        {
            int rc = Sqlite3.SQLITE_OK;

            if(this.st != null)
            {
                rc = Sqlite3.sqlite3_finalize(this.st);
                this.st = null;
            }

            this.in_use = false;

            return rc;
        }

        /// <summary>
        /// Resets the statement if it is marked in use; clears the in-use flag only when the reset succeeds.
        /// </summary>
        /// <returns>The <c>sqlite3_reset</c> result, or <c>SQLITE_OK</c> if nothing needed resetting.</returns>
        public int Reset()
        {
            int rc = Sqlite3.SQLITE_OK;

            if(this.in_use && this.st != null)
            {
                rc = Sqlite3.sqlite3_reset(this.st);

                if(rc == Sqlite3.SQLITE_OK)
                    this.in_use = false;
            }

            return rc;
        }

        // Removes all current parameter bindings; raises the connection's SQLite error on failure.
        private void ClearParameters()
        {
            if(Sqlite3.sqlite3_clear_bindings(this.st) != Sqlite3.SQLITE_OK)
                throw PythonSQLite.GetSqliteError(this.db, null);
        }

        /// <summary>Marks the statement as in use so that <see cref="Reset"/> will reset it.</summary>
        internal void MarkDirty()
        {
            this.in_use = true;
        }

        /// <summary>
        /// Re-prepares this statement's SQL. On success, binds <paramref name="parameters"/> to the
        /// new handle and swaps it in place of the old one, which is finalized.
        /// </summary>
        /// <returns>The <c>sqlite3_prepare</c> result; the statement is unchanged unless it is <c>SQLITE_OK</c>.</returns>
        internal int Recompile(CodeContext context, object parameters)
        {
            sqlite3_stmt new_st = null;
            string tail = null;

            int rc = Sqlite3.sqlite3_prepare(this.db, this.sql, -1, ref new_st, ref tail);
            if(rc == Sqlite3.SQLITE_OK)
            {
                // Bind through a temporary wrapper so the normal binding logic can be reused.
                Statement new_stmt = new Statement(this.db, new_st, this.sql, tail);
                try
                {
                    new_stmt.BindParameters(context, parameters);
                }
                catch
                {
                    // Binding failed: release the new handle now; this statement keeps its old one.
                    new_stmt.SqliteFinalize();
                    throw;
                }

                // Transfer ownership of new_st to this instance. Otherwise the temporary wrapper's
                // finalizer would later finalize the very handle this statement is now using.
                new_stmt.st = null;
                GC.SuppressFinalize(new_stmt);

                if(this.st != null)
                    Sqlite3.sqlite3_finalize(this.st);

                this.st = new_st;
                this.bound = new_stmt.bound;
            }

            return rc;
        }
    }
}

