package main

import (
	"bytes"
	"database/sql"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net/http"
	"sort"
	"strings"
	"time"
)

// runExporter is the long-running entry point of exporter mode.
//
// Setup:
//   - load the exporter config;
//   - open the database;
//   - make sure the vm_export_state cursor table exists.
//
// Any setup failure terminates the process through log.Fatalf. After setup it
// runs one export/drop cycle immediately, then another on every tick of
// cfg.Interval seconds. Errors from a cycle are logged and the cycle is simply
// retried on the next tick.
//
// NOTE(review): time.NewTicker panics for a non-positive period, so
// cfg.Interval must be > 0.
func runExporter() {
	cfg := loadExporterConfig()
	log.Printf("Exporter starting -> VM %s, interval %ds", cfg.VMURL, cfg.Interval)

	db, err := NewDB(cfg.DBDsn)
	if err != nil {
		log.Fatalf("db: %v", err)
	}
	if err = db.ensureExportStateTable(); err != nil {
		log.Fatalf("export state: %v", err)
	}

	period := time.Duration(cfg.Interval) * time.Second
	ticker := time.NewTicker(period)
	defer ticker.Stop()

	// The post statement blocks on the ticker. The first cycle therefore runs
	// right away at startup, and every later cycle waits for the next tick.
	for ; ; <-ticker.C {
		if exportErr := exportOnce(db, cfg); exportErr != nil {
			log.Printf("export: %v", exportErr)
		}
		if dropErr := dropOldTables(db); dropErr != nil {
			log.Printf("drop old: %v", dropErr)
		}
	}
}

// ensureExportStateTable creates the vm_export_state table if it is missing.
//
// The table holds one export cursor per results table:
//   - table_name: the results table;
//   - last_id: the highest id already shipped;
//   - updated_at: when the cursor last changed.
//
// It returns the error from executing the DDL, if any.
func (d *DB) ensureExportStateTable() error {
	const ddl = `CREATE TABLE IF NOT EXISTS vm_export_state (
		table_name VARCHAR(64) NOT NULL PRIMARY KEY,
		last_id BIGINT NOT NULL DEFAULT 0,
		updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
	) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4`
	if _, err := d.conn.Exec(ddl); err != nil {
		return err
	}
	return nil
}

// getExportCursor returns the last exported id recorded for table in
// vm_export_state.
//
// A table with no cursor row yet reports 0 with a nil error, meaning nothing
// has been exported from it. Any other query failure is returned as-is.
func (d *DB) getExportCursor(table string) (int64, error) {
	var lastID int64
	row := d.conn.QueryRow("SELECT last_id FROM vm_export_state WHERE table_name = ?", table)
	switch err := row.Scan(&lastID); err {
	case nil:
		return lastID, nil
	case sql.ErrNoRows:
		// No row means this table has never been exported.
		return 0, nil
	default:
		return lastID, err
	}
}

// setExportCursor records id as the last exported id for table.
//
// It upserts into vm_export_state: a new row is inserted, or the existing
// row's last_id is overwritten.
//
// NOTE(review): the VALUES() function inside ON DUPLICATE KEY UPDATE is
// deprecated as of MySQL 8.0.20. It still works, but the row-alias syntax is
// the forward-compatible form if only MySQL (not MariaDB) must be supported.
func (d *DB) setExportCursor(table string, id int64) error {
	const upsert = "INSERT INTO vm_export_state (table_name, last_id) VALUES (?, ?) ON DUPLICATE KEY UPDATE last_id = VALUES(last_id)"
	if _, err := d.conn.Exec(upsert, table, id); err != nil {
		return err
	}
	return nil
}

