initial commit

This commit is contained in:
2026-04-07 17:41:25 +02:00
commit 1ed9bdfa55
45 changed files with 4712 additions and 0 deletions

View File

@@ -0,0 +1,148 @@
// Package vpn manages VPN certificate generation and server interactions.
package vpn
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
"protonvpn-wg-confgen/internal/api"
"protonvpn-wg-confgen/internal/config"
"protonvpn-wg-confgen/internal/constants"
"protonvpn-wg-confgen/pkg/timeutil"
"github.com/ProtonVPN/go-vpn-lib/ed25519"
)
// Client handles VPN operations
type Client struct {
config *config.Config
session *api.Session
httpClient *http.Client
}
// NewClient creates a new VPN client
func NewClient(cfg *config.Config, session *api.Session) *Client {
return &Client{
config: cfg,
session: session,
httpClient: &http.Client{Timeout: 10 * time.Second},
}
}
// GetCertificate generates a VPN certificate
func (c *Client) GetCertificate(keyPair *ed25519.KeyPair) (*api.VPNInfo, error) {
publicKeyPEM, err := keyPair.PublicKeyPKIXPem()
if err != nil {
return nil, fmt.Errorf("failed to get public key PEM: %w", err)
}
// Use provided device name or generate one
deviceName := c.config.DeviceName
if deviceName == "" {
deviceName = fmt.Sprintf("WireGuard-%s-%d", c.config.Username, time.Now().Unix())
}
// Parse duration
durationStr, err := timeutil.ParseToMinutes(c.config.Duration)
if err != nil {
return nil, fmt.Errorf("failed to parse duration: %w", err)
}
// Build certificate request matching official ProtonVPN API format
// Feature keys from: python-proton-vpn-api-core/proton/vpn/session/fetcher.py
certReq := map[string]interface{}{
"ClientPublicKey": publicKeyPEM,
"ClientPublicKeyMode": "EC",
"Mode": "persistent", // Create persistent configuration
"DeviceName": deviceName,
"Duration": durationStr,
"Features": map[string]interface{}{
"NetShieldLevel": 0, // NetShield disabled
"RandomNAT": false, // Moderate NAT disabled
"PortForwarding": false, // Port forwarding disabled
"SplitTCP": c.config.EnableAccelerator, // VPN Accelerator (called SplitTCP in API)
},
}
certJSON, err := json.Marshal(certReq)
if err != nil {
return nil, err
}
req, err := http.NewRequest(http.MethodPost, c.config.APIURL+"/vpn/v1/certificate", bytes.NewBuffer(certJSON))
if err != nil {
return nil, err
}
c.setHeaders(req)
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
var vpnInfo api.VPNInfo
if err := json.Unmarshal(body, &vpnInfo); err != nil {
return nil, err
}
if !constants.IsSuccessCode(vpnInfo.Code) {
// Include the actual API error message if available
if vpnInfo.Error != "" {
return nil, fmt.Errorf("VPN certificate error (code %d): %s", vpnInfo.Code, vpnInfo.Error)
}
return nil, fmt.Errorf("failed to get VPN certificate, code: %d", vpnInfo.Code)
}
return &vpnInfo, nil
}
// GetServers fetches the list of VPN servers
func (c *Client) GetServers() ([]api.LogicalServer, error) {
req, err := http.NewRequest(http.MethodGet, c.config.APIURL+"/vpn/v1/logicals", http.NoBody)
if err != nil {
return nil, err
}
c.setHeaders(req)
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
var response api.LogicalsResponse
if err := json.Unmarshal(body, &response); err != nil {
return nil, err
}
if !constants.IsSuccessCode(response.Code) {
return nil, fmt.Errorf("API returned error code: %d", response.Code)
}
return response.LogicalServers, nil
}
func (c *Client) setHeaders(req *http.Request) {
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.session.AccessToken))
req.Header.Set("x-pm-uid", c.session.UID)
req.Header.Set("x-pm-appversion", constants.AppVersion)
req.Header.Set("User-Agent", constants.UserAgent)
}

View File

