// Command-line tool that establishes a mutually authenticated TLS 1.3
// connection with a peer (acting as either client or server) and writes a
// secret derived from that connection with the TLS keying-material exporter
// to a file.
package main

import (
	"bytes"
	"context"
	"crypto/tls"
	"crypto/x509"
	"encoding/pem"
	"errors"
	"flag"
	"fmt"
	"io"
	"net"
	"os"
	"slices"
	"strconv"
	"time"

	"go.uber.org/zap"
)

const (
	// MAX_CERT_POOL_SIZE is the maximum number of CA certificates accepted
	// from the CA file.
	MAX_CERT_POOL_SIZE = 1024
	// HANDSHAKE_TIMEOUT bounds the TLS handshake (on the client side, the TCP
	// connection as well) and the client's wait for the server to close.
	HANDSHAKE_TIMEOUT = 5 * time.Second
)

// config holds the validated command-line settings.
type config struct {
	caFile        string // PEM file with the trusted CA certificates
	certFile      string // PEM file with the certificate presented to the peer
	keyFile       string // file with the private key matching certFile
	peerName      string // DNS name the peer certificate must carry
	isClient      bool   // act as TLS client (true) or TLS server (false)
	outFile       string // file receiving the exported secret
	serverAddress string // server host: bind address (server) or target (client)
	serverPort    int    // server TCP port
	exportLabel   string // keying-material exporter label
	exportLen     int    // length of the exported secret, in bytes
}

// endpoint returns the server address in "host:port" form. net.JoinHostPort
// brackets IPv6 literals ("[::1]:443"), which plain string formatting does
// not; an empty host yields ":port".
func (c config) endpoint() string {
	return net.JoinHostPort(c.serverAddress, strconv.Itoa(c.serverPort))
}

// parseArgs parses the command-line flags into a config.
//
// The ca, crt, key, peer, out and label flags are mandatory; the first missing
// one is reported. The exporter label and length are range-checked because
// out-of-range values cannot be served by the TLS 1.3 exporter, and the port
// must be a usable TCP port.
func parseArgs() (config, error) {
	const (
		// RFC 8446 §7.1: the HKDF label is "tls13 " + label, at most 255 bytes.
		maxLabelLen = 255 - len("tls13 ")
		// RFC 5869 §2.3: HKDF-Expand yields at most 255*HashLen bytes, and
		// SHA-256 (32 bytes) is the smallest hash of the TLS 1.3 cipher suites.
		maxExportLen = 255 * 32
	)

	caFile := flag.String("ca", "", "Path to the file containing all of the Certification Authorities that should be trusted")
	certFile := flag.String("crt", "", "Path to the file containing the certificate to present to the TLS peer")
	keyFile := flag.String("key", "", "Path to the file containing the key to use with the TLS peer")
	peerName := flag.String("peer", "", "Name of the peer that must be in the certificate returned")
	isClient := flag.Bool("client", false, "Whether we should act as a TLS client. If not specified, act as a server")
	outFile := flag.String("out", "", "Path to the output file containing the exported secret")
	exportLabel := flag.String("label", "", "Label associated with the exported secret")
	exportLen := flag.Int("length", 32, "Length of the exported secret, in bytes")
	serverAddress := flag.String("addr", "", "Address of the TLS server. As a server, the IP address to bind to; by default, it binds to all local addresses")
	serverPort := flag.Int("port", 443, "Port number of the server")
	flag.Parse()

	mandatory := []struct{ name, value string }{
		{"ca", *caFile},
		{"crt", *certFile},
		{"key", *keyFile},
		{"peer", *peerName},
		{"out", *outFile},
		{"label", *exportLabel},
	}
	for _, m := range mandatory {
		if m.value == "" {
			return config{}, fmt.Errorf("%s argument is mandatory", m.name)
		}
	}
	if len(*exportLabel) > maxLabelLen {
		return config{}, fmt.Errorf("label argument must be at most %d bytes long", maxLabelLen)
	}
	if *exportLen < 1 || *exportLen > maxExportLen {
		return config{}, fmt.Errorf("length argument must be between 1 and %d bytes, got %d", maxExportLen, *exportLen)
	}
	if *serverPort < 1 || *serverPort > 65535 {
		return config{}, fmt.Errorf("port argument must be between 1 and 65535, got %d", *serverPort)
	}

	return config{
		caFile:        *caFile,
		certFile:      *certFile,
		keyFile:       *keyFile,
		peerName:      *peerName,
		isClient:      *isClient,
		outFile:       *outFile,
		exportLabel:   *exportLabel,
		exportLen:     *exportLen,
		serverAddress: *serverAddress,
		serverPort:    *serverPort,
	}, nil
}

