package security

import (
	"crypto/rand"
	"encoding/base64"
	"fmt"
	"net"
	"net/http"
	"path/filepath"
	"strings"
	"sync"
	"time"
)

// RateLimiter is a fixed-window request limiter keyed by a client identifier
// (RateLimitMiddleware uses the client IP).
//
// Each key may make at most maxRequests requests per window. A window opens at
// the key's first request and is NOT extended by later requests; once it has
// elapsed, the next request opens a new window with a fresh count.
//
// All access to the client map happens under rl.mutex, so one limiter may be
// shared by concurrent handlers. NewRateLimiter is the usual constructor, but
// the zero value is also ready to use.
type RateLimiter struct {
	clients     map[string]*ClientInfo
	mutex       sync.RWMutex
	lastCleanup time.Time     // when cleanupExpired last ran
	maxWindow   time.Duration // longest window ever passed to AllowRequest
}

// ClientInfo holds the rate-limiting state of a single client.
type ClientInfo struct {
	requests    int       // requests counted in the current window
	windowStart time.Time // when the current window opened
	lastTime    time.Time // most recent request from this client, allowed or denied
}

const (
	// rateLimitRetention is the minimum idle time after which a client's
	// record may be discarded.
	rateLimitRetention = 5 * time.Minute
	// rateLimitCleanupThreshold is the map size above which cleanup is considered.
	rateLimitCleanupThreshold = 100
	// rateLimitCleanupInterval is the minimum time between two cleanup sweeps,
	// so the O(n) scan is not repeated on every request.
	rateLimitCleanupInterval = time.Minute
)

// NewRateLimiter returns an empty RateLimiter.
func NewRateLimiter() *RateLimiter {
	return &RateLimiter{
		clients: make(map[string]*ClientInfo),
	}
}

// cleanupExpired drops the records of clients that have been idle for longer
// than the retention period.
//
// The retention is rateLimitRetention, raised to the longest window seen by
// AllowRequest. Because lastTime is never earlier than windowStart, any record
// removed here belongs to a window that has already elapsed. Dropping it is
// therefore equivalent to the reset AllowRequest would perform anyway.
// The caller must hold rl.mutex.
func (rl *RateLimiter) cleanupExpired() {
	now := time.Now()
	retention := rateLimitRetention
	if rl.maxWindow > retention {
		retention = rl.maxWindow
	}
	for ip, info := range rl.clients {
		if now.Sub(info.lastTime) > retention {
			delete(rl.clients, ip)
		}
	}
	rl.lastCleanup = now
}

// AllowRequest reports whether the client identified by ip may make another
// request, and records the request if so.
//
// Up to maxRequests requests are allowed per fixed window of length window,
// measured from the first request of the window. Denied requests are not
// counted. The first request of a new client, or the first after its window
// has elapsed, is always allowed.
func (rl *RateLimiter) AllowRequest(ip string, maxRequests int, window time.Duration) bool {
	rl.mutex.Lock()
	defer rl.mutex.Unlock()

	now := time.Now()
	if rl.clients == nil {
		// Make the zero value usable instead of panicking on a nil-map write.
		rl.clients = make(map[string]*ClientInfo)
	}
	if window > rl.maxWindow {
		rl.maxWindow = window
	}

	// Amortise the O(n) sweep: only once the map is non-trivial, and at most
	// once per rateLimitCleanupInterval.
	if len(rl.clients) > rateLimitCleanupThreshold && now.Sub(rl.lastCleanup) >= rateLimitCleanupInterval {
		rl.cleanupExpired()
	}

	client, exists := rl.clients[ip]
	if !exists {
		rl.clients[ip] = &ClientInfo{
			requests:    1,
			windowStart: now,
			lastTime:    now,
		}
		return true
	}
	client.lastTime = now

	// The window elapsed: start a new one with this request as its first.
	if now.Sub(client.windowStart) > window {
		client.requests = 1
		client.windowStart = now
		return true
	}

	if client.requests >= maxRequests {
		return false
	}
	client.requests++
	return true
}

