Files
gpt-plus-gpt/pkg/auth/oauth.go
2026-03-15 20:48:19 +08:00

710 lines
22 KiB
Go

package auth
import (
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"net/http"
	"net/url"
	"sort"
	"strings"
	"time"

	"gpt-plus/pkg/httpclient"
	"gpt-plus/pkg/provider/email"
)
// LoginResult carries the credentials produced by a completed OAuth login.
//
// The three token fields are copied from the /oauth/token JSON response; the
// two ChatGPT* identifiers are read from the access token's
// "https://api.openai.com/auth" claim.
type LoginResult struct {
	// AccessToken is the "access_token" value returned by the token endpoint.
	AccessToken string
	// RefreshToken is the "refresh_token" value returned by the token endpoint.
	RefreshToken string
	// IDToken is the "id_token" value returned by the token endpoint.
	IDToken string
	// ChatGPTAccountID is the "id" of the first entry of the auth claim's
	// "organizations" list ("" when absent).
	ChatGPTAccountID string
	// ChatGPTUserID is the auth claim's "user_id" value ("" when absent).
	ChatGPTUserID string
}
// Login performs the full OAuth (PKCE) login flow for an existing account.
//
// Steps:
//   - step1: GET /oauth/authorize.
//   - step2: POST authorize/continue with the email address.
//   - step3: POST password/verify with the password.
//   - step3.5: only when the server asks for it, complete email OTP
//     verification (and the about-you form) via mailboxID/emailProvider.
//   - step4: walk the consent/workspace/organization pages to get a code.
//   - step5: exchange the code (with the PKCE verifier) for tokens.
//
// Cookies already in client's jar are kept on purpose (a verified registration
// session can let the server skip OTP); only the oai-did cookie is set to
// deviceID. sentinel supplies the openai-sentinel-token for steps 2 and 3.
// Returned errors are prefixed with the step that failed. If OTP is required
// but mailboxID is empty or emailProvider is nil, an error is returned.
func Login(
	ctx context.Context,
	client *httpclient.Client,
	emailAddr, password, deviceID string,
	sentinel *SentinelGenerator,
	mailboxID string,
	emailProvider email.EmailProvider,
) (*LoginResult, error) {
	// Keep registration cookies — the verified email session helps skip OTP during login.
	// Only ensure oai-did is set.
	cookieURL, _ := url.Parse(oauthIssuer)
	client.GetCookieJar().SetCookies(cookieURL, []*http.Cookie{
		{Name: "oai-did", Value: deviceID},
	})

	// Generate PKCE and state
	codeVerifier, codeChallenge := generatePKCE()
	state := generateState()

	// ===== Step 1: GET /oauth/authorize (no screen_hint) =====
	authorizeParams := url.Values{
		"response_type":              {"code"},
		"client_id":                  {oauthClientID},
		"redirect_uri":               {oauthRedirectURI},
		"scope":                      {oauthScope},
		"code_challenge":             {codeChallenge},
		"code_challenge_method":      {"S256"},
		"state":                      {state},
		"id_token_add_organizations": {"true"},
		"codex_cli_simplified_flow":  {"true"},
	}
	authorizeURL := oauthIssuer + "/oauth/authorize?" + authorizeParams.Encode()
	resp, err := client.Get(authorizeURL, navigateHeaders())
	if err != nil {
		return nil, fmt.Errorf("step1 authorize: %w", err)
	}
	// The response body is not inspected; only the request's side effects matter here.
	httpclient.ReadBody(resp)

	// ===== Step 2: POST authorize/continue with email =====
	sentinelToken, err := sentinel.GenerateToken(ctx, client, "authorize_continue")
	if err != nil {
		return nil, fmt.Errorf("step2 sentinel: %w", err)
	}
	headers := commonHeaders(deviceID)
	headers["referer"] = oauthIssuer + "/log-in"
	headers["openai-sentinel-token"] = sentinelToken
	continueBody := map[string]interface{}{
		"username": map[string]string{"kind": "email", "value": emailAddr},
	}
	resp, err = client.PostJSON(oauthIssuer+"/api/accounts/authorize/continue", continueBody, headers)
	if err != nil {
		return nil, fmt.Errorf("step2 authorize/continue: %w", err)
	}
	body, _ := httpclient.ReadBody(resp)
	if resp.StatusCode != 200 {
		return nil, fmt.Errorf("step2 failed (%d): %s", resp.StatusCode, string(body))
	}

	// ===== Step 3: POST password/verify =====
	sentinelToken, err = sentinel.GenerateToken(ctx, client, "password_verify")
	if err != nil {
		return nil, fmt.Errorf("step3 sentinel: %w", err)
	}
	headers = commonHeaders(deviceID)
	headers["referer"] = oauthIssuer + "/log-in/password"
	headers["openai-sentinel-token"] = sentinelToken
	pwBody := map[string]string{"password": password}
	resp, err = client.PostJSON(oauthIssuer+"/api/accounts/password/verify", pwBody, headers)
	if err != nil {
		return nil, fmt.Errorf("step3 password/verify: %w", err)
	}
	body, _ = httpclient.ReadBody(resp)
	if resp.StatusCode != 200 {
		return nil, fmt.Errorf("step3 password verify failed (%d): %s", resp.StatusCode, string(body))
	}
	var verifyResp struct {
		ContinueURL string `json:"continue_url"`
		Page        struct {
			Type string `json:"type"`
		} `json:"page"`
	}
	if err := json.Unmarshal(body, &verifyResp); err != nil {
		// Best-effort: empty fields fall back to the consent URL below.
		log.Printf("[login] step3 warning: parse password/verify response: %v", err)
	}
	continueURL := verifyResp.ContinueURL
	pageType := verifyResp.Page.Type
	log.Printf("[login] step3 password/verify: status=%d, page=%s, continue=%s", resp.StatusCode, pageType, continueURL)

	// ===== Step 3.5: Email OTP verification (if triggered for new accounts) =====
	if pageType == "email_otp_verification" || pageType == "email_otp_send" ||
		strings.Contains(continueURL, "email-verification") || strings.Contains(continueURL, "email-otp") {
		if mailboxID == "" || emailProvider == nil {
			return nil, errors.New("email verification required but no mailbox/provider available")
		}
		continueURL, pageType, err = loginCompleteEmailOTP(ctx, client, deviceID, mailboxID, emailProvider)
		if err != nil {
			return nil, err
		}
	}

	// Handle consent page type (and any step that yielded no continue URL).
	if strings.Contains(pageType, "consent") || continueURL == "" {
		continueURL = oauthIssuer + "/sign-in-with-chatgpt/codex/consent"
	}
	log.Printf("[login] step4 entering consent flow: page=%s, continue=%s", pageType, continueURL)

	// ===== Step 4: Follow consent/workspace/organization redirects to extract code =====
	authCode, err := followConsentRedirects(ctx, client, deviceID, continueURL, "")
	if err != nil {
		return nil, fmt.Errorf("step4 consent flow: %w", err)
	}

	// ===== Step 5: Exchange code for tokens =====
	tokens, err := exchangeCodeForTokens(ctx, client, authCode, codeVerifier)
	if err != nil {
		return nil, fmt.Errorf("step5 token exchange: %w", err)
	}
	return tokens, nil
}