// getCAPool reads a PEM file holding one or more CA certificates and returns
// a pool containing all of them.
//
// Text before or between PEM blocks is skipped by pem.Decode, and trailing
// whitespace is accepted. The function returns an error in any of these cases:
//   - non-whitespace data remains that holds no PEM block;
//   - a block is not a CERTIFICATE;
//   - a certificate fails to parse;
//   - the file contains no certificate;
//   - the file holds more than MAX_CERT_POOL_SIZE certificates.
func getCAPool(caFile string) (*x509.CertPool, error) {
	data, err := os.ReadFile(caFile)
	if err != nil {
		return nil, fmt.Errorf("failed to read CA file: %w", err)
	}

	pool := x509.NewCertPool()
	count := 0
	rest := data
	for len(bytes.TrimSpace(rest)) > 0 {
		block, next := pem.Decode(rest)
		if block == nil {
			// Quote at most 1 KiB of the offending data to keep errors readable.
			return nil, fmt.Errorf("invalid PEM encoding at %q", rest[:min(len(rest), 1024)])
		}
		rest = next

		if block.Type != "CERTIFICATE" {
			return nil, fmt.Errorf("unexpected PEM block of type %q in CA file", block.Type)
		}
		if count >= MAX_CERT_POOL_SIZE {
			return nil, errors.New("cert pool max size exceeded")
		}
		cert, err := x509.ParseCertificate(block.Bytes)
		if err != nil {
			return nil, fmt.Errorf("failed to parse certificate: %w", err)
		}
		pool.AddCert(cert)
		count++
	}
	if count == 0 {
		return nil, errors.New("no certificate found in CA file")
	}
	return pool, nil
}

// verifyClientCertificate checks that at least one verified chain has a leaf
// certificate suitable for the expected client. The leaf must meet all of:
//   - it lists peerName among its DNS names;
//   - it allows digital signatures (KeyUsage);
//   - it allows client authentication (ExtKeyUsage).
//
// It returns nil when such a chain exists and an error otherwise, including
// when verifiedChains is empty.
func verifyClientCertificate(verifiedChains [][]*x509.Certificate, peerName string) error {
	for _, chain := range verifiedChains {
		if len(chain) == 0 {
			return errors.New("empty cert chain: should never happen?")
		}
		leaf := chain[0]
		if slices.Contains(leaf.DNSNames, peerName) &&
			leaf.KeyUsage&x509.KeyUsageDigitalSignature != 0 &&
			slices.Contains(leaf.ExtKeyUsage, x509.ExtKeyUsageClientAuth) {
			return nil
		}
	}
	return fmt.Errorf("failed to find a client certificate for %q with digital signature and client authentication usages", peerName)
}

// getTLSConf builds the TLS 1.3 configuration for mutual authentication.
//
// Both sides present the certificate/key pair from generalConf and trust only
// the CAs loaded from caFile. The client requires the server certificate to be
// valid for peerName. The server requires a client certificate that chains to
// those CAs and also passes verifyClientCertificate.
func getTLSConf(generalConf config) (tls.Config, error) {
	cert, err := tls.LoadX509KeyPair(generalConf.certFile, generalConf.keyFile)
	if err != nil {
		return tls.Config{}, fmt.Errorf("failed to load certificates or private key from files: %w", err)
	}
	caPool, err := getCAPool(generalConf.caFile)
	if err != nil {
		return tls.Config{}, fmt.Errorf("failed to load CA certificates: %w", err)
	}

	// tls.Config holds an internal mutex, so composite literals are returned
	// directly instead of copying a populated variable (go vet copylocks).
	if generalConf.isClient {
		return tls.Config{
			MinVersion:   tls.VersionTLS13,
			Certificates: []tls.Certificate{cert},
			RootCAs:      caPool,
			ServerName:   generalConf.peerName,
		}, nil
	}
	return tls.Config{
		MinVersion:   tls.VersionTLS13,
		Certificates: []tls.Certificate{cert},
		ClientCAs:    caPool,
		ClientAuth:   tls.RequireAndVerifyClientCert,
		// Unlike VerifyPeerCertificate, VerifyConnection is also invoked for
		// resumed sessions, so the peer-name check always applies.
		VerifyConnection: func(cs tls.ConnectionState) error {
			return verifyClientCertificate(cs.VerifiedChains, generalConf.peerName)
		},
	}, nil
}

// awaitPeerClose waits up to timeout for the peer to close conn.
//
// The outcome is reported as follows:
//   - a close (io.EOF) is success;
//   - a TLS alert, such as the server rejecting our certificate, is an error;
//   - unexpected application data is an error;
//   - an expired deadline is an error.
func awaitPeerClose(conn *tls.Conn, timeout time.Duration) error {
	if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil {
		return err
	}
	var buf [1]byte
	for {
		n, err := conn.Read(buf[:])
		if n > 0 {
			return errors.New("unexpected application data from the peer")
		}
		if errors.Is(err, io.EOF) {
			return nil
		}
		if err != nil {
			return err
		}
	}
}

