package main

import (
	"bytes"
	"crypto/tls"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"net/http/httptrace"
	"strings"
	"sync"
	"time"
)

// runAgent is the agent's main loop; it never returns. After loading the
// configuration it repeatedly:
//   - refreshes the server list from the API when the list is empty or at
//     least 10 minutes old, keeping the previous list if the refresh fails
//     and retrying after 1 minute when there is no list at all,
//   - re-resolves the agent's public IP after every successful refresh,
//   - checks every server and posts the results to the API,
//   - sleeps cfg.CheckInterval seconds.
func runAgent() {
	cfg := loadAgentConfig()
	log.Printf("Agent starting -> %s, interval %ds", cfg.ServerURL, cfg.CheckInterval)

	secureClient := &http.Client{Timeout: 10 * time.Second}

	// Start from a copy of the default transport so insecure checks keep the
	// same proxy settings, dial/idle timeouts and HTTP/2 support as secure
	// checks; only certificate verification differs.
	insecureTransport := &http.Transport{}
	if dt, ok := http.DefaultTransport.(*http.Transport); ok {
		insecureTransport = dt.Clone()
	}
	insecureTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
	insecureClient := &http.Client{
		Timeout:   10 * time.Second,
		Transport: insecureTransport,
	}

	apiClient := &http.Client{
		Timeout: 15 * time.Second,
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			if len(via) >= 5 {
				return fmt.Errorf("too many redirects")
			}
			orig := via[0] // the request originally passed to Do
			if req.Method == orig.Method {
				return nil
			}
			// net/http turns POST into GET and drops the body on 301/302/303.
			// Restoring only the method would send an empty POST, so rewind
			// the original body as well.
			req.Method = orig.Method
			if orig.GetBody != nil {
				body, err := orig.GetBody()
				if err != nil {
					return fmt.Errorf("rewinding request body for redirect: %w", err)
				}
				req.Body = body
				req.GetBody = orig.GetBody
				req.ContentLength = orig.ContentLength
			}
			if ct := orig.Header.Get("Content-Type"); ct != "" && req.Header.Get("Content-Type") == "" {
				req.Header.Set("Content-Type", ct)
			}
			return nil
		},
	}

	var servers []serverDTO
	var lastFetch time.Time
	var cachedIP string

	for {
		if len(servers) == 0 || time.Since(lastFetch) >= 10*time.Minute {
			fetched, err := fetchServers(apiClient, cfg)
			if err != nil {
				log.Printf("fetchServers: %v", err)
				if len(servers) == 0 {
					log.Printf("No servers, retry in 1 min")
					time.Sleep(time.Minute)
					continue
				}
				log.Printf("Keeping old server list (%d servers)", len(servers))
			} else {
				servers = fetched
				lastFetch = time.Now()
				log.Printf("Server list refreshed: %d servers", len(servers))
				// Keep the last known IP if the lookup fails.
				if ip := getMyIP(); ip != "" {
					cachedIP = ip
				}
			}
		}

		results := checkAll(secureClient, insecureClient, servers)
		log.Printf("Checked %d servers, sending results", len(results))
		if err := sendResults(apiClient, cfg, results, cachedIP); err != nil {
			log.Printf("sendResults: %v", err)
		}

		time.Sleep(time.Duration(cfg.CheckInterval) * time.Second)
	}
}

// serverDTO is one entry of the server list returned by the API's
// /api/v1/servers endpoint (decoded by fetchServers).
//
// Fields as used by the checker:
//   - ID is echoed back as agentResult.ServerID.
//   - URL is fetched with a GET request.
//   - Insecure selects the client built with TLS certificate verification
//     disabled.
//   - ExpectedCode is the status the response must have; 0 means 200.
//   - ExpectedBody, when non-empty, is a substring the response body must
//     contain.
type serverDTO struct {
	ID           int    `json:"id"`
	URL          string `json:"url"`
	Insecure     bool   `json:"insecure"`
	ExpectedCode int    `json:"expected_code"`
	ExpectedBody string `json:"expected_body"`
}

// fetchServers downloads the list of servers this agent must check from
// GET <cfg.ServerURL>/api/v1/servers, authenticating with the X-Agent-Token
// header.
//
// Any status other than 200 OK is an error. Its message includes the start
// of the response body when there is one. Whenever an error is returned,
// the returned slice is nil. The caller never sees a partially decoded list.
func fetchServers(client *http.Client, cfg AgentConfig) ([]serverDTO, error) {
	// maxErrBody bounds how much of an error response is quoted in the error.
	const maxErrBody = 512

	req, err := http.NewRequest(http.MethodGet, cfg.ServerURL+"/api/v1/servers", nil)
	if err != nil {
		return nil, err
	}
	req.Header.Set("X-Agent-Token", cfg.Token)
	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best effort: the status code alone is enough if the body is unreadable.
		snippet, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrBody))
		if snippet = bytes.TrimSpace(snippet); len(snippet) > 0 {
			return nil, fmt.Errorf("server returned %d: %s", resp.StatusCode, snippet)
		}
		return nil, fmt.Errorf("server returned %d", resp.StatusCode)
	}

	var out []serverDTO
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		// Decode may have filled part of out (e.g. on a type mismatch); drop it.
		return nil, fmt.Errorf("decoding server list: %w", err)
	}
	return out, nil
}

