202 lines
4.9 KiB
Go
202 lines
4.9 KiB
Go
package handler
|
|
|
|
import (
	"encoding/json"
	"net/http"
	"strings"
	"time"

	"gpt-plus/internal/db"
	"gpt-plus/internal/task"

	"github.com/gin-gonic/gin"
	"github.com/google/uuid"
)
|
|
|
|
var taskManager *task.TaskManager
|
|
|
|
func SetTaskManager(tm *task.TaskManager) {
|
|
taskManager = tm
|
|
}
|
|
|
|
func RegisterTaskRoutes(api *gin.RouterGroup) {
|
|
api.POST("/tasks", CreateTask)
|
|
api.GET("/tasks", ListTasks)
|
|
api.GET("/tasks/:id", GetTask)
|
|
api.GET("/tasks/:id/logs", GetTaskLogs)
|
|
api.POST("/tasks/:id/start", StartTask)
|
|
api.POST("/tasks/:id/stop", StopTask)
|
|
api.POST("/tasks/:id/force-stop", ForceStopTask)
|
|
api.DELETE("/tasks/:id", DeleteTask)
|
|
}
|
|
|
|
// createTaskRequest is the JSON body bound by CreateTask. Both fields are
// mandatory; the binding tags are enforced by gin's validator.
type createTaskRequest struct {
	// Type is the task kind. CreateTask additionally restricts it to
	// "plus", "team" or "both".
	Type string `json:"type" binding:"required"`

	// Count must be at least 1. It is stored as the new task's TotalCount.
	Count int `json:"count" binding:"required,min=1"`
}
|
|
|
|
func CreateTask(c *gin.Context) {
|
|
var req createTaskRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
if req.Type != "plus" && req.Type != "team" && req.Type != "both" {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "类型必须为 plus, team 或 both"})
|
|
return
|
|
}
|
|
|
|
// Snapshot current config
|
|
var configs []db.SystemConfig
|
|
db.GetDB().Find(&configs)
|
|
cfgJSON, _ := json.Marshal(configs)
|
|
|
|
t := &db.Task{
|
|
ID: uuid.New().String(),
|
|
Type: req.Type,
|
|
TotalCount: req.Count,
|
|
Status: "pending",
|
|
Config: string(cfgJSON),
|
|
}
|
|
|
|
if err := db.GetDB().Create(t).Error; err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
c.JSON(http.StatusCreated, t)
|
|
}
|
|
|
|
func ListTasks(c *gin.Context) {
|
|
d := db.GetDB()
|
|
query := d.Model(&db.Task{}).Order("created_at DESC")
|
|
|
|
if status := c.Query("status"); status != "" {
|
|
statuses := splitComma(status)
|
|
query = query.Where("status IN ?", statuses)
|
|
}
|
|
|
|
p := db.PaginationParams{Page: intQuery(c, "page", 1), Size: intQuery(c, "size", 20)}
|
|
var tasks []db.Task
|
|
result, err := db.Paginate(query, p, &tasks)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
c.JSON(http.StatusOK, result)
|
|
}
|
|
|
|
func GetTask(c *gin.Context) {
|
|
var t db.Task
|
|
if err := db.GetDB().First(&t, "id = ?", c.Param("id")).Error; err != nil {
|
|
c.JSON(http.StatusNotFound, gin.H{"error": "任务不存在"})
|
|
return
|
|
}
|
|
c.JSON(http.StatusOK, t)
|
|
}
|
|
|
|
func GetTaskLogs(c *gin.Context) {
|
|
taskID := c.Param("id")
|
|
sinceID := intQuery(c, "since_id", 0)
|
|
limit := intQuery(c, "limit", 50)
|
|
if limit > 200 {
|
|
limit = 200
|
|
}
|
|
|
|
var logs []db.TaskLog
|
|
query := db.GetDB().Where("task_id = ?", taskID)
|
|
if sinceID > 0 {
|
|
query = query.Where("id > ?", sinceID)
|
|
}
|
|
query.Order("id ASC").Limit(limit).Find(&logs)
|
|
c.JSON(http.StatusOK, gin.H{"items": logs})
|
|
}
|
|
|
|
func StartTask(c *gin.Context) {
|
|
if taskManager == nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "任务管理器未初始化"})
|
|
return
|
|
}
|
|
if err := taskManager.Start(c.Param("id")); err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
c.JSON(http.StatusOK, gin.H{"message": "任务已启动"})
|
|
}
|
|
|
|
func StopTask(c *gin.Context) {
|
|
if taskManager == nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "任务管理器未初始化"})
|
|
return
|
|
}
|
|
if err := taskManager.Stop(c.Param("id")); err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
c.JSON(http.StatusOK, gin.H{"message": "正在停止..."})
|
|
}
|
|
|
|
func ForceStopTask(c *gin.Context) {
|
|
if taskManager == nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "任务管理器未初始化"})
|
|
return
|
|
}
|
|
if err := taskManager.ForceStop(c.Param("id")); err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
c.JSON(http.StatusOK, gin.H{"message": "已强制取消"})
|
|
}
|
|
|
|
func DeleteTask(c *gin.Context) {
|
|
taskID := c.Param("id")
|
|
var t db.Task
|
|
if err := db.GetDB().First(&t, "id = ?", taskID).Error; err != nil {
|
|
c.JSON(http.StatusNotFound, gin.H{"error": "任务不存在"})
|
|
return
|
|
}
|
|
if t.Status == "running" || t.Status == "stopping" {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "无法删除运行中的任务"})
|
|
return
|
|
}
|
|
db.GetDB().Where("task_id = ?", taskID).Delete(&db.TaskLog{})
|
|
db.GetDB().Delete(&t)
|
|
c.JSON(http.StatusOK, gin.H{"message": "已删除"})
|
|
}
|
|
|
|
func splitComma(s string) []string {
|
|
var result []string
|
|
for _, v := range splitStr(s, ",") {
|
|
v = trimStr(v)
|
|
if v != "" {
|
|
result = append(result, v)
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
// splitStr splits s around every non-overlapping occurrence of sep, scanning
// left to right. It keeps empty pieces (between adjacent separators and at
// either end), so the result always has at least one element for a non-empty
// sep.
//
// It delegates to strings.Split. The previous hand-rolled scan did not advance
// past a match, so overlapping separators (e.g. "a,,,b" with ",,") produced an
// inverted slice expression and panicked. An empty sep now splits s into its
// UTF-8 sequences, as documented for strings.Split.
func splitStr(s, sep string) []string {
	return strings.Split(s, sep)
}
|
|
|
|
// trimStr removes leading and trailing ASCII spaces and tabs from s. Any other
// whitespace (newlines, carriage returns, Unicode spaces) is left in place.
func trimStr(s string) string {
	isPad := func(b byte) bool { return b == ' ' || b == '\t' }

	lo, hi := 0, len(s)
	for lo < hi && isPad(s[lo]) {
		lo++
	}
	for hi > lo && isPad(s[hi-1]) {
		hi--
	}
	return s[lo:hi]
}
|
|
|
|
// Nothing else in this file appears to reference the "time" package; this
// typed blank assignment keeps that import from being rejected as unused.
var _ func() time.Time = time.Now
|