// Source: rio/internal/server/tls.go
package server
import (
"crypto/tls"
"strings"
"sync"
"git.polynom.me/rio/internal/certificates"
"git.polynom.me/rio/internal/dns"
"git.polynom.me/rio/internal/repo"
2024-01-01 13:19:19 +00:00
"github.com/go-acme/lego/v4/lego"
log "github.com/sirupsen/logrus"
)
var (
	// domainsLock guards workingDomains. Acquire it before reading from or
	// writing to the map.
	domainsLock = sync.Mutex{}

	// workingDomains is used as a set (Domain -> _). A domain being present as
	// a key means we are already requesting or renewing a certificate for that
	// domain; the stored value carries no meaning.
	workingDomains = make(map[string]bool)
)
func lockIfUnlockedDomain(domain string) bool {
domainsLock.Lock()
defer domainsLock.Unlock()
2024-01-01 13:19:19 +00:00
_, found := workingDomains[domain]
2024-01-01 13:19:19 +00:00
if !found {
workingDomains[domain] = true
2024-01-01 13:19:19 +00:00
}
return found
}
func unlockDomain(domain string) {
domainsLock.Lock()
defer domainsLock.Unlock()
2024-01-01 13:19:19 +00:00
delete(workingDomains, domain)
2024-01-01 13:19:19 +00:00
}
// buildDomainList returns the list of domain names to put on the certificate
// that is requested for domain.
//
// Currently this is always just domain itself; pagesDomain is unused until
// the wildcard handling below is enabled.
func buildDomainList(domain, pagesDomain string) []string {
	// TODO: For wildcards, we MUST use DNS01
	/*if domain == pagesDomain || strings.HasSuffix(domain, pagesDomain) {
		return []string{
			pagesDomain,
			"*." + pagesDomain,
		}
	}*/

	return []string{domain}
}
// getDomainKey returns the key under which the certificate for domain is
// tracked (both in the certificate cache lookup and in the set of domains
// currently being worked on).
//
// Currently the key is simply domain; pagesDomain is unused until the
// wildcard handling below is enabled.
func getDomainKey(domain, pagesDomain string) string {
	// TODO: For wildcards, we MUST use DNS01
	/*if domain == pagesDomain || strings.HasSuffix(domain, pagesDomain) {
		return "*." + pagesDomain
	}*/

	return domain
}
func MakeTlsConfig(pagesDomain, cachePath string, cache *certificates.CertificatesCache, acmeClient *lego.Client, giteaClient *repo.GiteaClient) *tls.Config {
2024-01-01 13:19:19 +00:00
return &tls.Config{
GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
// Validate that we should even care about this domain
isPagesDomain := info.ServerName == pagesDomain
cname := ""
2024-01-01 13:19:19 +00:00
if !strings.HasSuffix(info.ServerName, pagesDomain) {
// Note: We do not check err here because err != nil
// always implies that cname == "", which does not have
// pagesDomain as a suffix.
cname, _ = dns.LookupCNAME(info.ServerName)
2024-01-01 13:19:19 +00:00
if !strings.HasSuffix(cname, pagesDomain) {
log.Warnf("Got ServerName for Domain %s that we're not responsible for", info.ServerName)
return cache.FallbackCertificate.TlsCertificate, nil
}
}
// Figure out a username for later username checks
username := ""
if cname == "" {
// domain ends on pagesDomain
username = strings.Split(info.ServerName, ".")[0]
} else {
// cname ends on pagesDomain
username = strings.Split(cname, ".")[0]
}
// Find the correct certificate
domainKey := getDomainKey(info.ServerName, pagesDomain)
cert, found := cache.Certificates[domainKey]
2024-01-01 13:19:19 +00:00
if found {
if cert.IsValid() {
return cert.TlsCertificate, nil
} else {
if !isPagesDomain && !repo.CanRequestCertificate(username, giteaClient) {
log.Warnf(
"Cannot renew certificate for %s because CanRequestCertificate(%s) returned false",
info.ServerName,
username,
)
return cert.TlsCertificate, nil
}
2024-01-01 13:19:19 +00:00
// If we're already working on the domain,
// return the old certificate
if lockIfUnlockedDomain(domainKey) {
2024-01-01 13:19:19 +00:00
return cert.TlsCertificate, nil
}
defer unlockDomain(domainKey)
2024-01-01 13:19:19 +00:00
// Renew the certificate
log.Infof("Certificate for %s expired, renewing", info.ServerName)
newCert, err := certificates.RenewCertificate(&cert, acmeClient)
if err != nil {
log.Errorf("Failed to renew certificate for %s: %v", info.ServerName, err)
return cert.TlsCertificate, nil
}
log.Info("Successfully renewed certificate!")
cache.AddCert(newCert, cachePath)
return newCert.TlsCertificate, nil
2024-01-01 13:19:19 +00:00
}
} else {
if !isPagesDomain && !repo.CanRequestCertificate(username, giteaClient) {
log.Warnf(
"Cannot request certificate for %s because CanRequestCertificate(%s) returned false",
info.ServerName,
username,
)
return cache.FallbackCertificate.TlsCertificate, nil
}
2024-01-01 13:19:19 +00:00
// Don't request if we're already requesting.
key := getDomainKey(info.ServerName, pagesDomain)
if lockIfUnlockedDomain(domainKey) {
2024-01-01 13:19:19 +00:00
return cache.FallbackCertificate.TlsCertificate, nil
}
defer unlockDomain(key)
2024-01-01 13:19:19 +00:00
// Request new certificate
log.Infof("Obtaining new certificate for %s...", info.ServerName)
2024-01-01 13:19:19 +00:00
cert, err := certificates.ObtainNewCertificate(
buildDomainList(info.ServerName, pagesDomain),
domainKey,
2024-01-01 13:19:19 +00:00
acmeClient,
)
if err != nil {
log.Errorf(
"Failed to get certificate for %s: %v",
info.ServerName,
2024-01-01 13:19:19 +00:00
err,
)
return cache.FallbackCertificate.TlsCertificate, nil
}
// Add to cache and flush
log.Info("Successfully obtained new certificate!")
2024-01-01 13:19:19 +00:00
cache.AddCert(cert, cachePath)
return cert.TlsCertificate, nil
}
},
NextProtos: []string{
"http/0.9",
"http/1.0",
"http/1.1",
"h2",
"h2c",
},
MinVersion: tls.VersionTLS12,
CipherSuites: []uint16{
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
},
}
}