419 lines
11 KiB
Go
419 lines
11 KiB
Go
package service
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"log"
|
||
"net/http"
|
||
"strings"
|
||
"time"
|
||
|
||
"gpt-plus/config"
|
||
"gpt-plus/internal/db"
|
||
"gpt-plus/pkg/auth"
|
||
"gpt-plus/pkg/chatgpt"
|
||
"gpt-plus/pkg/httpclient"
|
||
"gpt-plus/pkg/proxy"
|
||
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
func getProxyClient(d *gorm.DB) (*httpclient.Client, error) {
|
||
var b2Enabled db.SystemConfig
|
||
if d.Where("key = ?", "proxy.b2proxy.enabled").First(&b2Enabled).Error == nil && b2Enabled.Value == "true" {
|
||
var apiBase, zone, proto db.SystemConfig
|
||
d.Where("key = ?", "proxy.b2proxy.api_base").First(&apiBase)
|
||
d.Where("key = ?", "proxy.b2proxy.zone").First(&zone)
|
||
d.Where("key = ?", "proxy.b2proxy.proto").First(&proto)
|
||
|
||
var country db.SystemConfig
|
||
d.Where("key = ?", "card.default_country").First(&country)
|
||
countryCode := country.Value
|
||
if countryCode == "" {
|
||
countryCode = "US"
|
||
}
|
||
|
||
sessTime := 5
|
||
var sessTimeCfg db.SystemConfig
|
||
if d.Where("key = ?", "proxy.b2proxy.sess_time").First(&sessTimeCfg).Error == nil {
|
||
fmt.Sscanf(sessTimeCfg.Value, "%d", &sessTime)
|
||
}
|
||
|
||
b2Cfg := config.B2ProxyConfig{
|
||
Enabled: true, APIBase: apiBase.Value,
|
||
Zone: zone.Value, Proto: proto.Value,
|
||
PType: 1, SessTime: sessTime,
|
||
}
|
||
|
||
proxyURL, err := proxy.FetchB2Proxy(b2Cfg, countryCode)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("fetch B2Proxy: %w", err)
|
||
}
|
||
return httpclient.NewClient(proxyURL)
|
||
}
|
||
|
||
var proxyCfg db.SystemConfig
|
||
if d.Where("key = ?", "proxy.url").First(&proxyCfg).Error == nil && proxyCfg.Value != "" {
|
||
return httpclient.NewClient(proxyCfg.Value)
|
||
}
|
||
|
||
return httpclient.NewClient("")
|
||
}
|
||
|
||
// parseErrorCode pulls an error code out of a JSON error response body.
//
// It looks at "detail.code" first and, when that is missing, empty, or the
// body does not decode into that shape, at a top-level "code" field.
// It returns "" when neither yields a non-empty string.
func parseErrorCode(body []byte) string {
	var nested struct {
		Detail struct {
			Code string `json:"code"`
		} `json:"detail"`
	}
	if err := json.Unmarshal(body, &nested); err == nil {
		if code := nested.Detail.Code; code != "" {
			return code
		}
	}

	// Second pass: some responses carry the code at the top level.
	var flat struct {
		Code string `json:"code"`
	}
	if err := json.Unmarshal(body, &flat); err != nil {
		return ""
	}
	return flat.Code
}
|
||
|
||
// refreshAccountToken attempts to refresh the access token via Codex OAuth.
|
||
// Returns true if refresh succeeded and DB was updated.
|
||
func refreshAccountToken(d *gorm.DB, acct *db.Account) bool {
|
||
client, err := getProxyClient(d)
|
||
if err != nil {
|
||
log.Printf("[token-refresh] %s: proxy failed: %v", acct.Email, err)
|
||
return false
|
||
}
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||
defer cancel()
|
||
|
||
// team_member 不传 workspace_id,用默认 personal workspace 刷新
|
||
workspaceID := ""
|
||
if acct.Plan == "team_owner" {
|
||
workspaceID = acct.TeamWorkspaceID
|
||
}
|
||
tokens, err := auth.ObtainCodexTokens(ctx, client, acct.DeviceID, workspaceID)
|
||
if err != nil {
|
||
log.Printf("[token-refresh] %s: ObtainCodexTokens failed: %v", acct.Email, err)
|
||
return false
|
||
}
|
||
|
||
acct.AccessToken = tokens.AccessToken
|
||
if tokens.RefreshToken != "" {
|
||
acct.RefreshToken = tokens.RefreshToken
|
||
}
|
||
if tokens.IDToken != "" {
|
||
acct.IDToken = tokens.IDToken
|
||
}
|
||
if tokens.ChatGPTAccountID != "" {
|
||
acct.AccountID = tokens.ChatGPTAccountID
|
||
}
|
||
d.Save(acct)
|
||
log.Printf("[token-refresh] %s: token refreshed successfully", acct.Email)
|
||
return true
|
||
}
|
||
|
||
// AccountCheckResult is the JSON-serializable outcome of checking one account.
type AccountCheckResult struct {
	// ID is the database ID of the checked account.
	ID uint `json:"id"`
	// Email is the account's email; it stays empty when the account was not found.
	Email string `json:"email"`
	// Status is the resolved status. Values set in this file are "free",
	// "plus", "team", "banned", "unknown" and "error".
	Status string `json:"status"`
	// Plan is only filled in on a successful membership check.
	Plan string `json:"plan"`
	// Message is a human-readable explanation of the result.
	Message string `json:"message"`
}
|
||
|
||
func CheckAccountStatuses(d *gorm.DB, ids []uint) []AccountCheckResult {
|
||
var results []AccountCheckResult
|
||
|
||
for _, id := range ids {
|
||
var acct db.Account
|
||
if d.First(&acct, id).Error != nil {
|
||
results = append(results, AccountCheckResult{ID: id, Status: "error", Message: "账号不存在"})
|
||
continue
|
||
}
|
||
|
||
r := checkSingleAccount(d, &acct)
|
||
results = append(results, r)
|
||
}
|
||
|
||
return results
|
||
}
|
||
|
||
func checkSingleAccount(d *gorm.DB, acct *db.Account) AccountCheckResult {
|
||
r := AccountCheckResult{ID: acct.ID, Email: acct.Email}
|
||
now := time.Now()
|
||
acct.StatusCheckedAt = &now
|
||
|
||
client, err := getProxyClient(d)
|
||
if err != nil {
|
||
r.Status = "error"
|
||
r.Message = "代理连接失败: " + err.Error()
|
||
d.Save(acct)
|
||
return r
|
||
}
|
||
|
||
accounts, err := chatgpt.CheckAccountFull(client, acct.AccessToken, acct.DeviceID)
|
||
if err != nil {
|
||
errMsg := err.Error()
|
||
|
||
// Try to detect specific error codes from response body
|
||
if strings.Contains(errMsg, "401") {
|
||
// Could be token_invalidated or account_deactivated — try refresh
|
||
log.Printf("[account-check] %s: got 401, attempting token refresh...", acct.Email)
|
||
if refreshAccountToken(d, acct) {
|
||
// Retry with new token
|
||
accounts2, err2 := chatgpt.CheckAccountFull(client, acct.AccessToken, acct.DeviceID)
|
||
if err2 == nil {
|
||
return buildCheckSuccess(d, acct, accounts2)
|
||
}
|
||
log.Printf("[account-check] %s: retry after refresh still failed: %v", acct.Email, err2)
|
||
}
|
||
acct.Status = "banned"
|
||
r.Status = "banned"
|
||
r.Message = "令牌无效且刷新失败,可能已封禁"
|
||
} else if strings.Contains(errMsg, "403") {
|
||
acct.Status = "banned"
|
||
r.Status = "banned"
|
||
r.Message = "账号已封禁 (403)"
|
||
} else {
|
||
acct.Status = "unknown"
|
||
r.Status = "unknown"
|
||
r.Message = errMsg
|
||
}
|
||
d.Save(acct)
|
||
return r
|
||
}
|
||
|
||
return buildCheckSuccess(d, acct, accounts)
|
||
}
|
||
|
||
func buildCheckSuccess(d *gorm.DB, acct *db.Account, accounts []*chatgpt.AccountInfo) AccountCheckResult {
|
||
r := AccountCheckResult{ID: acct.ID, Email: acct.Email}
|
||
|
||
var planParts []string
|
||
for _, info := range accounts {
|
||
planParts = append(planParts, fmt.Sprintf("%s(%s)", info.PlanType, info.Structure))
|
||
}
|
||
|
||
target := selectMembershipAccount(acct, accounts)
|
||
if target == nil {
|
||
acct.Status = "unknown"
|
||
r.Status = "unknown"
|
||
r.Message = "membership check returned no matching account"
|
||
d.Save(acct)
|
||
return r
|
||
}
|
||
|
||
resolvedStatus := normalizeMembershipStatus(target.PlanType)
|
||
if resolvedStatus == "" {
|
||
resolvedStatus = "unknown"
|
||
}
|
||
|
||
acct.Status = resolvedStatus
|
||
r.Status = resolvedStatus
|
||
r.Plan = resolvedStatus
|
||
r.Message = fmt.Sprintf("membership=%s (%d accounts: %s)", resolvedStatus, len(accounts), strings.Join(planParts, ", "))
|
||
|
||
if target.AccountID != "" && target.AccountID != acct.AccountID {
|
||
acct.AccountID = target.AccountID
|
||
}
|
||
if target.Structure == "workspace" && target.AccountID != "" {
|
||
acct.TeamWorkspaceID = target.AccountID
|
||
}
|
||
d.Save(acct)
|
||
log.Printf("[account-check] %s: %s", acct.Email, r.Message)
|
||
return r
|
||
}
|
||
|
||
// normalizeMembershipStatus returns planType unchanged when it is one of the
// recognized plans ("free", "plus", "team") and "" for anything else.
// The comparison is exact and case-sensitive.
func normalizeMembershipStatus(planType string) string {
	if planType == "free" || planType == "plus" || planType == "team" {
		return planType
	}
	return ""
}
|
||
|
||
func selectMembershipAccount(acct *db.Account, accounts []*chatgpt.AccountInfo) *chatgpt.AccountInfo {
|
||
if len(accounts) == 0 {
|
||
return nil
|
||
}
|
||
|
||
if acct.Plan == "team_owner" || acct.Plan == "team_member" {
|
||
if acct.TeamWorkspaceID != "" {
|
||
for _, info := range accounts {
|
||
if info.AccountID == acct.TeamWorkspaceID {
|
||
return info
|
||
}
|
||
}
|
||
}
|
||
for _, info := range accounts {
|
||
if info.Structure == "workspace" && info.PlanType == "team" {
|
||
return info
|
||
}
|
||
}
|
||
for _, info := range accounts {
|
||
if info.Structure == "workspace" {
|
||
return info
|
||
}
|
||
}
|
||
}
|
||
|
||
if acct.AccountID != "" {
|
||
for _, info := range accounts {
|
||
if info.AccountID == acct.AccountID && info.Structure == "personal" {
|
||
return info
|
||
}
|
||
}
|
||
}
|
||
|
||
for _, info := range accounts {
|
||
if info.Structure == "personal" {
|
||
return info
|
||
}
|
||
}
|
||
|
||
if acct.AccountID != "" {
|
||
for _, info := range accounts {
|
||
if info.AccountID == acct.AccountID {
|
||
return info
|
||
}
|
||
}
|
||
}
|
||
|
||
return accounts[0]
|
||
}
|
||
|
||
// --- Model Test ---
|
||
|
||
// ModelTestResult is the JSON-serializable outcome of a model availability
// test for one account.
type ModelTestResult struct {
	// ID is the database ID of the tested account.
	ID uint `json:"id"`
	// Email is the account's email; it stays empty when the account was not found.
	Email string `json:"email"`
	// Model is the model ID that was tested.
	Model string `json:"model"`
	// Success is true only when the request returned HTTP 200.
	Success bool `json:"success"`
	// Message is a human-readable explanation of the result.
	Message string `json:"message"`
	// Output holds the "output_text" value of a successful response, when
	// present; it is omitted from JSON when empty.
	Output string `json:"output,omitempty"`
}
|
||
|
||
func TestModelAvailability(d *gorm.DB, accountID uint, modelID string) *ModelTestResult {
|
||
var acct db.Account
|
||
if d.First(&acct, accountID).Error != nil {
|
||
return &ModelTestResult{ID: accountID, Model: modelID, Message: "账号不存在"}
|
||
}
|
||
|
||
if modelID == "" {
|
||
modelID = "gpt-4o"
|
||
}
|
||
|
||
result := doModelTest(d, &acct, modelID)
|
||
|
||
// If token_invalidated, auto-refresh and retry once
|
||
if !result.Success && strings.Contains(result.Message, "令牌过期") {
|
||
log.Printf("[model-test] %s: token expired, refreshing...", acct.Email)
|
||
if refreshAccountToken(d, &acct) {
|
||
result = doModelTest(d, &acct, modelID)
|
||
if result.Success {
|
||
result.Message += " (令牌已自动刷新)"
|
||
}
|
||
}
|
||
}
|
||
|
||
return result
|
||
}
|
||
|
||
func doModelTest(d *gorm.DB, acct *db.Account, modelID string) *ModelTestResult {
|
||
client, err := getProxyClient(d)
|
||
if err != nil {
|
||
return &ModelTestResult{ID: acct.ID, Email: acct.Email, Model: modelID, Message: "代理连接失败: " + err.Error()}
|
||
}
|
||
|
||
result := &ModelTestResult{ID: acct.ID, Email: acct.Email, Model: modelID}
|
||
|
||
apiURL := "https://chatgpt.com/backend-api/codex/responses"
|
||
payload := map[string]interface{}{
|
||
"model": modelID,
|
||
"input": []map[string]interface{}{
|
||
{
|
||
"role": "user",
|
||
"content": []map[string]interface{}{
|
||
{"type": "input_text", "text": "hi"},
|
||
},
|
||
},
|
||
},
|
||
"stream": false,
|
||
"store": false,
|
||
}
|
||
|
||
headers := map[string]string{
|
||
"Authorization": "Bearer " + acct.AccessToken,
|
||
"Content-Type": "application/json",
|
||
"Accept": "*/*",
|
||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/145.0.0.0 Safari/537.36",
|
||
"Origin": "https://chatgpt.com",
|
||
"Referer": "https://chatgpt.com/",
|
||
"oai-language": "en-US",
|
||
}
|
||
if acct.AccountID != "" {
|
||
headers["chatgpt-account-id"] = acct.AccountID
|
||
}
|
||
if acct.DeviceID != "" {
|
||
headers["oai-device-id"] = acct.DeviceID
|
||
}
|
||
|
||
resp, err := client.PostJSON(apiURL, payload, headers)
|
||
if err != nil {
|
||
result.Message = fmt.Sprintf("请求失败: %v", err)
|
||
return result
|
||
}
|
||
|
||
body, _ := io.ReadAll(resp.Body)
|
||
resp.Body.Close()
|
||
|
||
switch resp.StatusCode {
|
||
case http.StatusOK:
|
||
result.Success = true
|
||
result.Message = fmt.Sprintf("模型 %s 可用", modelID)
|
||
var respData map[string]interface{}
|
||
if json.Unmarshal(body, &respData) == nil {
|
||
if output, ok := respData["output_text"]; ok {
|
||
result.Output = fmt.Sprintf("%v", output)
|
||
}
|
||
}
|
||
case http.StatusUnauthorized:
|
||
code := parseErrorCode(body)
|
||
switch code {
|
||
case "token_invalidated":
|
||
result.Message = "令牌过期 (token_invalidated)"
|
||
case "account_deactivated":
|
||
result.Message = "账号已封禁 (account_deactivated)"
|
||
acct.Status = "banned"
|
||
d.Save(acct)
|
||
default:
|
||
result.Message = fmt.Sprintf("认证失败 (401, code=%s)", code)
|
||
}
|
||
case http.StatusForbidden:
|
||
result.Message = "账号被封禁 (403)"
|
||
acct.Status = "banned"
|
||
d.Save(acct)
|
||
case http.StatusNotFound:
|
||
result.Message = fmt.Sprintf("模型 %s 不存在 (404)", modelID)
|
||
case http.StatusTooManyRequests:
|
||
result.Message = "请求限流 (429),请稍后再试"
|
||
default:
|
||
errMsg := string(body)
|
||
if len(errMsg) > 300 {
|
||
errMsg = errMsg[:300]
|
||
}
|
||
result.Message = fmt.Sprintf("HTTP %d: %s", resp.StatusCode, errMsg)
|
||
}
|
||
|
||
log.Printf("[model-test] %s model=%s status=%d success=%v msg=%s", acct.Email, modelID, resp.StatusCode, result.Success, result.Message)
|
||
return result
|
||
}
|