@@ -0,0 +1,165 @@
package vpn
import (
"errors"
"fmt"
"sort"
"strings"
"protonvpn-wg-confgen/internal/api"
"protonvpn-wg-confgen/internal/config"
"protonvpn-wg-confgen/internal/constants"
)
// ServerSelector handles server selection logic
type ServerSelector struct {
config *config.Config
}
// NewServerSelector creates a new server selector
func NewServerSelector(cfg *config.Config) *ServerSelector {
return &ServerSelector{config: cfg}
}
// SelectBest selects the best server based on configuration
func (s *ServerSelector) SelectBest(servers []api.LogicalServer) (*api.LogicalServer, error) {
filtered := s.filterServers(servers)
if s.config.Debug {
s.printDebugServerList(filtered)
}
if len(filtered) == 0 {
return nil, s.buildNoServersError()
}
// Sort servers: first by score (descending), then by load (ascending)
sort.Slice(filtered, func(i, j int) bool {
// If scores are different, higher score wins
if filtered[i].Score != filtered[j].Score {
return filtered[i].Score > filtered[j].Score
}
// If scores are equal, lower load wins
return filtered[i].Load < filtered[j].Load
})
return &filtered[0], nil
}
func (s *ServerSelector) filterServers(servers []api.LogicalServer) []api.LogicalServer {
var filtered []api.LogicalServer
for i := range servers {
if s.isServerEligible(&servers[i]) {
filtered = append(filtered, servers[i])
}
}
return filtered
}
func (s *ServerSelector) isServerEligible(server *api.LogicalServer) bool {
// Skip offline servers
if server.Status != constants.StatusOnline {
return false
}
// Filter by tier based on -free-only flag
if s.config.FreeOnly {
// When free-only is enabled, only accept Free tier servers
if server.Tier != api.TierFree {
return false
}
} else {
// Otherwise, filter out free tier servers
if server.Tier == api.TierFree {
return false
}
}
// Filter by P2P support if requested (but not when using Secure Core or Free tier)
if s.config.P2PServersOnly && !s.config.SecureCoreOnly && !s.config.FreeOnly && server.Features&api.FeatureP2P == 0 {
return false
}
// Filter by Secure Core if requested
if s.config.SecureCoreOnly && server.Features&api.FeatureSecureCore == 0 {
return false
}
// Filter by country
if !s.isCountryMatch(server) {
return false
}
// Skip servers with no physical servers
if len(server.Servers) == 0 {
return false
}
return true
}
func (s *ServerSelector) isCountryMatch(server *api.LogicalServer) bool {
for _, country := range s.config.Countries {
if server.ExitCountry == country {
return true
}
}
return false
}
func (s *ServerSelector) buildNoServersError() error {
errMsg := fmt.Sprintf("No suitable servers found for countries: %v", s.config.Countries)
if s.config.SecureCoreOnly {
errMsg += " with Secure Core"
} else if s.config.P2PServersOnly {
errMsg += " with P2P support"
}
return errors.New(errMsg)
}
// GetBestPhysicalServer returns the best physical server from a logical server
func GetBestPhysicalServer(server *api.LogicalServer) *api.PhysicalServer {
if len(server.Servers) == 0 {
return nil
}
// Find the first online physical server
for i := range server.Servers {
if server.Servers[i].Status == constants.StatusOnline {
return &server.Servers[i]
}
}
// If no online servers, return the first one
return &server.Servers[0]
}
// printDebugServerList prints a debug list of filtered servers
func (s *ServerSelector) printDebugServerList(servers []api.LogicalServer) {
fmt.Printf("\nDEBUG: Found %d servers after filtering:\n", len(servers))
fmt.Println("==================================================================================")
fmt.Printf("%-15s | %-18s | %-12s | Load | Score | Features\n", "Server", "City", "Tier")
fmt.Println("----------------------------------------------------------------------------------")
for i := range servers {
features := api.GetFeatureNames(servers[i].Features)
featureStr := "-"
if len(features) > 0 {
featureStr = strings.Join(features, ", ")
}
fmt.Printf("%-15s | %-18s | %-12s | %3d%% | %.2f | %s\n",
servers[i].Name,
servers[i].City,
api.GetTierName(servers[i].Tier),
servers[i].Load,
servers[i].Score,
featureStr)
}
fmt.Println("==================================================================================")
}