package main

import (
	"database/sql"
	"fmt"
	"sort"
	"strings"
	"time"

	_ "github.com/go-sql-driver/mysql"
)

// DB is the persistence layer for servers, agents and check results. It wraps
// the *sql.DB pool that NewDB opens with the "mysql" driver; all queries in
// this file go through conn.
type DB struct {
	conn *sql.DB
}

// Server is one row of the servers table: a URL whose check results are
// recorded per agent in the monthly results tables.
//
// Field notes:
//   - URL is unique across servers (UNIQUE column).
//   - Active: ListServers(true) returns only rows with active = 1.
//   - Insecure: NOTE(review): presumably makes checks skip TLS certificate
//     verification; confirm against the checker code.
//   - ExpectedCode: the expected HTTP status; the column defaults to 200 and
//     AddServer stores 200 when given 0.
//   - ExpectedBody: NOTE(review): presumably text the response body must
//     contain; confirm against the checker code.
//   - AnonName: an alternate name stored alongside Name; its use is outside
//     this file.
//   - CreatedAt is filled by the database (DEFAULT CURRENT_TIMESTAMP).
type Server struct {
	ID           int
	URL          string
	Name         string
	AnonName     string
	Active       bool
	Insecure     bool
	ExpectedCode int
	ExpectedBody string
	CreatedAt    time.Time
}

// Agent is one row of the agents table.
//
// Field notes:
//   - Token is unique and is the key AgentByToken looks agents up by
//     (presumably the agent's credential; confirm with the API layer).
//   - LastIP is "" when last_ip is NULL (queries read it via COALESCE).
//   - LastSeen is nil until UpdateAgentSeen has run for this agent
//     (last_seen is a nullable column).
//   - CreatedAt is filled by the database (DEFAULT CURRENT_TIMESTAMP).
type Agent struct {
	ID        int
	Name      string
	AnonName  string
	Token     string
	LastIP    string
	LastSeen  *time.Time
	CreatedAt time.Time
}

// CheckResult is a single check outcome of one server by one agent, stored as
// a row of a monthly results_YYYY_MM table.
//
// Field notes:
//   - SaveResults substitutes the insertion time for a zero CheckedAt.
//   - The *MS fields are milliseconds (columns connect_ms, ttfb_ms and
//     total_ms). TTFB presumably means time-to-first-byte.
//   - HTTPCode defaults to 0 in the table.
type CheckResult struct {
	AgentID   int
	ServerID  int
	CheckedAt time.Time
	Success   bool
	HTTPCode  int
	ConnectMS int
	TTFBMS    int
	TotalMS   int
}

// ServerStat holds aggregate statistics for one server as produced by
// ServerStats.
//
// Total is the number of result rows and Successes the number with success
// set. AvgTTFBMS is the average TTFB in milliseconds as computed by
// ServerStats; rows with ttfb_ms = 0 are excluded from that average.
type ServerStat struct {
	Total     int
	Successes int
	AvgTTFBMS int
}

// ChartPoint is one check result projected for charting by ServerChartData.
// T is the row's checked_at. TotalMS and TTFBMS are milliseconds. AgentID
// identifies which agent produced the point.
type ChartPoint struct {
	T       time.Time
	Success bool
	TotalMS int
	TTFBMS  int
	AgentID int
}

// ensureParseTime returns dsn with parseTime=true appended to its parameter
// list, so that DATETIME columns scan into time.Time. A DSN whose parameters
// already set parseTime (to any value) is returned unchanged.
//
// go-sql-driver/mysql DSNs look like
// [user[:password]@][net[(addr)]]/dbname[?param=value&...]. The driver takes
// the parameters from after the LAST '/', so only that tail is inspected here.
// Otherwise a '?' or "parseTime=" inside the password would be mistaken for
// the parameter section.
func ensureParseTime(dsn string) string {
	tail := dsn[strings.LastIndex(dsn, "/")+1:]
	q := strings.Index(tail, "?")
	if q < 0 {
		return dsn + "?parseTime=true"
	}
	if strings.Contains(tail[q+1:], "parseTime=") {
		return dsn
	}
	// Don't introduce an empty parameter when the DSN already ends in a separator.
	if strings.HasSuffix(dsn, "?") || strings.HasSuffix(dsn, "&") {
		return dsn + "parseTime=true"
	}
	return dsn + "&parseTime=true"
}

