Initial sanitized code sync
This commit is contained in:
418
internal/service/account_svc.go
Normal file
418
internal/service/account_svc.go
Normal file
@@ -0,0 +1,418 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gpt-plus/config"
|
||||
"gpt-plus/internal/db"
|
||||
"gpt-plus/pkg/auth"
|
||||
"gpt-plus/pkg/chatgpt"
|
||||
"gpt-plus/pkg/httpclient"
|
||||
"gpt-plus/pkg/proxy"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func getProxyClient(d *gorm.DB) (*httpclient.Client, error) {
|
||||
var b2Enabled db.SystemConfig
|
||||
if d.Where("key = ?", "proxy.b2proxy.enabled").First(&b2Enabled).Error == nil && b2Enabled.Value == "true" {
|
||||
var apiBase, zone, proto db.SystemConfig
|
||||
d.Where("key = ?", "proxy.b2proxy.api_base").First(&apiBase)
|
||||
d.Where("key = ?", "proxy.b2proxy.zone").First(&zone)
|
||||
d.Where("key = ?", "proxy.b2proxy.proto").First(&proto)
|
||||
|
||||
var country db.SystemConfig
|
||||
d.Where("key = ?", "card.default_country").First(&country)
|
||||
countryCode := country.Value
|
||||
if countryCode == "" {
|
||||
countryCode = "US"
|
||||
}
|
||||
|
||||
sessTime := 5
|
||||
var sessTimeCfg db.SystemConfig
|
||||
if d.Where("key = ?", "proxy.b2proxy.sess_time").First(&sessTimeCfg).Error == nil {
|
||||
fmt.Sscanf(sessTimeCfg.Value, "%d", &sessTime)
|
||||
}
|
||||
|
||||
b2Cfg := config.B2ProxyConfig{
|
||||
Enabled: true, APIBase: apiBase.Value,
|
||||
Zone: zone.Value, Proto: proto.Value,
|
||||
PType: 1, SessTime: sessTime,
|
||||
}
|
||||
|
||||
proxyURL, err := proxy.FetchB2Proxy(b2Cfg, countryCode)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetch B2Proxy: %w", err)
|
||||
}
|
||||
return httpclient.NewClient(proxyURL)
|
||||
}
|
||||
|
||||
var proxyCfg db.SystemConfig
|
||||
if d.Where("key = ?", "proxy.url").First(&proxyCfg).Error == nil && proxyCfg.Value != "" {
|
||||
return httpclient.NewClient(proxyCfg.Value)
|
||||
}
|
||||
|
||||
return httpclient.NewClient("")
|
||||
}
|
||||
|
||||
// parseErrorCode extracts the "code" field from a JSON error response body.
|
||||
func parseErrorCode(body []byte) string {
|
||||
var errResp struct {
|
||||
Detail struct {
|
||||
Code string `json:"code"`
|
||||
} `json:"detail"`
|
||||
}
|
||||
if json.Unmarshal(body, &errResp) == nil && errResp.Detail.Code != "" {
|
||||
return errResp.Detail.Code
|
||||
}
|
||||
// Fallback: try top-level code
|
||||
var simple struct {
|
||||
Code string `json:"code"`
|
||||
}
|
||||
if json.Unmarshal(body, &simple) == nil && simple.Code != "" {
|
||||
return simple.Code
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// refreshAccountToken attempts to refresh the access token via Codex OAuth.
|
||||
// Returns true if refresh succeeded and DB was updated.
|
||||
func refreshAccountToken(d *gorm.DB, acct *db.Account) bool {
|
||||
client, err := getProxyClient(d)
|
||||
if err != nil {
|
||||
log.Printf("[token-refresh] %s: proxy failed: %v", acct.Email, err)
|
||||
return false
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// team_member 不传 workspace_id,用默认 personal workspace 刷新
|
||||
workspaceID := ""
|
||||
if acct.Plan == "team_owner" {
|
||||
workspaceID = acct.TeamWorkspaceID
|
||||
}
|
||||
tokens, err := auth.ObtainCodexTokens(ctx, client, acct.DeviceID, workspaceID)
|
||||
if err != nil {
|
||||
log.Printf("[token-refresh] %s: ObtainCodexTokens failed: %v", acct.Email, err)
|
||||
return false
|
||||
}
|
||||
|
||||
acct.AccessToken = tokens.AccessToken
|
||||
if tokens.RefreshToken != "" {
|
||||
acct.RefreshToken = tokens.RefreshToken
|
||||
}
|
||||
if tokens.IDToken != "" {
|
||||
acct.IDToken = tokens.IDToken
|
||||
}
|
||||
if tokens.ChatGPTAccountID != "" {
|
||||
acct.AccountID = tokens.ChatGPTAccountID
|
||||
}
|
||||
d.Save(acct)
|
||||
log.Printf("[token-refresh] %s: token refreshed successfully", acct.Email)
|
||||
return true
|
||||
}
|
||||
|
||||
type AccountCheckResult struct {
|
||||
ID uint `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Status string `json:"status"`
|
||||
Plan string `json:"plan"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
func CheckAccountStatuses(d *gorm.DB, ids []uint) []AccountCheckResult {
|
||||
var results []AccountCheckResult
|
||||
|
||||
for _, id := range ids {
|
||||
var acct db.Account
|
||||
if d.First(&acct, id).Error != nil {
|
||||
results = append(results, AccountCheckResult{ID: id, Status: "error", Message: "账号不存在"})
|
||||
continue
|
||||
}
|
||||
|
||||
r := checkSingleAccount(d, &acct)
|
||||
results = append(results, r)
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
func checkSingleAccount(d *gorm.DB, acct *db.Account) AccountCheckResult {
|
||||
r := AccountCheckResult{ID: acct.ID, Email: acct.Email}
|
||||
now := time.Now()
|
||||
acct.StatusCheckedAt = &now
|
||||
|
||||
client, err := getProxyClient(d)
|
||||
if err != nil {
|
||||
r.Status = "error"
|
||||
r.Message = "代理连接失败: " + err.Error()
|
||||
d.Save(acct)
|
||||
return r
|
||||
}
|
||||
|
||||
accounts, err := chatgpt.CheckAccountFull(client, acct.AccessToken, acct.DeviceID)
|
||||
if err != nil {
|
||||
errMsg := err.Error()
|
||||
|
||||
// Try to detect specific error codes from response body
|
||||
if strings.Contains(errMsg, "401") {
|
||||
// Could be token_invalidated or account_deactivated — try refresh
|
||||
log.Printf("[account-check] %s: got 401, attempting token refresh...", acct.Email)
|
||||
if refreshAccountToken(d, acct) {
|
||||
// Retry with new token
|
||||
accounts2, err2 := chatgpt.CheckAccountFull(client, acct.AccessToken, acct.DeviceID)
|
||||
if err2 == nil {
|
||||
return buildCheckSuccess(d, acct, accounts2)
|
||||
}
|
||||
log.Printf("[account-check] %s: retry after refresh still failed: %v", acct.Email, err2)
|
||||
}
|
||||
acct.Status = "banned"
|
||||
r.Status = "banned"
|
||||
r.Message = "令牌无效且刷新失败,可能已封禁"
|
||||
} else if strings.Contains(errMsg, "403") {
|
||||
acct.Status = "banned"
|
||||
r.Status = "banned"
|
||||
r.Message = "账号已封禁 (403)"
|
||||
} else {
|
||||
acct.Status = "unknown"
|
||||
r.Status = "unknown"
|
||||
r.Message = errMsg
|
||||
}
|
||||
d.Save(acct)
|
||||
return r
|
||||
}
|
||||
|
||||
return buildCheckSuccess(d, acct, accounts)
|
||||
}
|
||||
|
||||
func buildCheckSuccess(d *gorm.DB, acct *db.Account, accounts []*chatgpt.AccountInfo) AccountCheckResult {
|
||||
r := AccountCheckResult{ID: acct.ID, Email: acct.Email}
|
||||
|
||||
var planParts []string
|
||||
for _, info := range accounts {
|
||||
planParts = append(planParts, fmt.Sprintf("%s(%s)", info.PlanType, info.Structure))
|
||||
}
|
||||
|
||||
target := selectMembershipAccount(acct, accounts)
|
||||
if target == nil {
|
||||
acct.Status = "unknown"
|
||||
r.Status = "unknown"
|
||||
r.Message = "membership check returned no matching account"
|
||||
d.Save(acct)
|
||||
return r
|
||||
}
|
||||
|
||||
resolvedStatus := normalizeMembershipStatus(target.PlanType)
|
||||
if resolvedStatus == "" {
|
||||
resolvedStatus = "unknown"
|
||||
}
|
||||
|
||||
acct.Status = resolvedStatus
|
||||
r.Status = resolvedStatus
|
||||
r.Plan = resolvedStatus
|
||||
r.Message = fmt.Sprintf("membership=%s (%d accounts: %s)", resolvedStatus, len(accounts), strings.Join(planParts, ", "))
|
||||
|
||||
if target.AccountID != "" && target.AccountID != acct.AccountID {
|
||||
acct.AccountID = target.AccountID
|
||||
}
|
||||
if target.Structure == "workspace" && target.AccountID != "" {
|
||||
acct.TeamWorkspaceID = target.AccountID
|
||||
}
|
||||
d.Save(acct)
|
||||
log.Printf("[account-check] %s: %s", acct.Email, r.Message)
|
||||
return r
|
||||
}
|
||||
|
||||
func normalizeMembershipStatus(planType string) string {
|
||||
switch planType {
|
||||
case "free", "plus", "team":
|
||||
return planType
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func selectMembershipAccount(acct *db.Account, accounts []*chatgpt.AccountInfo) *chatgpt.AccountInfo {
|
||||
if len(accounts) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if acct.Plan == "team_owner" || acct.Plan == "team_member" {
|
||||
if acct.TeamWorkspaceID != "" {
|
||||
for _, info := range accounts {
|
||||
if info.AccountID == acct.TeamWorkspaceID {
|
||||
return info
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, info := range accounts {
|
||||
if info.Structure == "workspace" && info.PlanType == "team" {
|
||||
return info
|
||||
}
|
||||
}
|
||||
for _, info := range accounts {
|
||||
if info.Structure == "workspace" {
|
||||
return info
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if acct.AccountID != "" {
|
||||
for _, info := range accounts {
|
||||
if info.AccountID == acct.AccountID && info.Structure == "personal" {
|
||||
return info
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, info := range accounts {
|
||||
if info.Structure == "personal" {
|
||||
return info
|
||||
}
|
||||
}
|
||||
|
||||
if acct.AccountID != "" {
|
||||
for _, info := range accounts {
|
||||
if info.AccountID == acct.AccountID {
|
||||
return info
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return accounts[0]
|
||||
}
|
||||
|
||||
// --- Model Test ---
|
||||
|
||||
type ModelTestResult struct {
|
||||
ID uint `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Model string `json:"model"`
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
Output string `json:"output,omitempty"`
|
||||
}
|
||||
|
||||
func TestModelAvailability(d *gorm.DB, accountID uint, modelID string) *ModelTestResult {
|
||||
var acct db.Account
|
||||
if d.First(&acct, accountID).Error != nil {
|
||||
return &ModelTestResult{ID: accountID, Model: modelID, Message: "账号不存在"}
|
||||
}
|
||||
|
||||
if modelID == "" {
|
||||
modelID = "gpt-4o"
|
||||
}
|
||||
|
||||
result := doModelTest(d, &acct, modelID)
|
||||
|
||||
// If token_invalidated, auto-refresh and retry once
|
||||
if !result.Success && strings.Contains(result.Message, "令牌过期") {
|
||||
log.Printf("[model-test] %s: token expired, refreshing...", acct.Email)
|
||||
if refreshAccountToken(d, &acct) {
|
||||
result = doModelTest(d, &acct, modelID)
|
||||
if result.Success {
|
||||
result.Message += " (令牌已自动刷新)"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func doModelTest(d *gorm.DB, acct *db.Account, modelID string) *ModelTestResult {
|
||||
client, err := getProxyClient(d)
|
||||
if err != nil {
|
||||
return &ModelTestResult{ID: acct.ID, Email: acct.Email, Model: modelID, Message: "代理连接失败: " + err.Error()}
|
||||
}
|
||||
|
||||
result := &ModelTestResult{ID: acct.ID, Email: acct.Email, Model: modelID}
|
||||
|
||||
apiURL := "https://chatgpt.com/backend-api/codex/responses"
|
||||
payload := map[string]interface{}{
|
||||
"model": modelID,
|
||||
"input": []map[string]interface{}{
|
||||
{
|
||||
"role": "user",
|
||||
"content": []map[string]interface{}{
|
||||
{"type": "input_text", "text": "hi"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"stream": false,
|
||||
"store": false,
|
||||
}
|
||||
|
||||
headers := map[string]string{
|
||||
"Authorization": "Bearer " + acct.AccessToken,
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "*/*",
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/145.0.0.0 Safari/537.36",
|
||||
"Origin": "https://chatgpt.com",
|
||||
"Referer": "https://chatgpt.com/",
|
||||
"oai-language": "en-US",
|
||||
}
|
||||
if acct.AccountID != "" {
|
||||
headers["chatgpt-account-id"] = acct.AccountID
|
||||
}
|
||||
if acct.DeviceID != "" {
|
||||
headers["oai-device-id"] = acct.DeviceID
|
||||
}
|
||||
|
||||
resp, err := client.PostJSON(apiURL, payload, headers)
|
||||
if err != nil {
|
||||
result.Message = fmt.Sprintf("请求失败: %v", err)
|
||||
return result
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
|
||||
switch resp.StatusCode {
|
||||
case http.StatusOK:
|
||||
result.Success = true
|
||||
result.Message = fmt.Sprintf("模型 %s 可用", modelID)
|
||||
var respData map[string]interface{}
|
||||
if json.Unmarshal(body, &respData) == nil {
|
||||
if output, ok := respData["output_text"]; ok {
|
||||
result.Output = fmt.Sprintf("%v", output)
|
||||
}
|
||||
}
|
||||
case http.StatusUnauthorized:
|
||||
code := parseErrorCode(body)
|
||||
switch code {
|
||||
case "token_invalidated":
|
||||
result.Message = "令牌过期 (token_invalidated)"
|
||||
case "account_deactivated":
|
||||
result.Message = "账号已封禁 (account_deactivated)"
|
||||
acct.Status = "banned"
|
||||
d.Save(acct)
|
||||
default:
|
||||
result.Message = fmt.Sprintf("认证失败 (401, code=%s)", code)
|
||||
}
|
||||
case http.StatusForbidden:
|
||||
result.Message = "账号被封禁 (403)"
|
||||
acct.Status = "banned"
|
||||
d.Save(acct)
|
||||
case http.StatusNotFound:
|
||||
result.Message = fmt.Sprintf("模型 %s 不存在 (404)", modelID)
|
||||
case http.StatusTooManyRequests:
|
||||
result.Message = "请求限流 (429),请稍后再试"
|
||||
default:
|
||||
errMsg := string(body)
|
||||
if len(errMsg) > 300 {
|
||||
errMsg = errMsg[:300]
|
||||
}
|
||||
result.Message = fmt.Sprintf("HTTP %d: %s", resp.StatusCode, errMsg)
|
||||
}
|
||||
|
||||
log.Printf("[model-test] %s model=%s status=%d success=%v msg=%s", acct.Email, modelID, resp.StatusCode, result.Success, result.Message)
|
||||
return result
|
||||
}
|
||||
125
internal/service/cpa_svc.go
Normal file
125
internal/service/cpa_svc.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
|
||||
"gpt-plus/internal/db"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// getConfigValue reads a config value from the database, auto-decrypting password fields.
|
||||
func getConfigValue(d *gorm.DB, key string) (string, error) {
|
||||
var cfg db.SystemConfig
|
||||
if err := d.Where("key = ?", key).First(&cfg).Error; err != nil {
|
||||
return "", err
|
||||
}
|
||||
if cfg.Type == "password" && cfg.Value != "" {
|
||||
decrypted, err := db.Decrypt(cfg.Value)
|
||||
if err != nil {
|
||||
// Fallback to raw value (may be plaintext from seed)
|
||||
return cfg.Value, nil
|
||||
}
|
||||
return decrypted, nil
|
||||
}
|
||||
return cfg.Value, nil
|
||||
}
|
||||
|
||||
// TransferResult holds the result for a single account transfer.
|
||||
type TransferResult struct {
|
||||
ID uint `json:"id"`
|
||||
Email string `json:"email"`
|
||||
OK bool `json:"ok"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// TransferAccountToCPA builds an auth file for the account and uploads it to CPA.
|
||||
func TransferAccountToCPA(d *gorm.DB, accountID uint) TransferResult {
|
||||
// Get account first (for email in result)
|
||||
var acct db.Account
|
||||
if err := d.First(&acct, accountID).Error; err != nil {
|
||||
return TransferResult{ID: accountID, Error: "账号不存在"}
|
||||
}
|
||||
|
||||
result := TransferResult{ID: acct.ID, Email: acct.Email}
|
||||
|
||||
// Get CPA config
|
||||
baseURL, err := getConfigValue(d, "cpa.base_url")
|
||||
if err != nil || baseURL == "" {
|
||||
result.Error = "CPA 地址未配置"
|
||||
return result
|
||||
}
|
||||
managementKey, err := getConfigValue(d, "cpa.management_key")
|
||||
if err != nil || managementKey == "" {
|
||||
result.Error = "CPA Management Key 未配置"
|
||||
return result
|
||||
}
|
||||
|
||||
// Build auth file
|
||||
auth := buildAuthFile(&acct)
|
||||
jsonData, err := json.MarshalIndent(auth, "", " ")
|
||||
if err != nil {
|
||||
result.Error = fmt.Sprintf("序列化失败: %v", err)
|
||||
return result
|
||||
}
|
||||
|
||||
// If team_owner, also transfer sub-accounts
|
||||
if acct.Plan == "team_owner" {
|
||||
var subs []db.Account
|
||||
d.Where("parent_id = ?", acct.ID).Find(&subs)
|
||||
for _, sub := range subs {
|
||||
subAuth := buildAuthFile(&sub)
|
||||
subData, _ := json.MarshalIndent(subAuth, "", " ")
|
||||
if err := uploadAuthFile(baseURL, managementKey, sub.Email+".auth.json", subData); err != nil {
|
||||
result.Error = fmt.Sprintf("子号 %s 上传失败: %v", sub.Email, err)
|
||||
return result
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Upload main account
|
||||
if err := uploadAuthFile(baseURL, managementKey, acct.Email+".auth.json", jsonData); err != nil {
|
||||
result.Error = fmt.Sprintf("上传失败: %v", err)
|
||||
return result
|
||||
}
|
||||
|
||||
result.OK = true
|
||||
return result
|
||||
}
|
||||
|
||||
// uploadAuthFile uploads a single auth JSON file to CPA via multipart form.
|
||||
func uploadAuthFile(baseURL, managementKey, filename string, data []byte) error {
|
||||
var body bytes.Buffer
|
||||
writer := multipart.NewWriter(&body)
|
||||
part, err := writer.CreateFormFile("file", filename)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建表单失败: %w", err)
|
||||
}
|
||||
part.Write(data)
|
||||
writer.Close()
|
||||
|
||||
req, err := http.NewRequest("POST", baseURL+"/v0/management/auth-files", &body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
req.Header.Set("Authorization", "Bearer "+managementKey)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("请求失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("CPA 返回 %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
121
internal/service/export_svc.go
Normal file
121
internal/service/export_svc.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"gpt-plus/internal/db"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type authFile struct {
|
||||
OpenAIAPIKey string `json:"OPENAI_API_KEY"`
|
||||
AccessToken string `json:"access_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
LastRefresh string `json:"last_refresh"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
Tokens authTokens `json:"tokens"`
|
||||
}
|
||||
|
||||
type authTokens struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
AccountID string `json:"account_id"`
|
||||
IDToken string `json:"id_token"`
|
||||
LastRefresh string `json:"last_refresh"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
|
||||
func buildAuthFile(acct *db.Account) *authFile {
|
||||
now := time.Now().UTC().Format("2006-01-02T15:04:05Z")
|
||||
token := acct.AccessToken
|
||||
if acct.WorkspaceToken != "" {
|
||||
token = acct.WorkspaceToken
|
||||
}
|
||||
accountID := acct.AccountID
|
||||
if acct.TeamWorkspaceID != "" {
|
||||
accountID = acct.TeamWorkspaceID
|
||||
}
|
||||
return &authFile{
|
||||
AccessToken: token,
|
||||
IDToken: acct.IDToken,
|
||||
LastRefresh: now,
|
||||
RefreshToken: acct.RefreshToken,
|
||||
Tokens: authTokens{
|
||||
AccessToken: token,
|
||||
AccountID: accountID,
|
||||
IDToken: acct.IDToken,
|
||||
LastRefresh: now,
|
||||
RefreshToken: acct.RefreshToken,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ExportAccounts exports accounts as auth.json files.
|
||||
// If multiple files, returns a zip archive.
|
||||
func ExportAccounts(d *gorm.DB, ids []uint, note string) ([]byte, string, error) {
|
||||
var accounts []db.Account
|
||||
d.Where("id IN ?", ids).Find(&accounts)
|
||||
|
||||
if len(accounts) == 0 {
|
||||
return nil, "", fmt.Errorf("未找到账号")
|
||||
}
|
||||
|
||||
// Collect all files to export
|
||||
type exportFile struct {
|
||||
Name string
|
||||
Data []byte
|
||||
}
|
||||
var files []exportFile
|
||||
|
||||
for _, acct := range accounts {
|
||||
auth := buildAuthFile(&acct)
|
||||
data, _ := json.MarshalIndent(auth, "", " ")
|
||||
files = append(files, exportFile{
|
||||
Name: acct.Email + ".auth.json",
|
||||
Data: data,
|
||||
})
|
||||
|
||||
// If team_owner, also export sub-accounts
|
||||
if acct.Plan == "team_owner" {
|
||||
var subs []db.Account
|
||||
d.Where("parent_id = ?", acct.ID).Find(&subs)
|
||||
for _, sub := range subs {
|
||||
subAuth := buildAuthFile(&sub)
|
||||
subData, _ := json.MarshalIndent(subAuth, "", " ")
|
||||
files = append(files, exportFile{
|
||||
Name: sub.Email + ".auth.json",
|
||||
Data: subData,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Single file — return directly
|
||||
if len(files) == 1 {
|
||||
return files[0].Data, files[0].Name, nil
|
||||
}
|
||||
|
||||
// Multiple files — zip archive
|
||||
var buf bytes.Buffer
|
||||
zw := zip.NewWriter(&buf)
|
||||
for _, f := range files {
|
||||
w, err := zw.Create(f.Name)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
w.Write(f.Data)
|
||||
}
|
||||
zw.Close()
|
||||
|
||||
ts := time.Now().Format("20060102_150405")
|
||||
filename := fmt.Sprintf("export_%s_%s.zip", note, ts)
|
||||
if note == "" {
|
||||
filename = fmt.Sprintf("export_%s.zip", ts)
|
||||
}
|
||||
|
||||
return buf.Bytes(), filename, nil
|
||||
}
|
||||
168
internal/service/export_test.go
Normal file
168
internal/service/export_test.go
Normal file
@@ -0,0 +1,168 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"gpt-plus/internal/db"
|
||||
|
||||
"github.com/glebarez/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
func setupExportTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
d, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("open db: %v", err)
|
||||
}
|
||||
d.AutoMigrate(&db.Account{})
|
||||
return d
|
||||
}
|
||||
|
||||
func TestExportSinglePlusAccount(t *testing.T) {
|
||||
d := setupExportTestDB(t)
|
||||
d.Create(&db.Account{
|
||||
Email: "plus@test.com", Plan: "plus", Status: "active",
|
||||
AccessToken: "at-123", RefreshToken: "rt-456", IDToken: "id-789",
|
||||
AccountID: "acc-001",
|
||||
})
|
||||
|
||||
data, filename, err := ExportAccounts(d, []uint{1}, "")
|
||||
if err != nil {
|
||||
t.Fatalf("export: %v", err)
|
||||
}
|
||||
if !strings.HasSuffix(filename, ".auth.json") {
|
||||
t.Fatalf("filename = %q, want *.auth.json", filename)
|
||||
}
|
||||
|
||||
var auth authFile
|
||||
if err := json.Unmarshal(data, &auth); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if auth.AccessToken != "at-123" {
|
||||
t.Fatalf("access_token = %q", auth.AccessToken)
|
||||
}
|
||||
if auth.Tokens.AccountID != "acc-001" {
|
||||
t.Fatalf("tokens.account_id = %q", auth.Tokens.AccountID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExportTeamWithSubAccounts(t *testing.T) {
|
||||
d := setupExportTestDB(t)
|
||||
owner := db.Account{
|
||||
Email: "owner@test.com", Plan: "team_owner", Status: "active",
|
||||
AccessToken: "owner-at", RefreshToken: "owner-rt",
|
||||
AccountID: "team-acc", TeamWorkspaceID: "ws-123", WorkspaceToken: "ws-tok",
|
||||
}
|
||||
d.Create(&owner)
|
||||
d.Create(&db.Account{
|
||||
Email: "member@test.com", Plan: "team_member", Status: "active",
|
||||
ParentID: &owner.ID, AccessToken: "mem-at", RefreshToken: "mem-rt",
|
||||
})
|
||||
|
||||
data, filename, err := ExportAccounts(d, []uint{owner.ID}, "test")
|
||||
if err != nil {
|
||||
t.Fatalf("export: %v", err)
|
||||
}
|
||||
if !strings.HasSuffix(filename, ".zip") {
|
||||
t.Fatalf("filename = %q, want *.zip", filename)
|
||||
}
|
||||
|
||||
r, err := zip.NewReader(bytes.NewReader(data), int64(len(data)))
|
||||
if err != nil {
|
||||
t.Fatalf("open zip: %v", err)
|
||||
}
|
||||
if len(r.File) != 2 {
|
||||
t.Fatalf("zip has %d files, want 2", len(r.File))
|
||||
}
|
||||
|
||||
names := make(map[string]bool)
|
||||
for _, f := range r.File {
|
||||
names[f.Name] = true
|
||||
}
|
||||
if !names["owner@test.com.auth.json"] {
|
||||
t.Fatal("missing owner auth file")
|
||||
}
|
||||
if !names["member@test.com.auth.json"] {
|
||||
t.Fatal("missing member auth file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExportTeamOwnerUsesWorkspaceToken(t *testing.T) {
|
||||
d := setupExportTestDB(t)
|
||||
d.Create(&db.Account{
|
||||
Email: "team@test.com", Plan: "team_owner", Status: "active",
|
||||
AccessToken: "normal-at", WorkspaceToken: "ws-at",
|
||||
TeamWorkspaceID: "team-ws-id", AccountID: "personal-acc",
|
||||
})
|
||||
|
||||
data, _, err := ExportAccounts(d, []uint{1}, "")
|
||||
if err != nil {
|
||||
t.Fatalf("export: %v", err)
|
||||
}
|
||||
|
||||
var auth authFile
|
||||
json.Unmarshal(data, &auth)
|
||||
if auth.AccessToken != "ws-at" {
|
||||
t.Fatalf("should use workspace token, got %q", auth.AccessToken)
|
||||
}
|
||||
if auth.Tokens.AccountID != "team-ws-id" {
|
||||
t.Fatalf("should use team workspace ID, got %q", auth.Tokens.AccountID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExportMultiplePlusAccounts(t *testing.T) {
|
||||
d := setupExportTestDB(t)
|
||||
d.Create(&db.Account{Email: "a@test.com", Plan: "plus", AccessToken: "at-a"})
|
||||
d.Create(&db.Account{Email: "b@test.com", Plan: "plus", AccessToken: "at-b"})
|
||||
|
||||
data, filename, err := ExportAccounts(d, []uint{1, 2}, "batch")
|
||||
if err != nil {
|
||||
t.Fatalf("export: %v", err)
|
||||
}
|
||||
if !strings.HasSuffix(filename, ".zip") {
|
||||
t.Fatalf("filename = %q, want *.zip for multiple", filename)
|
||||
}
|
||||
if !strings.Contains(filename, "batch") {
|
||||
t.Fatalf("filename should contain note, got %q", filename)
|
||||
}
|
||||
|
||||
r, _ := zip.NewReader(bytes.NewReader(data), int64(len(data)))
|
||||
if len(r.File) != 2 {
|
||||
t.Fatalf("zip has %d files, want 2", len(r.File))
|
||||
}
|
||||
}
|
||||
|
||||
func TestExportNotFound(t *testing.T) {
|
||||
d := setupExportTestDB(t)
|
||||
|
||||
_, _, err := ExportAccounts(d, []uint{999}, "")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for nonexistent account")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAuthFileFallbacks(t *testing.T) {
|
||||
acct := &db.Account{
|
||||
Email: "basic@test.com", AccessToken: "at-1", RefreshToken: "rt-1",
|
||||
IDToken: "id-1", AccountID: "acc-1",
|
||||
}
|
||||
auth := buildAuthFile(acct)
|
||||
|
||||
if auth.AccessToken != "at-1" {
|
||||
t.Fatalf("access_token = %q", auth.AccessToken)
|
||||
}
|
||||
if auth.Tokens.AccountID != "acc-1" {
|
||||
t.Fatalf("tokens.account_id = %q", auth.Tokens.AccountID)
|
||||
}
|
||||
if auth.LastRefresh == "" {
|
||||
t.Fatal("last_refresh should be set")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user