initial commit

This commit is contained in:
2026-04-07 17:41:25 +02:00
commit 1ed9bdfa55
45 changed files with 4712 additions and 0 deletions

View File

@@ -0,0 +1,503 @@
// Package auth handles ProtonVPN authentication using the SRP protocol.
package auth
import (
"bufio"
"bytes"
"crypto/tls"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"strings"
"time"
"protonvpn-wg-confgen/internal/api"
"protonvpn-wg-confgen/internal/config"
"protonvpn-wg-confgen/internal/constants"
"protonvpn-wg-confgen/pkg/timeutil"
"github.com/ProtonMail/go-srp"
"golang.org/x/term"
)
// Client handles ProtonVPN authentication
type Client struct {
config *config.Config
httpClient *http.Client
sessionStore *SessionStore
}
// NewClient creates a new authentication client
func NewClient(cfg *config.Config) *Client {
return &Client{
config: cfg,
sessionStore: NewSessionStore(),
httpClient: &http.Client{
Timeout: 30 * time.Second,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: false,
MinVersion: tls.VersionTLS12,
},
},
},
}
}
// handleSessionRefresh attempts to refresh a session and save it if successful
func (c *Client) handleSessionRefresh(savedSession *api.Session, reason string) (*api.Session, error) {
fmt.Println(reason)
refreshedSession, err := RefreshSession(c.httpClient, c.config.APIURL, savedSession)
if err != nil {
fmt.Printf("Token refresh failed: %v\n", err)
fmt.Println("Re-authenticating with password...")
fmt.Println("(Your trusted device status for MFA will be preserved)")
_ = c.sessionStore.Delete()
return nil, err
}
fmt.Println("Session refreshed successfully!")
// Check if refresh token was rotated
if savedSession.RefreshToken != refreshedSession.RefreshToken {
fmt.Println("Refresh token was rotated")
}
// Save the refreshed session
if !c.config.NoSession {
sessionDuration, _ := timeutil.ParseSessionDuration(c.config.SessionDuration)
if err := c.sessionStore.Save(refreshedSession, c.config.Username, sessionDuration); err != nil {
fmt.Printf("Warning: Failed to save refreshed session: %v\n", err)
}
}
return refreshedSession, nil
}
// tryExistingSession attempts to use an existing saved session
func (c *Client) tryExistingSession() (*api.Session, error) {
savedSession, timeUntilExpiry, err := c.sessionStore.Load(c.config.Username)
if err != nil {
fmt.Printf("Warning: Failed to load saved session: %v\n", err)
return nil, err
}
if savedSession == nil {
return nil, nil
}
// Determine what to do with the saved session
switch {
case c.config.ForceRefresh:
reason := fmt.Sprintf("Forcing session refresh (current session expires in %s)", timeutil.HumanizeDuration(timeUntilExpiry))
return c.handleSessionRefresh(savedSession, reason)
case timeUntilExpiry < time.Duration(constants.SessionRefreshDays)*24*time.Hour && timeUntilExpiry > 0:
reason := fmt.Sprintf("Session expires soon (in %s), attempting refresh...", timeutil.HumanizeDuration(timeUntilExpiry))
return c.handleSessionRefresh(savedSession, reason)
case VerifySession(c.httpClient, c.config.APIURL, savedSession):
fmt.Printf("Using saved session (expires in %s)\n", timeutil.HumanizeDuration(timeUntilExpiry))
return savedSession, nil
default:
fmt.Println("Saved session invalid, re-authenticating...")
_ = c.sessionStore.Delete()
return nil, nil
}
}
// Authenticate performs the full authentication flow
func (c *Client) Authenticate() (*api.Session, error) {
if err := c.ensureUsername(); err != nil {
return nil, err
}
// Try existing session unless clearing or disabled
if session := c.handleExistingSession(); session != nil {
return session, nil
}
if err := c.ensurePassword(); err != nil {
return nil, err
}
// Perform fresh authentication
session, err := c.performFreshAuth()
if err != nil {
return nil, err
}
// Handle session scope upgrade if needed
if err := c.upgradeSessionIfNeeded(session); err != nil {
return nil, err
}
c.saveSessionIfEnabled(session)
return session, nil
}
// handleExistingSession handles session clearing or reuse
func (c *Client) handleExistingSession() *api.Session {
if c.config.ClearSession {
fmt.Println("Clearing saved session...")
_ = c.sessionStore.Delete()
return nil
}
if c.config.NoSession {
return nil
}
session, err := c.tryExistingSession()
if err == nil && session != nil {
return session
}
return nil
}
// performFreshAuth performs SRP authentication and returns a new session
func (c *Client) performFreshAuth() (*api.Session, error) {
authInfo, err := c.getAuthInfo()
if err != nil {
return nil, fmt.Errorf("failed to get auth info: %w", err)
}
clientProofs, err := c.generateSRPProofs(authInfo)
if err != nil {
return nil, err
}
authReq := c.buildAuthRequest(authInfo, clientProofs)
// Handle 2FA if needed
if authInfo.TwoFA.Enabled == constants.EnabledTrue && authInfo.TwoFA.TOTP == constants.EnabledTrue {
code, err := c.get2FACode()
if err != nil {
return nil, err
}
authReq["TwoFactorCode"] = code
}
session, err := c.sendAuthRequest(authReq)
if err != nil {
return nil, err
}
// Verify server proof
if session.ServerProof != base64.StdEncoding.EncodeToString(clientProofs.ExpectedServerProof) {
return nil, fmt.Errorf("server proof verification failed")
}
return session, nil
}
// generateSRPProofs generates SRP client proofs for authentication
func (c *Client) generateSRPProofs(authInfo *api.AuthInfoResponse) (*srp.Proofs, error) {
auth, err := srp.NewAuth(
authInfo.Version,
c.config.Username,
[]byte(c.config.Password),
authInfo.Salt,
authInfo.Modulus,
authInfo.ServerEphemeral,
)
if err != nil {
return nil, fmt.Errorf("failed to create SRP auth: %w", err)
}
proofs, err := auth.GenerateProofs(2048)
if err != nil {
return nil, fmt.Errorf("failed to generate SRP proofs: %w", err)
}
return proofs, nil
}
// buildAuthRequest builds the authentication request payload
func (c *Client) buildAuthRequest(authInfo *api.AuthInfoResponse, proofs *srp.Proofs) map[string]interface{} {
return map[string]interface{}{
"Username": c.config.Username,
"ClientEphemeral": base64.StdEncoding.EncodeToString(proofs.ClientEphemeral),
"ClientProof": base64.StdEncoding.EncodeToString(proofs.ClientProof),
"SRPSession": authInfo.SRPSession,
"PersistentCookies": 0,
}
}
// upgradeSessionIfNeeded upgrades session with 2FA if VPN scope is missing
func (c *Client) upgradeSessionIfNeeded(session *api.Session) error {
hasVPNScope, hasTwoFactorScope := c.checkSessionScopes(session)
if hasVPNScope || !hasTwoFactorScope {
return nil
}
fmt.Println("Session lacks VPN scope - 2FA verification required to upgrade session...")
code, err := c.get2FACode()
if err != nil {
return fmt.Errorf("failed to get 2FA code: %w", err)
}
updatedScopes, err := c.submit2FA(session, code)
if err != nil {
return fmt.Errorf("2FA verification failed: %w", err)
}
session.Scopes = updatedScopes
fmt.Println("2FA verified - session upgraded with VPN scope")
return nil
}
// checkSessionScopes checks if session has VPN and twofactor scopes
func (c *Client) checkSessionScopes(session *api.Session) (hasVPN, hasTwoFactor bool) {
for _, scope := range session.Scopes {
switch scope {
case "vpn":
hasVPN = true
case "twofactor":
hasTwoFactor = true
}
}
return
}
// saveSessionIfEnabled saves the session if persistence is enabled
func (c *Client) saveSessionIfEnabled(session *api.Session) {
if c.config.NoSession {
return
}
sessionDuration, err := timeutil.ParseSessionDuration(c.config.SessionDuration)
if err != nil {
fmt.Printf("Warning: Invalid session duration, using default: %v\n", err)
sessionDuration = 0
}
if err := c.sessionStore.Save(session, c.config.Username, sessionDuration); err != nil {
fmt.Printf("Warning: Failed to save session: %v\n", err)
}
}
func (c *Client) ensureUsername() error {
if c.config.Username == "" {
fmt.Print("Username (without @protonmail.com): ")
reader := bufio.NewReader(os.Stdin)
username, err := reader.ReadString('\n')
if err != nil {
return fmt.Errorf("error reading username: %w", err)
}
c.config.Username = strings.TrimSpace(username)
if c.config.Username == "" {
return fmt.Errorf("username cannot be empty")
}
}
return nil
}
func (c *Client) ensurePassword() error {
if c.config.Password == "" {
fmt.Print("Password: ")
passwordBytes, err := term.ReadPassword(int(os.Stdin.Fd()))
fmt.Println()
if err != nil {
return fmt.Errorf("error reading password: %w", err)
}
c.config.Password = string(passwordBytes)
}
return nil
}
func (c *Client) get2FACode() (string, error) {
fmt.Print("2FA Code: ")
reader := bufio.NewReader(os.Stdin)
code, err := reader.ReadString('\n')
if err != nil {
return "", fmt.Errorf("error reading 2FA code: %w", err)
}
code = strings.TrimSpace(code)
// Validate that code is numeric (TOTP codes are 6 digits)
if code == "" {
return "", fmt.Errorf("2FA code cannot be empty")
}
for _, c := range code {
if c < '0' || c > '9' {
return "", fmt.Errorf("2FA code must be numeric (TOTP only).\n" +
"FIDO2/WebAuthn security keys are not supported.\n" +
"Please ensure you have TOTP (authenticator app) configured as your 2FA method")
}
}
return code, nil
}
func (c *Client) getAuthInfo() (*api.AuthInfoResponse, error) {
reqBody := map[string]interface{}{
"Username": c.config.Username,
"Intent": "Proton",
}
body, err := json.Marshal(reqBody)
if err != nil {
return nil, err
}
req, err := http.NewRequest(http.MethodPost, c.config.APIURL+"/core/v4/auth/info", bytes.NewBuffer(body))
if err != nil {
return nil, err
}
c.setHeaders(req)
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("HTTP error %d: %s", resp.StatusCode, string(respBody))
}
var authInfo api.AuthInfoResponse
if err := json.Unmarshal(respBody, &authInfo); err != nil {
return nil, fmt.Errorf("failed to parse auth info: %w", err)
}
if authInfo.Code != CodeSuccess {
return nil, fmt.Errorf("failed to get auth info, code: %d", authInfo.Code)
}
// Validate required fields
if authInfo.Modulus == "" {
return nil, fmt.Errorf("received empty modulus from auth info")
}
if authInfo.ServerEphemeral == "" {
return nil, fmt.Errorf("received empty server ephemeral from auth info")
}
return &authInfo, nil
}
func (c *Client) sendAuthRequest(authReq map[string]interface{}) (*api.Session, error) {
body, err := json.Marshal(authReq)
if err != nil {
return nil, err
}
req, err := http.NewRequest(http.MethodPost, c.config.APIURL+"/core/v4/auth", bytes.NewBuffer(body))
if err != nil {
return nil, err
}
c.setHeaders(req)
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("authentication HTTP error %d: %s", resp.StatusCode, string(respBody))
}
var session api.Session
if err := json.Unmarshal(respBody, &session); err != nil {
return nil, err
}
// Handle mailbox password request (2-password mode)
// Code 10013 means the account uses legacy 2-password mode which requires a separate mailbox password
// VPN doesn't need mailbox decryption, but the auth flow requires completing it
if session.Code == CodeMailboxPasswordError {
return nil, fmt.Errorf("your account uses legacy 2-password mode which is not supported.\n" +
"Please switch to single-password mode:\n" +
" 1. Go to account.proton.me\n" +
" 2. Settings → All settings → Account and password → Passwords\n" +
" 3. Switch to 'One-password mode'\n" +
"This is recommended by Proton for most users and is required for this tool")
}
if session.Code != CodeSuccess {
return nil, NewError(session.Code)
}
return &session, nil
}
func (c *Client) setHeaders(req *http.Request) {
req.Header.Set("Content-Type", "application/json")
req.Header.Set("x-pm-appversion", constants.AppVersion)
req.Header.Set("User-Agent", constants.UserAgent)
}
// submit2FA submits a 2FA code to upgrade the session with additional scopes (like VPN)
func (c *Client) submit2FA(session *api.Session, code string) ([]string, error) {
reqBody := map[string]interface{}{
"TwoFactorCode": code,
}
body, err := json.Marshal(reqBody)
if err != nil {
return nil, err
}
req, err := http.NewRequest(http.MethodPost, c.config.APIURL+"/core/v4/auth/2fa", bytes.NewBuffer(body))
if err != nil {
return nil, err
}
// Need to include auth headers for 2FA upgrade
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", session.AccessToken))
req.Header.Set("x-pm-uid", session.UID)
req.Header.Set("x-pm-appversion", constants.AppVersion)
req.Header.Set("User-Agent", constants.UserAgent)
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("2FA HTTP error %d: %s", resp.StatusCode, string(respBody))
}
// Parse response to get updated scopes
var twoFAResp struct {
Code int `json:"Code"`
Scopes []string `json:"Scopes"`
Error string `json:"Error,omitempty"`
}
if err := json.Unmarshal(respBody, &twoFAResp); err != nil {
return nil, fmt.Errorf("failed to parse 2FA response: %w", err)
}
if twoFAResp.Code != CodeSuccess {
if twoFAResp.Error != "" {
return nil, fmt.Errorf("2FA failed (code %d): %s", twoFAResp.Code, twoFAResp.Error)
}
return nil, NewError(twoFAResp.Code)
}
return twoFAResp.Scopes, nil
}

