initial commit
This commit is contained in:
165
protonvpn-wg-confgen/internal/api/types.go
Normal file
165
protonvpn-wg-confgen/internal/api/types.go
Normal file
@@ -0,0 +1,165 @@
|
||||
// Package api defines the data structures for ProtonVPN API responses.
|
||||
package api
|
||||
|
||||
// AuthInfoResponse represents the response from the auth info endpoint
|
||||
type AuthInfoResponse struct {
|
||||
Code int `json:"Code"`
|
||||
Version int `json:"Version"`
|
||||
Modulus string `json:"Modulus"`
|
||||
ServerEphemeral string `json:"ServerEphemeral"`
|
||||
Salt string `json:"Salt"`
|
||||
SRPSession string `json:"SRPSession"`
|
||||
TwoFA struct {
|
||||
Enabled int `json:"Enabled"`
|
||||
TOTP int `json:"TOTP"`
|
||||
} `json:"2FA"`
|
||||
}
|
||||
|
||||
// AuthRequest represents the authentication request payload
|
||||
type AuthRequest struct {
|
||||
Username string `json:"Username"`
|
||||
ClientEphemeral string `json:"ClientEphemeral"`
|
||||
ClientProof string `json:"ClientProof"`
|
||||
SRPSession string `json:"SRPSession"`
|
||||
TwoFactorCode string `json:"TwoFactorCode,omitempty"`
|
||||
}
|
||||
|
||||
// Session represents a ProtonVPN session
|
||||
type Session struct {
|
||||
Code int `json:"Code"`
|
||||
AccessToken string `json:"AccessToken"`
|
||||
RefreshToken string `json:"RefreshToken"`
|
||||
TokenType string `json:"TokenType"`
|
||||
Scopes []string `json:"Scopes"`
|
||||
UID string `json:"UID"`
|
||||
UserID string `json:"UserID"`
|
||||
EventID string `json:"EventID"`
|
||||
ServerProof string `json:"ServerProof"`
|
||||
PasswordMode int `json:"PasswordMode"`
|
||||
ExpiresIn int `json:"ExpiresIn"` // Session expiration in seconds
|
||||
TwoFA struct {
|
||||
Enabled int `json:"Enabled"`
|
||||
TOTP int `json:"TOTP"`
|
||||
} `json:"2FA"`
|
||||
}
|
||||
|
||||
// VPNInfo represents VPN certificate information
|
||||
type VPNInfo struct {
|
||||
Code int `json:"Code"`
|
||||
Error string `json:"Error,omitempty"`
|
||||
SerialNumber string `json:"SerialNumber"`
|
||||
ClientKeyFingerprint string `json:"ClientKeyFingerprint"`
|
||||
ClientKey string `json:"ClientKey"`
|
||||
Certificate string `json:"Certificate"`
|
||||
ExpirationTime int64 `json:"ExpirationTime"`
|
||||
RefreshTime int64 `json:"RefreshTime"`
|
||||
Mode string `json:"Mode"`
|
||||
DeviceName string `json:"DeviceName"`
|
||||
ServerPublicKeyMode string `json:"ServerPublicKeyMode"`
|
||||
ServerPublicKey string `json:"ServerPublicKey"`
|
||||
Features struct {
|
||||
Bouncing bool `json:"bouncing"`
|
||||
ModerateNAT bool `json:"moderate-nat"`
|
||||
NetshieldLevel int `json:"netshield-level"`
|
||||
PortForwarding bool `json:"port-forwarding"`
|
||||
VPNAccelerator bool `json:"vpn-accelerator"`
|
||||
} `json:"Features"`
|
||||
}
|
||||
|
||||
// LogicalServer represents a ProtonVPN logical server
|
||||
type LogicalServer struct {
|
||||
ID string `json:"ID"`
|
||||
Name string `json:"Name"`
|
||||
EntryCountry string `json:"EntryCountry"`
|
||||
ExitCountry string `json:"ExitCountry"`
|
||||
Domain string `json:"Domain"`
|
||||
Tier int `json:"Tier"`
|
||||
Features int `json:"Features"`
|
||||
Region string `json:"Region"`
|
||||
City string `json:"City"`
|
||||
Score float64 `json:"Score"`
|
||||
Load int `json:"Load"`
|
||||
Status int `json:"Status"`
|
||||
Servers []PhysicalServer `json:"Servers"`
|
||||
HostCountry string `json:"HostCountry"`
|
||||
Location struct {
|
||||
Lat float64 `json:"Lat"`
|
||||
Long float64 `json:"Long"`
|
||||
} `json:"Location"`
|
||||
}
|
||||
|
||||
// PhysicalServer represents a physical VPN server
|
||||
type PhysicalServer struct {
|
||||
ID string `json:"ID"`
|
||||
EntryIP string `json:"EntryIP"`
|
||||
ExitIP string `json:"ExitIP"`
|
||||
Domain string `json:"Domain"`
|
||||
Status int `json:"Status"`
|
||||
Label string `json:"Label"`
|
||||
X25519PublicKey string `json:"X25519PublicKey"`
|
||||
Generation int `json:"Generation"`
|
||||
ServicesDownReason string `json:"ServicesDownReason"`
|
||||
}
|
||||
|
||||
// LogicalsResponse represents the response from the logicals endpoint
|
||||
type LogicalsResponse struct {
|
||||
Code int `json:"Code"`
|
||||
LogicalServers []LogicalServer `json:"LogicalServers"`
|
||||
}
|
||||
|
||||
// Server feature constants
|
||||
const (
|
||||
FeatureSecureCore = 1
|
||||
FeatureTor = 2
|
||||
FeatureP2P = 4
|
||||
FeatureStreaming = 8
|
||||
FeatureIPv6 = 16
|
||||
)
|
||||
|
||||
// Server tier constants
|
||||
const (
|
||||
TierFree = 0
|
||||
TierPlus = 2
|
||||
TierPM = 3
|
||||
)
|
||||
|
||||
// Password mode constants
|
||||
const (
|
||||
PasswordModeSingle = 1
|
||||
PasswordModeTwo = 2
|
||||
)
|
||||
|
||||
// GetTierName returns a human-readable name for the server tier
|
||||
func GetTierName(tier int) string {
|
||||
switch tier {
|
||||
case TierFree:
|
||||
return "Free"
|
||||
case TierPlus:
|
||||
return "Plus"
|
||||
case TierPM:
|
||||
return "ProtonMail"
|
||||
default:
|
||||
return "Unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// GetFeatureNames returns a list of enabled features for a server
|
||||
func GetFeatureNames(features int) []string {
|
||||
var result []string
|
||||
if features&FeatureSecureCore != 0 {
|
||||
result = append(result, "SecureCore")
|
||||
}
|
||||
if features&FeatureTor != 0 {
|
||||
result = append(result, "Tor")
|
||||
}
|
||||
if features&FeatureP2P != 0 {
|
||||
result = append(result, "P2P")
|
||||
}
|
||||
if features&FeatureStreaming != 0 {
|
||||
result = append(result, "Streaming")
|
||||
}
|
||||
if features&FeatureIPv6 != 0 {
|
||||
result = append(result, "IPv6")
|
||||
}
|
||||
return result
|
||||
}
|
||||
503
protonvpn-wg-confgen/internal/auth/auth.go
Normal file
503
protonvpn-wg-confgen/internal/auth/auth.go
Normal file
@@ -0,0 +1,503 @@
|
||||
// Package auth handles ProtonVPN authentication using the SRP protocol.
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"protonvpn-wg-confgen/internal/api"
|
||||
"protonvpn-wg-confgen/internal/config"
|
||||
"protonvpn-wg-confgen/internal/constants"
|
||||
"protonvpn-wg-confgen/pkg/timeutil"
|
||||
|
||||
"github.com/ProtonMail/go-srp"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
// Client handles ProtonVPN authentication
|
||||
type Client struct {
|
||||
config *config.Config
|
||||
httpClient *http.Client
|
||||
sessionStore *SessionStore
|
||||
}
|
||||
|
||||
// NewClient creates a new authentication client
|
||||
func NewClient(cfg *config.Config) *Client {
|
||||
return &Client{
|
||||
config: cfg,
|
||||
sessionStore: NewSessionStore(),
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: false,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// handleSessionRefresh attempts to refresh a session and save it if successful
|
||||
func (c *Client) handleSessionRefresh(savedSession *api.Session, reason string) (*api.Session, error) {
|
||||
fmt.Println(reason)
|
||||
refreshedSession, err := RefreshSession(c.httpClient, c.config.APIURL, savedSession)
|
||||
if err != nil {
|
||||
fmt.Printf("Token refresh failed: %v\n", err)
|
||||
fmt.Println("Re-authenticating with password...")
|
||||
fmt.Println("(Your trusted device status for MFA will be preserved)")
|
||||
_ = c.sessionStore.Delete()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fmt.Println("Session refreshed successfully!")
|
||||
// Check if refresh token was rotated
|
||||
if savedSession.RefreshToken != refreshedSession.RefreshToken {
|
||||
fmt.Println("Refresh token was rotated")
|
||||
}
|
||||
|
||||
// Save the refreshed session
|
||||
if !c.config.NoSession {
|
||||
sessionDuration, _ := timeutil.ParseSessionDuration(c.config.SessionDuration)
|
||||
if err := c.sessionStore.Save(refreshedSession, c.config.Username, sessionDuration); err != nil {
|
||||
fmt.Printf("Warning: Failed to save refreshed session: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
return refreshedSession, nil
|
||||
}
|
||||
|
||||
// tryExistingSession attempts to use an existing saved session
|
||||
func (c *Client) tryExistingSession() (*api.Session, error) {
|
||||
savedSession, timeUntilExpiry, err := c.sessionStore.Load(c.config.Username)
|
||||
if err != nil {
|
||||
fmt.Printf("Warning: Failed to load saved session: %v\n", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if savedSession == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Determine what to do with the saved session
|
||||
switch {
|
||||
case c.config.ForceRefresh:
|
||||
reason := fmt.Sprintf("Forcing session refresh (current session expires in %s)", timeutil.HumanizeDuration(timeUntilExpiry))
|
||||
return c.handleSessionRefresh(savedSession, reason)
|
||||
|
||||
case timeUntilExpiry < time.Duration(constants.SessionRefreshDays)*24*time.Hour && timeUntilExpiry > 0:
|
||||
reason := fmt.Sprintf("Session expires soon (in %s), attempting refresh...", timeutil.HumanizeDuration(timeUntilExpiry))
|
||||
return c.handleSessionRefresh(savedSession, reason)
|
||||
|
||||
case VerifySession(c.httpClient, c.config.APIURL, savedSession):
|
||||
fmt.Printf("Using saved session (expires in %s)\n", timeutil.HumanizeDuration(timeUntilExpiry))
|
||||
return savedSession, nil
|
||||
|
||||
default:
|
||||
fmt.Println("Saved session invalid, re-authenticating...")
|
||||
_ = c.sessionStore.Delete()
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Authenticate performs the full authentication flow
|
||||
func (c *Client) Authenticate() (*api.Session, error) {
|
||||
if err := c.ensureUsername(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Try existing session unless clearing or disabled
|
||||
if session := c.handleExistingSession(); session != nil {
|
||||
return session, nil
|
||||
}
|
||||
|
||||
if err := c.ensurePassword(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Perform fresh authentication
|
||||
session, err := c.performFreshAuth()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Handle session scope upgrade if needed
|
||||
if err := c.upgradeSessionIfNeeded(session); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.saveSessionIfEnabled(session)
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// handleExistingSession handles session clearing or reuse
|
||||
func (c *Client) handleExistingSession() *api.Session {
|
||||
if c.config.ClearSession {
|
||||
fmt.Println("Clearing saved session...")
|
||||
_ = c.sessionStore.Delete()
|
||||
return nil
|
||||
}
|
||||
|
||||
if c.config.NoSession {
|
||||
return nil
|
||||
}
|
||||
|
||||
session, err := c.tryExistingSession()
|
||||
if err == nil && session != nil {
|
||||
return session
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// performFreshAuth performs SRP authentication and returns a new session
|
||||
func (c *Client) performFreshAuth() (*api.Session, error) {
|
||||
authInfo, err := c.getAuthInfo()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get auth info: %w", err)
|
||||
}
|
||||
|
||||
clientProofs, err := c.generateSRPProofs(authInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
authReq := c.buildAuthRequest(authInfo, clientProofs)
|
||||
|
||||
// Handle 2FA if needed
|
||||
if authInfo.TwoFA.Enabled == constants.EnabledTrue && authInfo.TwoFA.TOTP == constants.EnabledTrue {
|
||||
code, err := c.get2FACode()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authReq["TwoFactorCode"] = code
|
||||
}
|
||||
|
||||
session, err := c.sendAuthRequest(authReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Verify server proof
|
||||
if session.ServerProof != base64.StdEncoding.EncodeToString(clientProofs.ExpectedServerProof) {
|
||||
return nil, fmt.Errorf("server proof verification failed")
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// generateSRPProofs generates SRP client proofs for authentication
|
||||
func (c *Client) generateSRPProofs(authInfo *api.AuthInfoResponse) (*srp.Proofs, error) {
|
||||
auth, err := srp.NewAuth(
|
||||
authInfo.Version,
|
||||
c.config.Username,
|
||||
[]byte(c.config.Password),
|
||||
authInfo.Salt,
|
||||
authInfo.Modulus,
|
||||
authInfo.ServerEphemeral,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create SRP auth: %w", err)
|
||||
}
|
||||
|
||||
proofs, err := auth.GenerateProofs(2048)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate SRP proofs: %w", err)
|
||||
}
|
||||
return proofs, nil
|
||||
}
|
||||
|
||||
// buildAuthRequest builds the authentication request payload
|
||||
func (c *Client) buildAuthRequest(authInfo *api.AuthInfoResponse, proofs *srp.Proofs) map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"Username": c.config.Username,
|
||||
"ClientEphemeral": base64.StdEncoding.EncodeToString(proofs.ClientEphemeral),
|
||||
"ClientProof": base64.StdEncoding.EncodeToString(proofs.ClientProof),
|
||||
"SRPSession": authInfo.SRPSession,
|
||||
"PersistentCookies": 0,
|
||||
}
|
||||
}
|
||||
|
||||
// upgradeSessionIfNeeded upgrades session with 2FA if VPN scope is missing
|
||||
func (c *Client) upgradeSessionIfNeeded(session *api.Session) error {
|
||||
hasVPNScope, hasTwoFactorScope := c.checkSessionScopes(session)
|
||||
|
||||
if hasVPNScope || !hasTwoFactorScope {
|
||||
return nil
|
||||
}
|
||||
|
||||
fmt.Println("Session lacks VPN scope - 2FA verification required to upgrade session...")
|
||||
code, err := c.get2FACode()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get 2FA code: %w", err)
|
||||
}
|
||||
|
||||
updatedScopes, err := c.submit2FA(session, code)
|
||||
if err != nil {
|
||||
return fmt.Errorf("2FA verification failed: %w", err)
|
||||
}
|
||||
session.Scopes = updatedScopes
|
||||
fmt.Println("2FA verified - session upgraded with VPN scope")
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkSessionScopes checks if session has VPN and twofactor scopes
|
||||
func (c *Client) checkSessionScopes(session *api.Session) (hasVPN, hasTwoFactor bool) {
|
||||
for _, scope := range session.Scopes {
|
||||
switch scope {
|
||||
case "vpn":
|
||||
hasVPN = true
|
||||
case "twofactor":
|
||||
hasTwoFactor = true
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// saveSessionIfEnabled saves the session if persistence is enabled
|
||||
func (c *Client) saveSessionIfEnabled(session *api.Session) {
|
||||
if c.config.NoSession {
|
||||
return
|
||||
}
|
||||
|
||||
sessionDuration, err := timeutil.ParseSessionDuration(c.config.SessionDuration)
|
||||
if err != nil {
|
||||
fmt.Printf("Warning: Invalid session duration, using default: %v\n", err)
|
||||
sessionDuration = 0
|
||||
}
|
||||
|
||||
if err := c.sessionStore.Save(session, c.config.Username, sessionDuration); err != nil {
|
||||
fmt.Printf("Warning: Failed to save session: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) ensureUsername() error {
|
||||
if c.config.Username == "" {
|
||||
fmt.Print("Username (without @protonmail.com): ")
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
username, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return fmt.Errorf("error reading username: %w", err)
|
||||
}
|
||||
c.config.Username = strings.TrimSpace(username)
|
||||
if c.config.Username == "" {
|
||||
return fmt.Errorf("username cannot be empty")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) ensurePassword() error {
|
||||
if c.config.Password == "" {
|
||||
fmt.Print("Password: ")
|
||||
passwordBytes, err := term.ReadPassword(int(os.Stdin.Fd()))
|
||||
fmt.Println()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error reading password: %w", err)
|
||||
}
|
||||
c.config.Password = string(passwordBytes)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) get2FACode() (string, error) {
|
||||
fmt.Print("2FA Code: ")
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
code, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error reading 2FA code: %w", err)
|
||||
}
|
||||
code = strings.TrimSpace(code)
|
||||
|
||||
// Validate that code is numeric (TOTP codes are 6 digits)
|
||||
if code == "" {
|
||||
return "", fmt.Errorf("2FA code cannot be empty")
|
||||
}
|
||||
|
||||
for _, c := range code {
|
||||
if c < '0' || c > '9' {
|
||||
return "", fmt.Errorf("2FA code must be numeric (TOTP only).\n" +
|
||||
"FIDO2/WebAuthn security keys are not supported.\n" +
|
||||
"Please ensure you have TOTP (authenticator app) configured as your 2FA method")
|
||||
}
|
||||
}
|
||||
|
||||
return code, nil
|
||||
}
|
||||
|
||||
func (c *Client) getAuthInfo() (*api.AuthInfoResponse, error) {
|
||||
reqBody := map[string]interface{}{
|
||||
"Username": c.config.Username,
|
||||
"Intent": "Proton",
|
||||
}
|
||||
|
||||
body, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, c.config.APIURL+"/core/v4/auth/info", bytes.NewBuffer(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.setHeaders(req)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("HTTP error %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var authInfo api.AuthInfoResponse
|
||||
if err := json.Unmarshal(respBody, &authInfo); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse auth info: %w", err)
|
||||
}
|
||||
|
||||
if authInfo.Code != CodeSuccess {
|
||||
return nil, fmt.Errorf("failed to get auth info, code: %d", authInfo.Code)
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
if authInfo.Modulus == "" {
|
||||
return nil, fmt.Errorf("received empty modulus from auth info")
|
||||
}
|
||||
if authInfo.ServerEphemeral == "" {
|
||||
return nil, fmt.Errorf("received empty server ephemeral from auth info")
|
||||
}
|
||||
|
||||
return &authInfo, nil
|
||||
}
|
||||
|
||||
func (c *Client) sendAuthRequest(authReq map[string]interface{}) (*api.Session, error) {
|
||||
body, err := json.Marshal(authReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, c.config.APIURL+"/core/v4/auth", bytes.NewBuffer(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.setHeaders(req)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("authentication HTTP error %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var session api.Session
|
||||
if err := json.Unmarshal(respBody, &session); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Handle mailbox password request (2-password mode)
|
||||
// Code 10013 means the account uses legacy 2-password mode which requires a separate mailbox password
|
||||
// VPN doesn't need mailbox decryption, but the auth flow requires completing it
|
||||
if session.Code == CodeMailboxPasswordError {
|
||||
return nil, fmt.Errorf("your account uses legacy 2-password mode which is not supported.\n" +
|
||||
"Please switch to single-password mode:\n" +
|
||||
" 1. Go to account.proton.me\n" +
|
||||
" 2. Settings → All settings → Account and password → Passwords\n" +
|
||||
" 3. Switch to 'One-password mode'\n" +
|
||||
"This is recommended by Proton for most users and is required for this tool")
|
||||
}
|
||||
|
||||
if session.Code != CodeSuccess {
|
||||
return nil, NewError(session.Code)
|
||||
}
|
||||
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
func (c *Client) setHeaders(req *http.Request) {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("x-pm-appversion", constants.AppVersion)
|
||||
req.Header.Set("User-Agent", constants.UserAgent)
|
||||
}
|
||||
|
||||
// submit2FA submits a 2FA code to upgrade the session with additional scopes (like VPN)
|
||||
func (c *Client) submit2FA(session *api.Session, code string) ([]string, error) {
|
||||
reqBody := map[string]interface{}{
|
||||
"TwoFactorCode": code,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, c.config.APIURL+"/core/v4/auth/2fa", bytes.NewBuffer(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Need to include auth headers for 2FA upgrade
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", session.AccessToken))
|
||||
req.Header.Set("x-pm-uid", session.UID)
|
||||
req.Header.Set("x-pm-appversion", constants.AppVersion)
|
||||
req.Header.Set("User-Agent", constants.UserAgent)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("2FA HTTP error %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
// Parse response to get updated scopes
|
||||
var twoFAResp struct {
|
||||
Code int `json:"Code"`
|
||||
Scopes []string `json:"Scopes"`
|
||||
Error string `json:"Error,omitempty"`
|
||||
}
|
||||
if err := json.Unmarshal(respBody, &twoFAResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse 2FA response: %w", err)
|
||||
}
|
||||
|
||||
if twoFAResp.Code != CodeSuccess {
|
||||
if twoFAResp.Error != "" {
|
||||
return nil, fmt.Errorf("2FA failed (code %d): %s", twoFAResp.Code, twoFAResp.Error)
|
||||
}
|
||||
return nil, NewError(twoFAResp.Code)
|
||||
}
|
||||
|
||||
return twoFAResp.Scopes, nil
|
||||
}
|
||||
82
protonvpn-wg-confgen/internal/auth/errors.go
Normal file
82
protonvpn-wg-confgen/internal/auth/errors.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"protonvpn-wg-confgen/internal/constants"
|
||||
)
|
||||
|
||||
// Error codes from ProtonVPN API
|
||||
// Official source: github.com/ProtonMail/protoncore_android/.../ResponseCodes.kt
|
||||
// See API_REFERENCE.md for full documentation.
|
||||
const (
|
||||
CodeSuccess = constants.APICodeSuccess
|
||||
CodeWrongPassword = 8002 // PASSWORD_WRONG: Incorrect password
|
||||
CodeWrongPasswordFormat = 8004 // Password format is incorrect (observed)
|
||||
CodeCaptchaRequired = 9001 // HUMAN_VERIFICATION_REQUIRED: CAPTCHA needed
|
||||
Code2FARequiredForVPN = 9100 // VPN-specific: certificate endpoint requires 2FA session (not in official docs)
|
||||
CodeAccountDeleted = 10002 // ACCOUNT_DELETED: Account has been deleted
|
||||
CodeAccountDisabled = 10003 // ACCOUNT_DISABLED: Account has been disabled
|
||||
CodeMailboxPasswordError = 10013 // Legacy 2-password mode / invalid refresh token (context-dependent)
|
||||
)
|
||||
|
||||
// Error represents an authentication error with ProtonVPN-specific error code
|
||||
type Error struct {
|
||||
Code int
|
||||
Message string
|
||||
}
|
||||
|
||||
// Error implements the error interface
|
||||
func (e Error) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
// NewError creates a new authentication error from an API response code
|
||||
func NewError(code int) error {
|
||||
message := getErrorMessage(code)
|
||||
return Error{
|
||||
Code: code,
|
||||
Message: message,
|
||||
}
|
||||
}
|
||||
|
||||
// getErrorMessage returns a human-readable error message for a given error code
|
||||
func getErrorMessage(code int) string {
|
||||
switch code {
|
||||
case CodeWrongPassword:
|
||||
return "incorrect username or password"
|
||||
case CodeWrongPasswordFormat:
|
||||
return "password format is incorrect"
|
||||
case CodeCaptchaRequired:
|
||||
return "CAPTCHA verification required"
|
||||
case Code2FARequiredForVPN:
|
||||
return "2FA required for VPN operations - your session was authenticated without 2FA (device trust). Use -clear-session to force re-authentication with 2FA"
|
||||
case CodeAccountDeleted:
|
||||
return "account has been deleted"
|
||||
case CodeAccountDisabled:
|
||||
return "account has been disabled"
|
||||
case CodeMailboxPasswordError:
|
||||
return "account uses legacy 2-password mode - please switch to single-password mode at account.proton.me"
|
||||
default:
|
||||
return fmt.Sprintf("authentication failed with code: %d", code)
|
||||
}
|
||||
}
|
||||
|
||||
// IsAccountError checks if the error is an account status error (deleted or disabled)
|
||||
func IsAccountError(err error) bool {
|
||||
var authErr Error
|
||||
if !errors.As(err, &authErr) {
|
||||
return false
|
||||
}
|
||||
return authErr.Code == CodeAccountDeleted || authErr.Code == CodeAccountDisabled
|
||||
}
|
||||
|
||||
// IsCaptchaError checks if the error requires CAPTCHA verification
|
||||
func IsCaptchaError(err error) bool {
|
||||
var authErr Error
|
||||
if !errors.As(err, &authErr) {
|
||||
return false
|
||||
}
|
||||
return authErr.Code == CodeCaptchaRequired
|
||||
}
|
||||
210
protonvpn-wg-confgen/internal/auth/session.go
Normal file
210
protonvpn-wg-confgen/internal/auth/session.go
Normal file
@@ -0,0 +1,210 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"protonvpn-wg-confgen/internal/api"
|
||||
"protonvpn-wg-confgen/internal/constants"
|
||||
)
|
||||
|
||||
// SessionStore handles persistent session storage
|
||||
type SessionStore struct {
|
||||
filePath string
|
||||
}
|
||||
|
||||
// NewSessionStore creates a new session store
|
||||
func NewSessionStore() *SessionStore {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
// Fallback to current directory
|
||||
homeDir = "."
|
||||
}
|
||||
|
||||
return &SessionStore{
|
||||
filePath: filepath.Join(homeDir, constants.SessionFileName),
|
||||
}
|
||||
}
|
||||
|
||||
// SavedSession represents a session with metadata
|
||||
type SavedSession struct {
|
||||
Session *api.Session `json:"session"`
|
||||
Username string `json:"username"`
|
||||
SavedAt time.Time `json:"saved_at"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
}
|
||||
|
||||
// Save stores the session to disk
|
||||
func (s *SessionStore) Save(session *api.Session, username string, duration time.Duration) error {
|
||||
savedSession := &SavedSession{
|
||||
Session: session,
|
||||
Username: username,
|
||||
SavedAt: time.Now(),
|
||||
}
|
||||
|
||||
// Calculate expiration based on API response
|
||||
apiExpiration := time.Now().Add(time.Duration(session.ExpiresIn) * time.Second)
|
||||
|
||||
if duration == 0 {
|
||||
// Use the API's expiration
|
||||
savedSession.ExpiresAt = apiExpiration
|
||||
} else {
|
||||
// Use the user-specified duration, but cap it at API expiration
|
||||
userExpiration := time.Now().Add(duration)
|
||||
if userExpiration.After(apiExpiration) {
|
||||
savedSession.ExpiresAt = apiExpiration
|
||||
} else {
|
||||
savedSession.ExpiresAt = userExpiration
|
||||
}
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(savedSession, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal session: %w", err)
|
||||
}
|
||||
|
||||
err = os.WriteFile(s.filePath, data, constants.SessionFileMode)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write session file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Load retrieves a saved session from disk
|
||||
func (s *SessionStore) Load(username string) (*api.Session, time.Duration, error) {
|
||||
data, err := os.ReadFile(s.filePath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, 0, nil // No saved session
|
||||
}
|
||||
return nil, 0, fmt.Errorf("failed to read session file: %w", err)
|
||||
}
|
||||
|
||||
var savedSession SavedSession
|
||||
err = json.Unmarshal(data, &savedSession)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to unmarshal session: %w", err)
|
||||
}
|
||||
|
||||
// Check if session is for the same user
|
||||
if savedSession.Username != username {
|
||||
return nil, 0, nil
|
||||
}
|
||||
|
||||
// Check if session has expired
|
||||
now := time.Now()
|
||||
if now.After(savedSession.ExpiresAt) {
|
||||
// Delete expired session
|
||||
_ = s.Delete()
|
||||
return nil, 0, nil
|
||||
}
|
||||
|
||||
// Calculate time until expiration
|
||||
timeUntilExpiry := savedSession.ExpiresAt.Sub(now)
|
||||
|
||||
return savedSession.Session, timeUntilExpiry, nil
|
||||
}
|
||||
|
||||
// Delete removes the saved session
|
||||
func (s *SessionStore) Delete() error {
|
||||
err := os.Remove(s.filePath)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to delete session file: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPath returns the session file path
|
||||
func (s *SessionStore) GetPath() string {
|
||||
return s.filePath
|
||||
}
|
||||
|
||||
// RefreshSession attempts to refresh the session using the refresh token.
|
||||
// It returns a new session with updated tokens if successful.
|
||||
func RefreshSession(httpClient *http.Client, apiURL string, oldSession *api.Session) (*api.Session, error) {
|
||||
// Based on proton-python-client/proton/api.py refresh() method
|
||||
reqBody := map[string]interface{}{
|
||||
"ResponseType": "token",
|
||||
"GrantType": "refresh_token",
|
||||
"RefreshToken": oldSession.RefreshToken,
|
||||
"RedirectURI": "http://protonmail.ch",
|
||||
}
|
||||
|
||||
body, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, apiURL+"/auth/refresh", bytes.NewBuffer(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set standard headers
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("x-pm-appversion", constants.AppVersion)
|
||||
req.Header.Set("User-Agent", constants.UserAgent)
|
||||
|
||||
// Include auth headers for refresh
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oldSession.AccessToken))
|
||||
req.Header.Set("x-pm-uid", oldSession.UID)
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
var session api.Session
|
||||
if err := json.Unmarshal(respBody, &session); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if constants.IsSuccessCode(session.Code) {
|
||||
return &session, nil
|
||||
}
|
||||
}
|
||||
|
||||
// If refresh fails, return error to trigger re-authentication
|
||||
return nil, fmt.Errorf("refresh failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
// VerifySession checks if a session is still valid by making a test API request.
|
||||
func VerifySession(httpClient *http.Client, apiURL string, session *api.Session) bool {
|
||||
// Make a simple request to verify the session
|
||||
req, err := http.NewRequest(http.MethodGet, apiURL+"/vpn/v1/logicals", http.NoBody)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", session.AccessToken))
|
||||
req.Header.Set("x-pm-uid", session.UID)
|
||||
req.Header.Set("x-pm-appversion", constants.AppVersion)
|
||||
req.Header.Set("User-Agent", constants.UserAgent)
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
// If we get a 401, the session is invalid
|
||||
if resp.StatusCode == http.StatusUnauthorized {
|
||||
return false
|
||||
}
|
||||
|
||||
// Any 2xx response means the session is valid
|
||||
return resp.StatusCode >= 200 && resp.StatusCode < 300
|
||||
}
|
||||
120
protonvpn-wg-confgen/internal/config/flags.go
Normal file
120
protonvpn-wg-confgen/internal/config/flags.go
Normal file
@@ -0,0 +1,120 @@
|
||||
// Package config handles command-line argument parsing and configuration management.
|
||||
package config
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"protonvpn-wg-confgen/internal/constants"
|
||||
"protonvpn-wg-confgen/pkg/validation"
|
||||
)
|
||||
|
||||
// Parse parses command-line flags and returns a Config
|
||||
func Parse() (*Config, error) {
|
||||
cfg := &Config{}
|
||||
|
||||
var countriesFlag string
|
||||
var dnsServersFlag string
|
||||
var allowedIPsFlag string
|
||||
|
||||
// Set default DNS and allowed IPs based on IPv6 support
|
||||
defaultDNS := constants.DefaultDNSIPv4
|
||||
defaultAllowedIPs := constants.DefaultAllowedIPsIPv4
|
||||
|
||||
// Authentication flags
|
||||
flag.StringVar(&cfg.Username, "username", "", "ProtonVPN username")
|
||||
flag.StringVar(&cfg.Password, "password", "", "ProtonVPN password (will prompt if not provided)")
|
||||
|
||||
// Server selection flags
|
||||
flag.StringVar(&countriesFlag, "countries", "", "Comma-separated list of country codes (e.g., US,NL,CH)")
|
||||
flag.BoolVar(&cfg.P2PServersOnly, "p2p-only", constants.DefaultP2POnly, "Use only P2P-enabled servers")
|
||||
flag.BoolVar(&cfg.SecureCoreOnly, "secure-core", false, "Use only Secure Core servers (multi-hop through privacy-friendly countries)")
|
||||
flag.BoolVar(&cfg.FreeOnly, "free-only", false, "Use only Free tier servers (tier 0)")
|
||||
|
||||
// Output configuration
|
||||
flag.StringVar(&cfg.OutputFile, "output", "protonvpn.conf", "Output WireGuard configuration file")
|
||||
flag.StringVar(&cfg.DeviceName, "device-name", "", "Device name for WireGuard config (auto-generated if empty)")
|
||||
|
||||
// Network configuration
|
||||
flag.BoolVar(&cfg.EnableIPv6, "ipv6", false, "Enable IPv6 support")
|
||||
flag.StringVar(&dnsServersFlag, "dns", "", "Comma-separated list of DNS servers (defaults based on IPv6 setting)")
|
||||
flag.StringVar(&allowedIPsFlag, "allowed-ips", "", "Comma-separated list of allowed IPs (defaults based on IPv6 setting)")
|
||||
flag.BoolVar(&cfg.EnableAccelerator, "accelerator", true, "Enable VPN accelerator")
|
||||
|
||||
// Certificate configuration
|
||||
flag.StringVar(&cfg.Duration, "duration", constants.DefaultCertDuration, "Certificate duration (e.g., 30m, 24h, 7d, 1h30m). Max: 365d")
|
||||
|
||||
// Session management
|
||||
flag.BoolVar(&cfg.ClearSession, "clear-session", false, "Clear saved session and force re-authentication")
|
||||
flag.BoolVar(&cfg.NoSession, "no-session", false, "Don't save or use session persistence")
|
||||
flag.BoolVar(&cfg.ForceRefresh, "force-refresh", false, "Force session refresh even if not expired")
|
||||
flag.StringVar(&cfg.SessionDuration, "session-duration", "0", "Session cache duration (e.g., 12h, 24h, 7d). 0 = no expiration")
|
||||
|
||||
// Advanced configuration
|
||||
flag.StringVar(&cfg.APIURL, "api-url", constants.DefaultAPIURL, "ProtonVPN API URL")
|
||||
flag.BoolVar(&cfg.Debug, "debug", false, "Enable debug output")
|
||||
|
||||
flag.Parse()
|
||||
|
||||
// Validate required flags
|
||||
if countriesFlag == "" {
|
||||
return nil, fmt.Errorf("countries flag is required")
|
||||
}
|
||||
|
||||
// Parse and validate country codes
|
||||
cfg.Countries = parseCountries(countriesFlag)
|
||||
for _, country := range cfg.Countries {
|
||||
if !validation.IsValidCountryCode(country) {
|
||||
return nil, fmt.Errorf("invalid country code: %s", country)
|
||||
}
|
||||
}
|
||||
|
||||
// Set defaults based on IPv6 setting
|
||||
if cfg.EnableIPv6 {
|
||||
defaultDNS = fmt.Sprintf("%s,%s", constants.DefaultDNSIPv4, constants.DefaultDNSIPv6)
|
||||
defaultAllowedIPs = fmt.Sprintf("%s,%s", constants.DefaultAllowedIPsIPv4, constants.DefaultAllowedIPsIPv6)
|
||||
}
|
||||
|
||||
// Use defaults if flags are empty
|
||||
if dnsServersFlag == "" {
|
||||
dnsServersFlag = defaultDNS
|
||||
}
|
||||
if allowedIPsFlag == "" {
|
||||
allowedIPsFlag = defaultAllowedIPs
|
||||
}
|
||||
|
||||
// Parse lists (with space trimming)
|
||||
cfg.DNSServers = parseCommaSeparatedList(dnsServersFlag)
|
||||
cfg.AllowedIPs = parseCommaSeparatedList(allowedIPsFlag)
|
||||
|
||||
// Clean up username
|
||||
cfg.Username = validation.CleanUsername(cfg.Username)
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// parseCommaSeparatedList parses a comma-separated string into a trimmed slice
|
||||
func parseCommaSeparatedList(input string) []string {
|
||||
parts := strings.Split(input, ",")
|
||||
var result []string
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if part != "" {
|
||||
result = append(result, part)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// parseCountries parses and normalizes country codes
|
||||
func parseCountries(countriesFlag string) []string {
|
||||
return parseCommaSeparatedList(strings.ToUpper(countriesFlag))
|
||||
}
|
||||
|
||||
// PrintUsage prints usage information
|
||||
func PrintUsage() {
|
||||
fmt.Fprintf(os.Stderr, "Usage: %s -username <username> -countries <country-codes> [options]\n\n", os.Args[0])
|
||||
flag.PrintDefaults()
|
||||
}
|
||||
48
protonvpn-wg-confgen/internal/config/types.go
Normal file
48
protonvpn-wg-confgen/internal/config/types.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package config
|
||||
|
||||
import "fmt"
|
||||
|
||||
// Config holds all configuration options
|
||||
type Config struct {
|
||||
// Authentication
|
||||
Username string
|
||||
Password string
|
||||
|
||||
// Server selection
|
||||
Countries []string
|
||||
P2PServersOnly bool
|
||||
SecureCoreOnly bool
|
||||
FreeOnly bool
|
||||
|
||||
// Output configuration
|
||||
OutputFile string
|
||||
ClientPrivateKey string
|
||||
DeviceName string
|
||||
|
||||
// Network configuration
|
||||
DNSServers []string
|
||||
AllowedIPs []string
|
||||
EnableAccelerator bool
|
||||
EnableIPv6 bool
|
||||
|
||||
// Certificate configuration
|
||||
Duration string
|
||||
|
||||
// Session management
|
||||
ClearSession bool
|
||||
NoSession bool
|
||||
ForceRefresh bool
|
||||
SessionDuration string
|
||||
|
||||
// Advanced configuration
|
||||
APIURL string
|
||||
Debug bool
|
||||
}
|
||||
|
||||
// ValidateCredentials checks if we have the required credentials
|
||||
func (c *Config) ValidateCredentials() error {
|
||||
if c.Username == "" {
|
||||
return fmt.Errorf("username is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
37
protonvpn-wg-confgen/internal/constants/api.go
Normal file
37
protonvpn-wg-confgen/internal/constants/api.go
Normal file
@@ -0,0 +1,37 @@
|
||||
// Package constants defines constants used throughout the application.
|
||||
package constants
|
||||
|
||||
// API endpoints
|
||||
const (
|
||||
DefaultAPIURL = "https://vpn-api.proton.me"
|
||||
AuthInfoPath = "/core/v4/auth/info"
|
||||
AuthPath = "/core/v4/auth"
|
||||
RefreshPath = "/auth/refresh"
|
||||
CertificatePath = "/vpn/v1/certificate"
|
||||
LogicalsPath = "/vpn/v1/logicals"
|
||||
)
|
||||
|
||||
// API version headers - can be overridden at build time via ldflags:
|
||||
// go build -ldflags "-X .../internal/constants.AppVersion=linux-vpn@X.Y.Z"
|
||||
var (
|
||||
AppVersion = "linux-vpn@4.13.1"
|
||||
UserAgent = "ProtonVPN/4.13.1 (Linux; Ubuntu)"
|
||||
)
|
||||
|
||||
// API response codes
|
||||
// Reference: proton-python-client/proton/api.py checks for codes 1000 and 1001
|
||||
const (
|
||||
APICodeSuccess = 1000
|
||||
APICodeMultiStatus = 1001 // Also indicates success in some contexts
|
||||
)
|
||||
|
||||
// IsSuccessCode checks if an API response code indicates success
|
||||
func IsSuccessCode(code int) bool {
|
||||
return code == APICodeSuccess || code == APICodeMultiStatus
|
||||
}
|
||||
|
||||
// Server/feature status values
|
||||
const (
|
||||
StatusOnline = 1
|
||||
EnabledTrue = 1
|
||||
)
|
||||
14
protonvpn-wg-confgen/internal/constants/defaults.go
Normal file
14
protonvpn-wg-confgen/internal/constants/defaults.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package constants
|
||||
|
||||
// Certificate defaults
|
||||
const (
|
||||
DefaultCertDuration = "365d"
|
||||
MaxCertDuration = 365 // days
|
||||
CertMode = "persistent"
|
||||
PublicKeyMode = "EC"
|
||||
)
|
||||
|
||||
// Server selection defaults
|
||||
const (
|
||||
DefaultP2POnly = true
|
||||
)
|
||||
9
protonvpn-wg-confgen/internal/constants/session.go
Normal file
9
protonvpn-wg-confgen/internal/constants/session.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package constants
|
||||
|
||||
// Session defaults
|
||||
const (
|
||||
SessionFileName = ".protonvpn-session.json"
|
||||
SessionFileMode = 0o600 // Read/write for owner only
|
||||
SessionRefreshDays = 7 // Refresh when less than 7 days remain
|
||||
SessionExpirySeconds = 2592000 // 30 days in seconds (from API)
|
||||
)
|
||||
17
protonvpn-wg-confgen/internal/constants/wireguard.go
Normal file
17
protonvpn-wg-confgen/internal/constants/wireguard.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package constants
|
||||
|
||||
// WireGuard defaults
|
||||
const (
|
||||
WireGuardPort = 51820
|
||||
DefaultMTU = 1420
|
||||
|
||||
// IPv4 configuration
|
||||
WireGuardIPv4 = "10.2.0.2/32"
|
||||
DefaultDNSIPv4 = "10.2.0.1"
|
||||
DefaultAllowedIPsIPv4 = "0.0.0.0/0"
|
||||
|
||||
// IPv6 configuration
|
||||
WireGuardIPv6 = "2a07:b944::2:2/128"
|
||||
DefaultDNSIPv6 = "2a07:b944::2:1"
|
||||
DefaultAllowedIPsIPv6 = "::/0"
|
||||
)
|
||||
148
protonvpn-wg-confgen/internal/vpn/client.go
Normal file
148
protonvpn-wg-confgen/internal/vpn/client.go
Normal file
@@ -0,0 +1,148 @@
|
||||
// Package vpn manages VPN certificate generation and server interactions.
|
||||
package vpn
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"protonvpn-wg-confgen/internal/api"
|
||||
"protonvpn-wg-confgen/internal/config"
|
||||
"protonvpn-wg-confgen/internal/constants"
|
||||
"protonvpn-wg-confgen/pkg/timeutil"
|
||||
|
||||
"github.com/ProtonVPN/go-vpn-lib/ed25519"
|
||||
)
|
||||
|
||||
// Client handles VPN operations
|
||||
type Client struct {
|
||||
config *config.Config
|
||||
session *api.Session
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewClient creates a new VPN client
|
||||
func NewClient(cfg *config.Config, session *api.Session) *Client {
|
||||
return &Client{
|
||||
config: cfg,
|
||||
session: session,
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
// GetCertificate generates a VPN certificate
|
||||
func (c *Client) GetCertificate(keyPair *ed25519.KeyPair) (*api.VPNInfo, error) {
|
||||
publicKeyPEM, err := keyPair.PublicKeyPKIXPem()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get public key PEM: %w", err)
|
||||
}
|
||||
|
||||
// Use provided device name or generate one
|
||||
deviceName := c.config.DeviceName
|
||||
if deviceName == "" {
|
||||
deviceName = fmt.Sprintf("WireGuard-%s-%d", c.config.Username, time.Now().Unix())
|
||||
}
|
||||
|
||||
// Parse duration
|
||||
durationStr, err := timeutil.ParseToMinutes(c.config.Duration)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse duration: %w", err)
|
||||
}
|
||||
|
||||
// Build certificate request matching official ProtonVPN API format
|
||||
// Feature keys from: python-proton-vpn-api-core/proton/vpn/session/fetcher.py
|
||||
certReq := map[string]interface{}{
|
||||
"ClientPublicKey": publicKeyPEM,
|
||||
"ClientPublicKeyMode": "EC",
|
||||
"Mode": "persistent", // Create persistent configuration
|
||||
"DeviceName": deviceName,
|
||||
"Duration": durationStr,
|
||||
"Features": map[string]interface{}{
|
||||
"NetShieldLevel": 0, // NetShield disabled
|
||||
"RandomNAT": false, // Moderate NAT disabled
|
||||
"PortForwarding": false, // Port forwarding disabled
|
||||
"SplitTCP": c.config.EnableAccelerator, // VPN Accelerator (called SplitTCP in API)
|
||||
},
|
||||
}
|
||||
|
||||
certJSON, err := json.Marshal(certReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, c.config.APIURL+"/vpn/v1/certificate", bytes.NewBuffer(certJSON))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.setHeaders(req)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var vpnInfo api.VPNInfo
|
||||
if err := json.Unmarshal(body, &vpnInfo); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !constants.IsSuccessCode(vpnInfo.Code) {
|
||||
// Include the actual API error message if available
|
||||
if vpnInfo.Error != "" {
|
||||
return nil, fmt.Errorf("VPN certificate error (code %d): %s", vpnInfo.Code, vpnInfo.Error)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get VPN certificate, code: %d", vpnInfo.Code)
|
||||
}
|
||||
|
||||
return &vpnInfo, nil
|
||||
}
|
||||
|
||||
// GetServers fetches the list of VPN servers
|
||||
func (c *Client) GetServers() ([]api.LogicalServer, error) {
|
||||
req, err := http.NewRequest(http.MethodGet, c.config.APIURL+"/vpn/v1/logicals", http.NoBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.setHeaders(req)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var response api.LogicalsResponse
|
||||
if err := json.Unmarshal(body, &response); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !constants.IsSuccessCode(response.Code) {
|
||||
return nil, fmt.Errorf("API returned error code: %d", response.Code)
|
||||
}
|
||||
|
||||
return response.LogicalServers, nil
|
||||
}
|
||||
|
||||
func (c *Client) setHeaders(req *http.Request) {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.session.AccessToken))
|
||||
req.Header.Set("x-pm-uid", c.session.UID)
|
||||
req.Header.Set("x-pm-appversion", constants.AppVersion)
|
||||
req.Header.Set("User-Agent", constants.UserAgent)
|
||||
}
|
||||
165
protonvpn-wg-confgen/internal/vpn/servers.go
Normal file
165
protonvpn-wg-confgen/internal/vpn/servers.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package vpn
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"protonvpn-wg-confgen/internal/api"
|
||||
"protonvpn-wg-confgen/internal/config"
|
||||
"protonvpn-wg-confgen/internal/constants"
|
||||
)
|
||||
|
||||
// ServerSelector handles server selection logic
|
||||
type ServerSelector struct {
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
// NewServerSelector creates a new server selector
|
||||
func NewServerSelector(cfg *config.Config) *ServerSelector {
|
||||
return &ServerSelector{config: cfg}
|
||||
}
|
||||
|
||||
// SelectBest selects the best server based on configuration
|
||||
func (s *ServerSelector) SelectBest(servers []api.LogicalServer) (*api.LogicalServer, error) {
|
||||
filtered := s.filterServers(servers)
|
||||
|
||||
if s.config.Debug {
|
||||
s.printDebugServerList(filtered)
|
||||
}
|
||||
|
||||
if len(filtered) == 0 {
|
||||
return nil, s.buildNoServersError()
|
||||
}
|
||||
|
||||
// Sort servers: first by score (descending), then by load (ascending)
|
||||
sort.Slice(filtered, func(i, j int) bool {
|
||||
// If scores are different, higher score wins
|
||||
if filtered[i].Score != filtered[j].Score {
|
||||
return filtered[i].Score > filtered[j].Score
|
||||
}
|
||||
// If scores are equal, lower load wins
|
||||
return filtered[i].Load < filtered[j].Load
|
||||
})
|
||||
|
||||
return &filtered[0], nil
|
||||
}
|
||||
|
||||
func (s *ServerSelector) filterServers(servers []api.LogicalServer) []api.LogicalServer {
|
||||
var filtered []api.LogicalServer
|
||||
|
||||
for i := range servers {
|
||||
if s.isServerEligible(&servers[i]) {
|
||||
filtered = append(filtered, servers[i])
|
||||
}
|
||||
}
|
||||
|
||||
return filtered
|
||||
}
|
||||
|
||||
func (s *ServerSelector) isServerEligible(server *api.LogicalServer) bool {
|
||||
// Skip offline servers
|
||||
if server.Status != constants.StatusOnline {
|
||||
return false
|
||||
}
|
||||
|
||||
// Filter by tier based on -free-only flag
|
||||
if s.config.FreeOnly {
|
||||
// When free-only is enabled, only accept Free tier servers
|
||||
if server.Tier != api.TierFree {
|
||||
return false
|
||||
}
|
||||
} else {
|
||||
// Otherwise, filter out free tier servers
|
||||
if server.Tier == api.TierFree {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Filter by P2P support if requested (but not when using Secure Core or Free tier)
|
||||
if s.config.P2PServersOnly && !s.config.SecureCoreOnly && !s.config.FreeOnly && server.Features&api.FeatureP2P == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Filter by Secure Core if requested
|
||||
if s.config.SecureCoreOnly && server.Features&api.FeatureSecureCore == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Filter by country
|
||||
if !s.isCountryMatch(server) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip servers with no physical servers
|
||||
if len(server.Servers) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *ServerSelector) isCountryMatch(server *api.LogicalServer) bool {
|
||||
for _, country := range s.config.Countries {
|
||||
if server.ExitCountry == country {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *ServerSelector) buildNoServersError() error {
|
||||
errMsg := fmt.Sprintf("No suitable servers found for countries: %v", s.config.Countries)
|
||||
|
||||
if s.config.SecureCoreOnly {
|
||||
errMsg += " with Secure Core"
|
||||
} else if s.config.P2PServersOnly {
|
||||
errMsg += " with P2P support"
|
||||
}
|
||||
|
||||
return errors.New(errMsg)
|
||||
}
|
||||
|
||||
// GetBestPhysicalServer returns the best physical server from a logical server
|
||||
func GetBestPhysicalServer(server *api.LogicalServer) *api.PhysicalServer {
|
||||
if len(server.Servers) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Find the first online physical server
|
||||
for i := range server.Servers {
|
||||
if server.Servers[i].Status == constants.StatusOnline {
|
||||
return &server.Servers[i]
|
||||
}
|
||||
}
|
||||
|
||||
// If no online servers, return the first one
|
||||
return &server.Servers[0]
|
||||
}
|
||||
|
||||
// printDebugServerList prints a debug list of filtered servers
|
||||
func (s *ServerSelector) printDebugServerList(servers []api.LogicalServer) {
|
||||
fmt.Printf("\nDEBUG: Found %d servers after filtering:\n", len(servers))
|
||||
fmt.Println("==================================================================================")
|
||||
fmt.Printf("%-15s | %-18s | %-12s | Load | Score | Features\n", "Server", "City", "Tier")
|
||||
fmt.Println("----------------------------------------------------------------------------------")
|
||||
|
||||
for i := range servers {
|
||||
features := api.GetFeatureNames(servers[i].Features)
|
||||
featureStr := "-"
|
||||
if len(features) > 0 {
|
||||
featureStr = strings.Join(features, ", ")
|
||||
}
|
||||
|
||||
fmt.Printf("%-15s | %-18s | %-12s | %3d%% | %.2f | %s\n",
|
||||
servers[i].Name,
|
||||
servers[i].City,
|
||||
api.GetTierName(servers[i].Tier),
|
||||
servers[i].Load,
|
||||
servers[i].Score,
|
||||
featureStr)
|
||||
}
|
||||
|
||||
fmt.Println("==================================================================================")
|
||||
}
|
||||
Reference in New Issue
Block a user