230 lines
6.1 KiB
Go
230 lines
6.1 KiB
Go
package task
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"time"
|
|
|
|
"gpt-plus/pkg/auth"
|
|
"gpt-plus/pkg/chatgpt"
|
|
)
|
|
|
|
// Polling parameters used by verifyTaskMembership: how many membership checks
// are attempted in total, and how long to wait between consecutive checks.
const (
	finalMembershipPolls     = 5
	finalMembershipPollDelay = time.Second * 2
)
|
|
|
|
type finalMembershipState struct {
|
|
Personal *chatgpt.AccountInfo
|
|
Workspace *chatgpt.AccountInfo
|
|
}
|
|
|
|
func (s *finalMembershipState) plusActive() bool {
|
|
return plusMembershipActive(s.Personal)
|
|
}
|
|
|
|
func (s *finalMembershipState) teamActive() bool {
|
|
return teamMembershipActive(s.Workspace)
|
|
}
|
|
|
|
func (s *finalMembershipState) satisfied(taskType string) bool {
|
|
switch taskType {
|
|
case TaskTypePlus:
|
|
return s.plusActive()
|
|
case TaskTypeTeam:
|
|
return s.teamActive()
|
|
case TaskTypeBoth:
|
|
return s.plusActive() && s.teamActive()
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
func (s *finalMembershipState) resultPlanForTask(taskType string) string {
|
|
if s != nil && s.satisfied(taskType) {
|
|
return taskType
|
|
}
|
|
return s.actualPlan()
|
|
}
|
|
|
|
func (s *finalMembershipState) actualPlan() string {
|
|
switch {
|
|
case s == nil:
|
|
return "unknown"
|
|
case s.plusActive() && s.teamActive():
|
|
return TaskTypeBoth
|
|
case s.teamActive():
|
|
return "team"
|
|
case s.plusActive():
|
|
return "plus"
|
|
case s.Personal != nil && s.Personal.PlanType != "":
|
|
return s.Personal.PlanType
|
|
case s.Workspace != nil && s.Workspace.PlanType != "":
|
|
return s.Workspace.PlanType
|
|
default:
|
|
return "unknown"
|
|
}
|
|
}
|
|
|
|
func (s *finalMembershipState) describe() string {
|
|
return fmt.Sprintf("personal=%s workspace=%s", describeMembership(s.Personal), describeMembership(s.Workspace))
|
|
}
|
|
|
|
func plusMembershipActive(info *chatgpt.AccountInfo) bool {
|
|
return info != nil &&
|
|
info.Structure == "personal" &&
|
|
info.PlanType == "plus" &&
|
|
info.HasActiveSubscription &&
|
|
info.SubscriptionID != ""
|
|
}
|
|
|
|
func teamMembershipActive(info *chatgpt.AccountInfo) bool {
|
|
return info != nil &&
|
|
info.Structure == "workspace" &&
|
|
info.PlanType == "team"
|
|
}
|
|
|
|
func describeMembership(info *chatgpt.AccountInfo) string {
|
|
if info == nil {
|
|
return "none"
|
|
}
|
|
return fmt.Sprintf("%s/%s(active=%v sub=%t id=%s)",
|
|
info.Structure, info.PlanType, info.HasActiveSubscription, info.SubscriptionID != "", info.AccountID)
|
|
}
|
|
|
|
func selectPersonalMembership(accounts []*chatgpt.AccountInfo) *chatgpt.AccountInfo {
|
|
for _, acct := range accounts {
|
|
if acct.Structure == "personal" {
|
|
return acct
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func selectWorkspaceMembership(accounts []*chatgpt.AccountInfo, preferredID string) *chatgpt.AccountInfo {
|
|
if preferredID != "" {
|
|
for _, acct := range accounts {
|
|
if acct.Structure == "workspace" && acct.AccountID == preferredID {
|
|
return acct
|
|
}
|
|
}
|
|
}
|
|
for _, acct := range accounts {
|
|
if acct.Structure == "workspace" && acct.PlanType == "team" {
|
|
return acct
|
|
}
|
|
}
|
|
for _, acct := range accounts {
|
|
if acct.Structure == "workspace" {
|
|
return acct
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (r *TaskRunner) verifyTaskMembership(
|
|
ctx context.Context,
|
|
taskType string,
|
|
session *chatgpt.Session,
|
|
teamAccountID string,
|
|
statusFn chatgpt.StatusFunc,
|
|
) (*finalMembershipState, error) {
|
|
var lastState *finalMembershipState
|
|
var lastErr error
|
|
|
|
for attempt := 1; attempt <= finalMembershipPolls; attempt++ {
|
|
if err := refreshVerificationTokens(ctx, session); err != nil && statusFn != nil {
|
|
statusFn(" -> Final token refresh %d/%d failed: %v", attempt, finalMembershipPolls, err)
|
|
}
|
|
|
|
accounts, err := chatgpt.CheckAccountFull(session.Client, session.AccessToken, session.DeviceID)
|
|
if err != nil {
|
|
lastErr = fmt.Errorf("accounts/check failed: %w", err)
|
|
} else {
|
|
lastState = &finalMembershipState{
|
|
Personal: selectPersonalMembership(accounts),
|
|
Workspace: selectWorkspaceMembership(accounts, teamAccountID),
|
|
}
|
|
}
|
|
|
|
if (taskType == TaskTypeTeam || taskType == TaskTypeBoth) &&
|
|
(lastState == nil || !lastState.teamActive()) &&
|
|
teamAccountID != "" {
|
|
workspaceToken, wsErr := chatgpt.GetWorkspaceAccessToken(session.Client, teamAccountID)
|
|
if wsErr == nil {
|
|
if wsAccounts, wsCheckErr := chatgpt.CheckAccountFull(session.Client, workspaceToken, session.DeviceID); wsCheckErr == nil {
|
|
if lastState == nil {
|
|
lastState = &finalMembershipState{}
|
|
}
|
|
lastState.Workspace = selectWorkspaceMembership(wsAccounts, teamAccountID)
|
|
} else if lastErr == nil {
|
|
lastErr = fmt.Errorf("workspace accounts/check failed: %w", wsCheckErr)
|
|
}
|
|
} else if lastErr == nil {
|
|
lastErr = fmt.Errorf("workspace token refresh failed: %w", wsErr)
|
|
}
|
|
}
|
|
|
|
if lastState != nil && statusFn != nil {
|
|
statusFn(" -> Final membership %d/%d: %s", attempt, finalMembershipPolls, lastState.describe())
|
|
}
|
|
if lastState != nil && lastState.satisfied(taskType) {
|
|
return lastState, nil
|
|
}
|
|
if lastState != nil {
|
|
lastErr = fmt.Errorf("membership mismatch for %s: %s", taskType, lastState.describe())
|
|
}
|
|
|
|
if attempt < finalMembershipPolls {
|
|
if err := sleepWithContext(ctx, finalMembershipPollDelay); err != nil {
|
|
return lastState, err
|
|
}
|
|
}
|
|
}
|
|
|
|
if lastState != nil {
|
|
return lastState, fmt.Errorf("final membership mismatch for %s: %s", taskType, lastState.describe())
|
|
}
|
|
if lastErr != nil {
|
|
return nil, fmt.Errorf("final membership check failed for %s: %w", taskType, lastErr)
|
|
}
|
|
return nil, fmt.Errorf("final membership check failed for %s", taskType)
|
|
}
|
|
|
|
func refreshVerificationTokens(ctx context.Context, session *chatgpt.Session) error {
|
|
if err := session.RefreshSession(); err == nil {
|
|
return nil
|
|
} else {
|
|
tokens, tokenErr := auth.ObtainCodexTokens(ctx, session.Client, session.DeviceID, "")
|
|
if tokenErr != nil {
|
|
return fmt.Errorf("session refresh failed: %v; codex refresh failed: %w", err, tokenErr)
|
|
}
|
|
if tokens.AccessToken == "" {
|
|
return fmt.Errorf("codex refresh returned empty access token")
|
|
}
|
|
session.AccessToken = tokens.AccessToken
|
|
if tokens.RefreshToken != "" {
|
|
session.RefreshToken = tokens.RefreshToken
|
|
}
|
|
if tokens.IDToken != "" {
|
|
session.IDToken = tokens.IDToken
|
|
}
|
|
if tokens.ChatGPTAccountID != "" {
|
|
session.AccountID = tokens.ChatGPTAccountID
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// sleepWithContext blocks for delay or until ctx is done, whichever happens
// first.
//
// It returns ctx.Err() if the context is (or becomes) done, otherwise nil.
// An already-cancelled context always yields its error. Without that check,
// a select whose timer is also ready (e.g. delay <= 0) would pick a case at
// random and could report success. A non-positive delay returns immediately.
func sleepWithContext(ctx context.Context, delay time.Duration) error {
	if err := ctx.Err(); err != nil {
		return err
	}
	if delay <= 0 {
		return nil
	}

	timer := time.NewTimer(delay)
	defer timer.Stop()

	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-timer.C:
		return nil
	}
}
|