From 369216834626a1b444ecc7dca979f458628c2526 Mon Sep 17 00:00:00 2001 From: "Alexander \"PapaTutuWawa" Date: Tue, 2 Jan 2024 18:23:46 +0100 Subject: [PATCH] feat: First steps towards using wildcard certificates --- internal/certificates/certificate.go | 8 ++-- internal/certificates/store.go | 29 ++++++++---- internal/server/tls.go | 60 +++++++++++++++--------- internal/server/tls_test.go | 69 ++++++++++++++++++++++++++++ 4 files changed, 131 insertions(+), 35 deletions(-) create mode 100644 internal/server/tls_test.go diff --git a/internal/certificates/certificate.go b/internal/certificates/certificate.go index 13de687..2267e1b 100644 --- a/internal/certificates/certificate.go +++ b/internal/certificates/certificate.go @@ -38,7 +38,7 @@ func RenewCertificate(old *CertificateWrapper, acmeClient *lego.Client) (Certifi } wrapper := CertificateWrapper{ TlsCertificate: &tlsCert, - Domain: old.Domain, + DomainKey: old.DomainKey, NotAfter: time.Now().Add(time.Hour * 24 * 60), PrivateKeyEncoded: base64.StdEncoding.EncodeToString(new.PrivateKey), Certificate: new.Certificate, @@ -47,7 +47,7 @@ func RenewCertificate(old *CertificateWrapper, acmeClient *lego.Client) (Certifi return wrapper, nil } -func ObtainNewCertificate(domains []string, acmeClient *lego.Client) (CertificateWrapper, error) { +func ObtainNewCertificate(domains []string, domainKey string, acmeClient *lego.Client) (CertificateWrapper, error) { req := certificate.ObtainRequest{ Domains: domains, Bundle: true, @@ -64,7 +64,7 @@ func ObtainNewCertificate(domains []string, acmeClient *lego.Client) (Certificat wrapper := CertificateWrapper{ TlsCertificate: &tlsCert, - Domain: cert.Domain, + DomainKey: domainKey, //NotAfter: tlsCert.Leaf.NotAfter, NotAfter: time.Now().Add(time.Hour * 24 * 60), PrivateKeyEncoded: base64.StdEncoding.EncodeToString(cert.PrivateKey), @@ -127,7 +127,7 @@ func MakeFallbackCertificate(pagesDomain string) (*CertificateWrapper, error) { } return &CertificateWrapper{ TlsCertificate: &tlsCertificate, - Domain: pagesDomain, + DomainKey: "*." + pagesDomain, NotAfter: notAfter, PrivateKeyEncoded: base64.StdEncoding.EncodeToString(certcrypto.PEMEncode(key)), Certificate: outBytes, diff --git a/internal/certificates/store.go b/internal/certificates/store.go index 465b4ce..820beb9 100644 --- a/internal/certificates/store.go +++ b/internal/certificates/store.go @@ -14,12 +14,23 @@ import ( // A convenience wrapper around a TLS certificate type CertificateWrapper struct { - TlsCertificate *tls.Certificate `json:"-"` - Domain string `json:"domain"` - NotAfter time.Time `json:"not_after"` - PrivateKeyEncoded string `json:"private_key"` - Certificate []byte `json:"certificate"` - CSR []byte `json:"csr"` + // The parsed TLS certificate we can pass to the tls listener + TlsCertificate *tls.Certificate `json:"-"` + + // Key identifying for which domain(s) this certificate is valid. + DomainKey string `json:"domain"` + + // Indicates at which point in time this certificate is no longer valid. + NotAfter time.Time `json:"not_after"` + + // The encoded private key. + PrivateKeyEncoded string `json:"private_key"` + + // The PEM-encoded certificate. + Certificate []byte `json:"certificate"` + + // The CSR provided when we requested the certificate. + CSR []byte `json:"csr"` } // A structure to store all the certificates we know of in. @@ -27,7 +38,7 @@ type CertificatesCache struct { // The certificate to use as a fallback if all else fails. FallbackCertificate *CertificateWrapper - // Mapping of domain name to certificate. + // Mapping of a domain's domain key to the certificate. Certificates map[string]CertificateWrapper } @@ -83,7 +94,7 @@ func (c *CertificatesCache) FlushToDisk(path string) { } func (c *CertificatesCache) AddCert(cert CertificateWrapper, path string) { - c.Certificates[cert.Domain] = cert + c.Certificates[cert.DomainKey] = cert c.FlushToDisk(path) } @@ -105,7 +116,7 @@ func CertificateCacheFromFile(path string) (CertificatesCache, error) { certs := make(map[string]CertificateWrapper) for _, cert := range store.Certificates { cert.initTlsCertificate() - certs[cert.Domain] = cert + certs[cert.DomainKey] = cert } cache.Certificates = certs diff --git a/internal/server/tls.go b/internal/server/tls.go index 88fbafa..67298a6 100644 --- a/internal/server/tls.go +++ b/internal/server/tls.go @@ -43,10 +43,30 @@ func unlockDomain(domain string) { delete(workingDomains, domain) } +func buildDomainList(domain, pagesDomain string) []string { + if domain == pagesDomain || strings.HasSuffix(domain, pagesDomain) { + return []string{ + pagesDomain, + "*." + pagesDomain, + } + } + + return []string{domain} +} + +func getDomainKey(domain, pagesDomain string) string { + if domain == pagesDomain || strings.HasSuffix(domain, pagesDomain) { + return "*." + pagesDomain + } + + return domain +} + func MakeTlsConfig(pagesDomain, cachePath string, cache *certificates.CertificatesCache, acmeClient *lego.Client, giteaClient *gitea.Client) *tls.Config { return &tls.Config{ GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { // Validate that we should even care about this domain + isPagesDomain := info.ServerName == pagesDomain cname := "" if !strings.HasSuffix(info.ServerName, pagesDomain) { // Note: We do not check err here because err != nil @@ -59,33 +79,27 @@ func MakeTlsConfig(pagesDomain, cachePath string, cache *certificates.Certificat } } - // If we want to access ., then we can just - // use a wildcard certificate. - domain := info.ServerName - /*if strings.HasSuffix(info.ServerName, pagesDomain) { - domain = "*." + pagesDomain - }*/ - // Figure out a username for later username checks username := "" if cname == "" { // domain ends on pagesDomain - username = strings.Split(domain, ".")[0] + username = strings.Split(info.ServerName, ".")[0] } else { // cname ends on pagesDomain username = strings.Split(cname, ".")[0] } // Find the correct certificate - cert, found := cache.Certificates[info.ServerName] + domainKey := getDomainKey(info.ServerName, pagesDomain) + cert, found := cache.Certificates[domainKey] if found { if cert.IsValid() { return cert.TlsCertificate, nil } else { - if !repo.CanRequestCertificate(username, giteaClient) { + if !isPagesDomain && !repo.CanRequestCertificate(username, giteaClient) { log.Warnf( "Cannot renew certificate for %s because CanRequestCertificate(%s) returned false", - domain, + info.ServerName, username, ) return cert.TlsCertificate, nil @@ -93,16 +107,16 @@ func MakeTlsConfig(pagesDomain, cachePath string, cache *certificates.Certificat // If we're already working on the domain, // return the old certificate - if lockIfUnlockedDomain(domain) { + if lockIfUnlockedDomain(domainKey) { return cert.TlsCertificate, nil } - defer unlockDomain(domain) + defer unlockDomain(domainKey) // Renew the certificate - log.Infof("Certificate for %s expired, renewing", domain) + log.Infof("Certificate for %s expired, renewing", info.ServerName) newCert, err := certificates.RenewCertificate(&cert, acmeClient) if err != nil { - log.Errorf("Failed to renew certificate for %s: %v", domain, err) + log.Errorf("Failed to renew certificate for %s: %v", info.ServerName, err) return cert.TlsCertificate, nil } @@ -111,31 +125,33 @@ func MakeTlsConfig(pagesDomain, cachePath string, cache *certificates.Certificat return newCert.TlsCertificate, nil } } else { - if !repo.CanRequestCertificate(username, giteaClient) { + if !isPagesDomain && !repo.CanRequestCertificate(username, giteaClient) { log.Warnf( "Cannot request certificate for %s because CanRequestCertificate(%s) returned false", - domain, + info.ServerName, username, ) return cache.FallbackCertificate.TlsCertificate, nil } // Don't request if we're already requesting. - if lockIfUnlockedDomain(domain) { + key := getDomainKey(info.ServerName, pagesDomain) + if lockIfUnlockedDomain(domainKey) { return cache.FallbackCertificate.TlsCertificate, nil } - defer unlockDomain(domain) + defer unlockDomain(key) // Request new certificate - log.Infof("Obtaining new certificate for %s...", domain) + log.Infof("Obtaining new certificate for %s...", info.ServerName) cert, err := certificates.ObtainNewCertificate( - []string{domain}, + buildDomainList(info.ServerName, pagesDomain), + domainKey, acmeClient, ) if err != nil { log.Errorf( "Failed to get certificate for %s: %v", - domain, + info.ServerName, err, ) return cache.FallbackCertificate.TlsCertificate, nil diff --git a/internal/server/tls_test.go b/internal/server/tls_test.go new file mode 100644 index 0000000..8291ac8 --- /dev/null +++ b/internal/server/tls_test.go @@ -0,0 +1,69 @@ +package server + +import ( + "testing" +) + +const ( + pagesDomain = "pages.local" + pagesDomainWildcard = "*.pages.local" +) + +func equals(a, b []string) bool { + if len(a) != len(b) { + return false + } + + for i, _ := range a { + if a[i] != b[i] { + return false + } + } + + return true +} + +func TestDomainListBare(t *testing.T) { + expect := []string{pagesDomain, pagesDomainWildcard} + res := buildDomainList(pagesDomain, pagesDomain) + if !equals(res, expect) { + t.Fatalf("%v != %v", res, expect) + } +} + +func TestDomainListSubdomain(t *testing.T) { + expect := []string{pagesDomain, pagesDomainWildcard} + res := buildDomainList("user."+pagesDomain, pagesDomain) + if !equals(res, expect) { + t.Fatalf("%v != %v", res, expect) + } +} + +func TestDomainListCNAME(t *testing.T) { + expect := []string{"testdomain.example"} + res := buildDomainList("testdomain.example", pagesDomain) + if !equals(res, expect) { + t.Fatalf("%v != %v", res, expect) + } +} + +func TestDomainKeyBare(t *testing.T) { + res := getDomainKey(pagesDomain, pagesDomain) + if res != pagesDomainWildcard { + t.Fatalf("%s != %s", res, pagesDomainWildcard) + } +} + +func TestDomainKeySubdomain(t *testing.T) { + res := getDomainKey("user."+pagesDomain, pagesDomain) + if res != pagesDomainWildcard { + t.Fatalf("%s != %s", res, pagesDomainWildcard) + } +} + +func TestDomainKeyCNAME(t *testing.T) { + res := getDomainKey("testdomain.example", pagesDomain) + if res != "testdomain.example" { + t.Fatalf("%s != %s", res, "testdomain.example") + } +}