get_secrets/main.go
2025-01-22 10:57:00 +01:00

294 lines
8 KiB
Go

package main
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"flag"
"fmt"
"net"
"os"
"slices"
"syscall"
"time"
"go.uber.org/zap"
)
// Limits and timings used while loading the CA bundle and supervising the
// TLS handshake.
const (
	// MAX_CERT_POOL_SIZE is the largest number of certificates getCAPool
	// accepts from the CA file.
	MAX_CERT_POOL_SIZE = 1 << 10

	// HANDSHAKE_TICK is the polling period used by handshakeCompleted.
	HANDSHAKE_TICK = time.Millisecond * 10

	// HANDSHAKE_TIMEOUT is the timeout applied to the handshake phase of a
	// connection.
	HANDSHAKE_TIMEOUT = time.Second * 5
)
// config holds the settings parsed from the command line by parseArgs.
type config struct {
	// Paths to the PEM files: trusted CAs, our certificate and its key.
	caFile, certFile, keyFile string
	// peerName is the name the remote certificate must carry.
	peerName string
	// isClient selects the TLS client role; otherwise act as the server.
	isClient bool
	// outFile receives the exported secret.
	outFile string
	// serverAddress and serverPort: where the server binds, or where the
	// client connects.
	serverAddress string
	serverPort    int
	// exportLabel and exportLen parameterize the keying-material export.
	exportLabel string
	exportLen   int
}
// parseArgs parses and validates the command-line flags.
//
// The -ca, -crt, -key, -peer, -out and -label flags are mandatory; the first
// missing one (checked in that order) is reported as "<name> argument is
// mandatory". -length must be strictly positive (a zero-length export would
// write an empty secret) and -port must be a valid TCP port (1-65535).
// On error the returned config is the zero value.
func parseArgs() (config, error) {
	caFile := flag.String("ca", "", "Path to the file containing all of the Certification Authorities that should be trusted")
	certFile := flag.String("crt", "", "Path to the file containing the certificate to present to the TLS peer")
	keyFile := flag.String("key", "", "Path to the file containing the key to use with the TLS peer")
	peerName := flag.String("peer", "", "Name of the peer that must be in the certificate returned")
	isClient := flag.Bool("client", false, "Whether we should act as a TLS client. If not specified, act as a server")
	outFile := flag.String("out", "", "Path to the output file containing the exported secret")
	exportLabel := flag.String("label", "", "Label associated with the exported secret")
	exportLen := flag.Int("length", 32, "Length of the exported secret, in bytes")
	serverAddress := flag.String("addr", "", "IP address to which the TLS server must bind. By default, it binds to all local addresses")
	serverPort := flag.Int("port", 443, "Port number of the server")
	flag.Parse()

	// Checked in a fixed order so the first missing flag is the one reported.
	mandatory := []struct {
		name  string
		value string
	}{
		{"ca", *caFile},
		{"crt", *certFile},
		{"key", *keyFile},
		{"peer", *peerName},
		{"out", *outFile},
		{"label", *exportLabel},
	}
	for _, m := range mandatory {
		if m.value == "" {
			return config{}, fmt.Errorf("%s argument is mandatory", m.name)
		}
	}
	// A negative length would make the keying-material export allocate a
	// negative-size buffer; zero would produce an empty "secret".
	if *exportLen <= 0 {
		return config{}, fmt.Errorf("length argument must be strictly positive, got %d", *exportLen)
	}
	if *serverPort < 1 || *serverPort > 65535 {
		return config{}, fmt.Errorf("port argument must be between 1 and 65535, got %d", *serverPort)
	}

	return config{
		caFile:        *caFile,
		certFile:      *certFile,
		keyFile:       *keyFile,
		peerName:      *peerName,
		isClient:      *isClient,
		outFile:       *outFile,
		exportLabel:   *exportLabel,
		exportLen:     *exportLen,
		serverAddress: *serverAddress,
		serverPort:    *serverPort,
	}, nil
}
// getCAPool builds a certificate pool from the PEM bundle stored in caFile.
//
// The file is memory-mapped read-only and scanned block by block with
// pem.Decode, which skips any text preceding a PEM block. Every block must be
// of type "CERTIFICATE" and parse as an X.509 certificate. Whitespace after the
// last block is tolerated; any other trailing data is rejected. An empty file,
// a file holding no certificate, or one holding more than MAX_CERT_POOL_SIZE
// certificates is an error. On error the returned pool is nil.
func getCAPool(caFile string) (*x509.CertPool, error) {
	f, err := os.Open(caFile)
	if err != nil {
		return nil, fmt.Errorf("failed to open CA file: %w", err)
	}
	defer f.Close()

	fileInfo, err := f.Stat()
	if err != nil {
		return nil, fmt.Errorf("failed to stat() CA file: %w", err)
	}
	// mmap refuses zero-length mappings; report the actual problem instead of
	// an opaque EINVAL.
	if fileInfo.Size() == 0 {
		return nil, fmt.Errorf("CA file %q is empty", caFile)
	}
	fileMap, err := syscall.Mmap(int(f.Fd()), 0, int(fileInfo.Size()), syscall.PROT_READ, syscall.MAP_PRIVATE)
	if err != nil {
		return nil, fmt.Errorf("failed to mmap file content: %w", err)
	}
	defer syscall.Munmap(fileMap)

	cp := x509.NewCertPool()
	certPoolSize := 0
	rest := fileMap
	for len(rest) > 0 {
		block, rem := pem.Decode(rest)
		if block == nil {
			// Trailing whitespace (e.g. an extra blank line at the end of the
			// file) yields no block: accept it instead of failing.
			if isBlank(rest) {
				break
			}
			n := len(rest)
			if n > 1024 {
				n = 1024
			}
			return nil, fmt.Errorf("invalid PEM encoding at %q", string(rest[:n]))
		}
		rest = rem
		if block.Type != "CERTIFICATE" {
			return nil, fmt.Errorf("unexpected PEM block of type %q in CA file", block.Type)
		}
		if certPoolSize == MAX_CERT_POOL_SIZE {
			return nil, fmt.Errorf("cert pool max size exceeded")
		}
		cert, err := x509.ParseCertificate(block.Bytes)
		if err != nil {
			return nil, fmt.Errorf("failed to parse certificate: %w", err)
		}
		cp.AddCert(cert)
		certPoolSize++
	}
	if certPoolSize == 0 {
		return nil, fmt.Errorf("no certificate found in CA file %q", caFile)
	}
	return cp, nil
}

