#include "bouncer.h"
 
+#include <usual/pgutil.h>
+
 static const char *hdr2hex(const struct MBuf *data, char *buf, unsigned buflen)
 {
        const uint8_t *bin = data->data + data->read_pos;
        return false;
 }
 
-bool set_pool(PgSocket *client, const char *dbname, const char *username)
+/* mask to get offset into valid_crypt_salt[] */
+#define SALT_MASK  0x3F
+
+static const char valid_crypt_salt[] =
+"./0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
+
+static bool send_client_authreq(PgSocket *client)
+{
+       uint8_t saltlen = 0;
+       int res;
+       int auth = cf_auth_type;
+       uint8_t randbuf[2];
+
+       if (auth == AUTH_CRYPT) {
+               saltlen = 2;
+               get_random_bytes(randbuf, saltlen);
+               client->tmp_login_salt[0] = valid_crypt_salt[randbuf[0] & SALT_MASK];
+               client->tmp_login_salt[1] = valid_crypt_salt[randbuf[1] & SALT_MASK];
+               client->tmp_login_salt[2] = 0;
+       } else if (cf_auth_type == AUTH_MD5) {
+               saltlen = 4;
+               get_random_bytes((void*)client->tmp_login_salt, saltlen);
+       } else if (auth == AUTH_ANY)
+               auth = AUTH_TRUST;
+
+       SEND_generic(res, client, 'R', "ib", auth, client->tmp_login_salt, saltlen);
+       return res;
+}
+
+static void start_auth_request(PgSocket *client, const char *username)
+{
+       int res;
+       char quoted_username[64], query[128];
+
+       client->auth_user = client->db->auth_user;
+       /* have to fetch user info from db */
+       client->pool = get_pool(client->db, client->db->auth_user);
+       if (!find_server(client)) {
+               client->wait_for_user_conn = true;
+               return;
+       }
+       slog_noise(client, "Doing auth_conn query");
+       client->wait_for_user_conn = false;
+       client->wait_for_user = true;
+       if (!sbuf_pause(&client->sbuf)) {
+               release_server(client->link);
+               disconnect_client(client, true, "pause failed");
+               return;
+       }
+       client->link->ready = 0;
+
+       pg_quote_literal(quoted_username, username, sizeof(quoted_username));
+       snprintf(query, sizeof(query), "SELECT usename, passwd FROM pg_shadow WHERE usename=%s", quoted_username);
+       SEND_generic(res, client->link, 'Q', "s", query);
+       if (!res)
+               disconnect_server(client->link, false, "unable to send login query");
+}
+
+static bool finish_set_pool(PgSocket *client, bool takeover)
 {
-       PgDatabase *db;
-       PgUser *user;
+       PgUser *user = client->auth_user;
+       /* pool user may be forced */
+       if (client->db->forced_user) {
+               user = client->db->forced_user;
+       }
+       client->pool = get_pool(client->db, user);
+       if (!client->pool) {
+               disconnect_client(client, true, "no memory for pool");
+               return false;
+       }
+
+       if (cf_log_connections)
+               slog_info(client, "login attempt: db=%s user=%s", client->db->name, client->auth_user->name);
+
+       if (!check_fast_fail(client))
+               return false;
+
+       if (takeover)
+               return true;
+
+       if (client->pool->db->admin) {
+               if (!admin_post_login(client))
+                       return false;
+       }
+
+       if (cf_auth_type <= AUTH_TRUST || client->own_user) {
+               if (!finish_client_login(client))
+                       return false;
+       } else {
+               if (!send_client_authreq(client)) {
+                       disconnect_client(client, false, "failed to send auth req");
+                       return false;
+               }
+       }
+       return true;
+}
 
