initial commit
This commit is contained in:
503
protonvpn-wg-confgen/internal/auth/auth.go
Normal file
503
protonvpn-wg-confgen/internal/auth/auth.go
Normal file
@@ -0,0 +1,503 @@
|
||||
// Package auth handles ProtonVPN authentication using the SRP protocol.
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"protonvpn-wg-confgen/internal/api"
|
||||
"protonvpn-wg-confgen/internal/config"
|
||||
"protonvpn-wg-confgen/internal/constants"
|
||||
"protonvpn-wg-confgen/pkg/timeutil"
|
||||
|
||||
"github.com/ProtonMail/go-srp"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
// Client handles ProtonVPN authentication
|
||||
type Client struct {
|
||||
config *config.Config
|
||||
httpClient *http.Client
|
||||
sessionStore *SessionStore
|
||||
}
|
||||
|
||||
// NewClient creates a new authentication client
|
||||
func NewClient(cfg *config.Config) *Client {
|
||||
return &Client{
|
||||
config: cfg,
|
||||
sessionStore: NewSessionStore(),
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: false,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// handleSessionRefresh attempts to refresh a session and save it if successful
|
||||
func (c *Client) handleSessionRefresh(savedSession *api.Session, reason string) (*api.Session, error) {
|
||||
fmt.Println(reason)
|
||||
refreshedSession, err := RefreshSession(c.httpClient, c.config.APIURL, savedSession)
|
||||
if err != nil {
|
||||
fmt.Printf("Token refresh failed: %v\n", err)
|
||||
fmt.Println("Re-authenticating with password...")
|
||||
fmt.Println("(Your trusted device status for MFA will be preserved)")
|
||||
_ = c.sessionStore.Delete()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fmt.Println("Session refreshed successfully!")
|
||||
// Check if refresh token was rotated
|
||||
if savedSession.RefreshToken != refreshedSession.RefreshToken {
|
||||
fmt.Println("Refresh token was rotated")
|
||||
}
|
||||
|
||||
// Save the refreshed session
|
||||
if !c.config.NoSession {
|
||||
sessionDuration, _ := timeutil.ParseSessionDuration(c.config.SessionDuration)
|
||||
if err := c.sessionStore.Save(refreshedSession, c.config.Username, sessionDuration); err != nil {
|
||||
fmt.Printf("Warning: Failed to save refreshed session: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
return refreshedSession, nil
|
||||
}
|
||||
|
||||
// tryExistingSession attempts to use an existing saved session
|
||||
func (c *Client) tryExistingSession() (*api.Session, error) {
|
||||
savedSession, timeUntilExpiry, err := c.sessionStore.Load(c.config.Username)
|
||||
if err != nil {
|
||||
fmt.Printf("Warning: Failed to load saved session: %v\n", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if savedSession == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Determine what to do with the saved session
|
||||
switch {
|
||||
case c.config.ForceRefresh:
|
||||
reason := fmt.Sprintf("Forcing session refresh (current session expires in %s)", timeutil.HumanizeDuration(timeUntilExpiry))
|
||||
return c.handleSessionRefresh(savedSession, reason)
|
||||
|
||||
case timeUntilExpiry < time.Duration(constants.SessionRefreshDays)*24*time.Hour && timeUntilExpiry > 0:
|
||||
reason := fmt.Sprintf("Session expires soon (in %s), attempting refresh...", timeutil.HumanizeDuration(timeUntilExpiry))
|
||||
return c.handleSessionRefresh(savedSession, reason)
|
||||
|
||||
case VerifySession(c.httpClient, c.config.APIURL, savedSession):
|
||||
fmt.Printf("Using saved session (expires in %s)\n", timeutil.HumanizeDuration(timeUntilExpiry))
|
||||
return savedSession, nil
|
||||
|
||||
default:
|
||||
fmt.Println("Saved session invalid, re-authenticating...")
|
||||
_ = c.sessionStore.Delete()
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Authenticate performs the full authentication flow
|
||||
func (c *Client) Authenticate() (*api.Session, error) {
|
||||
if err := c.ensureUsername(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Try existing session unless clearing or disabled
|
||||
if session := c.handleExistingSession(); session != nil {
|
||||
return session, nil
|
||||
}
|
||||
|
||||
if err := c.ensurePassword(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Perform fresh authentication
|
||||
session, err := c.performFreshAuth()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Handle session scope upgrade if needed
|
||||
if err := c.upgradeSessionIfNeeded(session); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.saveSessionIfEnabled(session)
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// handleExistingSession handles session clearing or reuse
|
||||
func (c *Client) handleExistingSession() *api.Session {
|
||||
if c.config.ClearSession {
|
||||
fmt.Println("Clearing saved session...")
|
||||
_ = c.sessionStore.Delete()
|
||||
return nil
|
||||
}
|
||||
|
||||
if c.config.NoSession {
|
||||
return nil
|
||||
}
|
||||
|
||||
session, err := c.tryExistingSession()
|
||||
if err == nil && session != nil {
|
||||
return session
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// performFreshAuth performs SRP authentication and returns a new session
|
||||
func (c *Client) performFreshAuth() (*api.Session, error) {
|
||||
authInfo, err := c.getAuthInfo()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get auth info: %w", err)
|
||||
}
|
||||
|
||||
clientProofs, err := c.generateSRPProofs(authInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
authReq := c.buildAuthRequest(authInfo, clientProofs)
|
||||
|
||||
// Handle 2FA if needed
|
||||
if authInfo.TwoFA.Enabled == constants.EnabledTrue && authInfo.TwoFA.TOTP == constants.EnabledTrue {
|
||||
code, err := c.get2FACode()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authReq["TwoFactorCode"] = code
|
||||
}
|
||||
|
||||
session, err := c.sendAuthRequest(authReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Verify server proof
|
||||
if session.ServerProof != base64.StdEncoding.EncodeToString(clientProofs.ExpectedServerProof) {
|
||||
return nil, fmt.Errorf("server proof verification failed")
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// generateSRPProofs generates SRP client proofs for authentication
|
||||
func (c *Client) generateSRPProofs(authInfo *api.AuthInfoResponse) (*srp.Proofs, error) {
|
||||
auth, err := srp.NewAuth(
|
||||
authInfo.Version,
|
||||
c.config.Username,
|
||||
[]byte(c.config.Password),
|
||||
authInfo.Salt,
|
||||
authInfo.Modulus,
|
||||
authInfo.ServerEphemeral,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create SRP auth: %w", err)
|
||||
}
|
||||
|
||||
proofs, err := auth.GenerateProofs(2048)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate SRP proofs: %w", err)
|
||||
}
|
||||
return proofs, nil
|
||||
}
|
||||
|
||||
// buildAuthRequest builds the authentication request payload
|
||||
func (c *Client) buildAuthRequest(authInfo *api.AuthInfoResponse, proofs *srp.Proofs) map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"Username": c.config.Username,
|
||||
"ClientEphemeral": base64.StdEncoding.EncodeToString(proofs.ClientEphemeral),
|
||||
"ClientProof": base64.StdEncoding.EncodeToString(proofs.ClientProof),
|
||||
"SRPSession": authInfo.SRPSession,
|
||||
"PersistentCookies": 0,
|
||||
}
|
||||
}
|
||||
|
||||
// upgradeSessionIfNeeded upgrades session with 2FA if VPN scope is missing
|
||||
func (c *Client) upgradeSessionIfNeeded(session *api.Session) error {
|
||||
hasVPNScope, hasTwoFactorScope := c.checkSessionScopes(session)
|
||||
|
||||
if hasVPNScope || !hasTwoFactorScope {
|
||||
return nil
|
||||
}
|
||||
|
||||
fmt.Println("Session lacks VPN scope - 2FA verification required to upgrade session...")
|
||||
code, err := c.get2FACode()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get 2FA code: %w", err)
|
||||
}
|
||||
|
||||
updatedScopes, err := c.submit2FA(session, code)
|
||||
if err != nil {
|
||||
return fmt.Errorf("2FA verification failed: %w", err)
|
||||
}
|
||||
session.Scopes = updatedScopes
|
||||
fmt.Println("2FA verified - session upgraded with VPN scope")
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkSessionScopes checks if session has VPN and twofactor scopes
|
||||
func (c *Client) checkSessionScopes(session *api.Session) (hasVPN, hasTwoFactor bool) {
|
||||
for _, scope := range session.Scopes {
|
||||
switch scope {
|
||||
case "vpn":
|
||||
hasVPN = true
|
||||
case "twofactor":
|
||||
hasTwoFactor = true
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// saveSessionIfEnabled saves the session if persistence is enabled
|
||||
func (c *Client) saveSessionIfEnabled(session *api.Session) {
|
||||
if c.config.NoSession {
|
||||
return
|
||||
}
|
||||
|
||||
sessionDuration, err := timeutil.ParseSessionDuration(c.config.SessionDuration)
|
||||
if err != nil {
|
||||
fmt.Printf("Warning: Invalid session duration, using default: %v\n", err)
|
||||
sessionDuration = 0
|
||||
}
|
||||
|
||||
if err := c.sessionStore.Save(session, c.config.Username, sessionDuration); err != nil {
|
||||
fmt.Printf("Warning: Failed to save session: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) ensureUsername() error {
|
||||
if c.config.Username == "" {
|
||||
fmt.Print("Username (without @protonmail.com): ")
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
username, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return fmt.Errorf("error reading username: %w", err)
|
||||
}
|
||||
c.config.Username = strings.TrimSpace(username)
|
||||
if c.config.Username == "" {
|
||||
return fmt.Errorf("username cannot be empty")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) ensurePassword() error {
|
||||
if c.config.Password == "" {
|
||||
fmt.Print("Password: ")
|
||||
passwordBytes, err := term.ReadPassword(int(os.Stdin.Fd()))
|
||||
fmt.Println()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error reading password: %w", err)
|
||||
}
|
||||
c.config.Password = string(passwordBytes)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) get2FACode() (string, error) {
|
||||
fmt.Print("2FA Code: ")
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
code, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error reading 2FA code: %w", err)
|
||||
}
|
||||
code = strings.TrimSpace(code)
|
||||
|
||||
// Validate that code is numeric (TOTP codes are 6 digits)
|
||||
if code == "" {
|
||||
return "", fmt.Errorf("2FA code cannot be empty")
|
||||
}
|
||||
|
||||
for _, c := range code {
|
||||
if c < '0' || c > '9' {
|
||||
return "", fmt.Errorf("2FA code must be numeric (TOTP only).\n" +
|
||||
"FIDO2/WebAuthn security keys are not supported.\n" +
|
||||
"Please ensure you have TOTP (authenticator app) configured as your 2FA method")
|
||||
}
|
||||
}
|
||||
|
||||
return code, nil
|
||||
}
|
||||
|
||||
func (c *Client) getAuthInfo() (*api.AuthInfoResponse, error) {
|
||||
reqBody := map[string]interface{}{
|
||||
"Username": c.config.Username,
|
||||
"Intent": "Proton",
|
||||
}
|
||||
|
||||
body, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, c.config.APIURL+"/core/v4/auth/info", bytes.NewBuffer(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.setHeaders(req)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("HTTP error %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var authInfo api.AuthInfoResponse
|
||||
if err := json.Unmarshal(respBody, &authInfo); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse auth info: %w", err)
|
||||
}
|
||||
|
||||
if authInfo.Code != CodeSuccess {
|
||||
return nil, fmt.Errorf("failed to get auth info, code: %d", authInfo.Code)
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
if authInfo.Modulus == "" {
|
||||
return nil, fmt.Errorf("received empty modulus from auth info")
|
||||
}
|
||||
if authInfo.ServerEphemeral == "" {
|
||||
return nil, fmt.Errorf("received empty server ephemeral from auth info")
|
||||
}
|
||||
|
||||
return &authInfo, nil
|
||||
}
|
||||
|
||||
func (c *Client) sendAuthRequest(authReq map[string]interface{}) (*api.Session, error) {
|
||||
body, err := json.Marshal(authReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, c.config.APIURL+"/core/v4/auth", bytes.NewBuffer(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.setHeaders(req)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("authentication HTTP error %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var session api.Session
|
||||
if err := json.Unmarshal(respBody, &session); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Handle mailbox password request (2-password mode)
|
||||
// Code 10013 means the account uses legacy 2-password mode which requires a separate mailbox password
|
||||
// VPN doesn't need mailbox decryption, but the auth flow requires completing it
|
||||
if session.Code == CodeMailboxPasswordError {
|
||||
return nil, fmt.Errorf("your account uses legacy 2-password mode which is not supported.\n" +
|
||||
"Please switch to single-password mode:\n" +
|
||||
" 1. Go to account.proton.me\n" +
|
||||
" 2. Settings → All settings → Account and password → Passwords\n" +
|
||||
" 3. Switch to 'One-password mode'\n" +
|
||||
"This is recommended by Proton for most users and is required for this tool")
|
||||
}
|
||||
|
||||
if session.Code != CodeSuccess {
|
||||
return nil, NewError(session.Code)
|
||||
}
|
||||
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
func (c *Client) setHeaders(req *http.Request) {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("x-pm-appversion", constants.AppVersion)
|
||||
req.Header.Set("User-Agent", constants.UserAgent)
|
||||
}
|
||||
|
||||
// submit2FA submits a 2FA code to upgrade the session with additional scopes (like VPN)
|
||||
func (c *Client) submit2FA(session *api.Session, code string) ([]string, error) {
|
||||
reqBody := map[string]interface{}{
|
||||
"TwoFactorCode": code,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, c.config.APIURL+"/core/v4/auth/2fa", bytes.NewBuffer(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Need to include auth headers for 2FA upgrade
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", session.AccessToken))
|
||||
req.Header.Set("x-pm-uid", session.UID)
|
||||
req.Header.Set("x-pm-appversion", constants.AppVersion)
|
||||
req.Header.Set("User-Agent", constants.UserAgent)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("2FA HTTP error %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
// Parse response to get updated scopes
|
||||
var twoFAResp struct {
|
||||
Code int `json:"Code"`
|
||||
Scopes []string `json:"Scopes"`
|
||||
Error string `json:"Error,omitempty"`
|
||||
}
|
||||
if err := json.Unmarshal(respBody, &twoFAResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse 2FA response: %w", err)
|
||||
}
|
||||
|
||||
if twoFAResp.Code != CodeSuccess {
|
||||
if twoFAResp.Error != "" {
|
||||
return nil, fmt.Errorf("2FA failed (code %d): %s", twoFAResp.Code, twoFAResp.Error)
|
||||
}
|
||||
return nil, NewError(twoFAResp.Code)
|
||||
}
|
||||
|
||||
return twoFAResp.Scopes, nil
|
||||
}
|
||||
Reference in New Issue
Block a user