// getClientSecret connects to the server, completes the TLS handshake within
// HANDSHAKE_TIMEOUT and returns exportLen bytes of keying material exported
// with exportLabel (no context value).
func getClientSecret(generalConf config, tlsConfig *tls.Config) ([]byte, error) {
	ctx, cancel := context.WithTimeout(context.Background(), HANDSHAKE_TIMEOUT)
	defer cancel()

	// DialContext performs both the TCP connection and the TLS handshake under
	// ctx, so a stalled server cannot block us forever.
	dialer := tls.Dialer{Config: tlsConfig}
	netConn, err := dialer.DialContext(ctx, "tcp", generalConf.endpoint())
	if err != nil {
		return nil, fmt.Errorf("failed to connect to the TLS server: %w", err)
	}
	conn := netConn.(*tls.Conn) // documented to always be a *tls.Conn
	defer conn.Close()

	// In TLS 1.3 the client finishes its handshake before the server has
	// validated the client certificate, and a rejection only surfaces as an
	// alert on the next read. The server closes the connection once it has
	// exported its secret, so wait for that close before trusting the result.
	if err := awaitPeerClose(conn, HANDSHAKE_TIMEOUT); err != nil {
		return nil, fmt.Errorf("server did not confirm the handshake: %w", err)
	}

	cs := conn.ConnectionState()
	keyMat, err := cs.ExportKeyingMaterial(generalConf.exportLabel, nil, generalConf.exportLen)
	if err != nil {
		return nil, fmt.Errorf("failed to export key material from TLS connection: %w", err)
	}
	return keyMat, nil
}

// getServerSecret listens on the configured endpoint and accepts a single
// client. It completes the TLS handshake within HANDSHAKE_TIMEOUT and returns
// exportLen bytes of keying material exported with exportLabel (no context
// value). Closing the connection on return is what signals success to the
// client.
func getServerSecret(generalConf config, tlsConfig *tls.Config) ([]byte, error) {
	lstn, err := net.Listen("tcp", generalConf.endpoint())
	if err != nil {
		return nil, fmt.Errorf("failed to create the new server listener: %w", err)
	}
	defer lstn.Close()

	conn, err := lstn.Accept()
	if err != nil {
		return nil, fmt.Errorf("failed to accept connection from the client: %w", err)
	}
	tlsConn := tls.Server(conn, tlsConfig)
	defer tlsConn.Close() // also closes conn

	ctx, cancel := context.WithTimeout(context.Background(), HANDSHAKE_TIMEOUT)
	defer cancel()
	if err := tlsConn.HandshakeContext(ctx); err != nil {
		return nil, fmt.Errorf("failed to perform handshake: %w", err)
	}

	cs := tlsConn.ConnectionState()
	keyMat, err := cs.ExportKeyingMaterial(generalConf.exportLabel, nil, generalConf.exportLen)
	if err != nil {
		return nil, fmt.Errorf("failed to export key material from TLS connection: %w", err)
	}
	return keyMat, nil
}

// writeSecret writes secret to path with owner-only (0600) permissions.
//
// os.WriteFile only applies its permission argument when it creates the file,
// so a pre-existing file would keep broader permissions. The explicit Chmod
// tightens them before the secret is written.
func writeSecret(path string, secret []byte) error {
	f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600)
	if err != nil {
		return err
	}
	if err := f.Chmod(0o600); err != nil {
		f.Close()
		return err
	}
	if _, err := f.Write(secret); err != nil {
		f.Close()
		return err
	}
	return f.Close()
}

func main() {
	rawLogger, err := zap.NewProduction()
	if err != nil {
		fmt.Fprintf(os.Stderr, "failed to initialize logger: %s\n", err)
		os.Exit(1)
	}
	defer func() { _ = rawLogger.Sync() }() // best-effort flush
	logger := rawLogger.Sugar()

	generalConf, err := parseArgs()
	if err != nil {
		logger.Fatalf("failed to parse args: %s", err.Error())
	}

	tlsConfig, err := getTLSConf(generalConf)
	if err != nil {
		logger.Fatalf("failed to initialize TLS config: %s", err.Error())
	}

	getSecret := getServerSecret
	if generalConf.isClient {
		getSecret = getClientSecret
	}
	secret, err := getSecret(generalConf, &tlsConfig)
	if err != nil {
		logger.Fatalf("failed to get secret: %s", err.Error())
	}

	if err := writeSecret(generalConf.outFile, secret); err != nil {
		logger.Fatalf("failed to write to file exported secret: %s", err.Error())
	}
}