package handler

import (
	"encoding/json"
	"net/http"
	"strings"
	"time"

	"gpt-plus/internal/db"
	"gpt-plus/internal/task"

	"github.com/gin-gonic/gin"
	"github.com/google/uuid"
)

// taskManager drives task execution for the start/stop endpoints.
// It is nil until SetTaskManager is called.
var taskManager *task.TaskManager

// SetTaskManager installs the task manager used by StartTask, StopTask and
// ForceStopTask.
func SetTaskManager(tm *task.TaskManager) {
	taskManager = tm
}

// RegisterTaskRoutes mounts the task CRUD and lifecycle endpoints on api.
func RegisterTaskRoutes(api *gin.RouterGroup) {
	api.POST("/tasks", CreateTask)
	api.GET("/tasks", ListTasks)
	api.GET("/tasks/:id", GetTask)
	api.GET("/tasks/:id/logs", GetTaskLogs)
	api.POST("/tasks/:id/start", StartTask)
	api.POST("/tasks/:id/stop", StopTask)
	api.POST("/tasks/:id/force-stop", ForceStopTask)
	api.DELETE("/tasks/:id", DeleteTask)
}

// createTaskRequest is the JSON body accepted by CreateTask.
type createTaskRequest struct {
	Type  string `json:"type" binding:"required"`
	Count int    `json:"count" binding:"required,min=1"`
}

// CreateTask creates a new task in "pending" status.
//
// The body must carry a type of "plus", "team" or "both" and a count >= 1.
// All SystemConfig rows are serialized to JSON and stored on the task as a
// snapshot of the configuration at creation time. Responds 201 with the task,
// 400 on invalid input, or 500 if the config snapshot or insert fails.
func CreateTask(c *gin.Context) {
	var req createTaskRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}
	switch req.Type {
	case "plus", "team", "both":
	default:
		c.JSON(http.StatusBadRequest, gin.H{"error": "类型必须为 plus, team 或 both"})
		return
	}

	// Snapshot the current config. A failed read must not silently produce a
	// task whose stored config is "null".
	var configs []db.SystemConfig
	if err := db.GetDB().Find(&configs).Error; err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	cfgJSON, err := json.Marshal(configs)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	t := &db.Task{
		ID:         uuid.New().String(),
		Type:       req.Type,
		TotalCount: req.Count,
		Status:     "pending",
		Config:     string(cfgJSON),
	}
	if err := db.GetDB().Create(t).Error; err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	c.JSON(http.StatusCreated, t)
}

// ListTasks returns a page of tasks ordered by created_at, newest first.
//
// Query parameters:
//   - status: optional comma-separated list of statuses to filter by
//   - page, size: pagination (defaults 1 and 20), applied via db.Paginate
func ListTasks(c *gin.Context) {
	query := db.GetDB().Model(&db.Task{}).Order("created_at DESC")
	if status := c.Query("status"); status != "" {
		query = query.Where("status IN ?", splitComma(status))
	}

	p := db.PaginationParams{Page: intQuery(c, "page", 1), Size: intQuery(c, "size", 20)}
	var tasks []db.Task
	result, err := db.Paginate(query, p, &tasks)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	c.JSON(http.StatusOK, result)
}

// GetTask returns the task identified by the :id path parameter, or 404 if
// the lookup fails.
func GetTask(c *gin.Context) {
	var t db.Task
	if err := db.GetDB().First(&t, "id = ?", c.Param("id")).Error; err != nil {
		c.JSON(http.StatusNotFound, gin.H{"error": "任务不存在"})
		return
	}
	c.JSON(http.StatusOK, t)
}