// isBlank reports whether b holds only spaces, tabs, carriage returns and
// line feeds.
func isBlank(b []byte) bool {
	for _, c := range b {
		switch c {
		case ' ', '\t', '\r', '\n':
		default:
			return false
		}
	}
	return true
}
// verifyClientCertificate checks that a verified client chain is acceptable
// for peerName.
//
// Only the leaf (first element) of each chain is inspected. A chain is
// accepted when its leaf lists peerName verbatim among its DNS names (exact,
// case-sensitive match), allows digital signatures, and carries the
// clientAuth extended key usage. An empty chain is reported as an error
// immediately; if no chain qualifies, an error is returned.
func verifyClientCertificate(verifiedChains [][]*x509.Certificate, peerName string) error {
	for _, certChain := range verifiedChains {
		if len(certChain) == 0 {
			return fmt.Errorf("empty cert chain: should never happen?")
		}
		// Inspect the leaf through its pointer: x509.Certificate is large and
		// there is no reason to copy it.
		leaf := certChain[0]
		if slices.Contains(leaf.DNSNames, peerName) &&
			leaf.KeyUsage&x509.KeyUsageDigitalSignature != 0 &&
			slices.Contains(leaf.ExtKeyUsage, x509.ExtKeyUsageClientAuth) {
			return nil
		}
	}
	return fmt.Errorf("failed to find a proper certificate matching the expected values")
}
// getTLSConf builds the TLS 1.3 configuration for the role selected in
// generalConf.
//
// The local certificate/key pair is always presented to the peer, and the CA
// pool loaded from caFile is the trust anchor for either role:
//   - client: the pool becomes RootCAs and peerName is used as ServerName;
//   - server: the pool becomes ClientCAs, a verified client certificate is
//     required, and verifyClientCertificate additionally checks its leaf
//     against peerName.
//
// The results are returned through naked returns so tls.Config (which holds
// a mutex) is never copied by an explicit return statement.
func getTLSConf(generalConf config) (tc tls.Config, retErr error) {
	keyPair, loadErr := tls.LoadX509KeyPair(generalConf.certFile, generalConf.keyFile)
	if loadErr != nil {
		retErr = fmt.Errorf("failed to load certificates or private key from files: %w", loadErr)
		return
	}
	pool, poolErr := getCAPool(generalConf.caFile)
	if poolErr != nil {
		retErr = fmt.Errorf("failed to load CA certificates: %w", poolErr)
		return
	}

	tc.MinVersion = tls.VersionTLS13
	tc.Certificates = []tls.Certificate{keyPair}

	switch peer := generalConf.peerName; {
	case generalConf.isClient:
		tc.RootCAs = pool
		tc.ServerName = peer
	default:
		tc.ClientCAs = pool
		tc.ClientAuth = tls.RequireAndVerifyClientCert
		tc.VerifyPeerCertificate = func(_ [][]byte, chains [][]*x509.Certificate) error {
			return verifyClientCertificate(chains, peer)
		}
	}
	return
}
// handshakeCompleted returns a channel that is closed once conn reports a
// completed TLS handshake.
//
// The state is polled every HANDSHAKE_TICK. If ctx is done first, the polling
// goroutine exits and the channel is never closed. Closing (rather than
// sending on) the channel means the goroutine never blocks on a receiver that
// has already given up, and every receiver observes the completion.
func handshakeCompleted(ctx context.Context, conn *tls.Conn) <-chan struct{} {
	done := make(chan struct{})
	go func() {
		tick := time.NewTicker(HANDSHAKE_TICK)
		defer tick.Stop()
		for {
			if conn.ConnectionState().HandshakeComplete {
				close(done)
				return
			}
			select {
			case <-ctx.Done():
				return
			case <-tick.C:
			}
		}
	}()
	return done
}
// getClientSecret connects to the TLS server described by generalConf and
// exports exportLen bytes of keying material under exportLabel (with no
// context value) from the established session.
//
// Connection setup and the TLS handshake are bounded by HANDSHAKE_TIMEOUT.
// The connection is closed on every return path.
func getClientSecret(generalConf config, tlsConfig *tls.Config) ([]byte, error) {
	// JoinHostPort brackets IPv6 literals, which "%s:%d" did not.
	addr := net.JoinHostPort(generalConf.serverAddress, fmt.Sprint(generalConf.serverPort))
	// DialWithDialer applies the dialer timeout to the TCP connect and the TLS
	// handshake as a whole, and returns only once the handshake is done.
	conn, err := tls.DialWithDialer(&net.Dialer{Timeout: HANDSHAKE_TIMEOUT}, "tcp", addr, tlsConfig)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to the TLS server: %w", err)
	}
	// Registered immediately so the connection is released on error paths too.
	defer conn.Close()

	ctx, cancelFunc := context.WithTimeout(context.Background(), HANDSHAKE_TIMEOUT)
	defer cancelFunc()
	select {
	case <-handshakeCompleted(ctx, conn):
	case <-ctx.Done():
		return nil, fmt.Errorf("failed to perform handshake during the allotted time")
	}

	cs := conn.ConnectionState()
	keyMat, err := cs.ExportKeyingMaterial(generalConf.exportLabel, nil, generalConf.exportLen)
	if err != nil {
		return nil, fmt.Errorf("failed to export key material from TLS connection: %w", err)
	}
	return keyMat, nil
}
// getServerSecret listens on the address described by generalConf, accepts a
// single client, performs the TLS handshake and exports exportLen bytes of
// keying material under exportLabel (with no context value).
//
// Waiting for the client is unbounded; the handshake itself is bounded by
// HANDSHAKE_TIMEOUT. The listener and the connection are closed on every
// return path, and export failures are reported rather than yielding an
// empty secret.
func getServerSecret(generalConf config, tlsConfig *tls.Config) ([]byte, error) {
	// JoinHostPort brackets IPv6 literals, which "%s:%d" did not.
	addr := net.JoinHostPort(generalConf.serverAddress, fmt.Sprint(generalConf.serverPort))
	lstn, err := net.Listen("tcp", addr)
	if err != nil {
		return nil, fmt.Errorf("failed to create the new server listener: %w", err)
	}
	// Only one client is served; release the port when done.
	defer lstn.Close()

	conn, err := lstn.Accept()
	if err != nil {
		return nil, fmt.Errorf("failed to accept connection from the client: %w", err)
	}
	tlsConn := tls.Server(conn, tlsConfig)
	// Closing the TLS connection also closes the underlying conn.
	defer tlsConn.Close()

	// Bound the handshake so a stalled client cannot hang the server forever.
	ctx, cancelFunc := context.WithTimeout(context.Background(), HANDSHAKE_TIMEOUT)
	defer cancelFunc()
	if err := tlsConn.HandshakeContext(ctx); err != nil {
		return nil, fmt.Errorf("failed to perform handshake: %w", err)
	}
	select {
	case <-handshakeCompleted(ctx, tlsConn):
	case <-ctx.Done():
		return nil, fmt.Errorf("failed to perform handshake during the allotted time")
	}

	cs := tlsConn.ConnectionState()
	keyMat, err := cs.ExportKeyingMaterial(generalConf.exportLabel, nil, generalConf.exportLen)
	if err != nil {
		return nil, fmt.Errorf("failed to export key material from TLS connection: %w", err)
	}
	return keyMat, nil
}
// main parses the flags, builds the TLS configuration, performs a single TLS
// exchange as client or server, and writes the exported keying material to
// the output file (created with mode 0600). Any failure is logged and
// terminates the process.
func main() {
	rawLogger, err := zap.NewProduction()
	if err != nil {
		// No logger is available yet: report on stderr so the failure is not
		// a silent exit.
		fmt.Fprintf(os.Stderr, "failed to initialize logger: %s\n", err.Error())
		os.Exit(1)
	}
	defer rawLogger.Sync()
	logger := rawLogger.Sugar()

	generalConf, err := parseArgs()
	if err != nil {
		logger.Fatalf("failed to parse args: %s", err.Error())
	}
	tlsConfig, err := getTLSConf(generalConf)
	if err != nil {
		logger.Fatalf("failed to initialize TLS config: %s", err.Error())
	}

	// Both roles share the same signature; pick one instead of duplicating
	// the call and its error handling.
	getSecret := getServerSecret
	if generalConf.isClient {
		getSecret = getClientSecret
	}
	secret, err := getSecret(generalConf, &tlsConfig)
	if err != nil {
		logger.Fatalf("failed to get secret: %s", err.Error())
	}

	// NOTE(review): os.WriteFile only applies 0o600 when it creates the file;
	// an existing output file keeps its current (possibly broader) permissions.
	if err := os.WriteFile(generalConf.outFile, secret, 0o600); err != nil {
		logger.Fatalf("failed to write to file exported secret: %s", err.Error())
	}
}