// serverLabels loads the servers table into a map keyed by server id.
//
// Each value is the label set attached to exported samples: server_url,
// server_name and server_anon. A scan error aborts with a nil map. An
// iteration error from rows.Err is returned together with the labels
// collected so far.
//
// NOTE(review): the columns are scanned into plain strings, so a NULL
// url/name/anon_name would fail the whole load. Confirm they are NOT NULL.
func (d *DB) serverLabels() (map[int]map[string]string, error) {
	rows, err := d.conn.Query("SELECT id, url, name, anon_name FROM servers")
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	labelsByID := make(map[int]map[string]string)
	for rows.Next() {
		var serverID int
		var serverURL, serverName, anonName string
		if scanErr := rows.Scan(&serverID, &serverURL, &serverName, &anonName); scanErr != nil {
			return nil, scanErr
		}
		labelsByID[serverID] = map[string]string{
			"server_url":  serverURL,
			"server_name": serverName,
			"server_anon": anonName,
		}
	}
	return labelsByID, rows.Err()
}

// agentLabels loads the agents table into a map keyed by agent id.
//
// Each value holds the agent_name and agent_anon labels attached to exported
// samples. A scan error aborts with a nil map. An iteration error from
// rows.Err is returned together with the labels collected so far.
func (d *DB) agentLabels() (map[int]map[string]string, error) {
	const query = "SELECT id, name, anon_name FROM agents"
	rows, err := d.conn.Query(query)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	byAgent := make(map[int]map[string]string)
	for rows.Next() {
		var agentID int
		var agentName, anonName string
		if scanErr := rows.Scan(&agentID, &agentName, &anonName); scanErr != nil {
			return nil, scanErr
		}
		labels := make(map[string]string, 2)
		labels["agent_name"] = agentName
		labels["agent_anon"] = anonName
		byAgent[agentID] = labels
	}
	return byAgent, rows.Err()
}

// vmLine is one JSON line of the body posted to VictoriaMetrics'
// /api/v1/import endpoint.
//
// Metric carries the label set, with the metric name stored under the
// "__name__" key. Values and Timestamps are encoded as parallel arrays;
// timestamps are Unix milliseconds. exportOnce always sends exactly one value
// and one timestamp per line. Field order determines the JSON key order of
// each encoded line.
type vmLine struct {
	Metric     map[string]string `json:"metric"`
	Values     []float64         `json:"values"`
	Timestamps []int64           `json:"timestamps"`
}

// exportOnce pushes every results-table row past its stored cursor to
// VictoriaMetrics.
//
// Tables are processed in sorted name order. Each table is read in batches of
// cfg.Batch rows until a short batch shows nothing more is pending. Server and
// agent label sets are loaded once per call and shared by every batch.
//
// It returns on the first error; later tables are left for the next call.
// Because a cursor only advances after a successful post (see
// exportResultsBatch), a failed batch is retried from the same position next
// time.
func exportOnce(d *DB, cfg ExporterConfig) error {
	// LIMIT 0 would return no rows, so the exporter would silently never make
	// progress, and a negative limit is meaningless. Fail loudly instead.
	if cfg.Batch <= 0 {
		return fmt.Errorf("export batch size must be positive, got %d", cfg.Batch)
	}
	tables, err := d.resultsTables()
	if err != nil {
		return err
	}
	sort.Strings(tables)
	servers, err := d.serverLabels()
	if err != nil {
		return err
	}
	agents, err := d.agentLabels()
	if err != nil {
		return err
	}

	for _, tbl := range tables {
		for {
			count, err := exportResultsBatch(d, cfg.VMURL, cfg.Batch, tbl, servers, agents)
			if err != nil {
				return err
			}
			// A short (or empty) batch means the table currently has no
			// more rows past its cursor.
			if count < cfg.Batch {
				break
			}
		}
	}
	return nil
}