// GetTaskLogs returns log entries for the task in ascending id order as
// {"items": [...]}.
//
// Query parameters:
//   - since_id: when > 0, only logs with id > since_id are returned, which
//     lets clients poll incrementally
//   - limit: maximum number of rows; defaults to 50 and is clamped to [1, 200]
func GetTaskLogs(c *gin.Context) {
	const (
		defaultLimit = 50
		maxLimit     = 200
	)

	taskID := c.Param("id")
	sinceID := intQuery(c, "since_id", 0)
	limit := intQuery(c, "limit", defaultLimit)
	switch {
	case limit <= 0:
		// GORM treats a negative limit as "no limit", which would bypass the
		// cap below; zero would return nothing. Fall back to the default.
		limit = defaultLimit
	case limit > maxLimit:
		limit = maxLimit
	}

	// Start from an empty (non-nil) slice so the response is "items": []
	// rather than "items": null when there are no rows.
	logs := []db.TaskLog{}
	query := db.GetDB().Where("task_id = ?", taskID)
	if sinceID > 0 {
		query = query.Where("id > ?", sinceID)
	}
	if err := query.Order("id ASC").Limit(limit).Find(&logs).Error; err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	c.JSON(http.StatusOK, gin.H{"items": logs})
}

// ensureTaskManager reports whether SetTaskManager has been called. If it
// has not, it writes a 500 response and the caller must return immediately.
func ensureTaskManager(c *gin.Context) bool {
	if taskManager != nil {
		return true
	}
	c.JSON(http.StatusInternalServerError, gin.H{"error": "任务管理器未初始化"})
	return false
}

// StartTask asks the task manager to start the task identified by :id.
// Errors from the manager are reported as 400.
func StartTask(c *gin.Context) {
	if !ensureTaskManager(c) {
		return
	}
	if err := taskManager.Start(c.Param("id")); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}
	c.JSON(http.StatusOK, gin.H{"message": "任务已启动"})
}

// StopTask asks the task manager to stop the task identified by :id.
// Errors from the manager are reported as 400.
func StopTask(c *gin.Context) {
	if !ensureTaskManager(c) {
		return
	}
	if err := taskManager.Stop(c.Param("id")); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}
	c.JSON(http.StatusOK, gin.H{"message": "正在停止..."})
}

// ForceStopTask asks the task manager to force-stop the task identified by
// :id. Errors from the manager are reported as 400.
func ForceStopTask(c *gin.Context) {
	if !ensureTaskManager(c) {
		return
	}
	if err := taskManager.ForceStop(c.Param("id")); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}
	c.JSON(http.StatusOK, gin.H{"message": "已强制取消"})
}

// DeleteTask deletes the task identified by :id together with its logs.
//
// Tasks in "running" or "stopping" status are rejected with 400; a failed
// lookup yields 404. The log and task deletions run in a single transaction
// so a failure cannot leave a task without its logs, or the reverse.
func DeleteTask(c *gin.Context) {
	taskID := c.Param("id")

	var t db.Task
	if err := db.GetDB().First(&t, "id = ?", taskID).Error; err != nil {
		c.JSON(http.StatusNotFound, gin.H{"error": "任务不存在"})
		return
	}
	if t.Status == "running" || t.Status == "stopping" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "无法删除运行中的任务"})
		return
	}

	tx := db.GetDB().Begin()
	if tx.Error != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": tx.Error.Error()})
		return
	}
	if err := tx.Where("task_id = ?", taskID).Delete(&db.TaskLog{}).Error; err != nil {
		tx.Rollback()
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	if err := tx.Delete(&t).Error; err != nil {
		tx.Rollback()
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	if err := tx.Commit().Error; err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	c.JSON(http.StatusOK, gin.H{"message": "已删除"})
}

// splitComma splits s on commas, trims spaces and tabs from each piece, and
// drops empty pieces. It returns nil when no non-empty piece remains.
func splitComma(s string) []string {
	var result []string
	for _, v := range strings.Split(s, ",") {
		if v = trimStr(v); v != "" {
			result = append(result, v)
		}
	}
	return result
}

// splitStr splits s around every occurrence of sep. It is kept as a thin
// wrapper over strings.Split for existing callers; for an empty sep it now
// follows strings.Split semantics (one element per UTF-8 sequence).
func splitStr(s, sep string) []string {
	return strings.Split(s, sep)
}

// trimStr removes leading and trailing spaces and tabs (only those two
// characters) from s.
func trimStr(s string) string {
	return strings.Trim(s, " \t")
}

// Keeps the "time" import referenced.
var _ = time.Now