// NewDB opens a MySQL connection pool for dsn (with parseTime=true forced via
// ensureParseTime), verifies connectivity with a ping, and creates or migrates
// the schema.
//
// If the ping or the schema setup fails, the pool is closed before the error
// is returned, so a failed NewDB does not leak connections.
func NewDB(dsn string) (*DB, error) {
	conn, err := sql.Open("mysql", ensureParseTime(dsn))
	if err != nil {
		return nil, err
	}
	conn.SetMaxOpenConns(10)
	conn.SetMaxIdleConns(5)
	// Retire pooled connections before the MySQL server (wait_timeout) or a
	// middlebox closes them underneath us, as the driver documentation advises.
	conn.SetConnMaxLifetime(3 * time.Minute)

	if err := conn.Ping(); err != nil {
		_ = conn.Close() // the ping error is the one worth reporting
		return nil, fmt.Errorf("ping: %w", err)
	}
	d := &DB{conn: conn}
	if err := d.initSchema(); err != nil {
		_ = conn.Close() // the schema error is the one worth reporting
		return nil, fmt.Errorf("initSchema: %w", err)
	}
	return d, nil
}

// initSchema brings the database schema up to date in three phases:
//  1. create the servers and agents tables if they are missing;
//  2. add columns that may be absent from tables created before those columns
//     appeared in the CREATE statements (each is skipped if already present);
//  3. make sure the results table for the current month exists.
//
// The first error encountered is returned and later steps are not run.
func (d *DB) initSchema() error {
	createTables := []string{
		`CREATE TABLE IF NOT EXISTS servers (
			id INT AUTO_INCREMENT PRIMARY KEY,
			url VARCHAR(512) NOT NULL UNIQUE,
			name VARCHAR(255) NOT NULL DEFAULT '',
			anon_name VARCHAR(255) NOT NULL DEFAULT '',
			active TINYINT(1) NOT NULL DEFAULT 1,
			insecure TINYINT(1) NOT NULL DEFAULT 0,
			expected_code INT NOT NULL DEFAULT 200,
			expected_body VARCHAR(512) NOT NULL DEFAULT '',
			created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
		) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4`,
		`CREATE TABLE IF NOT EXISTS agents (
			id INT AUTO_INCREMENT PRIMARY KEY,
			name VARCHAR(255) NOT NULL,
			anon_name VARCHAR(255) NOT NULL DEFAULT '',
			token VARCHAR(128) NOT NULL UNIQUE,
			last_ip VARCHAR(45),
			last_seen DATETIME NULL,
			created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
		) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4`,
	}
	for i := range createTables {
		if _, err := d.conn.Exec(createTables[i]); err != nil {
			return err
		}
	}

	// Applied in order; AFTER clauses reference columns added earlier in the list.
	columnMigrations := []struct {
		table, column, definition string
	}{
		{"servers", "anon_name", "VARCHAR(255) NOT NULL DEFAULT '' AFTER name"},
		{"agents", "anon_name", "VARCHAR(255) NOT NULL DEFAULT '' AFTER name"},
		{"servers", "insecure", "TINYINT(1) NOT NULL DEFAULT 0 AFTER active"},
		{"servers", "expected_code", "INT NOT NULL DEFAULT 200 AFTER insecure"},
		{"servers", "expected_body", "VARCHAR(512) NOT NULL DEFAULT '' AFTER expected_code"},
	}
	for _, m := range columnMigrations {
		if err := d.migrateAddColumn(m.table, m.column, m.definition); err != nil {
			return err
		}
	}

	return d.EnsureResultsTable(time.Now())
}

