module db;

import std.algorithm;
import std.array;
import std.typecons;
import std.conv;

import ddbc;
import slf4d;
import handy_httpd.components.optional;

/**
 * The ddbc data source that `getDb` obtains connections from.
 *
 * Declared without `shared`/`__gshared`, so in D this variable is
 * thread-local: each thread holds its own copy.
 */
private DataSource dataSource;

/**
 * Module constructor: assembles the ddbc connection string from environment
 * variables and creates `dataSource` from it.
 *
 * Environment variables read (fallback value when unset):
 *   TEACHER_TOOLS_DB_URL      - "postgresql://localhost:5432/teacher-tools-dev"
 *   TEACHER_TOOLS_DB_USERNAME - "teacher-tools-dev"
 *   TEACHER_TOOLS_DB_PASSWORD - "testpass"
 *
 * This is a plain `static this` (not `shared static this`), so D runs it once
 * per thread, initializing that thread's own `dataSource`.
 */
static this() {
    import std.process : environment;

    auto dbUrl = environment.get("TEACHER_TOOLS_DB_URL", "postgresql://localhost:5432/teacher-tools-dev");
    auto username = environment.get("TEACHER_TOOLS_DB_USERNAME", "teacher-tools-dev");
    auto password = environment.get("TEACHER_TOOLS_DB_PASSWORD", "testpass");

    // Credentials are appended as query parameters separated by ','.
    // NOTE(review): values are not escaped; a ',' in the password would
    // probably break parameter parsing. Confirm against ddbc's URL parser.
    auto connectionStr = dbUrl ~ "?user=" ~ username ~ ",password=" ~ password;
    dataSource = createDataSource(connectionStr);
}

/**
 * Obtains a connection from the current thread's `dataSource`.
 *
 * Returns: the `Connection` produced by `dataSource.getConnection()`.
 * Nothing here closes it, so presumably the caller must (TODO confirm).
 */
Connection getDb() {
    Connection connection = dataSource.getConnection();
    return connection;
}

/**
 * Runs a query and maps every result row to a `T`.
 *
 * Params:
 *   conn   = open connection to run the query on.
 *   query  = SQL text with positional parameters.
 *   parser = converts the current row of the result set into a `T`.
 *   args   = values bound to the query's parameters in order, starting at
 *            index 1 (see `bindAllArgs`).
 * Returns: one `T` per row, in the order the result set yields them.
 */
T[] findAll(T, Args...)(
    Connection conn,
    string query,
    T function(DataSetReader) parser,
    Args args
) {
    auto stmt = conn.prepareStatement(query);
    scope(exit) stmt.close();
    bindAllArgs(stmt, args);

    auto results = stmt.executeQuery();
    scope(exit) results.close();

    auto collected = appender!(T[]);
    foreach (row; results) {
        collected.put(parser(row));
    }
    return collected.data;
}

/**
 * Runs a query and maps only its first result row, if any, to a `T`.
 *
 * Params:
 *   conn   = open connection to run the query on.
 *   query  = SQL text with positional parameters.
 *   parser = converts the current row of the result set into a `T`.
 *   args   = values bound to the query's parameters in order (see `bindAllArgs`).
 * Returns: the parsed first row, or an empty `Optional` when the query
 *          yields no rows. Any further rows are ignored.
 */
Optional!T findOne(T, Args...)(
    Connection conn,
    string query,
    T function(DataSetReader) parser,
    Args args
) {
    auto stmt = conn.prepareStatement(query);
    scope(exit) stmt.close();
    bindAllArgs(stmt, args);

    auto results = stmt.executeQuery();
    scope(exit) results.close();

    // Guard: no first row means there is nothing to parse.
    if (!results.next()) {
        return Optional!T.empty;
    }
    return Optional!T.of(parser(results));
}