// loginCompleteEmailOTP performs step 3.5 of Login. It triggers the OTP email,
// waits for the code, validates it and, when the server then points at
// about-you, submits that form. It returns the continue URL and page type to
// hand to the consent step.
func loginCompleteEmailOTP(
	ctx context.Context,
	client *httpclient.Client,
	deviceID, mailboxID string,
	emailProvider email.EmailProvider,
) (string, string, error) {
	log.Printf("[login] step3.5 OTP required — triggering email-otp/send...")

	// Snapshot the cutoff BEFORE triggering the send, so the registration OTP is
	// skipped but the login OTP cannot be missed. This assumes
	// WaitForVerificationCode only considers mail after this instant (TODO confirm).
	otpSince := time.Now()

	sendHeaders := commonHeaders(deviceID)
	sendHeaders["referer"] = oauthIssuer + "/email-verification"
	// GET /api/accounts/email-otp/send to trigger the OTP email
	sendResp, sendErr := client.Get(oauthIssuer+"/api/accounts/email-otp/send", sendHeaders)
	if sendErr != nil {
		log.Printf("[login] step3.5 email-otp/send failed (non-fatal): %v", sendErr)
	} else {
		// Drain the body (the original code dropped this response without reading it).
		httpclient.ReadBody(sendResp)
		if sendResp.StatusCode != 200 {
			log.Printf("[login] step3.5 email-otp/send status=%d (non-fatal)", sendResp.StatusCode)
		}
	}

	otpCode, err := emailProvider.WaitForVerificationCode(ctx, mailboxID, 120*time.Second, otpSince)
	if err != nil {
		return "", "", fmt.Errorf("step3.5 wait for otp: %w", err)
	}
	if otpCode == "" {
		return "", "", errors.New("step3.5 wait for otp: empty code returned")
	}
	log.Printf("[login] step3.5 got OTP code: %s", otpCode)

	headers := commonHeaders(deviceID)
	headers["referer"] = oauthIssuer + "/email-verification"
	otpBody := map[string]string{"code": otpCode}
	resp, err := client.PostJSON(oauthIssuer+"/api/accounts/email-otp/validate", otpBody, headers)
	if err != nil {
		return "", "", fmt.Errorf("step3.5 validate otp: %w", err)
	}
	body, _ := httpclient.ReadBody(resp)
	if resp.StatusCode != 200 {
		return "", "", fmt.Errorf("step3.5 otp validate failed (%d): %s", resp.StatusCode, string(body))
	}

	// Decode into a fresh value: reusing the password/verify struct would keep its
	// email-verification continue_url/page whenever this response omits them.
	var validateResp struct {
		ContinueURL string `json:"continue_url"`
		Page        struct {
			Type string `json:"type"`
		} `json:"page"`
	}
	if err := json.Unmarshal(body, &validateResp); err != nil {
		log.Printf("[login] step3.5 warning: parse otp validate response: %v", err)
	}
	continueURL := validateResp.ContinueURL
	pageType := validateResp.Page.Type
	log.Printf("[login] step3.5 otp validate: page=%s, continue=%s", pageType, continueURL)

	// If about-you step needed, submit name/birthdate
	if strings.Contains(continueURL, "about-you") {
		continueURL, err = loginSubmitAboutYou(client, deviceID)
		if err != nil {
			return "", "", err
		}
	}
	return continueURL, pageType, nil
}