// exportResultsBatch exports one batch of rows from a results table.
//
// It reads at most batch rows of tbl with id greater than the stored cursor,
// in id order, and converts each row into five samples:
//   - status_check_success
//   - status_check_http_code
//   - status_check_connect_ms
//   - status_check_ttfb_ms
//   - status_check_total_ms
//
// Each sample is timestamped with checked_at in Unix milliseconds and
// labelled with server_id, agent_id and the server/agent label sets. The
// batch is posted to vmURL. Only after a successful post is the cursor
// advanced to the highest id sent.
//
// It returns the number of rows exported; 0 means nothing was pending and
// nothing was posted.
//
// NOTE(review): an id cursor assumes ids become visible in increasing order.
// A transaction that commits a lower id after a higher id was already
// exported would be skipped permanently. Confirm the writer cannot produce
// such gaps.
func exportResultsBatch(d *DB, vmURL string, batch int, tbl string, servers, agents map[int]map[string]string) (int, error) {
	cursor, err := d.getExportCursor(tbl)
	if err != nil {
		return 0, fmt.Errorf("%s cursor: %w", tbl, err)
	}
	// Identifiers cannot be bound as placeholders, so tbl is formatted in.
	// It comes from d.resultsTables(), assumed to return only trusted names.
	rows, err := d.conn.Query(fmt.Sprintf(
		"SELECT id, agent_id, server_id, checked_at, success, http_code, connect_ms, ttfb_ms, total_ms FROM %s WHERE id > ? ORDER BY id LIMIT ?",
		tbl,
	), cursor, batch)
	if err != nil {
		return 0, fmt.Errorf("%s select: %w", tbl, err)
	}
	defer rows.Close()

	var buf bytes.Buffer
	enc := json.NewEncoder(&buf)
	var maxID int64
	count := 0
	for rows.Next() {
		var id int64
		var agentID, serverID, httpCode, connectMS, ttfbMS, totalMS int
		var checkedAt time.Time
		var success bool
		// NOTE(review): plain int/bool destinations reject SQL NULL. If any of
		// these columns is nullable, a single NULL row stalls this table.
		if err := rows.Scan(&id, &agentID, &serverID, &checkedAt, &success, &httpCode, &connectMS, &ttfbMS, &totalMS); err != nil {
			return 0, fmt.Errorf("%s scan: %w", tbl, err)
		}

		base := map[string]string{
			"server_id": fmt.Sprintf("%d", serverID),
			"agent_id":  fmt.Sprintf("%d", agentID),
		}
		// Unknown ids index a nil map, and ranging over it adds nothing.
		for k, v := range servers[serverID] {
			base[k] = v
		}
		for k, v := range agents[agentID] {
			base[k] = v
		}

		succ := 0.0
		if success {
			succ = 1.0
		}
		ts := checkedAt.UnixMilli()
		// A fixed array (not a map) keeps sample order stable, so request
		// bodies are reproducible.
		samples := [...]struct {
			name  string
			value float64
		}{
			{"status_check_success", succ},
			{"status_check_http_code", float64(httpCode)},
			{"status_check_connect_ms", float64(connectMS)},
			{"status_check_ttfb_ms", float64(ttfbMS)},
			{"status_check_total_ms", float64(totalMS)},
		}
		for _, s := range samples {
			labels := copyLabels(base)
			labels["__name__"] = s.name
			if err := enc.Encode(vmLine{Metric: labels, Values: []float64{s.value}, Timestamps: []int64{ts}}); err != nil {
				return 0, fmt.Errorf("%s encode: %w", tbl, err)
			}
		}

		if id > maxID {
			maxID = id
		}
		count++
	}
	// Without this check an iteration error would look like a short batch
	// and be mistaken for "table drained".
	if err := rows.Err(); err != nil {
		return 0, fmt.Errorf("%s rows: %w", tbl, err)
	}
	if count == 0 {
		return 0, nil
	}
	// rows.Next returning false has already released the connection, so it
	// is not held while the HTTP post is in flight.
	if err := postVM(vmURL, &buf); err != nil {
		return 0, fmt.Errorf("%s post: %w", tbl, err)
	}
	if err := d.setExportCursor(tbl, maxID); err != nil {
		return 0, fmt.Errorf("%s cursor update: %w", tbl, err)
	}
	log.Printf("exported %d rows from %s (cursor=%d)", count, tbl, maxID)
	return count, nil
}