// RateLimitMiddleware builds middleware that answers 429 Too Many Requests
// once a client, as identified by getClientIP, exceeds maxRequests within
// window. A single limiter is created per call and shared by every handler
// the returned wrapper is applied to.
func RateLimitMiddleware(maxRequests int, window time.Duration) func(http.Handler) http.Handler {
	limiter := NewRateLimiter()

	wrap := func(next http.Handler) http.Handler {
		handle := func(w http.ResponseWriter, r *http.Request) {
			if allowed := limiter.AllowRequest(getClientIP(r), maxRequests, window); !allowed {
				http.Error(w, "请求过于频繁，请稍后再试", http.StatusTooManyRequests)
				return
			}
			next.ServeHTTP(w, r)
		}
		return http.HandlerFunc(handle)
	}
	return wrap
}

// getClientIP returns the address used to identify a client for rate limiting.
//
// Sources are tried in this order, and an empty or whitespace-only value is
// skipped:
//  1. the first (left-most) entry of X-Forwarded-For;
//  2. X-Real-IP;
//  3. the host part of r.RemoteAddr. If RemoteAddr has no port, it is
//     returned unchanged.
//
// SECURITY NOTE(review): X-Forwarded-For and X-Real-IP come from the client
// unless a trusted reverse proxy overwrites them. If the server is reachable
// directly, a client can send a different value on every request and bypass
// RateLimitMiddleware. Confirm the service only runs behind a proxy that sets
// these headers.
func getClientIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		first, _, _ := strings.Cut(xff, ",")
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}

	if xri := strings.TrimSpace(r.Header.Get("X-Real-IP")); xri != "" {
		return xri
	}

	// SplitHostPort understands bracketed IPv6 such as "[::1]:8080", which a
	// naive split on the last ':' gets wrong.
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}

// generateNonce returns a fresh CSP nonce: 16 bytes from crypto/rand, encoded
// with URL-safe base64 (24 characters including padding).
//
// It panics if the system CSPRNG fails. Continuing with the untouched
// zero-filled buffer would emit the same predictable nonce on every response
// and silently defeat the nonce-based script-src policy.
func generateNonce() string {
	b := make([]byte, 16)
	if _, err := rand.Read(b); err != nil {
		panic(fmt.Errorf("security: generating CSP nonce: %w", err))
	}
	return base64.URLEncoding.EncodeToString(b)
}

// SecurityHeaders is middleware that sets HTTP security headers on every
// response, then calls next.
//
// A new nonce is generated per request. It is embedded in the
// Content-Security-Policy script-src directive and also stored in the
// X-Content-Security-Nonce response header. Because that header is set before
// next runs, downstream handlers and templates can read it with
// w.Header().Get("X-Content-Security-Nonce").
func SecurityHeaders(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()
		nonce := generateNonce()

		h.Set("X-Content-Security-Nonce", nonce)

		// Disable MIME-type sniffing.
		h.Set("X-Content-Type-Options", "nosniff")

		// Forbid framing (clickjacking protection).
		h.Set("X-Frame-Options", "DENY")

		// Explicitly disable the legacy browser XSS auditor. "1; mode=block"
		// is deprecated and can itself introduce cross-site leaks. The CSP
		// below is the real XSS defence.
		h.Set("X-XSS-Protection", "0")

		// Scripts need a matching nonce instead of 'unsafe-inline'. Styles
		// still permit inline.
		h.Set("Content-Security-Policy",
			"default-src 'self'; "+
				"img-src 'self' data: https:; "+
				"script-src 'self' 'nonce-"+nonce+"'; "+
				"style-src 'self' 'unsafe-inline'; "+
				"connect-src 'self'; "+
				"font-src 'self'; "+
				"object-src 'none'; "+
				"base-uri 'self'; "+
				"form-action 'self'; "+
				"upgrade-insecure-requests")

		// Enforce HTTPS for a year, including subdomains. Browsers ignore this
		// header on plain-HTTP responses.
		h.Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")

		h.Set("Referrer-Policy", "strict-origin-when-cross-origin")

		// Deny access to sensitive browser features.
		h.Set("Permissions-Policy",
			"geolocation=(), "+
				"microphone=(), "+
				"camera=(), "+
				"payment=(), "+
				"usb=(), "+
				"magnetometer=(), "+
				"gyroscope=(), "+
				"accelerometer=()")

		next.ServeHTTP(w, r)
	})
}