// loginSubmitAboutYou posts a random name and birthdate to create_account and
// returns the continue URL to use next. A 400 response mentioning
// "already_exists" is treated as success and mapped to the Codex consent page.
// Any other non-200 status is an error.
func loginSubmitAboutYou(client *httpclient.Client, deviceID string) (string, error) {
	firstName, lastName := generateRandomName()
	birthdate := generateRandomBirthday()
	headers := commonHeaders(deviceID)
	headers["referer"] = oauthIssuer + "/about-you"
	createBody := map[string]string{
		"name":      firstName + " " + lastName,
		"birthdate": birthdate,
	}
	resp, err := client.PostJSON(oauthIssuer+"/api/accounts/create_account", createBody, headers)
	if err != nil {
		return "", fmt.Errorf("step3.5 create_account: %w", err)
	}
	body, _ := httpclient.ReadBody(resp)
	switch {
	case resp.StatusCode == 200:
		var createResp struct {
			ContinueURL string `json:"continue_url"`
		}
		if err := json.Unmarshal(body, &createResp); err != nil {
			log.Printf("[login] step3.5 warning: parse create_account response: %v", err)
		}
		// An empty URL is replaced with the consent URL by Login.
		return createResp.ContinueURL, nil
	case resp.StatusCode == 400 && strings.Contains(string(body), "already_exists"):
		return oauthIssuer + "/sign-in-with-chatgpt/codex/consent", nil
	default:
		return "", fmt.Errorf("step3.5 create_account failed (%d): %s", resp.StatusCode, string(body))
	}
}
// decodeAuthSessionCookie locates the oai-client-auth-session cookie that
// client's jar holds for oauthIssuer and decodes its payload as a JSON object.
//
// The value is treated as <base64 payload>.<timestamp>.<signature>
// (Flask/itsdangerous style). Everything before the first '.' is padded to a
// multiple of four and decoded as URL-safe base64, with standard base64 as a
// fallback. The signature is not verified. Only the first cookie with that
// name is considered; an error is returned if none exists or decoding fails.
func decodeAuthSessionCookie(client *httpclient.Client) (map[string]interface{}, error) {
	issuerURL, _ := url.Parse(oauthIssuer)

	value, found := "", false
	for _, ck := range client.GetCookieJar().Cookies(issuerURL) {
		if ck.Name == "oai-client-auth-session" {
			value, found = ck.Value, true
			break
		}
	}
	if !found {
		return nil, errors.New("oai-client-auth-session cookie not found")
	}

	// Keep only the payload segment; a leading '.' (idx 0) leaves the value whole.
	payload := value
	if dot := strings.Index(payload, "."); dot > 0 {
		payload = payload[:dot]
	}
	if rem := len(payload) % 4; rem != 0 {
		payload += strings.Repeat("=", 4-rem)
	}

	raw, err := base64.URLEncoding.DecodeString(payload)
	if err != nil {
		if raw, err = base64.StdEncoding.DecodeString(payload); err != nil {
			return nil, fmt.Errorf("decode auth session base64: %w", err)
		}
	}

	var data map[string]interface{}
	if err := json.Unmarshal(raw, &data); err != nil {
		return nil, fmt.Errorf("parse auth session json: %w", err)
	}
	return data, nil
}
// followConsentRedirects drives the post-login consent flow until it obtains an
// OAuth authorization code.
//
//   - continueURL is loaded (a leading "/" is resolved against oauthIssuer).
//     With no targetWorkspaceID, redirects are followed automatically and a
//     code reached that way is returned at once. With a targetWorkspaceID,
//     redirects are followed hop by hop, and a code appearing in a Location
//     header is ignored because it would belong to the default workspace.
//   - The oai-client-auth-session cookie is decoded, and a workspace is chosen:
//     the one whose id equals targetWorkspaceID if present, else the first one.
//     If no workspace is found, continueURL is re-followed looking for a code.
//   - workspace/select is POSTed. When its 200 response lists orgs,
//     organization/select is POSTed for the first org (and its first project).
//   - The resulting redirect or continue_url chain is followed to the code.
//
// deviceID is passed to commonHeaders for the JSON POSTs.
func followConsentRedirects(ctx context.Context, client *httpclient.Client, deviceID, continueURL, targetWorkspaceID string) (string, error) {
	// Normalize URL
	if strings.HasPrefix(continueURL, "/") {
		continueURL = oauthIssuer + continueURL
	}

	if targetWorkspaceID == "" {
		code, err := consentLoadAutoRedirect(client, continueURL)
		if err != nil {
			return "", err
		}
		if code != "" {
			return code, nil
		}
	} else if err := consentLoadManualRedirect(ctx, client, continueURL, targetWorkspaceID); err != nil {
		return "", err
	}

	// Decode oai-client-auth-session cookie to extract workspace data
	sessionData, err := decodeAuthSessionCookie(client)
	if err != nil {
		log.Printf("[login] warning: %v", err)
		cookieURL, _ := url.Parse(oauthIssuer)
		for _, c := range client.GetCookieJar().Cookies(cookieURL) {
			log.Printf("[login] cookie: %s (len=%d)", c.Name, len(c.Value))
		}
		return "", fmt.Errorf("consent flow: %w", err)
	}
	log.Printf("[login] decoded auth session, keys: %v", getMapKeys(sessionData))

	workspaceID := consentPickWorkspace(sessionData, targetWorkspaceID)
	if workspaceID == "" {
		// For new accounts that have no workspaces yet, the consent page may already
		// have set up the session, so re-following continueURL may yield a code.
		log.Printf("[login] no workspaces in session, trying re-authorize...")
		if code, err := followRedirectsForCode(ctx, client, continueURL); err == nil && code != "" {
			return code, nil
		}
		return "", errors.New("no workspace_id found in auth session cookie and re-authorize failed")
	}

	// POST workspace/select (no redirect following)
	headers := commonHeaders(deviceID)
	headers["referer"] = continueURL
	resp, wsData, err := consentPostNoRedirect(ctx, client, "/api/accounts/workspace/select",
		map[string]string{"workspace_id": workspaceID}, headers, "workspace/select")
	if err != nil {
		return "", err
	}
	log.Printf("[login] workspace/select: status=%d", resp.StatusCode)
	if code := consentCodeFromRedirect(ctx, client, resp); code != "" {
		return code, nil
	}

	// Parse workspace/select response for org data
	if resp.StatusCode == 200 {
		var wsResp struct {
			Data struct {
				Orgs []struct {
					ID       string `json:"id"`
					Projects []struct {
						ID string `json:"id"`
					} `json:"projects"`
				} `json:"orgs"`
			} `json:"data"`
			ContinueURL string `json:"continue_url"`
		}
		if err := json.Unmarshal(wsData, &wsResp); err != nil {
			// Best-effort: with nothing parsed we fall through to the current continueURL.
			log.Printf("[login] warning: parse workspace/select response: %v", err)
		}

		var orgID, projectID string
		if len(wsResp.Data.Orgs) > 0 {
			orgID = wsResp.Data.Orgs[0].ID
			if len(wsResp.Data.Orgs[0].Projects) > 0 {
				projectID = wsResp.Data.Orgs[0].Projects[0].ID
			}
		}

		if orgID != "" {
			log.Printf("[login] org_id: %s, project_id: %s", orgID, projectID)
			orgBody := map[string]string{"org_id": orgID}
			if projectID != "" {
				orgBody["project_id"] = projectID
			}
			orgResp, orgData, err := consentPostNoRedirect(ctx, client, "/api/accounts/organization/select",
				orgBody, headers, "organization/select")
			if err != nil {
				return "", err
			}
			log.Printf("[login] organization/select: status=%d", orgResp.StatusCode)
			if code := consentCodeFromRedirect(ctx, client, orgResp); code != "" {
				return code, nil
			}
			if orgResp.StatusCode == 200 {
				var parsed struct {
					ContinueURL string `json:"continue_url"`
				}
				if err := json.Unmarshal(orgData, &parsed); err != nil {
					log.Printf("[login] warning: parse organization/select response: %v", err)
				}
				if parsed.ContinueURL != "" {
					continueURL = parsed.ContinueURL
				}
			}
		} else if wsResp.ContinueURL != "" {
			continueURL = wsResp.ContinueURL
		}
	}

	// Follow the final redirect chain to extract the authorization code
	if strings.HasPrefix(continueURL, "/") {
		continueURL = oauthIssuer + continueURL
	}
	code, err := followRedirectsForCode(ctx, client, continueURL)
	if err != nil {
		return "", err
	}
	if code == "" {
		return "", errors.New("could not extract authorization code from redirect chain")
	}
	return code, nil
}