/**
 * Runs a query whose first column of the first row holds a count.
 *
 * Params:
 *   conn  = open connection to run the query on.
 *   query = SQL text with positional parameters.
 *   args  = values bound to the query's parameters in order.
 * Returns: column 1 of the first row read as `ulong`, or 0 when the query
 *          yields no rows.
 */
ulong count(Args...)(Connection conn, string query, Args args) {
    auto firstValue = findOne(conn, query, r => r.getUlong(1), args);
    return firstValue.orElse(0);
}

/**
 * Reports whether a query yields at least one row.
 *
 * Params:
 *   conn  = open connection to run the query on.
 *   query = SQL text with positional parameters.
 *   args  = values bound to the query's parameters in order.
 * Returns: `true` if the result set has a first row, otherwise `false`.
 */
bool recordExists(Args...)(Connection conn, string query, Args args) {
    auto stmt = conn.prepareStatement(query);
    scope(exit) stmt.close();
    bindAllArgs(stmt, args);

    auto results = stmt.executeQuery();
    scope(exit) results.close();

    immutable bool hasRow = results.next();
    return hasRow;
}

/**
 * Executes an insert statement that must affect exactly one row.
 *
 * Params:
 *   conn  = open connection to run the statement on.
 *   query = SQL insert text with positional parameters.
 *   args  = values bound to the statement's parameters in order.
 * Returns: the insert id reported by `executeUpdate`, coerced to `ulong`.
 * Throws: `Exception` if the affected row count is not exactly 1. Any error
 *         raised by coercing the reported id to `ulong` also propagates.
 */
ulong insertOne(Args...)(Connection conn, string query, Args args) {
    auto stmt = conn.prepareStatement(query);
    scope(exit) stmt.close();
    bindAllArgs(stmt, args);

    import std.variant : Variant;
    Variant generatedId;
    immutable int rowCount = stmt.executeUpdate(generatedId);
    if (rowCount == 1) {
        return generatedId.coerce!ulong;
    }
    throw new Exception("Failed to insert exactly 1 row.");
}

/**
 * Executes a data-modifying statement.
 *
 * Params:
 *   conn  = open connection to run the statement on.
 *   query = SQL text with positional parameters.
 *   args  = values bound to the statement's parameters in order.
 * Returns: the value returned by `executeUpdate()`.
 */
int update(Args...)(Connection conn, string query, Args args) {
    auto stmt = conn.prepareStatement(query);
    scope(exit) stmt.close();
    bindAllArgs(stmt, args);
    immutable int affected = stmt.executeUpdate();
    return affected;
}

/**
 * Binds each argument to the statement's positional parameters, using the
 * typed setter that matches the argument's type.
 *
 * The argument at position i (0-based) is bound to parameter i + 1.
 * Type qualifiers (`const`, `immutable`, ...) are stripped before dispatch, so
 * e.g. `const(bool)` and `immutable(ulong)` are accepted like their unqualified
 * types.
 *
 * Params:
 *   ps   = statement whose parameters are being set.
 *   args = values to bind, in parameter order.
 *
 * An argument of any other type is rejected at compile time by a
 * `static assert`.
 */
void bindAllArgs(Args...)(PreparedStatement ps, Args args) {
    import std.traits : Unqual;
    static foreach (i, arg; args) {{
        // Inner braces give each iteration its own scope for these declarations.
        alias U = Unqual!(typeof(arg));
        enum int idx = cast(int) i + 1; // JDBC-style parameters are 1-based.
        static if (is(U == string)) ps.setString(idx, arg);
        else static if (is(U == bool)) ps.setBoolean(idx, arg);
        else static if (is(U == ulong)) ps.setUlong(idx, arg);
        else static if (is(U == long)) ps.setLong(idx, arg);
        else static if (is(U == uint)) ps.setUint(idx, arg);
        else static if (is(U == int)) ps.setInt(idx, arg);
        else static if (is(U == ushort)) ps.setUshort(idx, arg);
        else static if (is(U == short)) ps.setShort(idx, arg);
        else static if (is(U == double)) ps.setDouble(idx, arg);
        else static if (is(U == float)) ps.setFloat(idx, arg);
        else static assert(false, "Unsupported argument type: " ~ (typeof(arg).stringof));
    }}
}