// agentResult is the outcome of checking one server. checkOne builds it and
// sendResults posts it to the API.
//
// Fields:
//   - ServerID is the ID of the checked serverDTO.
//   - Success reports whether the check passed, as decided by checkOne.
//   - HTTPCode is the response status; it stays 0 when no response was
//     received.
//   - ConnectMS, TTFBMS and TotalMS are durations in whole milliseconds.
//     checkOne zeroes all three for a failed check that did receive a
//     response.
//   - CheckedAt is the UTC time the check started, formatted as RFC 3339.
type agentResult struct {
	ServerID  int    `json:"server_id"`
	Success   bool   `json:"success"`
	HTTPCode  int    `json:"http_code"`
	ConnectMS int    `json:"connect_ms"`
	TTFBMS    int    `json:"ttfb_ms"`
	TotalMS   int    `json:"total_ms"`
	CheckedAt string `json:"checked_at"`
}

// checkAll checks each server in turn and returns one result per server, in
// the same order as servers. Servers flagged Insecure are checked with the
// insecure client; every other server uses the secure client.
func checkAll(secure, insecure *http.Client, servers []serverDTO) []agentResult {
	clientFor := func(sv serverDTO) *http.Client {
		if sv.Insecure {
			return insecure
		}
		return secure
	}

	out := make([]agentResult, len(servers))
	for i := range servers {
		out[i] = checkOne(clientFor(servers[i]), servers[i])
	}
	return out
}

// checkOne performs a single GET health check of sv.URL with client and
// returns its outcome.
//
// Pass/fail rules:
//   - The check passes when the response status equals sv.ExpectedCode (200
//     when unset).
//   - If sv.ExpectedBody is non-empty, the body must also contain it. Only
//     the first 1 MiB of the body is searched; the rest is drained so that
//     TotalMS still covers the full download.
//   - An error while reading the body fails the check.
//
// Timings are gathered with net/http/httptrace:
//   - ConnectMS is the time to establish the first new TCP connection. It is
//     0 when no new connection was dialled.
//   - TTFBMS runs from request start to the first response byte.
//   - TotalMS covers the whole exchange, including reading the body.
//
// When a response was received but the check failed, all three timings are
// reported as 0.
func checkOne(client *http.Client, sv serverDTO) agentResult {
	// maxBodyBytes bounds how much of the response is held in memory for
	// the ExpectedBody match.
	const maxBodyBytes = 1 << 20

	expectedCode := sv.ExpectedCode
	if expectedCode == 0 {
		expectedCode = http.StatusOK
	}
	r := agentResult{
		ServerID:  sv.ID,
		CheckedAt: time.Now().UTC().Format(time.RFC3339),
	}

	// httptrace hooks may run concurrently on transport goroutines (e.g. the
	// parallel IPv4/IPv6 dials for a dual-stack host), and some may fire
	// after Do returns, so every access to the timestamps holds mu.
	var mu sync.Mutex
	var connectStart, connectDone, firstByte time.Time
	start := time.Now()

	trace := &httptrace.ClientTrace{
		ConnectStart: func(_, _ string) {
			mu.Lock()
			defer mu.Unlock()
			// The connect phase begins with the earliest dial attempt.
			if connectStart.IsZero() {
				connectStart = time.Now()
			}
		},
		ConnectDone: func(_, _ string, err error) {
			mu.Lock()
			defer mu.Unlock()
			// Only a successful dial ends the connect phase. A failed or
			// cancelled racing dial must not overwrite it.
			if err == nil && connectDone.IsZero() {
				connectDone = time.Now()
			}
		},
		GotFirstResponseByte: func() {
			mu.Lock()
			defer mu.Unlock()
			firstByte = time.Now()
		},
	}

	req, err := http.NewRequest(http.MethodGet, sv.URL, nil)
	if err != nil {
		log.Printf("[check] %s: %v", sv.URL, err)
		return r
	}
	req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))

	resp, err := client.Do(req)
	if err != nil {
		log.Printf("[check] %s: %v", sv.URL, err)
		// NOTE(review): this failure path keeps TotalMS (time until the
		// error), while failures after a response zero it below. Confirm
		// which one the API expects.
		r.TotalMS = int(time.Since(start).Milliseconds())
		return r
	}
	defer resp.Body.Close()

	body, readErr := io.ReadAll(io.LimitReader(resp.Body, maxBodyBytes))
	if readErr == nil {
		// Drain the remainder so TotalMS reflects the full download and the
		// connection can be reused.
		_, readErr = io.Copy(io.Discard, resp.Body)
	}
	r.TotalMS = int(time.Since(start).Milliseconds())
	r.HTTPCode = resp.StatusCode

	mu.Lock()
	if !connectStart.IsZero() && !connectDone.IsZero() {
		r.ConnectMS = int(connectDone.Sub(connectStart).Milliseconds())
	}
	if !firstByte.IsZero() {
		r.TTFBMS = int(firstByte.Sub(start).Milliseconds())
	}
	mu.Unlock()

	switch {
	case readErr != nil:
		// A truncated or reset body is not a healthy response.
		log.Printf("[check] %s: reading body: %v", sv.URL, readErr)
		r.Success = false
	case resp.StatusCode != expectedCode:
		r.Success = false
	case sv.ExpectedBody != "":
		r.Success = bytes.Contains(body, []byte(sv.ExpectedBody))
	default:
		r.Success = true
	}
	if !r.Success {
		// Failed checks report no latency figures.
		r.ConnectMS, r.TTFBMS, r.TotalMS = 0, 0, 0
	}
	log.Printf("[check] %s -> %d success=%v total=%dms", sv.URL, resp.StatusCode, r.Success, r.TotalMS)
	return r
}

