Files
gpt-plus-gpt/internal/task/manager.go
2026-03-15 20:48:19 +08:00

98 lines
2.1 KiB
Go

package task
import (
"fmt"
"log"
"sync"
"time"
"gpt-plus/internal/db"
"gorm.io/gorm"
)
// TaskManager controls the lifecycle of task execution.
// Only one task may run at a time (single-task serial constraint).
type TaskManager struct {
mu sync.Mutex
current *TaskRunner
gormDB *gorm.DB
}
// NewTaskManager returns a TaskManager that reads and writes task records
// through d. No runner is active until Start is called.
func NewTaskManager(d *gorm.DB) *TaskManager {
	m := new(TaskManager)
	m.gormDB = d
	return m
}
// Init reconciles persisted task state on startup: every task still recorded
// as running or stopping (left over from a previous process) is set to
// interrupted, with stopped_at stamped to the current time.
//
// Init has no error result, so a database failure is logged rather than
// returned; the success message is only logged when the update succeeded.
func (m *TaskManager) Init() {
	res := m.gormDB.Model(&db.Task{}).
		Where("status IN ?", []string{StatusRunning, StatusStopping}).
		Updates(map[string]interface{}{
			"status":     StatusInterrupted,
			"stopped_at": time.Now(),
		})
	if res.Error != nil {
		log.Printf("[task-manager] init: failed to mark leftover tasks as interrupted: %v", res.Error)
		return
	}
	log.Println("[task-manager] init: marked leftover tasks as interrupted")
}
// Start launches the task identified by taskID on a background goroutine.
//
// Only one task may execute at a time. The manager's current slot is
// occupied from the moment a runner is created until that runner's Run
// method returns; the goroutine then clears the slot. The task must exist
// and be in the pending, stopped or interrupted state.
//
// Start returns an error in these cases:
//   - the slot is still occupied by another runner;
//   - the task cannot be loaded;
//   - the task's status does not permit starting;
//   - the runner cannot be created.
//
// A nil return only means the runner was launched.
func (m *TaskManager) Start(taskID string) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	// Gate on the slot itself rather than on current.IsRunning(). The slot is
	// cleared only after Run returns, so this check still holds while a
	// runner is starting up or winding down after a stop request. During
	// those windows the runner may not report itself as running, and relying
	// on that report could let two runners execute concurrently.
	if m.current != nil {
		return fmt.Errorf("已有任务正在运行 (ID: %s)", m.current.taskID)
	}

	var t db.Task
	if err := m.gormDB.First(&t, "id = ?", taskID).Error; err != nil {
		return fmt.Errorf("任务不存在: %w", err)
	}

	switch t.Status {
	case StatusPending, StatusStopped, StatusInterrupted:
		// startable
	default:
		return fmt.Errorf("任务状态不允许启动: %s", t.Status)
	}

	runner, err := NewTaskRunner(taskID, m.gormDB)
	if err != nil {
		return fmt.Errorf("创建任务运行器失败: %w", err)
	}

	m.current = runner
	go func() {
		// Free the slot once Run returns. The identity check keeps this
		// goroutine from clearing a slot that has since been reassigned.
		defer func() {
			m.mu.Lock()
			if m.current == runner {
				m.current = nil
			}
			m.mu.Unlock()
		}()
		runner.Run()
	}()
	return nil
}
// Stop asks the runner of taskID to stop gracefully. It returns an error
// when taskID is not the task the manager is currently tracking.
func (m *TaskManager) Stop(taskID string) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	if r := m.current; r != nil && r.taskID == taskID {
		r.GracefulStop()
		return nil
	}
	return fmt.Errorf("该任务未在运行")
}
// ForceStop asks the runner of taskID to stop immediately via its ForceStop
// method. It returns an error when taskID is not the task the manager is
// currently tracking.
func (m *TaskManager) ForceStop(taskID string) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	active := m.current
	tracked := active != nil && active.taskID == taskID
	if !tracked {
		return fmt.Errorf("该任务未在运行")
	}
	active.ForceStop()
	return nil
}
// IsRunning reports whether the manager is tracking a runner and that runner
// itself reports that it is running.
func (m *TaskManager) IsRunning() bool {
	m.mu.Lock()
	defer m.mu.Unlock()

	if r := m.current; r != nil {
		return r.IsRunning()
	}
	return false
}