// consentLoadAutoRedirect GETs continueURL and lets the client follow redirects.
// It returns the authorization code when the chain ends on a URL carrying one.
// It also returns the code when the last hop failed and the returned error is
// a *url.Error whose URL carries one, as net/http reports for redirect failures.
// ("", nil) means the page loaded without yielding a code.
func consentLoadAutoRedirect(client *httpclient.Client, continueURL string) (string, error) {
	resp, err := client.Get(continueURL, navigateHeaders())
	if err != nil {
		// Presumably the final hop is the localhost callback that nothing listens on;
		// the code is still recoverable from the URL the client was trying to reach.
		var urlErr *url.Error
		if errors.As(err, &urlErr) {
			if code := extractCodeFromURL(urlErr.URL); code != "" {
				log.Printf("[login] extracted code from failed consent redirect URL: %s", urlErr.URL)
				return code, nil
			}
		}
		return "", fmt.Errorf("get consent: %w", err)
	}
	httpclient.ReadBody(resp)

	finalURL := continueURL
	// Guard: resp.Request may be nil for some transports; the original code
	// dereferenced it unconditionally when logging.
	if resp.Request != nil && resp.Request.URL != nil {
		finalURL = resp.Request.URL.String()
		if code := extractCodeFromURL(finalURL); code != "" {
			return code, nil
		}
	}
	log.Printf("[login] consent page: status=%d, final_url=%s", resp.StatusCode, finalURL)
	return "", nil
}