View File

@@ -0,0 +1,82 @@
package auth
import (
"errors"
"fmt"
"protonvpn-wg-confgen/internal/constants"
)
// Error codes from ProtonVPN API
// Official source: github.com/ProtonMail/protoncore_android/.../ResponseCodes.kt
// See API_REFERENCE.md for full documentation.
const (
CodeSuccess = constants.APICodeSuccess
CodeWrongPassword = 8002 // PASSWORD_WRONG: Incorrect password
CodeWrongPasswordFormat = 8004 // Password format is incorrect (observed)
CodeCaptchaRequired = 9001 // HUMAN_VERIFICATION_REQUIRED: CAPTCHA needed
Code2FARequiredForVPN = 9100 // VPN-specific: certificate endpoint requires 2FA session (not in official docs)
CodeAccountDeleted = 10002 // ACCOUNT_DELETED: Account has been deleted
CodeAccountDisabled = 10003 // ACCOUNT_DISABLED: Account has been disabled
CodeMailboxPasswordError = 10013 // Legacy 2-password mode / invalid refresh token (context-dependent)
)
// Error represents an authentication error with ProtonVPN-specific error code
type Error struct {
Code int
Message string
}
// Error implements the error interface
func (e Error) Error() string {
return e.Message
}
// NewError creates a new authentication error from an API response code
func NewError(code int) error {
message := getErrorMessage(code)
return Error{
Code: code,
Message: message,
}
}
// getErrorMessage returns a human-readable error message for a given error code
func getErrorMessage(code int) string {
switch code {
case CodeWrongPassword:
return "incorrect username or password"
case CodeWrongPasswordFormat:
return "password format is incorrect"
case CodeCaptchaRequired:
return "CAPTCHA verification required"
case Code2FARequiredForVPN:
return "2FA required for VPN operations - your session was authenticated without 2FA (device trust). Use -clear-session to force re-authentication with 2FA"
case CodeAccountDeleted:
return "account has been deleted"
case CodeAccountDisabled:
return "account has been disabled"
case CodeMailboxPasswordError:
return "account uses legacy 2-password mode - please switch to single-password mode at account.proton.me"
default:
return fmt.Sprintf("authentication failed with code: %d", code)
}
}
// IsAccountError checks if the error is an account status error (deleted or disabled)
func IsAccountError(err error) bool {
var authErr Error
if !errors.As(err, &authErr) {
return false
}
return authErr.Code == CodeAccountDeleted || authErr.Code == CodeAccountDisabled
}
// IsCaptchaError checks if the error requires CAPTCHA verification
func IsCaptchaError(err error) bool {
var authErr Error
if !errors.As(err, &authErr) {
return false
}
return authErr.Code == CodeCaptchaRequired
}