// copyLabels returns a shallow copy of m.
//
// The copy is sized with room for one extra key; exportOnce adds "__name__"
// to each copy. The result is always a fresh, writable map, even for a nil m,
// so mutating it never affects the source.
func copyLabels(m map[string]string) map[string]string {
	dup := make(map[string]string, len(m)+1)
	for key := range m {
		dup[key] = m[key]
	}
	return dup
}

// vmImportClient is shared by all postVM calls. Keep-alive connections to
// VictoriaMetrics are then reused between batches instead of being re-dialed
// every time, which is what happened with a per-call client.
var vmImportClient = &http.Client{Timeout: 60 * time.Second}

// maxVMErrorBody caps how many bytes of an error response are copied into
// the returned error, so a misbehaving endpoint cannot balloon memory.
const maxVMErrorBody = 64 << 10

// postVM POSTs body to the VictoriaMetrics /api/v1/import endpoint under
// vmURL, with Content-Type application/stream+json.
//
// Trailing slashes on vmURL are ignored, so "http://vm:8428/" and
// "http://vm:8428" both target "/api/v1/import". It returns a transport error
// if the request fails. For a response status of 300 or above, it returns an
// error containing the status code and the (trimmed, size-capped) response
// body.
func postVM(vmURL string, body io.Reader) error {
	endpoint := strings.TrimRight(vmURL, "/") + "/api/v1/import"
	req, err := http.NewRequest(http.MethodPost, endpoint, body)
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "application/stream+json")
	resp, err := vmImportClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode >= 300 {
		msg, _ := io.ReadAll(io.LimitReader(resp.Body, maxVMErrorBody))
		return fmt.Errorf("vm status %d: %s", resp.StatusCode, strings.TrimSpace(string(msg)))
	}
	// Drain the body so the transport can reuse the connection. A drain
	// failure only costs reuse, so it is deliberately ignored.
	_, _ = io.Copy(io.Discard, resp.Body)
	return nil
}

// dropOldTables removes results tables that belong to past months once they
// have been fully exported.
//
// Nothing happens until the current month (in local time) is at least 48
// hours old. This grace period lets late-arriving data for the previous month
// be written and exported first. After that, every table whose name sorts
// before ResultsTableName(now) is a candidate.
//
// A candidate is skipped, with a log line, while its export cursor is below
// the table's MAX(id). Otherwise the table is dropped and its
// vm_export_state row deleted. The first error aborts the sweep.
func dropOldTables(d *DB) error {
	now := time.Now()
	monthStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location())
	if now.Sub(monthStart) < 48*time.Hour {
		return nil
	}
	tables, err := d.resultsTables()
	if err != nil {
		return err
	}
	current := ResultsTableName(now)

	// dropIfExported handles a single past-month table. It returns nil both
	// when the table was dropped and when it was skipped as not yet exported.
	dropIfExported := func(tbl string) error {
		exported, err := d.getExportCursor(tbl)
		if err != nil {
			return err
		}
		// MAX(id) is NULL for an empty table, which counts as fully exported.
		var highest sql.NullInt64
		maxQuery := fmt.Sprintf("SELECT MAX(id) FROM %s", tbl)
		if err := d.conn.QueryRow(maxQuery).Scan(&highest); err != nil {
			return err
		}
		if highest.Valid && exported < highest.Int64 {
			log.Printf("skip drop %s: export incomplete (cursor=%d, max=%d)", tbl, exported, highest.Int64)
			return nil
		}
		if _, err := d.conn.Exec(fmt.Sprintf("DROP TABLE %s", tbl)); err != nil {
			return fmt.Errorf("drop %s: %w", tbl, err)
		}
		if _, err := d.conn.Exec("DELETE FROM vm_export_state WHERE table_name = ?", tbl); err != nil {
			return err
		}
		log.Printf("dropped old table %s", tbl)
		return nil
	}

	for _, tbl := range tables {
		if tbl < current {
			if err := dropIfExported(tbl); err != nil {
				return err
			}
		}
	}
	return nil
}
