Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 37 additions & 22 deletions internal/cli/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,19 @@ func startDaemon(store *storage.Storage) {
signal.Notify(stop, os.Interrupt)
signal.Notify(stop, syscall.SIGTERM)

reload := make(chan os.Signal, 1)
signal.Notify(reload, syscall.SIGHUP)

pool := worker.NewPool(store, config.Opts.WorkerPoolSize())

if config.Opts.HasSchedulerService() && !config.Opts.HasMaintenanceMode() {
runScheduler(store, pool)
}

var httpServers []*http.Server
var certReloadFn func()
if config.Opts.HasHTTPService() {
httpServers = server.StartWebServer(store, pool)
httpServers, certReloadFn = server.StartWebServer(store, pool)
}

metricsCtx, cancelMetrics := context.WithCancel(context.Background())
Expand Down Expand Up @@ -74,29 +78,40 @@ func startDaemon(store *storage.Storage) {
}
}

<-stop
slog.Debug("Shutting down the process")
cancelMetrics()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

if len(httpServers) > 0 {
slog.Debug("Shutting down HTTP servers...")
for _, server := range httpServers {
if server != nil {
if err := server.Shutdown(ctx); err != nil {
slog.Error("HTTP server shutdown error", slog.Any("error", err), slog.String("addr", server.Addr))
for {
select {
case <-stop:
slog.Debug("Shutting down the process")
cancelMetrics()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

if len(httpServers) > 0 {
slog.Debug("Shutting down HTTP servers...")
for _, srv := range httpServers {
if srv != nil {
if err := srv.Shutdown(ctx); err != nil {
slog.Error("HTTP server shutdown error", slog.Any("error", err), slog.String("addr", srv.Addr))
}
}
}
slog.Debug("All HTTP servers shut down.")
} else {
slog.Debug("No HTTP servers to shut down.")
}
}
slog.Debug("All HTTP servers shut down.")
} else {
slog.Debug("No HTTP servers to shut down.")
}

slog.Debug("Shutting down worker pool...")
pool.Shutdown()
slog.Debug("Worker pool shut down.")
slog.Debug("Shutting down worker pool...")
pool.Shutdown()
slog.Debug("Worker pool shut down.")

slog.Debug("Process gracefully stopped")
slog.Debug("Process gracefully stopped")
return

case <-reload:
slog.Info("Received SIGHUP, reloading TLS certificates")
if certReloadFn != nil {
certReloadFn()
}
}
}
}
71 changes: 53 additions & 18 deletions internal/http/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import (
"golang.org/x/crypto/acme/autocert"
)

func StartWebServer(store *storage.Storage, pool *worker.Pool) []*http.Server {
func StartWebServer(store *storage.Storage, pool *worker.Pool) ([]*http.Server, func()) {
var servers []*http.Server

autocertTLSConfig, challengeServer := setupAutocert(store)
Expand All @@ -39,6 +39,26 @@ func StartWebServer(store *storage.Storage, pool *worker.Pool) []*http.Server {
config.Opts.SetHTTPSValue(true)
}

// Create a single certificate loader shared by all TLS servers
// that use the same cert/key pair.
var certLoader *certificateLoader
if certFile != "" && keyFile != "" {
hasTLSTarget := false
for _, t := range targets {
if t.mode == modeTLS || t.mode == modeUnixSocketTLS {
hasTLSTarget = true
break
}
}
if hasTLSTarget {
var err error
certLoader, err = newCertificateLoader(certFile, keyFile)
if err != nil {
printErrorAndExit("Unable to load TLS certificate from %s / %s: %v", certFile, keyFile, err)
}
}
}

for _, t := range targets {
srv := &http.Server{
Addr: t.address,
Expand All @@ -55,19 +75,23 @@ func StartWebServer(store *storage.Storage, pool *worker.Pool) []*http.Server {
case modeUnixSocket:
startUnixSocketServer(srv, t.address)
case modeUnixSocketTLS:
startUnixSocketTLSServer(srv, t.address, t.certFile, t.keyFile)
startUnixSocketTLSServer(srv, t.address, certLoader)
case modeAutocertTLS:
startAutoCertTLSServer(srv, autocertTLSConfig)
case modeTLS:
startTLSServer(srv, t.certFile, t.keyFile)
startTLSServer(srv, certLoader)
default:
startHTTPServer(srv)
}

servers = append(servers, srv)
}

return servers
certReloadFn := func() {}
if certLoader != nil {
certReloadFn = certLoader.Reload
}
return servers, certReloadFn
}

type listenerMode int
Expand All @@ -82,10 +106,8 @@ const (
)

type listenTarget struct {
address string
mode listenerMode
certFile string
keyFile string
address string
mode listenerMode
}

func determineListenTargets(addresses []string, certDomain, certFile, keyFile string) []listenTarget {
Expand All @@ -111,13 +133,13 @@ func determineListenTargets(addresses []string, certDomain, certFile, keyFile st

switch {
case isUnix && hasCertFiles:
targets = append(targets, listenTarget{address: addr, mode: modeUnixSocketTLS, certFile: certFile, keyFile: keyFile})
targets = append(targets, listenTarget{address: addr, mode: modeUnixSocketTLS})
case isUnix:
targets = append(targets, listenTarget{address: addr, mode: modeUnixSocket})
case hasAutocert && (addr == ":https" || (i == 0 && strings.Contains(addr, ":"))):
targets = append(targets, listenTarget{address: addr, mode: modeAutocertTLS})
case hasCertFiles:
targets = append(targets, listenTarget{address: addr, mode: modeTLS, certFile: certFile, keyFile: keyFile})
targets = append(targets, listenTarget{address: addr, mode: modeTLS})
default:
targets = append(targets, listenTarget{address: addr, mode: modeHTTP})
}
Expand Down Expand Up @@ -195,16 +217,20 @@ func startUnixSocketServer(server *http.Server, socketFile string) {
}()
}

func startUnixSocketTLSServer(server *http.Server, socketFile, certFile, keyFile string) {
func startUnixSocketTLSServer(server *http.Server, socketFile string, certLoader *certificateLoader) {
config := &tls.Config{
GetCertificate: certLoader.getCertificate,
NextProtos: []string{"h2", "http/1.1"},
}

listener := createUnixSocketListener(socketFile)

go func() {
slog.Info("Starting TLS server using a Unix socket",
slog.String("socket", socketFile),
slog.String("cert_file", certFile),
slog.String("key_file", keyFile),
)
if err := server.ServeTLS(listener, certFile, keyFile); err != http.ErrServerClosed {
tlsListener := tls.NewListener(listener, config)
if err := server.Serve(tlsListener); err != http.ErrServerClosed {
printErrorAndExit("TLS Unix socket server failed to start on %s: %v", socketFile, err)
}
}()
Expand Down Expand Up @@ -243,14 +269,23 @@ func startAutoCertTLSServer(server *http.Server, autoTLSConfig *tls.Config) {
}()
}

func startTLSServer(server *http.Server, certFile, keyFile string) {
func startTLSServer(server *http.Server, certLoader *certificateLoader) {
config := &tls.Config{
GetCertificate: certLoader.getCertificate,
NextProtos: []string{"h2", "http/1.1"},
}

listener, err := net.Listen("tcp", server.Addr)
if err != nil {
printErrorAndExit("TLS server failed to listen on %s: %v", server.Addr, err)
}

go func() {
slog.Info("Starting TLS server using a certificate",
slog.String("listen_address", server.Addr),
slog.String("cert_file", certFile),
slog.String("key_file", keyFile),
)
if err := server.ListenAndServeTLS(certFile, keyFile); err != http.ErrServerClosed {
tlsListener := tls.NewListener(listener, config)
if err := server.Serve(tlsListener); err != http.ErrServerClosed {
printErrorAndExit("TLS server failed to start on %s: %v", server.Addr, err)
}
}()
Expand Down
8 changes: 4 additions & 4 deletions internal/http/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func TestDetermineListenTargets(t *testing.T) {
certFile: "/path/to/cert.pem",
keyFile: "/path/to/key.pem",
expected: []listenTarget{
{address: ":443", mode: modeTLS, certFile: "/path/to/cert.pem", keyFile: "/path/to/key.pem"},
{address: ":443", mode: modeTLS},
},
},
{
Expand Down Expand Up @@ -94,7 +94,7 @@ func TestDetermineListenTargets(t *testing.T) {
certFile: "/path/to/cert.pem",
keyFile: "/path/to/key.pem",
expected: []listenTarget{
{address: "/var/run/miniflux.sock", mode: modeUnixSocketTLS, certFile: "/path/to/cert.pem", keyFile: "/path/to/key.pem"},
{address: "/var/run/miniflux.sock", mode: modeUnixSocketTLS},
},
},
{
Expand All @@ -103,8 +103,8 @@ func TestDetermineListenTargets(t *testing.T) {
certFile: "/path/to/cert.pem",
keyFile: "/path/to/key.pem",
expected: []listenTarget{
{address: "/var/run/miniflux.sock", mode: modeUnixSocketTLS, certFile: "/path/to/cert.pem", keyFile: "/path/to/key.pem"},
{address: ":8080", mode: modeTLS, certFile: "/path/to/cert.pem", keyFile: "/path/to/key.pem"},
{address: "/var/run/miniflux.sock", mode: modeUnixSocketTLS},
{address: ":8080", mode: modeTLS},
},
},
{
Expand Down
73 changes: 73 additions & 0 deletions internal/http/server/tls_certificate_loader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// SPDX-FileCopyrightText: Copyright The Miniflux Authors. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package server

import (
"crypto/tls"
"log/slog"
"path/filepath"
"sync"
)

// certificateLoader loads and caches a TLS certificate from disk, and
// provides a reload method that can be triggered on SIGHUP.
type certificateLoader struct {
mu sync.RWMutex
cert *tls.Certificate
certFile string
keyFile string
}

func newCertificateLoader(certFile, keyFile string) (*certificateLoader, error) {
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, err
}

loader := &certificateLoader{
cert: &cert,
certFile: filepath.Clean(certFile),
keyFile: filepath.Clean(keyFile),
}

slog.Info("TLS certificate loaded",
slog.String("cert_file", loader.certFile),
slog.String("key_file", loader.keyFile),
)

return loader, nil
}

// getCertificate returns the currently cached TLS certificate. It satisfies
// the tls.Config.GetCertificate callback and is called by the TLS layer on
// every handshake.
func (cl *certificateLoader) getCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) {
cl.mu.RLock()
defer cl.mu.RUnlock()
return cl.cert, nil
}

// Reload loads the certificate and key from disk and replaces the cached
// copy. If loading fails, the existing certificate is kept and the error
// is logged.
func (cl *certificateLoader) Reload() {
cert, err := tls.LoadX509KeyPair(cl.certFile, cl.keyFile)
if err != nil {
slog.Error("Unable to reload TLS certificate",
slog.String("cert_file", cl.certFile),
slog.String("key_file", cl.keyFile),
slog.Any("error", err),
)
return
}

cl.mu.Lock()
cl.cert = &cert
cl.mu.Unlock()

slog.Info("TLS certificate reloaded successfully",
slog.String("cert_file", cl.certFile),
slog.String("key_file", cl.keyFile),
)
}
Loading