// migrateAddColumn runs "ALTER TABLE <table> ADD COLUMN <column> <definition>"
// only when information_schema shows that column is not yet present in the
// current database. Already-migrated tables are therefore never ALTERed.
//
// table, column and definition are spliced into the SQL text unescaped, so
// they must be trusted, code-supplied values (as they are in initSchema).
func (d *DB) migrateAddColumn(table, column, definition string) error {
	var matches int
	row := d.conn.QueryRow(
		`SELECT COUNT(*) FROM information_schema.COLUMNS
		 WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = ? AND COLUMN_NAME = ?`,
		table, column,
	)
	if err := row.Scan(&matches); err != nil {
		return err
	}
	if matches > 0 {
		// Column is already there; nothing to do.
		return nil
	}
	alter := fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s", table, column, definition)
	if _, err := d.conn.Exec(alter); err != nil {
		return err
	}
	return nil
}

// ResultsTableName returns the name of the monthly results table that covers
// t, formatted as results_YYYY_MM with a zero-padded month (for example
// results_2024_03). Year and month are taken in t's own location.
func ResultsTableName(t time.Time) string {
	year, month, _ := t.Date()
	return fmt.Sprintf("results_%04d_%02d", year, int(month))
}

// EnsureResultsTable creates the monthly results table for t (named by
// ResultsTableName) if it does not exist yet. Each table carries composite
// indexes on (server_id, checked_at) and (agent_id, checked_at) for the
// per-server and per-agent time-range queries in this file.
func (d *DB) EnsureResultsTable(t time.Time) error {
	ddl := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s (
		id BIGINT AUTO_INCREMENT PRIMARY KEY,
		agent_id INT NOT NULL,
		server_id INT NOT NULL,
		checked_at DATETIME NOT NULL,
		success TINYINT(1) NOT NULL,
		http_code INT NOT NULL DEFAULT 0,
		connect_ms INT NOT NULL DEFAULT 0,
		ttfb_ms INT NOT NULL DEFAULT 0,
		total_ms INT NOT NULL DEFAULT 0,
		INDEX idx_server_time (server_id, checked_at),
		INDEX idx_agent_time (agent_id, checked_at)
	) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4`, ResultsTableName(t))
	if _, err := d.conn.Exec(ddl); err != nil {
		return err
	}
	return nil
}

// resultsTables lists the monthly results tables (results_YYYY_MM) of the
// current database, oldest first.
//
// SHOW TABLES LIKE is only a coarse pre-filter. In a LIKE pattern '_' matches
// any single character and '%' any run of characters, so unrelated tables
// such as "results_archive" would also match and later be queried as if they
// held check results. Every name is therefore checked against the exact shape
// produced by ResultsTableName.
func (d *DB) resultsTables() ([]string, error) {
	rows, err := d.conn.Query("SHOW TABLES LIKE 'results_%'")
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var tables []string
	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			return nil, err
		}
		if isResultsTableName(name) {
			tables = append(tables, name)
		}
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	// Zero-padded YYYY_MM suffixes sort lexically in chronological order.
	sort.Strings(tables)
	return tables, nil
}

// isResultsTableName reports whether name has the exact results_YYYY_MM form
// that ResultsTableName produces for four-digit years.
func isResultsTableName(name string) bool {
	const prefix = "results_"
	if !strings.HasPrefix(name, prefix) {
		return false
	}
	suffix := name[len(prefix):]
	if len(suffix) != len("YYYY_MM") || suffix[4] != '_' {
		return false
	}
	for i := 0; i < len(suffix); i++ {
		if i == 4 {
			continue
		}
		if suffix[i] < '0' || suffix[i] > '9' {
			return false
		}
	}
	return true
}

// ListServers returns servers ordered by id. If activeOnly is true, rows with
// active = 0 are left out. If iteration fails part-way, the servers read so far
// are returned together with the error from rows.Err.
func (d *DB) ListServers(activeOnly bool) ([]Server, error) {
	var filter string
	if activeOnly {
		filter = " WHERE active = 1"
	}
	rows, err := d.conn.Query("SELECT id, url, name, anon_name, active, insecure, expected_code, expected_body, created_at FROM servers" + filter + " ORDER BY id")
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var servers []Server
	for rows.Next() {
		var srv Server
		scanErr := rows.Scan(
			&srv.ID, &srv.URL, &srv.Name, &srv.AnonName,
			&srv.Active, &srv.Insecure, &srv.ExpectedCode, &srv.ExpectedBody,
			&srv.CreatedAt,
		)
		if scanErr != nil {
			return nil, scanErr
		}
		servers = append(servers, srv)
	}
	return servers, rows.Err()
}

// AddServer inserts a new server row. An expectedCode of 0 is stored as 200,
// the same value as the column default. Database errors are returned
// unchanged, including a duplicate URL, which violates the UNIQUE constraint.
func (d *DB) AddServer(url, name, anonName string, insecure bool, expectedCode int, expectedBody string) error {
	const insertServer = "INSERT INTO servers (url, name, anon_name, insecure, expected_code, expected_body) VALUES (?, ?, ?, ?, ?, ?)"
	code := expectedCode
	if code == 0 {
		code = 200
	}
	if _, err := d.conn.Exec(insertServer, url, name, anonName, insecure, code, expectedBody); err != nil {
		return err
	}
	return nil
}

// SetServerActive sets the active flag (stored as 1 or 0) of server id.
// Updating a non-existent id is not an error.
func (d *DB) SetServerActive(id int, active bool) error {
	flag := 0
	if active {
		flag = 1
	}
	if _, err := d.conn.Exec("UPDATE servers SET active = ? WHERE id = ?", flag, id); err != nil {
		return err
	}
	return nil
}

// DeleteServer removes server id from the servers table. Rows in the results
// tables that reference the server are left in place (those tables declare no
// foreign keys). Deleting a non-existent id is not an error.
func (d *DB) DeleteServer(id int) error {
	if _, err := d.conn.Exec("DELETE FROM servers WHERE id = ?", id); err != nil {
		return err
	}
	return nil
}

// ListAgents returns every agent ordered by id. A NULL last_ip is reported as
// "" and a NULL last_seen as a nil LastSeen. If iteration fails part-way, the
// agents read so far are returned together with the error from rows.Err.
func (d *DB) ListAgents() ([]Agent, error) {
	const selectAgents = "SELECT id, name, anon_name, token, COALESCE(last_ip,''), last_seen, created_at FROM agents ORDER BY id"
	rows, err := d.conn.Query(selectAgents)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var agents []Agent
	for rows.Next() {
		var ag Agent
		scanErr := rows.Scan(
			&ag.ID, &ag.Name, &ag.AnonName, &ag.Token,
			&ag.LastIP, &ag.LastSeen, &ag.CreatedAt,
		)
		if scanErr != nil {
			return nil, scanErr
		}
		agents = append(agents, ag)
	}
	return agents, rows.Err()
}

// AddAgent inserts a new agent. Database errors are returned unchanged,
// including a duplicate token, which violates the UNIQUE constraint.
func (d *DB) AddAgent(name, anonName, token string) error {
	const insertAgent = "INSERT INTO agents (name, anon_name, token) VALUES (?, ?, ?)"
	if _, err := d.conn.Exec(insertAgent, name, anonName, token); err != nil {
		return err
	}
	return nil
}

// DeleteAgent removes agent id from the agents table. Result rows recorded by
// the agent are left in the results tables (no foreign keys are declared).
// Deleting a non-existent id is not an error.
func (d *DB) DeleteAgent(id int) error {
	if _, err := d.conn.Exec("DELETE FROM agents WHERE id = ?", id); err != nil {
		return err
	}
	return nil
}

// AgentByToken looks up the agent whose token equals token.
//
// It returns (nil, nil) when no agent has that token, so callers must check
// the pointer as well as the error. Any other database error is returned as is.
func (d *DB) AgentByToken(token string) (*Agent, error) {
	const selectByToken = "SELECT id, name, anon_name, token, COALESCE(last_ip,''), last_seen, created_at FROM agents WHERE token = ?"
	agent := new(Agent)
	switch err := d.conn.QueryRow(selectByToken, token).Scan(
		&agent.ID, &agent.Name, &agent.AnonName, &agent.Token,
		&agent.LastIP, &agent.LastSeen, &agent.CreatedAt,
	); {
	case err == sql.ErrNoRows:
		return nil, nil
	case err != nil:
		return nil, err
	}
	return agent, nil
}

// UpdateAgentSeen records ip as agent id's last address and stamps last_seen
// with the current time.
//
// The timestamp comes from the application clock and is passed as a
// parameter instead of MySQL's NOW(). NOW() yields the server session's time
// zone, but with parseTime the driver interprets DATETIME values in its
// configured location (UTC unless the DSN sets loc). A server running in
// another zone would therefore make last_seen read back shifted by the zone
// offset. Using time.Now() also matches SaveResults, which stamps checked_at
// from the application clock.
func (d *DB) UpdateAgentSeen(id int, ip string) error {
	_, err := d.conn.Exec("UPDATE agents SET last_ip = ?, last_seen = ? WHERE id = ?", ip, time.Now(), id)
	return err
}

// SaveResults inserts results in a single transaction into the results table
// of the current month, creating that table first if needed. The table is
// chosen from the insertion time, not from each row's CheckedAt.
//
// Rows with a zero CheckedAt are stamped with the insertion time. An empty
// slice is a no-op. If any insert fails, the whole batch is rolled back and the
// error is returned.
func (d *DB) SaveResults(results []CheckResult) error {
	if len(results) == 0 {
		return nil
	}
	now := time.Now()
	// Create the table before opening the transaction, which keeps the DDL
	// out of the transaction.
	if err := d.EnsureResultsTable(now); err != nil {
		return err
	}
	insertSQL := fmt.Sprintf(
		"INSERT INTO %s (agent_id, server_id, checked_at, success, http_code, connect_ms, ttfb_ms, total_ms) VALUES (?,?,?,?,?,?,?,?)",
		ResultsTableName(now),
	)

	tx, err := d.conn.Begin()
	if err != nil {
		return err
	}
	// No-op once Commit has succeeded; otherwise discards the partial batch.
	defer tx.Rollback()

	stmt, err := tx.Prepare(insertSQL)
	if err != nil {
		return err
	}
	defer stmt.Close()

	for i := range results {
		res := &results[i]
		checkedAt := res.CheckedAt
		if checkedAt.IsZero() {
			checkedAt = now
		}
		_, execErr := stmt.Exec(
			res.AgentID, res.ServerID, checkedAt, res.Success,
			res.HTTPCode, res.ConnectMS, res.TTFBMS, res.TotalMS,
		)
		if execErr != nil {
			return execErr
		}
	}
	return tx.Commit()
}

// ServerStats aggregates, per server id, the results with checked_at >= since
// across every monthly results table.
//
// For each server it reports the number of rows (Total), the number of
// successful rows (Successes) and the mean of the non-zero ttfb_ms values
// (AvgTTFBMS, truncated to whole milliseconds; 0 if there are none).
//
// Raw sums and counts are accumulated across tables and divided once at the
// end. Combining per-table averages would need weights equal to each table's
// number of non-zero TTFB rows, and truncating every intermediate average
// would compound rounding error.
//
// A table whose query fails is skipped (best effort). Scan and iteration
// errors abort the whole call.
func (d *DB) ServerStats(since time.Time) (map[int]ServerStat, error) {
	tables, err := d.resultsTables()
	if err != nil {
		return nil, err
	}

	type totals struct {
		checks, successes, ttfbSum, ttfbCount int64
	}
	acc := map[int]*totals{}

	for _, tbl := range tables {
		q := fmt.Sprintf(
			"SELECT server_id, COUNT(*), SUM(success), SUM(NULLIF(ttfb_ms,0)), COUNT(NULLIF(ttfb_ms,0)) FROM %s WHERE checked_at >= ? GROUP BY server_id",
			tbl,
		)
		rows, err := d.conn.Query(q, since)
		if err != nil {
			continue
		}
		for rows.Next() {
			var sid int
			var checks, ttfbCount int64
			// SUM() is NULL when every summed value is NULL.
			var succ, ttfbSum sql.NullInt64
			if err := rows.Scan(&sid, &checks, &succ, &ttfbSum, &ttfbCount); err != nil {
				rows.Close()
				return nil, err
			}
			t := acc[sid]
			if t == nil {
				t = &totals{}
				acc[sid] = t
			}
			t.checks += checks
			t.successes += succ.Int64
			t.ttfbSum += ttfbSum.Int64
			t.ttfbCount += ttfbCount
		}
		err = rows.Err()
		rows.Close()
		if err != nil {
			return nil, err
		}
	}

	out := make(map[int]ServerStat, len(acc))
	for sid, t := range acc {
		st := ServerStat{Total: int(t.checks), Successes: int(t.successes)}
		if t.ttfbCount > 0 {
			st.AvgTTFBMS = int(t.ttfbSum / t.ttfbCount)
		}
		out[sid] = st
	}
	return out, nil
}

// ServerChartData returns chart points for serverID with checked_at >= since,
// gathered from every monthly results table and ordered by time.
// agentID <= 0 means all agents; a positive agentID restricts the points to
// that agent.
//
// Each table's rows come back ordered by checked_at, but the merged result is
// sorted again. Table order is not guaranteed to be chronological, and
// SaveResults files rows by insertion time rather than checked_at, so a table
// can hold rows from the end of the previous month.
//
// A table whose query fails is skipped (best effort). Scan and iteration
// errors abort the whole call.
func (d *DB) ServerChartData(serverID int, agentID int, since time.Time) ([]ChartPoint, error) {
	tables, err := d.resultsTables()
	if err != nil {
		return nil, err
	}

	// The filter is the same for every table, so build it once.
	where := "server_id = ?"
	args := []interface{}{serverID}
	if agentID > 0 {
		where += " AND agent_id = ?"
		args = append(args, agentID)
	}
	where += " AND checked_at >= ?"
	args = append(args, since)

	var points []ChartPoint
	for _, tbl := range tables {
		q := fmt.Sprintf(
			"SELECT checked_at, success, total_ms, ttfb_ms, agent_id FROM %s WHERE %s ORDER BY checked_at",
			tbl, where,
		)
		rows, err := d.conn.Query(q, args...)
		if err != nil {
			continue
		}
		for rows.Next() {
			var p ChartPoint
			if err := rows.Scan(&p.T, &p.Success, &p.TotalMS, &p.TTFBMS, &p.AgentID); err != nil {
				rows.Close()
				return nil, err
			}
			points = append(points, p)
		}
		err = rows.Err()
		rows.Close()
		if err != nil {
			return nil, err
		}
	}

	sort.SliceStable(points, func(i, j int) bool {
		return points[i].T.Before(points[j].T)
	})
	return points, nil
}

// GetLatestServerResult returns the most recent check result (by checked_at)
// for serverID across all monthly results tables. It returns (nil, nil) when
// no result exists.
//
// The returned CheckResult has ServerID set to serverID; previously it was
// left at 0. Errors from individual tables, including "no rows", are skipped
// (best effort, like ServerStats and ServerChartData). Only a failure to list
// the tables is reported.
func (d *DB) GetLatestServerResult(serverID int) (*CheckResult, error) {
	tables, err := d.resultsTables()
	if err != nil {
		return nil, err
	}

	var latest *CheckResult
	for _, tbl := range tables {
		query := fmt.Sprintf(
			"SELECT agent_id, checked_at, success, http_code, connect_ms, ttfb_ms, total_ms FROM %s WHERE server_id = ? ORDER BY checked_at DESC LIMIT 1",
			tbl,
		)
		// Declared inside the loop so each candidate has its own storage.
		cr := CheckResult{ServerID: serverID}
		if err := d.conn.QueryRow(query, serverID).Scan(
			&cr.AgentID, &cr.CheckedAt, &cr.Success, &cr.HTTPCode,
			&cr.ConnectMS, &cr.TTFBMS, &cr.TotalMS,
		); err != nil {
			continue
		}
		if latest == nil || cr.CheckedAt.After(latest.CheckedAt) {
			latest = &cr
		}
	}
	return latest, nil
}