/**
 * Converts a camelCase identifier to snake_case.
 *
 * The first character is lowercased. Every later uppercase character becomes
 * '_' followed by its lowercase form. All other characters are copied
 * unchanged. The input is decoded as UTF-8, one code point at a time, so
 * non-ASCII letters are handled correctly instead of being split into
 * individual code units.
 *
 * Params:
 *   camelCase = identifier to convert.
 * Returns: the snake_case form. An empty input is returned as-is.
 */
private string toSnakeCase(string camelCase) {
    import std.uni : isUpper, toLower;
    if (camelCase.length == 0) return camelCase;
    auto app = appender!string;
    // `dchar` element type makes foreach decode UTF-8. `i` is the code-unit
    // offset where each code point starts, so i == 0 is the first character.
    foreach (size_t i, dchar c; camelCase) {
        if (i == 0) {
            app ~= toLower(c);
        } else if (isUpper(c)) {
            app ~= '_';
            app ~= toLower(c);
        } else {
            app ~= c;
        }
    }
    return app[];
}

/// Checks `toSnakeCase` on typical and edge-case identifiers.
unittest {
    assert(toSnakeCase("testValue") == "test_value");
    // Empty and single-character inputs.
    assert(toSnakeCase("") == "");
    assert(toSnakeCase("a") == "a");
    // A leading capital is lowercased without a leading underscore.
    assert(toSnakeCase("TestValue") == "test_value");
    // Each uppercase character gets its own underscore.
    assert(toSnakeCase("userID") == "user_i_d");
    // Input that is already snake_case passes through unchanged.
    assert(toSnakeCase("already_snake") == "already_snake");
}

/**
 * Computes a lowercase column name for each data field of `T`, in
 * declaration order.
 *
 * Only real data fields are considered (`FieldNameTuple!T`). Methods,
 * constructors, static members and inherited `Object` members are not
 * columns. This matches the field list `getArgsStr` uses (`Fields!T`), so
 * indices line up between the two.
 *
 * For each field:
 *   - if it has at least one UDA, the first UDA's `.name` is used
 *     (assumes that UDA exposes a `name` member; TODO confirm the attribute
 *     type used by callers);
 *   - otherwise the field name is converted with `toSnakeCase`.
 * Either way the result is lowercased.
 */
private string[] getColumnNames(T)() {
    import std.string : toLower;
    import std.traits : FieldNameTuple;
    alias fieldNames = FieldNameTuple!T;
    string[fieldNames.length] columnNames;
    static foreach (i, fieldName; fieldNames) {
        static if (__traits(getAttributes, __traits(getMember, T, fieldName)).length > 0) {
            columnNames[i] = toLower(__traits(getAttributes, __traits(getMember, T, fieldName))[0].name);
        } else {
            columnNames[i] = toLower(toSnakeCase(fieldName));
        }
    }
    return columnNames.dup;
}

/**
 * Builds a comma-separated argument list, intended for a string mixin, with
 * one `row.peek!(Type)(index)` expression per data field of `T`.
 *
 * Indices are 0-based and follow field declaration order (`Fields!T`).
 * A type with no fields yields an empty string.
 */
private string getArgsStr(T)() {
    import std.traits : Fields;
    string result = "";
    static foreach (idx, FieldType; Fields!T) {
        // Put the separator before every entry except the first.
        static if (idx > 0) result ~= ", ";
        result ~= "row.peek!(" ~ FieldType.stringof ~ ")(" ~ idx.to!string ~ ")";
    }
    return result;
}

// T parseRow(T)(Row row) {
//     mixin("T t = T(" ~ getArgsStr!T ~ ");");
//     return t;
// }