// sendResults posts one round of check results as JSON to
// POST <cfg.ServerURL>/api/v1/results, authenticated with the X-Agent-Token
// header. The payload has the form {"agent_ip": ip, "results": [...]}.
//
// The API must answer 204 No Content. Any other status is returned as an
// error that quotes the response body.
func sendResults(client *http.Client, cfg AgentConfig, results []agentResult, ip string) error {
	// agent_ip is declared first so keys are emitted in alphabetical order.
	payload := struct {
		AgentIP string        `json:"agent_ip"`
		Results []agentResult `json:"results"`
	}{
		AgentIP: ip,
		Results: results,
	}
	encoded, err := json.Marshal(payload)
	if err != nil {
		return err
	}

	endpoint := cfg.ServerURL + "/api/v1/results"
	req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(encoded))
	if err != nil {
		return err
	}
	req.Header.Set("X-Agent-Token", cfg.Token)
	req.Header.Set("Content-Type", "application/json")

	resp, err := client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode == http.StatusNoContent {
		return nil
	}
	reply, _ := io.ReadAll(resp.Body)
	return fmt.Errorf("server returned %d: %s", resp.StatusCode, reply)
}

// getMyIP asks https://api.ipify.org for this agent's public IP address.
//
// It logs the reason and returns "" when any of these holds:
//   - the request fails;
//   - the service does not answer 200 OK;
//   - the reply cannot be read;
//   - the reply is not an IP address;
//   - the address is private or loopback according to isPrivateIP.
func getMyIP() string {
	// Long enough for any textual IPv4/IPv6 address plus surrounding
	// whitespace; anything longer is not a valid reply anyway.
	const maxReplyBytes = 128

	client := &http.Client{Timeout: 5 * time.Second}
	resp, err := client.Get("https://api.ipify.org")
	if err != nil {
		log.Printf("getMyIP: %v", err)
		return ""
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		log.Printf("getMyIP: unexpected status %d", resp.StatusCode)
		return ""
	}

	b, err := io.ReadAll(io.LimitReader(resp.Body, maxReplyBytes))
	if err != nil {
		log.Printf("getMyIP: reading reply: %v", err)
		return ""
	}
	ip := strings.TrimSpace(string(b))
	if net.ParseIP(ip) == nil {
		log.Printf("getMyIP: reply %q is not an IP address, ignoring", ip)
		return ""
	}
	if isPrivateIP(ip) {
		log.Printf("getMyIP: got private IP %s, ignoring", ip)
		return ""
	}
	return ip
}

// isPrivateIP reports whether s must be rejected as a non-public address.
//
// It returns true when s is not a valid IP literal, or when s lies in any of
// these ranges:
//   - an RFC 1918 private range,
//   - the IPv4 or IPv6 loopback range,
//   - the IPv6 unique-local range fc00::/7.
func isPrivateIP(s string) bool {
	addr := net.ParseIP(s)
	if addr == nil {
		// Unparseable input is reported as private so callers discard it.
		return true
	}

	inRange := func(cidr string) bool {
		_, network, _ := net.ParseCIDR(cidr) // constant, known-valid CIDRs
		return network.Contains(addr)
	}

	switch {
	case inRange("10.0.0.0/8"), inRange("172.16.0.0/12"), inRange("192.168.0.0/16"):
		return true // RFC 1918
	case inRange("127.0.0.0/8"), inRange("::1/128"):
		return true // loopback
	case inRange("fc00::/7"):
		return true // IPv6 unique local
	default:
		return false
	}
}
