// Listing metadata: 294 lines, 8 KiB, Go.

package main
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"encoding/pem"
|
|
"flag"
|
|
"fmt"
|
|
"net"
|
|
"os"
|
|
"slices"
|
|
"syscall"
|
|
"time"
|
|
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// Limits used while loading the CA bundle and while waiting for the TLS
// handshake to finish.
const (
	// MAX_CERT_POOL_SIZE is the number of certificates getCAPool tolerates
	// in the CA file before it reports an error.
	MAX_CERT_POOL_SIZE = 1024

	// HANDSHAKE_TIMEOUT is the deadline given to the context that guards the
	// wait for handshake completion.
	HANDSHAKE_TIMEOUT = time.Second * 5

	// HANDSHAKE_TICK is the interval at which handshakeCompleted polls the
	// connection state.
	HANDSHAKE_TICK = time.Millisecond * 10
)
|
|
|
|
// config carries the settings that parseArgs collects from the command line.
type config struct {
	// Paths given by -ca (trusted authorities), -crt (certificate presented
	// to the peer) and -key (matching private key).
	caFile, certFile, keyFile string

	// peerName is the name that must appear in the peer's certificate (-peer).
	peerName string

	// isClient selects the TLS client role (-client); false means server.
	isClient bool

	// outFile is the path the exported secret is written to (-out).
	outFile string

	// serverAddress (-addr) and serverPort (-port) form the address the
	// server listens on and the client dials.
	serverAddress string
	serverPort    int

	// exportLabel (-label) and exportLen (-length, in bytes) are passed to
	// the TLS keying-material exporter.
	exportLabel string
	exportLen   int
}
|
|
|
|
func parseArgs() (c config, err error) {
|
|
caFile := flag.String("ca", "", "Path to the file containing all of the Certification Authorities that should be trusted")
|
|
certFile := flag.String("crt", "", "Path to the file containing the certificate to present to the TLS peer")
|
|
keyFile := flag.String("key", "", "Path to the file containing the key to use with the TLS peer")
|
|
peerName := flag.String("peer", "", "Name of the peer that must be in the certificate returned")
|
|
isClient := flag.Bool("client", false, "Whether we should act as a TLS client. If not specified, act as a server")
|
|
outFile := flag.String("out", "", "Path to the output file containing the exported secret")
|
|
exportLabel := flag.String("label", "", "Label associated with the exported secret")
|
|
exportLen := flag.Int("length", 32, "Length of the exported secret, in bytes")
|
|
serverAddress := flag.String("addr", "", "IP address to which the TLS server must bind. By default, it binds to all local addresses")
|
|
serverPort := flag.Int("port", 443, "Port number of the server")
|
|
flag.Parse()
|
|
|
|
if *caFile == "" {
|
|
err = fmt.Errorf("ca argument is mandatory")
|
|
return
|
|
}
|
|
if *certFile == "" {
|
|
err = fmt.Errorf("crt argument is mandatory")
|
|
return
|
|
}
|
|
if *keyFile == "" {
|
|
err = fmt.Errorf("key argument is mandatory")
|
|
return
|
|
}
|
|
if *peerName == "" {
|
|
err = fmt.Errorf("peer argument is mandatory")
|
|
return
|
|
}
|
|
if *outFile == "" {
|
|
err = fmt.Errorf("out argument is mandatory")
|
|
return
|
|
}
|
|
if *exportLabel == "" {
|
|
err = fmt.Errorf("label argument is mandatory")
|
|
}
|
|
c = config{
|
|
caFile: *caFile,
|
|
certFile: *certFile,
|
|
keyFile: *keyFile,
|
|
peerName: *peerName,
|
|
isClient: *isClient,
|
|
outFile: *outFile,
|
|
exportLabel: *exportLabel,
|
|
exportLen: *exportLen,
|
|
serverAddress: *serverAddress,
|
|
serverPort: *serverPort,
|
|
}
|
|
return
|
|
}
|
|
|
|
func getCAPool(caFile string) (cp *x509.CertPool, retErr error) {
|
|
cp = x509.NewCertPool()
|
|
f, err := os.Open(caFile)
|
|
if err != nil {
|
|
retErr = fmt.Errorf("failed to open CA file: %w", err)
|
|
return
|
|
}
|
|
defer f.Close()
|
|
fileInfo, err := f.Stat()
|
|
if err != nil {
|
|
retErr = fmt.Errorf("failed to stat() CA file: %w", err)
|
|
return
|
|
}
|
|
fileMap, err := syscall.Mmap(int(f.Fd()), 0, int(fileInfo.Size()), syscall.PROT_READ, syscall.MAP_PRIVATE)
|
|
if err != nil {
|
|
retErr = fmt.Errorf("failed to mmap file content: %w", err)
|
|
return
|
|
}
|
|
defer syscall.Munmap(fileMap)
|
|
|
|
certPoolSize := 0
|
|
fm := fileMap[:]
|
|
for len(fm) > 0 {
|
|
block, rem := pem.Decode(fm)
|
|
if block == nil {
|
|
max := 1024
|
|
if len(fm) < max {
|
|
max = len(fm)
|
|
}
|
|
retErr = fmt.Errorf("invalid PEM encoding at %q", string(fm[:max]))
|
|
return
|
|
}
|
|
fm = rem
|
|
cert, err := x509.ParseCertificate(block.Bytes)
|
|
if err != nil {
|
|
retErr = fmt.Errorf("failed to parse certificate: %w", err)
|
|
return
|
|
}
|
|
cp.AddCert(cert)
|
|
certPoolSize += 1
|
|
if certPoolSize > MAX_CERT_POOL_SIZE {
|
|
retErr = fmt.Errorf("cert pool max size exceeded")
|
|
return
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
// verifyClientCertificate inspects the leaf (first) certificate of each
// verified chain and returns nil as soon as one of them
//   - lists peerName among its DNS names,
//   - has the digital-signature key usage bit set, and
//   - carries the client-authentication extended key usage.
//
// An empty chain stops the search with an error. If no leaf qualifies, an
// error is returned.
func verifyClientCertificate(verifiedChains [][]*x509.Certificate, peerName string) error {
	for i := range verifiedChains {
		chain := verifiedChains[i]
		if len(chain) == 0 {
			return fmt.Errorf("empty cert chain: should never happen?")
		}

		leaf := chain[0]
		nameMatches := slices.Contains(leaf.DNSNames, peerName)
		canSign := leaf.KeyUsage&x509.KeyUsageDigitalSignature != 0
		forClientAuth := slices.Contains(leaf.ExtKeyUsage, x509.ExtKeyUsageClientAuth)
		if nameMatches && canSign && forClientAuth {
			return nil
		}
	}
	return fmt.Errorf("failed to found a proper certificate matching the expected values")
}
|
|
|
|
func getTLSConf(generalConf config) (tc tls.Config, retErr error) {
|
|
cert, err := tls.LoadX509KeyPair(generalConf.certFile, generalConf.keyFile)
|
|
if err != nil {
|
|
retErr = fmt.Errorf("failed to load certificates or private key from files: %w", err)
|
|
return
|
|
}
|
|
caPool, err := getCAPool(generalConf.caFile)
|
|
if err != nil {
|
|
retErr = fmt.Errorf("failed to load CA certificates: %w", err)
|
|
return
|
|
}
|
|
tc.MinVersion = tls.VersionTLS13
|
|
tc.Certificates = []tls.Certificate{cert}
|
|
|
|
if generalConf.isClient { // TLS Client
|
|
tc.RootCAs = caPool
|
|
tc.ServerName = generalConf.peerName
|
|
} else { // TLS Server
|
|
tc.ClientCAs = caPool
|
|
tc.ClientAuth = tls.RequireAndVerifyClientCert
|
|
tc.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
|
return verifyClientCertificate(verifiedChains, generalConf.peerName)
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
func handshakeCompleted(ctx context.Context, conn *tls.Conn) <-chan struct{} {
|
|
c := make(chan struct{}, 0)
|
|
go func() {
|
|
tick := time.NewTicker(HANDSHAKE_TICK)
|
|
for {
|
|
cs := conn.ConnectionState()
|
|
if cs.HandshakeComplete {
|
|
c <- struct{}{}
|
|
return
|
|
}
|
|
tick.Reset(HANDSHAKE_TICK)
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-tick.C:
|
|
}
|
|
}
|
|
}()
|
|
return c
|
|
}
|
|
|
|
func getClientSecret(generalConf config, tlsConfig *tls.Config) ([]byte, error) {
|
|
conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", generalConf.serverAddress, generalConf.serverPort), tlsConfig)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to connect to the TLS server: %w", err)
|
|
}
|
|
if err := conn.Handshake(); err != nil {
|
|
return nil, fmt.Errorf("failed to perform handshake: %w", err)
|
|
}
|
|
ctx, cancelFunc := context.WithTimeout(context.Background(), HANDSHAKE_TIMEOUT)
|
|
defer cancelFunc()
|
|
select {
|
|
case <-handshakeCompleted(ctx, conn):
|
|
cancelFunc()
|
|
case <-ctx.Done():
|
|
return nil, fmt.Errorf("failed to perform handshake during the alloted time")
|
|
}
|
|
defer conn.Close()
|
|
cs := conn.ConnectionState()
|
|
keyMat, err := cs.ExportKeyingMaterial(generalConf.exportLabel, nil, generalConf.exportLen)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to export key material from TLS connection: %w", err)
|
|
}
|
|
return keyMat, nil
|
|
}
|
|
|
|
func getServerSecret(generalConf config, tlsConfig *tls.Config) ([]byte, error) {
|
|
lstn, err := net.Listen("tcp", fmt.Sprintf("%s:%d", generalConf.serverAddress, generalConf.serverPort))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create the new server listener: %w", err)
|
|
}
|
|
conn, err := lstn.Accept()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to accept connection from the client: %w", err)
|
|
}
|
|
defer conn.Close()
|
|
|
|
tlsConn := tls.Server(conn, tlsConfig)
|
|
if err := tlsConn.Handshake(); err != nil {
|
|
return nil, fmt.Errorf("failed to perform handshake: %w", err)
|
|
}
|
|
defer tlsConn.Close()
|
|
|
|
ctx, cancelFunc := context.WithTimeout(context.Background(), HANDSHAKE_TIMEOUT)
|
|
defer cancelFunc()
|
|
select {
|
|
case <-handshakeCompleted(ctx, tlsConn):
|
|
cancelFunc()
|
|
case <-ctx.Done():
|
|
return nil, fmt.Errorf("failed to perform handshake during the alloted time")
|
|
}
|
|
|
|
cs := tlsConn.ConnectionState()
|
|
keyMat, _ := cs.ExportKeyingMaterial(generalConf.exportLabel, nil, generalConf.exportLen)
|
|
return keyMat, nil
|
|
}
|
|
|
|
func main() {
|
|
rawLogger, err := zap.NewProduction()
|
|
if err != nil {
|
|
os.Exit(1)
|
|
}
|
|
defer rawLogger.Sync()
|
|
logger := rawLogger.Sugar()
|
|
|
|
generalConf, err := parseArgs()
|
|
if err != nil {
|
|
logger.Fatalf("failed to parse args: %s", err.Error())
|
|
}
|
|
|
|
tlsConfig, err := getTLSConf(generalConf)
|
|
if err != nil {
|
|
logger.Fatalf("failed to initialize TLS config: %s", err.Error())
|
|
}
|
|
|
|
var secret []byte
|
|
if generalConf.isClient {
|
|
secret, err = getClientSecret(generalConf, &tlsConfig)
|
|
if err != nil {
|
|
logger.Fatalf("failed to get secret: %s", err.Error())
|
|
}
|
|
} else {
|
|
secret, err = getServerSecret(generalConf, &tlsConfig)
|
|
if err != nil {
|
|
logger.Fatalf("failed to get secret: %s", err.Error())
|
|
}
|
|
}
|
|
if err := os.WriteFile(generalConf.outFile, secret, 0o600); err != nil {
|
|
logger.Fatalf("failed to write to file exported secret: %s", err.Error())
|
|
}
|
|
}
|