+bool set_pool(PgSocket *client, const char *dbname, const char *username, bool takeover)
+{
        /* find database */
-       db = find_database(dbname);
-       if (!db) {
-               db = register_auto_database(dbname);
-               if (!db) {
+       client->db = find_database(dbname);
+       if (!client->db) {
+               client->db = register_auto_database(dbname);
+               if (!client->db) {
                        disconnect_client(client, true, "No such database: %s", dbname);
+                       if (cf_log_connections)
+                               slog_info(client, "login failed: db=%s user=%s", dbname, username);
                        return false;
                }
                else {
        }
 
        /* are new connections allowed? */
-       if (db->db_disabled) {
+       if (client->db->db_disabled) {
                disconnect_client(client, true, "database does not allow connections: %s", dbname);
                return false;
        }
 
+       if (client->db->admin) {
+               if (admin_pre_login(client, username))
+                       return finish_set_pool(client, takeover);
+       }
+
        /* find user */
        if (cf_auth_type == AUTH_ANY) {
                /* ignore requested user */
-               user = NULL;
-
-               if (db->forced_user == NULL) {
+               if (client->db->forced_user == NULL) {
                        slog_error(client, "auth_type=any requires forced user");
                        disconnect_client(client, true, "bouncer config error");
                        return false;
                }
-               client->auth_user = db->forced_user;
+               client->auth_user = client->db->forced_user;
        } else {
                /* the user clients wants to log in as */
-               user = find_user(username);
-               if (!user) {
+               client->auth_user = find_user(username);
+               if (!client->auth_user && client->db->auth_user) {
+                       if (takeover) {
+                               client->auth_user = add_db_user(client->db, username, "");
+                               return finish_set_pool(client, takeover);
+                       }
+                       start_auth_request(client, username);
+                       return false;
+               }
+               if (!client->auth_user) {
                        disconnect_client(client, true, "No such user: %s", username);
+                       if (cf_log_connections)
+                               slog_info(client, "login failed: db=%s user=%s", dbname, username);
                        return false;
                }
-               client->auth_user = user;
        }
+       return finish_set_pool(client, takeover);
+}
 
-       /* pool user may be forced */
-       if (db->forced_user)
-               user = db->forced_user;
-       client->pool = get_pool(db, user);
-       if (!client->pool) {
-               disconnect_client(client, true, "no memory for pool");
+bool handle_auth_response(PgSocket *client, PktHdr *pkt) {
+       uint16_t columns;
+       uint32_t length;
+       const char *username, *password;
+       PgUser user;
+
+       switch(pkt->type) {
+       case 'T':       /* RowDescription */
+               if (!mbuf_get_uint16be(&pkt->data, &columns)) {
+                       disconnect_server(client->link, false, "bad packet");
+                       return false;
+               }
+               if (columns != 2u) {
+                       disconnect_server(client->link, false, "expected 1 column from login query, not %hu", columns);
+                       return false;
+               }
+               break;
+       case 'D':       /* DataRow */
+               memset(&user, 0, sizeof(user));
+               if (!mbuf_get_uint16be(&pkt->data, &columns)) {
+                       disconnect_server(client->link, false, "bad packet");
+                       return false;
+               }
+               if (columns != 2u) {
+                       disconnect_server(client->link, false, "expected 1 column from login query, not %hu", columns);
+                       return false;
+               }
+               if (!mbuf_get_uint32be(&pkt->data, &length)) {
+                       disconnect_server(client->link, false, "bad packet");
+                       return false;
+               }
+               if (!mbuf_get_chars(&pkt->data, length, &username)) {
+                       disconnect_server(client->link, false, "bad packet");
+                       return false;
+               }
+               if (sizeof(user.name) - 1 < length)
+                       length = sizeof(user.name) - 1;
+               memcpy(user.name, username, length);
+               if (!mbuf_get_uint32be(&pkt->data, &length)) {
+                       disconnect_server(client->link, false, "bad packet");
+                       return false;
+               }
+               if (length == (uint32_t)-1) {
+                       // NULL - set an md5 password with an impossible value,
+                       // so that nothing will ever match
+                       password = "md5";
+                       length = 3;
+               } else {
+                       if (!mbuf_get_chars(&pkt->data, length, &password)) {
+                               disconnect_server(client->link, false, "bad packet");
+                               return false;
+                       }
+               }
+               if (sizeof(user.passwd)  - 1 < length)
+                       length = sizeof(user.passwd) - 1;
+               memcpy(user.passwd, password, length);
+
+               client->auth_user = add_db_user(client->db, user.name, user.passwd);
+               if (!client->auth_user) {
+                       disconnect_server(client->link, false, "unable to allocate new user for auth");
+                       return false;
+               }
+               break;
+       case 'C':       /* CommandComplete */
+               break;
+       case 'Z':       /* ReadyForQuery */
+               sbuf_prepare_skip(&client->link->sbuf, pkt->len);
+               if (!client->auth_user) {
+                       if (cf_log_connections)
+                               slog_info(client, "login failed: db=%s", client->db->name);
+                       disconnect_client(client, true, "No such user");
+               } else {
+                       slog_noise(client, "auth query complete");
+                       sbuf_continue(&client->sbuf);
+               }
+               return true;
+       default:
+               disconnect_server(client->link, false, "unexpected response from login query");
                return false;
        }
-
-       return check_fast_fail(client);
+       sbuf_prepare_skip(&client->link->sbuf, pkt->len);
+       return true;
 }
 
 static bool decide_startup_pool(PgSocket *client, PktHdr *pkt)
                }
        }
 
-       /* find pool and log about it */
-       if (set_pool(client, dbname, username)) {
-               if (cf_log_connections)
-                       slog_info(client, "login attempt: db=%s user=%s", dbname, username);
-               return true;
-       } else {
-               if (cf_log_connections)
-                       slog_info(client, "login failed: db=%s user=%s", dbname, username);
-               return false;
-       }
-}
-
-/* mask to get offset into valid_crypt_salt[] */
-#define SALT_MASK  0x3F
-
-static const char valid_crypt_salt[] =
-"./0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
-
-static bool send_client_authreq(PgSocket *client)
-{
-       uint8_t saltlen = 0;
-       int res;
-       int auth = cf_auth_type;
-       uint8_t randbuf[2];
-
-       if (auth == AUTH_CRYPT) {
-               saltlen = 2;
-               get_random_bytes(randbuf, saltlen);
-               client->tmp_login_salt[0] = valid_crypt_salt[randbuf[0] & SALT_MASK];
-               client->tmp_login_salt[1] = valid_crypt_salt[randbuf[1] & SALT_MASK];
-               client->tmp_login_salt[2] = 0;
-       } else if (cf_auth_type == AUTH_MD5) {
-               saltlen = 4;
-               get_random_bytes((void*)client->tmp_login_salt, saltlen);
-       } else if (auth == AUTH_ANY)
-               auth = AUTH_TRUST;
-
-       SEND_generic(res, client, 'R', "ib", auth, client->tmp_login_salt, saltlen);
-       return res;
+       /* find pool */
+       return set_pool(client, dbname, username, false);
 }
 
 /* decide on packets of client in login phase */
                disconnect_client(client, true, "Old V2 protocol not supported");
                return false;
        case PKT_STARTUP:
-               if (client->pool) {
+               if (client->pool && !client->wait_for_user_conn && !client->wait_for_user) {
                        disconnect_client(client, true, "client re-sent startup pkt");
                        return false;
                }
 
-               if (!decide_startup_pool(client, pkt))
-                       return false;
-
-               if (client->pool->db->admin) {
-                       if (!admin_pre_login(client))
+               if (client->wait_for_user) {
+                       client->wait_for_user = false;
+                       if (!finish_set_pool(client, false))
                                return false;
+               } else if (!decide_startup_pool(client, pkt)) {
+                       return false;
                }
 
-               if (cf_auth_type <= AUTH_TRUST || client->own_user) {
-                       if (!finish_client_login(client))
-                               return false;
-               } else {
-                       if (!send_client_authreq(client)) {
-                               disconnect_client(client, false, "failed to send auth req");
-                               return false;
-                       }
-               }
                break;
        case 'p':               /* PasswordMessage */
                /* haven't requested it */
 
        return strcmp(name, user->name);
 }
 
+/* destroy PgUser, for usage with btree */
+static void user_node_release(struct AANode *node, void *arg)
+{
+       PgUser *user = container_of(node, PgUser, tree_node);
+       slab_free(user_cache, user);
+}
+
 /* initialization before config loading */
 void init_objects(void)
 {
                statlist_remove(&justfree_client_list, &client->head);
                break;
        case CL_LOGIN:
+               if (newstate == CL_WAITING)
+                       newstate = CL_WAITING_LOGIN;
                statlist_remove(&login_client_list, &client->head);
                break;
+       case CL_WAITING_LOGIN:
+               if (newstate == CL_ACTIVE)
+                       newstate = CL_LOGIN;
        case CL_WAITING:
                statlist_remove(&pool->waiting_client_list, &client->head);
                break;
                statlist_append(&login_client_list, &client->head);
                break;
        case CL_WAITING:
+       case CL_WAITING_LOGIN:
                statlist_append(&pool->waiting_client_list, &client->head);
                break;
        case CL_ACTIVE:
                        slab_free(db_cache, db);
                        return NULL;
                }
+               aatree_init(&db->user_tree, user_node_cmp, user_node_release);
                put_in_order(&db->head, &database_list, cmp_database);
        }
 
        return user;
 }
 
+/* add or update db users */
+PgUser *add_db_user(PgDatabase *db, const char *name, const char *passwd)
+{
+       PgUser *user = NULL;
+       struct AANode *node;
+
+       node = aatree_search(&db->user_tree, (uintptr_t)name);
+       user = node ? container_of(node, PgUser, tree_node) : NULL;
+
+       if (user == NULL) {
+               user = slab_alloc(user_cache);
+               if (!user)
+                       return NULL;
+
+               list_init(&user->head);
+               list_init(&user->pool_list);
+               safe_strcpy(user->name, name, sizeof(user->name));
+
+               aatree_insert(&db->user_tree, (uintptr_t)user->name, &user->tree_node);
+               user->pool_mode = POOL_INHERIT;
+       }
+       safe_strcpy(user->passwd, passwd, sizeof(user->passwd));
+       return user;
+}
+
 /* create separate user object for storing server user info */
 PgUser *force_user(PgDatabase *db, const char *name, const char *passwd)
 {
 /* deactivate socket and put into wait queue */
 static void pause_client(PgSocket *client)
 {
-       Assert(client->state == CL_ACTIVE);
+       Assert(client->state == CL_ACTIVE || client->state == CL_LOGIN);
 
        slog_debug(client, "pause_client");
        change_client_state(client, CL_WAITING);
 /* wake client from wait */
 void activate_client(PgSocket *client)
 {
-       Assert(client->state == CL_WAITING);
+       Assert(client->state == CL_WAITING || client->state == CL_WAITING_LOGIN);
 
        slog_debug(client, "activate_client");
        change_client_state(client, CL_ACTIVE);
        bool res;
        bool varchange = false;
 
-       Assert(client->state == CL_ACTIVE);
+       Assert(client->state == CL_ACTIVE || client->state == CL_LOGIN);
 
        if (client->link)
                return true;
                }
        case CL_LOGIN:
        case CL_WAITING:
+       case CL_WAITING_LOGIN:
        case CL_CANCEL:
                break;
        default:
                return false;
        client->suspended = 1;
 
-       if (!set_pool(client, dbname, username))
+       if (!set_pool(client, dbname, username, true))
                return false;
 
        change_client_state(client, CL_ACTIVE);