diff --git a/internal/cli/daemon.go b/internal/cli/daemon.go index c5554448b54..f8005e5b21b 100644 --- a/internal/cli/daemon.go +++ b/internal/cli/daemon.go @@ -27,6 +27,9 @@ func startDaemon(store *storage.Storage) { signal.Notify(stop, os.Interrupt) signal.Notify(stop, syscall.SIGTERM) + reload := make(chan os.Signal, 1) + signal.Notify(reload, syscall.SIGHUP) + pool := worker.NewPool(store, config.Opts.WorkerPoolSize()) if config.Opts.HasSchedulerService() && !config.Opts.HasMaintenanceMode() { @@ -34,8 +37,9 @@ func startDaemon(store *storage.Storage) { } var httpServers []*http.Server + var certReloadFn func() if config.Opts.HasHTTPService() { - httpServers = server.StartWebServer(store, pool) + httpServers, certReloadFn = server.StartWebServer(store, pool) } metricsCtx, cancelMetrics := context.WithCancel(context.Background()) @@ -74,29 +78,40 @@ func startDaemon(store *storage.Storage) { } } - <-stop - slog.Debug("Shutting down the process") - cancelMetrics() - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - if len(httpServers) > 0 { - slog.Debug("Shutting down HTTP servers...") - for _, server := range httpServers { - if server != nil { - if err := server.Shutdown(ctx); err != nil { - slog.Error("HTTP server shutdown error", slog.Any("error", err), slog.String("addr", server.Addr)) + for { + select { + case <-stop: + slog.Debug("Shutting down the process") + cancelMetrics() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if len(httpServers) > 0 { + slog.Debug("Shutting down HTTP servers...") + for _, srv := range httpServers { + if srv != nil { + if err := srv.Shutdown(ctx); err != nil { + slog.Error("HTTP server shutdown error", slog.Any("error", err), slog.String("addr", srv.Addr)) + } + } } + slog.Debug("All HTTP servers shut down.") + } else { + slog.Debug("No HTTP servers to shut down.") } - } - slog.Debug("All HTTP servers shut down.") - } else { - slog.Debug("No HTTP servers to shut down.") - } - slog.Debug("Shutting down worker pool...") - pool.Shutdown() - slog.Debug("Worker pool shut down.") + slog.Debug("Shutting down worker pool...") + pool.Shutdown() + slog.Debug("Worker pool shut down.") - slog.Debug("Process gracefully stopped") + slog.Debug("Process gracefully stopped") + return + + case <-reload: + slog.Info("Received SIGHUP, reloading TLS certificates") + if certReloadFn != nil { + certReloadFn() + } + } + } } diff --git a/internal/http/server/server.go b/internal/http/server/server.go index f8fee0e5298..54f12b8ad87 100644 --- a/internal/http/server/server.go +++ b/internal/http/server/server.go @@ -21,7 +21,7 @@ import ( "golang.org/x/crypto/acme/autocert" ) -func StartWebServer(store *storage.Storage, pool *worker.Pool) []*http.Server { +func StartWebServer(store *storage.Storage, pool *worker.Pool) ([]*http.Server, func()) { var servers []*http.Server autocertTLSConfig, challengeServer := setupAutocert(store) @@ -39,6 +39,26 @@ func StartWebServer(store *storage.Storage, pool *worker.Pool) []*http.Server { config.Opts.SetHTTPSValue(true) } + // Create a single certificate loader shared by all TLS servers + // that use the same cert/key pair. + var certLoader *certificateLoader + if certFile != "" && keyFile != "" { + hasTLSTarget := false + for _, t := range targets { + if t.mode == modeTLS || t.mode == modeUnixSocketTLS { + hasTLSTarget = true + break + } + } + if hasTLSTarget { + var err error + certLoader, err = newCertificateLoader(certFile, keyFile) + if err != nil { + printErrorAndExit("Unable to load TLS certificate from %s / %s: %v", certFile, keyFile, err) + } + } + } + for _, t := range targets { srv := &http.Server{ Addr: t.address, @@ -55,11 +75,11 @@ func StartWebServer(store *storage.Storage, pool *worker.Pool) []*http.Server { case modeUnixSocket: startUnixSocketServer(srv, t.address) case modeUnixSocketTLS: - startUnixSocketTLSServer(srv, t.address, t.certFile, t.keyFile) + startUnixSocketTLSServer(srv, t.address, certLoader) case modeAutocertTLS: startAutoCertTLSServer(srv, autocertTLSConfig) case modeTLS: - startTLSServer(srv, t.certFile, t.keyFile) + startTLSServer(srv, certLoader) default: startHTTPServer(srv) } @@ -67,7 +87,11 @@ func StartWebServer(store *storage.Storage, pool *worker.Pool) []*http.Server { servers = append(servers, srv) } - return servers + certReloadFn := func() {} + if certLoader != nil { + certReloadFn = certLoader.Reload + } + return servers, certReloadFn } type listenerMode int @@ -82,10 +106,8 @@ const ( ) type listenTarget struct { - address string - mode listenerMode - certFile string - keyFile string + address string + mode listenerMode } func determineListenTargets(addresses []string, certDomain, certFile, keyFile string) []listenTarget { @@ -111,13 +133,13 @@ func determineListenTargets(addresses []string, certDomain, certFile, keyFile st switch { case isUnix && hasCertFiles: - targets = append(targets, listenTarget{address: addr, mode: modeUnixSocketTLS, certFile: certFile, keyFile: keyFile}) + targets = append(targets, listenTarget{address: addr, mode: modeUnixSocketTLS}) case isUnix: targets = append(targets, listenTarget{address: addr, mode: modeUnixSocket}) case hasAutocert && (addr == ":https" || (i == 0 && strings.Contains(addr, ":"))): targets = append(targets, listenTarget{address: addr, mode: modeAutocertTLS}) case hasCertFiles: - targets = append(targets, listenTarget{address: addr, mode: modeTLS, certFile: certFile, keyFile: keyFile}) + targets = append(targets, listenTarget{address: addr, mode: modeTLS}) default: targets = append(targets, listenTarget{address: addr, mode: modeHTTP}) } @@ -195,16 +217,20 @@ func startUnixSocketServer(server *http.Server, socketFile string) { }() } -func startUnixSocketTLSServer(server *http.Server, socketFile, certFile, keyFile string) { +func startUnixSocketTLSServer(server *http.Server, socketFile string, certLoader *certificateLoader) { + config := &tls.Config{ + GetCertificate: certLoader.getCertificate, + NextProtos: []string{"h2", "http/1.1"}, + } + listener := createUnixSocketListener(socketFile) go func() { slog.Info("Starting TLS server using a Unix socket", slog.String("socket", socketFile), - slog.String("cert_file", certFile), - slog.String("key_file", keyFile), ) - if err := server.ServeTLS(listener, certFile, keyFile); err != http.ErrServerClosed { + tlsListener := tls.NewListener(listener, config) + if err := server.Serve(tlsListener); err != http.ErrServerClosed { printErrorAndExit("TLS Unix socket server failed to start on %s: %v", socketFile, err) } }() @@ -243,14 +269,23 @@ func startAutoCertTLSServer(server *http.Server, autoTLSConfig *tls.Config) { }() } -func startTLSServer(server *http.Server, certFile, keyFile string) { +func startTLSServer(server *http.Server, certLoader *certificateLoader) { + config := &tls.Config{ + GetCertificate: certLoader.getCertificate, + NextProtos: []string{"h2", "http/1.1"}, + } + + listener, err := net.Listen("tcp", server.Addr) + if err != nil { + printErrorAndExit("TLS server failed to listen on %s: %v", server.Addr, err) + } + go func() { slog.Info("Starting TLS server using a certificate", slog.String("listen_address", server.Addr), - slog.String("cert_file", certFile), - slog.String("key_file", keyFile), ) - if err := server.ListenAndServeTLS(certFile, keyFile); err != http.ErrServerClosed { + tlsListener := tls.NewListener(listener, config) + if err := server.Serve(tlsListener); err != http.ErrServerClosed { printErrorAndExit("TLS server failed to start on %s: %v", server.Addr, err) } }() diff --git a/internal/http/server/server_test.go b/internal/http/server/server_test.go index d4674fb16b2..b2471d89e41 100644 --- a/internal/http/server/server_test.go +++ b/internal/http/server/server_test.go @@ -37,7 +37,7 @@ func TestDetermineListenTargets(t *testing.T) { certFile: "/path/to/cert.pem", keyFile: "/path/to/key.pem", expected: []listenTarget{ - {address: ":443", mode: modeTLS, certFile: "/path/to/cert.pem", keyFile: "/path/to/key.pem"}, + {address: ":443", mode: modeTLS}, }, }, { @@ -94,7 +94,7 @@ func TestDetermineListenTargets(t *testing.T) { certFile: "/path/to/cert.pem", keyFile: "/path/to/key.pem", expected: []listenTarget{ - {address: "/var/run/miniflux.sock", mode: modeUnixSocketTLS, certFile: "/path/to/cert.pem", keyFile: "/path/to/key.pem"}, + {address: "/var/run/miniflux.sock", mode: modeUnixSocketTLS}, }, }, { @@ -103,8 +103,8 @@ func TestDetermineListenTargets(t *testing.T) { certFile: "/path/to/cert.pem", keyFile: "/path/to/key.pem", expected: []listenTarget{ - {address: "/var/run/miniflux.sock", mode: modeUnixSocketTLS, certFile: "/path/to/cert.pem", keyFile: "/path/to/key.pem"}, - {address: ":8080", mode: modeTLS, certFile: "/path/to/cert.pem", keyFile: "/path/to/key.pem"}, + {address: "/var/run/miniflux.sock", mode: modeUnixSocketTLS}, + {address: ":8080", mode: modeTLS}, }, }, { diff --git a/internal/http/server/tls_certificate_loader.go b/internal/http/server/tls_certificate_loader.go new file mode 100644 index 00000000000..befcf6d004b --- /dev/null +++ b/internal/http/server/tls_certificate_loader.go @@ -0,0 +1,73 @@ +// SPDX-FileCopyrightText: Copyright The Miniflux Authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package server + +import ( + "crypto/tls" + "log/slog" + "path/filepath" + "sync" +) + +// certificateLoader loads and caches a TLS certificate from disk, and +// provides a reload method that can be triggered on SIGHUP. +type certificateLoader struct { + mu sync.RWMutex + cert *tls.Certificate + certFile string + keyFile string +} + +func newCertificateLoader(certFile, keyFile string) (*certificateLoader, error) { + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return nil, err + } + + loader := &certificateLoader{ + cert: &cert, + certFile: filepath.Clean(certFile), + keyFile: filepath.Clean(keyFile), + } + + slog.Info("TLS certificate loaded", + slog.String("cert_file", loader.certFile), + slog.String("key_file", loader.keyFile), + ) + + return loader, nil +} + +// getCertificate returns the currently cached TLS certificate. It satisfies +// the tls.Config.GetCertificate callback and is called by the TLS layer on +// every handshake. +func (cl *certificateLoader) getCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) { + cl.mu.RLock() + defer cl.mu.RUnlock() + return cl.cert, nil +} + +// Reload loads the certificate and key from disk and replaces the cached +// copy. If loading fails, the existing certificate is kept and the error +// is logged. +func (cl *certificateLoader) Reload() { + cert, err := tls.LoadX509KeyPair(cl.certFile, cl.keyFile) + if err != nil { + slog.Error("Unable to reload TLS certificate", + slog.String("cert_file", cl.certFile), + slog.String("key_file", cl.keyFile), + slog.Any("error", err), + ) + return + } + + cl.mu.Lock() + cl.cert = &cert + cl.mu.Unlock() + + slog.Info("TLS certificate reloaded successfully", + slog.String("cert_file", cl.certFile), + slog.String("key_file", cl.keyFile), + ) +} diff --git a/internal/http/server/tls_certificate_loader_test.go b/internal/http/server/tls_certificate_loader_test.go new file mode 100644 index 00000000000..073b8642b27 --- /dev/null +++ b/internal/http/server/tls_certificate_loader_test.go @@ -0,0 +1,212 @@ +// SPDX-FileCopyrightText: Copyright The Miniflux Authors. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package server + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "net" + "os" + "path/filepath" + "testing" + "time" +) + +// generateTestCert creates a self-signed certificate and key and writes them +// to PEM files in the given directory. Returns cert and key file paths. +func generateTestCert(t *testing.T, dir, prefix string) (string, string) { + t.Helper() + + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("failed to generate RSA key: %v", err) + } + + serial := big.NewInt(time.Now().UnixNano()) + tmpl := &x509.Certificate{ + SerialNumber: serial, + Subject: pkix.Name{ + CommonName: prefix + ".example.com", + }, + NotBefore: time.Now().Add(-1 * time.Hour), + NotAfter: time.Now().Add(1 * time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + } + + certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key) + if err != nil { + t.Fatalf("failed to create certificate: %v", err) + } + + certFile := filepath.Join(dir, prefix+".pem") + keyFile := filepath.Join(dir, prefix+"-key.pem") + + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + keyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(key), + }) + + if err := os.WriteFile(certFile, certPEM, 0600); err != nil { + t.Fatalf("failed to write cert file: %v", err) + } + if err := os.WriteFile(keyFile, keyPEM, 0600); err != nil { + t.Fatalf("failed to write key file: %v", err) + } + + return certFile, keyFile +} + +// certLoaderSerial extracts the serial number of the first certificate +// returned by the loader's getCertificate callback. +func certLoaderSerial(cl *certificateLoader) *big.Int { + cert, err := cl.getCertificate(nil) + if err != nil || cert == nil { + return nil + } + x509Cert, err := x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + return nil + } + return x509Cert.SerialNumber +} + +// TestCertificateLoaderInitialLoad verifies that a new certificateLoader +// loads the certificate successfully and serves it via getCertificate. +func TestCertificateLoaderInitialLoad(t *testing.T) { + dir := t.TempDir() + certFile, keyFile := generateTestCert(t, dir, "initial") + + cl, err := newCertificateLoader(certFile, keyFile) + if err != nil { + t.Fatalf("newCertificateLoader failed: %v", err) + } + + cert, err := cl.getCertificate(nil) + if err != nil { + t.Fatalf("getCertificate failed: %v", err) + } + if cert == nil { + t.Fatal("getCertificate returned nil certificate") + } + if len(cert.Certificate) == 0 { + t.Fatal("certificate chain is empty") + } + if certLoaderSerial(cl) == nil { + t.Fatal("unable to parse certificate serial") + } +} + +// TestCertificateLoaderReload verifies that Reload picks up a new certificate +// written to disk. +func TestCertificateLoaderReload(t *testing.T) { + dir := t.TempDir() + certFile, keyFile := generateTestCert(t, dir, "reload") + + cl, err := newCertificateLoader(certFile, keyFile) + if err != nil { + t.Fatalf("newCertificateLoader failed: %v", err) + } + + origSerial := certLoaderSerial(cl) + if origSerial == nil { + t.Fatal("unable to parse original certificate serial") + } + + // Write a new certificate to the same file paths. + generateTestCert(t, dir, "reload") + + cl.Reload() + + newSerial := certLoaderSerial(cl) + if newSerial == nil { + t.Fatal("unable to parse certificate serial after reload") + } + if origSerial.Cmp(newSerial) == 0 { + t.Fatal("certificate serial did not change after reload") + } +} + +// TestCertificateLoaderReloadFailureKeepsOldCert verifies that if Reload fails +// the old certificate is preserved. +func TestCertificateLoaderReloadFailureKeepsOldCert(t *testing.T) { + dir := t.TempDir() + certFile, keyFile := generateTestCert(t, dir, "keep-old") + + cl, err := newCertificateLoader(certFile, keyFile) + if err != nil { + t.Fatalf("newCertificateLoader failed: %v", err) + } + + origSerial := certLoaderSerial(cl) + if origSerial == nil { + t.Fatal("unable to parse original certificate serial") + } + + // Corrupt the key file. + if err := os.WriteFile(keyFile, []byte("not a valid PEM key"), 0600); err != nil { + t.Fatalf("failed to write corrupted key file: %v", err) + } + + cl.Reload() + + curSerial := certLoaderSerial(cl) + if curSerial == nil { + t.Fatal("unable to parse certificate serial after failed reload") + } + if origSerial.Cmp(curSerial) != 0 { + t.Fatal("certificate changed after a failed reload") + } +} + +// TestCertificateLoaderNilClientHello verifies getCertificate handles a nil +// *tls.ClientHelloInfo argument. +func TestCertificateLoaderNilClientHello(t *testing.T) { + dir := t.TempDir() + certFile, keyFile := generateTestCert(t, dir, "nil-hello") + + cl, err := newCertificateLoader(certFile, keyFile) + if err != nil { + t.Fatalf("newCertificateLoader failed: %v", err) + } + + cert, err := cl.getCertificate(nil) + if err != nil { + t.Fatalf("getCertificate(nil) returned error: %v", err) + } + if cert == nil { + t.Fatal("getCertificate(nil) returned nil") + } +} + +// TestCertificateLoaderClientHelloInfo verifies that getCertificate works +// when called with a real *tls.ClientHelloInfo. +func TestCertificateLoaderClientHelloInfo(t *testing.T) { + dir := t.TempDir() + certFile, keyFile := generateTestCert(t, dir, "sni") + + cl, err := newCertificateLoader(certFile, keyFile) + if err != nil { + t.Fatalf("newCertificateLoader failed: %v", err) + } + + hello := &tls.ClientHelloInfo{ + ServerName: "sni.example.com", + } + cert, err := cl.getCertificate(hello) + if err != nil { + t.Fatalf("getCertificate with ClientHelloInfo failed: %v", err) + } + if cert == nil { + t.Fatal("getCertificate with ClientHelloInfo returned nil") + } +}