710 lines
22 KiB
Go
710 lines
22 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"log"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
|
|
"gpt-plus/pkg/httpclient"
|
|
"gpt-plus/pkg/provider/email"
|
|
)
|
|
|
|
// LoginResult carries the credentials produced by a successful Login.
type LoginResult struct {
	// AccessToken is the access_token returned by the OAuth token endpoint.
	AccessToken string
	// RefreshToken is the refresh_token returned by the OAuth token endpoint.
	RefreshToken string
	// IDToken is the id_token returned by the OAuth token endpoint.
	IDToken string
	// ChatGPTAccountID is taken from the access token's
	// "https://api.openai.com/auth" claim; empty when it could not be read.
	ChatGPTAccountID string
	// ChatGPTUserID is taken from the same claim; empty when it could not be read.
	ChatGPTUserID string
}
|
|
|
|
// Login performs the full OAuth login flow: authorize -> password verify -> token exchange.
|
|
func Login(
|
|
ctx context.Context,
|
|
client *httpclient.Client,
|
|
emailAddr, password, deviceID string,
|
|
sentinel *SentinelGenerator,
|
|
mailboxID string,
|
|
emailProvider email.EmailProvider,
|
|
) (*LoginResult, error) {
|
|
// Keep registration cookies — the verified email session helps skip OTP during login.
|
|
// Only ensure oai-did is set.
|
|
cookieURL, _ := url.Parse(oauthIssuer)
|
|
client.GetCookieJar().SetCookies(cookieURL, []*http.Cookie{
|
|
{Name: "oai-did", Value: deviceID},
|
|
})
|
|
|
|
// Generate PKCE and state
|
|
codeVerifier, codeChallenge := generatePKCE()
|
|
state := generateState()
|
|
|
|
// ===== Step 1: GET /oauth/authorize (no screen_hint) =====
|
|
authorizeParams := url.Values{
|
|
"response_type": {"code"},
|
|
"client_id": {oauthClientID},
|
|
"redirect_uri": {oauthRedirectURI},
|
|
"scope": {oauthScope},
|
|
"code_challenge": {codeChallenge},
|
|
"code_challenge_method": {"S256"},
|
|
"state": {state},
|
|
"id_token_add_organizations": {"true"},
|
|
"codex_cli_simplified_flow": {"true"},
|
|
}
|
|
authorizeURL := oauthIssuer + "/oauth/authorize?" + authorizeParams.Encode()
|
|
|
|
navH := navigateHeaders()
|
|
resp, err := client.Get(authorizeURL, navH)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("step1 authorize: %w", err)
|
|
}
|
|
httpclient.ReadBody(resp)
|
|
|
|
// ===== Step 2: POST authorize/continue with email =====
|
|
sentinelToken, err := sentinel.GenerateToken(ctx, client, "authorize_continue")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("step2 sentinel: %w", err)
|
|
}
|
|
|
|
headers := commonHeaders(deviceID)
|
|
headers["referer"] = oauthIssuer + "/log-in"
|
|
headers["openai-sentinel-token"] = sentinelToken
|
|
|
|
continueBody := map[string]interface{}{
|
|
"username": map[string]string{"kind": "email", "value": emailAddr},
|
|
}
|
|
resp, err = client.PostJSON(oauthIssuer+"/api/accounts/authorize/continue", continueBody, headers)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("step2 authorize/continue: %w", err)
|
|
}
|
|
body, _ := httpclient.ReadBody(resp)
|
|
if resp.StatusCode != 200 {
|
|
return nil, fmt.Errorf("step2 failed (%d): %s", resp.StatusCode, string(body))
|
|
}
|
|
|
|
// ===== Step 3: POST password/verify =====
|
|
sentinelToken, err = sentinel.GenerateToken(ctx, client, "password_verify")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("step3 sentinel: %w", err)
|
|
}
|
|
|
|
headers = commonHeaders(deviceID)
|
|
headers["referer"] = oauthIssuer + "/log-in/password"
|
|
headers["openai-sentinel-token"] = sentinelToken
|
|
|
|
pwBody := map[string]string{"password": password}
|
|
resp, err = client.PostJSON(oauthIssuer+"/api/accounts/password/verify", pwBody, headers)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("step3 password/verify: %w", err)
|
|
}
|
|
body, _ = httpclient.ReadBody(resp)
|
|
if resp.StatusCode != 200 {
|
|
return nil, fmt.Errorf("step3 password verify failed (%d): %s", resp.StatusCode, string(body))
|
|
}
|
|
|
|
var verifyResp struct {
|
|
ContinueURL string `json:"continue_url"`
|
|
Page struct {
|
|
Type string `json:"type"`
|
|
} `json:"page"`
|
|
}
|
|
json.Unmarshal(body, &verifyResp)
|
|
continueURL := verifyResp.ContinueURL
|
|
pageType := verifyResp.Page.Type
|
|
log.Printf("[login] step3 password/verify: status=%d, page=%s, continue=%s", resp.StatusCode, pageType, continueURL)
|
|
|
|
// ===== Step 3.5: Email OTP verification (if triggered for new accounts) =====
|
|
if pageType == "email_otp_verification" || pageType == "email_otp_send" ||
|
|
strings.Contains(continueURL, "email-verification") || strings.Contains(continueURL, "email-otp") {
|
|
if mailboxID == "" || emailProvider == nil {
|
|
return nil, errors.New("email verification required but no mailbox/provider available")
|
|
}
|
|
|
|
// Trigger the OTP send explicitly (like registration does) and poll for delivery.
|
|
log.Printf("[login] step3.5 OTP required — triggering email-otp/send...")
|
|
|
|
// Snapshot existing emails BEFORE triggering the send, to skip the registration OTP.
|
|
sendHeaders := commonHeaders(deviceID)
|
|
sendHeaders["referer"] = oauthIssuer + "/email-verification"
|
|
|
|
// GET /api/accounts/email-otp/send to trigger the OTP email
|
|
_, err = client.Get(oauthIssuer+"/api/accounts/email-otp/send", sendHeaders)
|
|
if err != nil {
|
|
log.Printf("[login] step3.5 email-otp/send failed (non-fatal): %v", err)
|
|
}
|
|
|
|
otpCode, err := emailProvider.WaitForVerificationCode(ctx, mailboxID, 120*time.Second, time.Now())
|
|
if err != nil {
|
|
return nil, fmt.Errorf("step3.5 wait for otp: %w", err)
|
|
}
|
|
if otpCode == "" {
|
|
return nil, errors.New("step3.5 wait for otp: empty code returned")
|
|
}
|
|
log.Printf("[login] step3.5 got OTP code: %s", otpCode)
|
|
|
|
headers = commonHeaders(deviceID)
|
|
headers["referer"] = oauthIssuer + "/email-verification"
|
|
|
|
otpBody := map[string]string{"code": otpCode}
|
|
resp, err = client.PostJSON(oauthIssuer+"/api/accounts/email-otp/validate", otpBody, headers)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("step3.5 validate otp: %w", err)
|
|
}
|
|
body, _ = httpclient.ReadBody(resp)
|
|
if resp.StatusCode != 200 {
|
|
return nil, fmt.Errorf("step3.5 otp validate failed (%d): %s", resp.StatusCode, string(body))
|
|
}
|
|
|
|
json.Unmarshal(body, &verifyResp)
|
|
continueURL = verifyResp.ContinueURL
|
|
pageType = verifyResp.Page.Type
|
|
log.Printf("[login] step3.5 otp validate: page=%s, continue=%s", pageType, continueURL)
|
|
|
|
// If about-you step needed, submit name/birthdate
|
|
if strings.Contains(continueURL, "about-you") {
|
|
firstName, lastName := generateRandomName()
|
|
birthdate := generateRandomBirthday()
|
|
|
|
headers = commonHeaders(deviceID)
|
|
headers["referer"] = oauthIssuer + "/about-you"
|
|
|
|
createBody := map[string]string{
|
|
"name": firstName + " " + lastName,
|
|
"birthdate": birthdate,
|
|
}
|
|
resp, err = client.PostJSON(oauthIssuer+"/api/accounts/create_account", createBody, headers)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("step3.5 create_account: %w", err)
|
|
}
|
|
body, _ = httpclient.ReadBody(resp)
|
|
if resp.StatusCode == 200 {
|
|
var createResp struct {
|
|
ContinueURL string `json:"continue_url"`
|
|
}
|
|
json.Unmarshal(body, &createResp)
|
|
continueURL = createResp.ContinueURL
|
|
} else if resp.StatusCode == 400 && strings.Contains(string(body), "already_exists") {
|
|
continueURL = oauthIssuer + "/sign-in-with-chatgpt/codex/consent"
|
|
}
|
|
}
|
|
}
|
|
|
|
// Handle consent page type
|
|
if strings.Contains(pageType, "consent") || continueURL == "" {
|
|
continueURL = oauthIssuer + "/sign-in-with-chatgpt/codex/consent"
|
|
}
|
|
|
|
log.Printf("[login] step4 entering consent flow: page=%s, continue=%s", pageType, continueURL)
|
|
|
|
// ===== Step 4: Follow consent/workspace/organization redirects to extract code =====
|
|
authCode, err := followConsentRedirects(ctx, client, deviceID, continueURL, "")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("step4 consent flow: %w", err)
|
|
}
|
|
|
|
// ===== Step 5: Exchange code for tokens =====
|
|
tokens, err := exchangeCodeForTokens(ctx, client, authCode, codeVerifier)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("step5 token exchange: %w", err)
|
|
}
|
|
|
|
return tokens, nil
|
|
}
|
|
|
|
// decodeAuthSessionCookie parses the oai-client-auth-session cookie.
|
|
// Format: base64(json).timestamp.signature (Flask/itsdangerous style).
|
|
func decodeAuthSessionCookie(client *httpclient.Client) (map[string]interface{}, error) {
|
|
cookieURL, _ := url.Parse(oauthIssuer)
|
|
for _, c := range client.GetCookieJar().Cookies(cookieURL) {
|
|
if c.Name == "oai-client-auth-session" {
|
|
val := c.Value
|
|
firstPart := val
|
|
if idx := strings.Index(val, "."); idx > 0 {
|
|
firstPart = val[:idx]
|
|
}
|
|
// Add base64 padding
|
|
if pad := 4 - len(firstPart)%4; pad != 4 {
|
|
firstPart += strings.Repeat("=", pad)
|
|
}
|
|
raw, err := base64.URLEncoding.DecodeString(firstPart)
|
|
if err != nil {
|
|
// Try standard base64
|
|
raw, err = base64.StdEncoding.DecodeString(firstPart)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("decode auth session base64: %w", err)
|
|
}
|
|
}
|
|
var data map[string]interface{}
|
|
if err := json.Unmarshal(raw, &data); err != nil {
|
|
return nil, fmt.Errorf("parse auth session json: %w", err)
|
|
}
|
|
return data, nil
|
|
}
|
|
}
|
|
return nil, errors.New("oai-client-auth-session cookie not found")
|
|
}
|
|
|
|
// followConsentRedirects navigates through consent, workspace/select, organization/select
|
|
// until extracting the authorization code from the callback URL.
|
|
// targetWorkspaceID: if non-empty, select this specific workspace instead of the first one.
|
|
func followConsentRedirects(ctx context.Context, client *httpclient.Client, deviceID, continueURL, targetWorkspaceID string) (string, error) {
|
|
// Normalize URL
|
|
if strings.HasPrefix(continueURL, "/") {
|
|
continueURL = oauthIssuer + continueURL
|
|
}
|
|
|
|
navH := navigateHeaders()
|
|
var resp *http.Response
|
|
|
|
if targetWorkspaceID == "" {
|
|
// Default: auto-follow redirects. If we get a code directly, use it.
|
|
var err error
|
|
resp, err = client.Get(continueURL, navH)
|
|
if err != nil {
|
|
return "", fmt.Errorf("get consent: %w", err)
|
|
}
|
|
httpclient.ReadBody(resp)
|
|
|
|
if resp.Request != nil {
|
|
if code := extractCodeFromURL(resp.Request.URL.String()); code != "" {
|
|
return code, nil
|
|
}
|
|
}
|
|
log.Printf("[login] consent page: status=%d, final_url=%s", resp.StatusCode, resp.Request.URL.String())
|
|
} else {
|
|
// Specific workspace: use DoNoRedirect to prevent auto-selecting the default workspace.
|
|
currentURL := continueURL
|
|
for i := 0; i < 15; i++ {
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, currentURL, nil)
|
|
if err != nil {
|
|
return "", fmt.Errorf("build consent redirect request: %w", err)
|
|
}
|
|
for k, v := range navH {
|
|
req.Header.Set(k, v)
|
|
}
|
|
|
|
resp, err := client.DoNoRedirect(req)
|
|
if err != nil {
|
|
if code := extractCodeFromURL(currentURL); code != "" {
|
|
log.Printf("[login] consent redirect got default code, ignoring for workspace %s", targetWorkspaceID)
|
|
break
|
|
}
|
|
return "", fmt.Errorf("consent redirect: %w", err)
|
|
}
|
|
httpclient.ReadBody(resp)
|
|
|
|
if resp.StatusCode >= 300 && resp.StatusCode < 400 {
|
|
loc := resp.Header.Get("Location")
|
|
if loc == "" {
|
|
break
|
|
}
|
|
if !strings.HasPrefix(loc, "http") {
|
|
loc = oauthIssuer + loc
|
|
}
|
|
if code := extractCodeFromURL(loc); code != "" {
|
|
log.Printf("[login] consent redirect has default code, ignoring for workspace %s", targetWorkspaceID)
|
|
break
|
|
}
|
|
log.Printf("[login] consent redirect %d: %d → %s", i+1, resp.StatusCode, loc)
|
|
currentURL = loc
|
|
continue
|
|
}
|
|
log.Printf("[login] consent page reached: status=%d, url=%s", resp.StatusCode, currentURL)
|
|
break
|
|
}
|
|
}
|
|
|
|
// Decode oai-client-auth-session cookie to extract workspace data
|
|
sessionData, err := decodeAuthSessionCookie(client)
|
|
if err != nil {
|
|
log.Printf("[login] warning: %v", err)
|
|
cookieURL, _ := url.Parse(oauthIssuer)
|
|
for _, c := range client.GetCookieJar().Cookies(cookieURL) {
|
|
log.Printf("[login] cookie: %s (len=%d)", c.Name, len(c.Value))
|
|
}
|
|
return "", fmt.Errorf("consent flow: %w", err)
|
|
}
|
|
|
|
log.Printf("[login] decoded auth session, keys: %v", getMapKeys(sessionData))
|
|
|
|
// Extract workspace_id from session data
|
|
var workspaceID string
|
|
if workspaces, ok := sessionData["workspaces"].([]interface{}); ok && len(workspaces) > 0 {
|
|
// If a specific workspace is requested, find it by ID
|
|
if targetWorkspaceID != "" {
|
|
for _, w := range workspaces {
|
|
if ws, ok := w.(map[string]interface{}); ok {
|
|
if id, ok := ws["id"].(string); ok && id == targetWorkspaceID {
|
|
workspaceID = id
|
|
kind, _ := ws["kind"].(string)
|
|
log.Printf("[login] matched target workspace_id: %s (kind: %s)", workspaceID, kind)
|
|
break
|
|
}
|
|
}
|
|
}
|
|
if workspaceID == "" {
|
|
log.Printf("[login] target workspace %s not found in session, falling back to first", targetWorkspaceID)
|
|
}
|
|
}
|
|
// Fallback to first workspace if target not found or not specified
|
|
if workspaceID == "" {
|
|
if ws, ok := workspaces[0].(map[string]interface{}); ok {
|
|
if id, ok := ws["id"].(string); ok {
|
|
workspaceID = id
|
|
kind, _ := ws["kind"].(string)
|
|
log.Printf("[login] workspace_id: %s (kind: %s)", workspaceID, kind)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if workspaceID == "" {
|
|
// For new accounts that have no workspaces yet, try direct authorize flow
|
|
// The consent page itself might have set up the session — try following the
|
|
// authorize URL again which should now redirect with a code
|
|
log.Printf("[login] no workspaces in session, trying re-authorize...")
|
|
code, err := followRedirectsForCode(ctx, client, continueURL)
|
|
if err == nil && code != "" {
|
|
return code, nil
|
|
}
|
|
return "", errors.New("no workspace_id found in auth session cookie and re-authorize failed")
|
|
}
|
|
|
|
// POST workspace/select (no redirect following)
|
|
headers := commonHeaders(deviceID)
|
|
headers["referer"] = continueURL
|
|
|
|
wsBody := map[string]string{"workspace_id": workspaceID}
|
|
wsJSON, _ := json.Marshal(wsBody)
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, oauthIssuer+"/api/accounts/workspace/select", strings.NewReader(string(wsJSON)))
|
|
if err != nil {
|
|
return "", fmt.Errorf("build workspace/select request: %w", err)
|
|
}
|
|
for k, v := range headers {
|
|
req.Header.Set(k, v)
|
|
}
|
|
resp, err = client.DoNoRedirect(req)
|
|
if err != nil {
|
|
return "", fmt.Errorf("workspace/select: %w", err)
|
|
}
|
|
wsData, _ := httpclient.ReadBody(resp)
|
|
log.Printf("[login] workspace/select: status=%d", resp.StatusCode)
|
|
|
|
// Check for redirect with code
|
|
if resp.StatusCode >= 300 && resp.StatusCode < 400 {
|
|
loc := resp.Header.Get("Location")
|
|
if code := extractCodeFromURL(loc); code != "" {
|
|
return code, nil
|
|
}
|
|
// Follow redirect chain
|
|
if loc != "" {
|
|
if !strings.HasPrefix(loc, "http") {
|
|
loc = oauthIssuer + loc
|
|
}
|
|
code, err := followRedirectsForCode(ctx, client, loc)
|
|
if err == nil && code != "" {
|
|
return code, nil
|
|
}
|
|
}
|
|
}
|
|
|
|
// Parse workspace/select response for org data
|
|
if resp.StatusCode == 200 {
|
|
var wsResp struct {
|
|
Data struct {
|
|
Orgs []struct {
|
|
ID string `json:"id"`
|
|
Projects []struct {
|
|
ID string `json:"id"`
|
|
} `json:"projects"`
|
|
} `json:"orgs"`
|
|
} `json:"data"`
|
|
ContinueURL string `json:"continue_url"`
|
|
}
|
|
json.Unmarshal(wsData, &wsResp)
|
|
|
|
// Extract org_id and project_id
|
|
var orgID, projectID string
|
|
if len(wsResp.Data.Orgs) > 0 {
|
|
orgID = wsResp.Data.Orgs[0].ID
|
|
if len(wsResp.Data.Orgs[0].Projects) > 0 {
|
|
projectID = wsResp.Data.Orgs[0].Projects[0].ID
|
|
}
|
|
}
|
|
|
|
if orgID != "" {
|
|
log.Printf("[login] org_id: %s, project_id: %s", orgID, projectID)
|
|
|
|
// POST organization/select
|
|
orgBody := map[string]string{"org_id": orgID}
|
|
if projectID != "" {
|
|
orgBody["project_id"] = projectID
|
|
}
|
|
orgJSON, _ := json.Marshal(orgBody)
|
|
orgReq, err := http.NewRequestWithContext(ctx, http.MethodPost, oauthIssuer+"/api/accounts/organization/select", strings.NewReader(string(orgJSON)))
|
|
if err != nil {
|
|
return "", fmt.Errorf("build organization/select request: %w", err)
|
|
}
|
|
for k, v := range headers {
|
|
orgReq.Header.Set(k, v)
|
|
}
|
|
resp, err = client.DoNoRedirect(orgReq)
|
|
if err != nil {
|
|
return "", fmt.Errorf("organization/select: %w", err)
|
|
}
|
|
orgData, _ := httpclient.ReadBody(resp)
|
|
log.Printf("[login] organization/select: status=%d", resp.StatusCode)
|
|
|
|
// Check for redirect with code
|
|
if resp.StatusCode >= 300 && resp.StatusCode < 400 {
|
|
loc := resp.Header.Get("Location")
|
|
if code := extractCodeFromURL(loc); code != "" {
|
|
return code, nil
|
|
}
|
|
if loc != "" {
|
|
if !strings.HasPrefix(loc, "http") {
|
|
loc = oauthIssuer + loc
|
|
}
|
|
code, err := followRedirectsForCode(ctx, client, loc)
|
|
if err == nil && code != "" {
|
|
return code, nil
|
|
}
|
|
}
|
|
}
|
|
|
|
// Parse continue_url from response
|
|
if resp.StatusCode == 200 {
|
|
var orgResp struct {
|
|
ContinueURL string `json:"continue_url"`
|
|
}
|
|
json.Unmarshal(orgData, &orgResp)
|
|
if orgResp.ContinueURL != "" {
|
|
continueURL = orgResp.ContinueURL
|
|
}
|
|
}
|
|
} else if wsResp.ContinueURL != "" {
|
|
continueURL = wsResp.ContinueURL
|
|
}
|
|
}
|
|
|
|
// Follow the final redirect chain to extract the authorization code
|
|
if strings.HasPrefix(continueURL, "/") {
|
|
continueURL = oauthIssuer + continueURL
|
|
}
|
|
|
|
code, err := followRedirectsForCode(ctx, client, continueURL)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if code == "" {
|
|
return "", errors.New("could not extract authorization code from redirect chain")
|
|
}
|
|
return code, nil
|
|
}
|
|
|
|
// getMapKeys lists the keys of m, in Go's unspecified map iteration order, for
// use in log messages. An empty or nil map yields an empty, non-nil slice.
func getMapKeys(m map[string]interface{}) []string {
	out := make([]string, len(m))
	i := 0
	for key := range m {
		out[i] = key
		i++
	}
	return out
}
|
|
|
|
// followRedirectsForCode follows HTTP redirects manually to capture the code from the callback URL.
|
|
func followRedirectsForCode(ctx context.Context, client *httpclient.Client, startURL string) (string, error) {
|
|
currentURL := startURL
|
|
navH := navigateHeaders()
|
|
|
|
for i := 0; i < 20; i++ { // max 20 redirects
|
|
// Before making the request, check if the current URL itself contains the code
|
|
// (e.g. localhost callback that we can't actually connect to)
|
|
if code := extractCodeFromURL(currentURL); code != "" {
|
|
log.Printf("[login] extracted code from URL before request: %s", currentURL)
|
|
return code, nil
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, currentURL, nil)
|
|
if err != nil {
|
|
return "", fmt.Errorf("build redirect request: %w", err)
|
|
}
|
|
for k, v := range navH {
|
|
req.Header.Set(k, v)
|
|
}
|
|
|
|
resp, err := client.DoNoRedirect(req)
|
|
if err != nil {
|
|
// Connection error to localhost callback — extract code from the URL we were trying to reach
|
|
if code := extractCodeFromURL(currentURL); code != "" {
|
|
log.Printf("[login] extracted code from failed redirect URL: %s", currentURL)
|
|
return code, nil
|
|
}
|
|
return "", fmt.Errorf("redirect request to %s: %w", currentURL, err)
|
|
}
|
|
httpclient.ReadBody(resp)
|
|
|
|
if resp.StatusCode >= 300 && resp.StatusCode < 400 {
|
|
loc := resp.Header.Get("Location")
|
|
if loc == "" {
|
|
return "", errors.New("redirect with no Location header")
|
|
}
|
|
if !strings.HasPrefix(loc, "http") {
|
|
loc = oauthIssuer + loc
|
|
}
|
|
|
|
log.Printf("[login] redirect %d: %d → %s", i+1, resp.StatusCode, loc)
|
|
|
|
if code := extractCodeFromURL(loc); code != "" {
|
|
return code, nil
|
|
}
|
|
currentURL = loc
|
|
continue
|
|
}
|
|
log.Printf("[login] redirect chain ended: status=%d, url=%s", resp.StatusCode, currentURL)
|
|
|
|
// Check the final URL for code
|
|
if resp.Request != nil {
|
|
if code := extractCodeFromURL(resp.Request.URL.String()); code != "" {
|
|
return code, nil
|
|
}
|
|
}
|
|
break
|
|
}
|
|
return "", nil
|
|
}
|
|
|
|
// exchangeCodeForTokens exchanges an authorization code for OAuth tokens.
|
|
func exchangeCodeForTokens(ctx context.Context, client *httpclient.Client, code, codeVerifier string) (*LoginResult, error) {
|
|
tokenURL := oauthIssuer + "/oauth/token"
|
|
|
|
values := url.Values{
|
|
"grant_type": {"authorization_code"},
|
|
"code": {code},
|
|
"redirect_uri": {oauthRedirectURI},
|
|
"client_id": {oauthClientID},
|
|
"code_verifier": {codeVerifier},
|
|
}
|
|
|
|
headers := map[string]string{
|
|
"Content-Type": "application/x-www-form-urlencoded",
|
|
}
|
|
|
|
var tokenResp struct {
|
|
AccessToken string `json:"access_token"`
|
|
RefreshToken string `json:"refresh_token"`
|
|
IDToken string `json:"id_token"`
|
|
TokenType string `json:"token_type"`
|
|
ExpiresIn int `json:"expires_in"`
|
|
}
|
|
|
|
var lastErr error
|
|
for attempt := 0; attempt < 2; attempt++ {
|
|
resp, err := client.PostForm(tokenURL, values, headers)
|
|
if err != nil {
|
|
lastErr = fmt.Errorf("token exchange: %w", err)
|
|
time.Sleep(1 * time.Second)
|
|
continue
|
|
}
|
|
body, _ := httpclient.ReadBody(resp)
|
|
if resp.StatusCode != 200 {
|
|
lastErr = fmt.Errorf("token exchange failed (%d): %s", resp.StatusCode, string(body))
|
|
time.Sleep(1 * time.Second)
|
|
continue
|
|
}
|
|
if err := json.Unmarshal(body, &tokenResp); err != nil {
|
|
return nil, fmt.Errorf("parse token response: %w", err)
|
|
}
|
|
lastErr = nil
|
|
break
|
|
}
|
|
if lastErr != nil {
|
|
return nil, lastErr
|
|
}
|
|
|
|
// Extract account/user IDs from JWT
|
|
accountID, userID := extractIDsFromJWT(tokenResp.AccessToken)
|
|
|
|
return &LoginResult{
|
|
AccessToken: tokenResp.AccessToken,
|
|
RefreshToken: tokenResp.RefreshToken,
|
|
IDToken: tokenResp.IDToken,
|
|
ChatGPTAccountID: accountID,
|
|
ChatGPTUserID: userID,
|
|
}, nil
|
|
}
|
|
|
|
// extractCodeFromURL returns the value of the "code" query parameter of rawURL.
// It returns "" when rawURL has no such parameter or cannot be parsed.
func extractCodeFromURL(rawURL string) string {
	// Cheap pre-filter: only pay for URL parsing when a code can be present.
	if strings.Contains(rawURL, "code=") {
		if u, err := url.Parse(rawURL); err == nil {
			return u.Query().Get("code")
		}
	}
	return ""
}
|
|
|
|
// extractIDsFromJWT decodes the payload of a JWT, without verifying it, and
// returns the ChatGPT account and user IDs from its
// "https://api.openai.com/auth" claim.
//
//   - userID comes from "user_id", falling back to "chatgpt_user_id".
//   - accountID comes from the "id" of the first "organizations" entry,
//     falling back to "chatgpt_account_id".
//
// The signature is not checked. A missing or malformed piece leaves the
// affected value empty; the function never fails.
func extractIDsFromJWT(token string) (accountID, userID string) {
	parts := strings.SplitN(token, ".", 3)
	if len(parts) < 2 {
		return "", ""
	}

	// JWT segments are unpadded base64url (RFC 7515). Strip any padding a
	// producer may have added so both forms decode.
	payload, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(parts[1], "="))
	if err != nil {
		return "", ""
	}

	var claims map[string]interface{}
	if err := json.Unmarshal(payload, &claims); err != nil {
		return "", ""
	}
	auth, ok := claims["https://api.openai.com/auth"].(map[string]interface{})
	if !ok {
		return "", ""
	}

	userID, _ = auth["user_id"].(string)
	if userID == "" {
		userID, _ = auth["chatgpt_user_id"].(string)
	}

	// Account ID: first organization's id, else the flat chatgpt_account_id field.
	if orgs, ok := auth["organizations"].([]interface{}); ok && len(orgs) > 0 {
		if org, ok := orgs[0].(map[string]interface{}); ok {
			accountID, _ = org["id"].(string)
		}
	}
	if accountID == "" {
		accountID, _ = auth["chatgpt_account_id"].(string)
	}

	return accountID, userID
}
|