View File

@@ -0,0 +1,210 @@
package auth
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"time"
"protonvpn-wg-confgen/internal/api"
"protonvpn-wg-confgen/internal/constants"
)
// SessionStore handles persistent session storage
type SessionStore struct {
filePath string
}
// NewSessionStore creates a new session store
func NewSessionStore() *SessionStore {
homeDir, err := os.UserHomeDir()
if err != nil {
// Fallback to current directory
homeDir = "."
}
return &SessionStore{
filePath: filepath.Join(homeDir, constants.SessionFileName),
}
}
// SavedSession represents a session with metadata
type SavedSession struct {
Session *api.Session `json:"session"`
Username string `json:"username"`
SavedAt time.Time `json:"saved_at"`
ExpiresAt time.Time `json:"expires_at"`
}
// Save stores the session to disk
func (s *SessionStore) Save(session *api.Session, username string, duration time.Duration) error {
savedSession := &SavedSession{
Session: session,
Username: username,
SavedAt: time.Now(),
}
// Calculate expiration based on API response
apiExpiration := time.Now().Add(time.Duration(session.ExpiresIn) * time.Second)
if duration == 0 {
// Use the API's expiration
savedSession.ExpiresAt = apiExpiration
} else {
// Use the user-specified duration, but cap it at API expiration
userExpiration := time.Now().Add(duration)
if userExpiration.After(apiExpiration) {
savedSession.ExpiresAt = apiExpiration
} else {
savedSession.ExpiresAt = userExpiration
}
}
data, err := json.MarshalIndent(savedSession, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal session: %w", err)
}
err = os.WriteFile(s.filePath, data, constants.SessionFileMode)
if err != nil {
return fmt.Errorf("failed to write session file: %w", err)
}
return nil
}
// Load retrieves a saved session from disk
func (s *SessionStore) Load(username string) (*api.Session, time.Duration, error) {
data, err := os.ReadFile(s.filePath)
if err != nil {
if os.IsNotExist(err) {
return nil, 0, nil // No saved session
}
return nil, 0, fmt.Errorf("failed to read session file: %w", err)
}
var savedSession SavedSession
err = json.Unmarshal(data, &savedSession)
if err != nil {
return nil, 0, fmt.Errorf("failed to unmarshal session: %w", err)
}
// Check if session is for the same user
if savedSession.Username != username {
return nil, 0, nil
}
// Check if session has expired
now := time.Now()
if now.After(savedSession.ExpiresAt) {
// Delete expired session
_ = s.Delete()
return nil, 0, nil
}
// Calculate time until expiration
timeUntilExpiry := savedSession.ExpiresAt.Sub(now)
return savedSession.Session, timeUntilExpiry, nil
}
// Delete removes the saved session
func (s *SessionStore) Delete() error {
err := os.Remove(s.filePath)
if err != nil && !os.IsNotExist(err) {
return fmt.Errorf("failed to delete session file: %w", err)
}
return nil
}
// GetPath returns the session file path
func (s *SessionStore) GetPath() string {
return s.filePath
}
// RefreshSession attempts to refresh the session using the refresh token.
// It returns a new session with updated tokens if successful.
func RefreshSession(httpClient *http.Client, apiURL string, oldSession *api.Session) (*api.Session, error) {
// Based on proton-python-client/proton/api.py refresh() method
reqBody := map[string]interface{}{
"ResponseType": "token",
"GrantType": "refresh_token",
"RefreshToken": oldSession.RefreshToken,
"RedirectURI": "http://protonmail.ch",
}
body, err := json.Marshal(reqBody)
if err != nil {
return nil, err
}
req, err := http.NewRequest(http.MethodPost, apiURL+"/auth/refresh", bytes.NewBuffer(body))
if err != nil {
return nil, err
}
// Set standard headers
req.Header.Set("Content-Type", "application/json")
req.Header.Set("x-pm-appversion", constants.AppVersion)
req.Header.Set("User-Agent", constants.UserAgent)
// Include auth headers for refresh
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oldSession.AccessToken))
req.Header.Set("x-pm-uid", oldSession.UID)
resp, err := httpClient.Do(req)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if resp.StatusCode == http.StatusOK {
var session api.Session
if err := json.Unmarshal(respBody, &session); err != nil {
return nil, err
}
if constants.IsSuccessCode(session.Code) {
return &session, nil
}
}
// If refresh fails, return error to trigger re-authentication
return nil, fmt.Errorf("refresh failed (status %d): %s", resp.StatusCode, string(respBody))
}
// VerifySession checks if a session is still valid by making a test API request.
func VerifySession(httpClient *http.Client, apiURL string, session *api.Session) bool {
// Make a simple request to verify the session
req, err := http.NewRequest(http.MethodGet, apiURL+"/vpn/v1/logicals", http.NoBody)
if err != nil {
return false
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", session.AccessToken))
req.Header.Set("x-pm-uid", session.UID)
req.Header.Set("x-pm-appversion", constants.AppVersion)
req.Header.Set("User-Agent", constants.UserAgent)
resp, err := httpClient.Do(req)
if err != nil {
return false
}
defer func() { _ = resp.Body.Close() }()
// If we get a 401, the session is invalid
if resp.StatusCode == http.StatusUnauthorized {
return false
}
// Any 2xx response means the session is valid
return resp.StatusCode >= 200 && resp.StatusCode < 300
}