// normalizeURLPath cleans p into an absolute, slash-separated path.
// "" and "." become "/", and a trailing slash is removed.
func normalizeURLPath(p string) string {
	cleaned := filepath.ToSlash(filepath.Clean(p))
	if !strings.HasPrefix(cleaned, "/") {
		cleaned = filepath.ToSlash(filepath.Clean("/" + cleaned))
	}
	return cleaned
}

// pathBelowBase reports whether reqPath equals base or lies under it at a "/"
// boundary. It returns the remainder below base, without a leading slash.
// Both arguments must already be cleaned.
func pathBelowBase(reqPath, base string) (string, bool) {
	if reqPath == base {
		return "", true
	}
	prefix := base
	if !strings.HasSuffix(prefix, "/") { // only "/" itself ends with a slash
		prefix += "/"
	}
	if !strings.HasPrefix(reqPath, prefix) {
		return "", false
	}
	return reqPath[len(prefix):], true
}

// ValidatePath reports whether requestPath stays inside basePath and contains
// no blocked path component. It guards against path-traversal attacks.
//
// Both paths are treated as slash-separated URL paths. basePath is normalised
// to an absolute path, so "" and "." mean "/". requestPath is cleaned but not
// made absolute, so a relative request path is always rejected.
//
// The request must equal basePath or continue it at a "/" boundary. With base
// "/app", the paths "/app" and "/app/x" pass, but "/application" does not.
// The blocked-component list is checked only against the part of the request
// below basePath. An operator-chosen base such as "/var/www" therefore does
// not reject every request.
func ValidatePath(basePath, requestPath string) bool {
	cleanBasePath := normalizeURLPath(basePath)
	cleanReqPath := filepath.ToSlash(filepath.Clean(requestPath))

	rest, ok := pathBelowBase(cleanReqPath, cleanBasePath)
	if !ok {
		return false
	}

	for _, part := range strings.Split(rest, "/") {
		switch part {
		case "..", "~", "$HOME", "%HOME", "etc", "bin", "usr", "var", "tmp":
			return false
		}
	}
	return true
}

// IsAllowedPath maps requestPath to one of the routes served under basePath.
// Both paths are normalised to absolute, slash-separated form first.
//
// Results:
//   - request "/"          -> (normalised basePath, true). This is "/" (the
//     gallery page) when basePath is "/"; otherwise it is the base, which the
//     caller is expected to redirect to.
//   - request == basePath  -> ("/", true), the gallery page
//   - basePath + "/list.json" -> ("/list.json", true), the JSON endpoint
//   - basePath + "/health"    -> ("/health", true), the health check
//   - anything else, including paths rejected by ValidatePath -> ("", false)
func IsAllowedPath(basePath, requestPath string) (string, bool) {
	cleanBasePath := normalizeURLPath(basePath)
	cleanReqPath := normalizeURLPath(requestPath)

	// Handle the root before validation. ValidatePath rejects "/" whenever
	// the base is not "/", which would make this redirect unreachable.
	if cleanReqPath == "/" {
		return cleanBasePath, true
	}

	if !ValidatePath(cleanBasePath, cleanReqPath) {
		return "", false
	}

	rest, ok := pathBelowBase(cleanReqPath, cleanBasePath)
	if !ok {
		return "", false
	}
	switch rest {
	case "":
		return "/", true
	case "list.json":
		return "/list.json", true
	case "health":
		return "/health", true
	}
	return "", false
}