// consentLoadManualRedirect walks the redirect chain from continueURL one hop at
// a time (at most 15 hops) so the default workspace is not auto-selected.
// It stops at the first non-3xx response, at a 3xx without Location, or at a
// Location carrying a code; that code belongs to the default workspace and is
// discarded. Only the cookies set along the way are used afterwards.
func consentLoadManualRedirect(ctx context.Context, client *httpclient.Client, continueURL, targetWorkspaceID string) error {
	navH := navigateHeaders()
	currentURL := continueURL
	for hop := 1; hop <= 15; hop++ {
		req, err := http.NewRequestWithContext(ctx, http.MethodGet, currentURL, nil)
		if err != nil {
			return fmt.Errorf("build consent redirect request: %w", err)
		}
		for k, v := range navH {
			req.Header.Set(k, v)
		}
		resp, err := client.DoNoRedirect(req)
		if err != nil {
			if extractCodeFromURL(currentURL) != "" {
				log.Printf("[login] consent redirect got default code, ignoring for workspace %s", targetWorkspaceID)
				return nil
			}
			return fmt.Errorf("consent redirect: %w", err)
		}
		httpclient.ReadBody(resp)

		if resp.StatusCode < 300 || resp.StatusCode >= 400 {
			log.Printf("[login] consent page reached: status=%d, url=%s", resp.StatusCode, currentURL)
			return nil
		}
		loc := resp.Header.Get("Location")
		if loc == "" {
			return nil
		}
		if !strings.HasPrefix(loc, "http") {
			loc = oauthIssuer + loc
		}
		if extractCodeFromURL(loc) != "" {
			log.Printf("[login] consent redirect has default code, ignoring for workspace %s", targetWorkspaceID)
			return nil
		}
		log.Printf("[login] consent redirect %d: %d → %s", hop, resp.StatusCode, loc)
		currentURL = loc
	}
	return nil
}

// consentPickWorkspace chooses the workspace id from sessionData["workspaces"].
// It returns the entry whose "id" equals targetWorkspaceID when that is non-empty
// and present, otherwise the first entry's "id". It returns "" when the list is
// missing, empty, or its first entry has no string id.
func consentPickWorkspace(sessionData map[string]interface{}, targetWorkspaceID string) string {
	workspaces, ok := sessionData["workspaces"].([]interface{})
	if !ok || len(workspaces) == 0 {
		return ""
	}
	if targetWorkspaceID != "" {
		for _, w := range workspaces {
			ws, ok := w.(map[string]interface{})
			if !ok {
				continue
			}
			if id, ok := ws["id"].(string); ok && id == targetWorkspaceID {
				kind, _ := ws["kind"].(string)
				log.Printf("[login] matched target workspace_id: %s (kind: %s)", id, kind)
				return id
			}
		}
		log.Printf("[login] target workspace %s not found in session, falling back to first", targetWorkspaceID)
	}
	if ws, ok := workspaces[0].(map[string]interface{}); ok {
		if id, ok := ws["id"].(string); ok {
			kind, _ := ws["kind"].(string)
			log.Printf("[login] workspace_id: %s (kind: %s)", id, kind)
			return id
		}
	}
	return ""
}

// consentPostNoRedirect POSTs payload as JSON to oauthIssuer+path with headers,
// without following redirects. It returns the response and its body.
// label names the call in errors ("build <label> request: ..." / "<label>: ...").
func consentPostNoRedirect(ctx context.Context, client *httpclient.Client, path string, payload map[string]string, headers map[string]string, label string) (*http.Response, []byte, error) {
	encoded, err := json.Marshal(payload)
	if err != nil {
		return nil, nil, fmt.Errorf("build %s request: %w", label, err)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, oauthIssuer+path, strings.NewReader(string(encoded)))
	if err != nil {
		return nil, nil, fmt.Errorf("build %s request: %w", label, err)
	}
	for k, v := range headers {
		req.Header.Set(k, v)
	}
	resp, err := client.DoNoRedirect(req)
	if err != nil {
		return nil, nil, fmt.Errorf("%s: %w", label, err)
	}
	body, _ := httpclient.ReadBody(resp)
	return resp, body, nil
}

// consentCodeFromRedirect returns the authorization code carried by a 3xx
// response. The code is taken from the Location header when present there;
// otherwise the chain is followed from Location. It returns "" for non-3xx
// responses or when no code is found. Follow errors are ignored so the caller
// can fall back to its continue URL.
func consentCodeFromRedirect(ctx context.Context, client *httpclient.Client, resp *http.Response) string {
	if resp.StatusCode < 300 || resp.StatusCode >= 400 {
		return ""
	}
	loc := resp.Header.Get("Location")
	if code := extractCodeFromURL(loc); code != "" {
		return code
	}
	if loc == "" {
		return ""
	}
	if !strings.HasPrefix(loc, "http") {
		loc = oauthIssuer + loc
	}
	if code, err := followRedirectsForCode(ctx, client, loc); err == nil && code != "" {
		return code
	}
	return ""
}
// getMapKeys returns the keys of m in ascending order, for logging.
// Sorting keeps log lines stable across runs, because Go randomizes map
// iteration order. A nil or empty map yields an empty, non-nil slice.
func getMapKeys(m map[string]interface{}) []string {
	keys := make([]string, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	return keys
}
// followRedirectsForCode follows HTTP redirects one hop at a time (at most 20)
// to capture the "code" parameter of the OAuth callback URL.
//
// The code is taken from the first of these that carries one: the URL about
// to be requested, a Location header, the URL whose request failed (e.g. a
// localhost callback nothing listens on), or the final non-3xx response's
// request URL. Relative Location values are resolved against the URL that
// produced them.
//
// Results:
//   - ("", nil) when the chain ends on a non-redirect without a code.
//   - An error on request failures without a code, on a 3xx without
//     Location, or when 20 hops pass without reaching a non-redirect.
func followRedirectsForCode(ctx context.Context, client *httpclient.Client, startURL string) (string, error) {
	const maxRedirects = 20
	currentURL := startURL
	navH := navigateHeaders()
	for hop := 1; hop <= maxRedirects; hop++ {
		// Before making the request, check if the current URL itself contains the code
		// (e.g. localhost callback that we can't actually connect to)
		if code := extractCodeFromURL(currentURL); code != "" {
			log.Printf("[login] extracted code from URL before request: %s", currentURL)
			return code, nil
		}
		req, err := http.NewRequestWithContext(ctx, http.MethodGet, currentURL, nil)
		if err != nil {
			return "", fmt.Errorf("build redirect request: %w", err)
		}
		for k, v := range navH {
			req.Header.Set(k, v)
		}
		resp, err := client.DoNoRedirect(req)
		if err != nil {
			// Connection error to localhost callback — extract code from the URL we were trying to reach
			if code := extractCodeFromURL(currentURL); code != "" {
				log.Printf("[login] extracted code from failed redirect URL: %s", currentURL)
				return code, nil
			}
			return "", fmt.Errorf("redirect request to %s: %w", currentURL, err)
		}
		httpclient.ReadBody(resp)

		if resp.StatusCode < 300 || resp.StatusCode >= 400 {
			log.Printf("[login] redirect chain ended: status=%d, url=%s", resp.StatusCode, currentURL)
			if resp.Request != nil && resp.Request.URL != nil {
				if code := extractCodeFromURL(resp.Request.URL.String()); code != "" {
					return code, nil
				}
			}
			return "", nil
		}

		loc := resp.Header.Get("Location")
		if loc == "" {
			return "", errors.New("redirect with no Location header")
		}
		next := resolveRedirectLocation(currentURL, loc)
		log.Printf("[login] redirect %d: %d → %s", hop, resp.StatusCode, next)
		if code := extractCodeFromURL(next); code != "" {
			return code, nil
		}
		currentURL = next
	}
	return "", fmt.Errorf("too many redirects (>%d) starting from %s", maxRedirects, startURL)
}

// resolveRedirectLocation turns a Location header value into an absolute URL by
// resolving it against baseURL, the URL that returned it (RFC 9110 §10.2.2).
// If either value fails to parse, it falls back to prefixing oauthIssuer onto
// values that do not start with "http".
func resolveRedirectLocation(baseURL, loc string) string {
	if base, err := url.Parse(baseURL); err == nil {
		if ref, err := url.Parse(loc); err == nil {
			return base.ResolveReference(ref).String()
		}
	}
	if !strings.HasPrefix(loc, "http") {
		return oauthIssuer + loc
	}
	return loc
}
// exchangeCodeForTokens redeems an authorization code at oauthIssuer's
// /oauth/token endpoint. It uses the authorization_code grant with the PKCE
// code_verifier, and returns the tokens plus the ChatGPT IDs read from the
// access token.
//
// The POST is attempted up to twice, one second apart. The wait is skipped
// after the final attempt and is cut short if ctx is cancelled. Errors:
//   - the last transport/HTTP error when every attempt fails;
//   - a JSON parse error from a 200 response;
//   - "token response missing access_token" when a 200 response has no
//     access_token.
func exchangeCodeForTokens(ctx context.Context, client *httpclient.Client, code, codeVerifier string) (*LoginResult, error) {
	const (
		maxAttempts = 2
		retryDelay  = time.Second
	)
	tokenURL := oauthIssuer + "/oauth/token"
	values := url.Values{
		"grant_type":    {"authorization_code"},
		"code":          {code},
		"redirect_uri":  {oauthRedirectURI},
		"client_id":     {oauthClientID},
		"code_verifier": {codeVerifier},
	}
	headers := map[string]string{
		"Content-Type": "application/x-www-form-urlencoded",
	}
	var tokenResp struct {
		AccessToken  string `json:"access_token"`
		RefreshToken string `json:"refresh_token"`
		IDToken      string `json:"id_token"`
		TokenType    string `json:"token_type"`
		ExpiresIn    int    `json:"expires_in"`
	}

	var lastErr error
	for attempt := 1; attempt <= maxAttempts; attempt++ {
		if attempt > 1 {
			// Wait between attempts only, and stop early if the caller gave up.
			select {
			case <-ctx.Done():
				return nil, fmt.Errorf("token exchange: %w (last error: %v)", ctx.Err(), lastErr)
			case <-time.After(retryDelay):
			}
		}
		resp, err := client.PostForm(tokenURL, values, headers)
		if err != nil {
			lastErr = fmt.Errorf("token exchange: %w", err)
			continue
		}
		body, _ := httpclient.ReadBody(resp)
		if resp.StatusCode != 200 {
			lastErr = fmt.Errorf("token exchange failed (%d): %s", resp.StatusCode, string(body))
			continue
		}
		if err := json.Unmarshal(body, &tokenResp); err != nil {
			return nil, fmt.Errorf("parse token response: %w", err)
		}
		lastErr = nil
		break
	}
	if lastErr != nil {
		return nil, lastErr
	}
	if tokenResp.AccessToken == "" {
		return nil, errors.New("token response missing access_token")
	}

	// Extract account/user IDs from JWT
	accountID, userID := extractIDsFromJWT(tokenResp.AccessToken)
	return &LoginResult{
		AccessToken:      tokenResp.AccessToken,
		RefreshToken:     tokenResp.RefreshToken,
		IDToken:          tokenResp.IDToken,
		ChatGPTAccountID: accountID,
		ChatGPTUserID:    userID,
	}, nil
}
// extractCodeFromURL returns the decoded value of the "code" query parameter
// of rawURL. It returns "" when rawURL lacks a literal "code=" substring,
// fails to parse, or has no such parameter. The substring test is a cheap
// filter that skips parsing for URLs that obviously carry no code.
func extractCodeFromURL(rawURL string) string {
	if strings.Index(rawURL, "code=") < 0 {
		return ""
	}
	if u, err := url.Parse(rawURL); err == nil {
		return u.Query().Get("code")
	}
	return ""
}
// extractIDsFromJWT decodes the payload segment of a JWT, without verifying
// its signature, and reads the "https://api.openai.com/auth" claim.
//
//   - userID is the claim's "user_id" string.
//   - accountID is the "id" of the first element of the claim's
//     "organizations" array.
//
// Either value is "" when the token is malformed or the field is missing or
// has the wrong type.
func extractIDsFromJWT(token string) (accountID, userID string) {
	segments := strings.SplitN(token, ".", 3)
	if len(segments) < 2 {
		return "", ""
	}
	rawPayload := segments[1]

	// Restore padding for the padded URL alphabet; a remainder of 1 is invalid
	// base64 and is left alone so decoding fails below.
	padded := rawPayload
	if rem := len(padded) % 4; rem >= 2 {
		padded += strings.Repeat("=", 4-rem)
	}
	decoded, err := base64.URLEncoding.DecodeString(padded)
	if err != nil {
		// Fall back to the unpadded alphabet on the untouched segment.
		if decoded, err = base64.RawURLEncoding.DecodeString(rawPayload); err != nil {
			return "", ""
		}
	}

	var claims map[string]interface{}
	if json.Unmarshal(decoded, &claims) != nil {
		return "", ""
	}

	// A missing or non-object claim leaves auth nil; reads from a nil map are safe.
	auth, _ := claims["https://api.openai.com/auth"].(map[string]interface{})
	userID, _ = auth["user_id"].(string)
	if orgs, _ := auth["organizations"].([]interface{}); len(orgs) > 0 {
		if first, ok := orgs[0].(map[string]interface{}); ok {
			accountID, _ = first["id"].(string)
		}
	}
	return accountID, userID
}