feat: First steps towards using wildcard certificates
All checks were successful
ci/woodpecker/push/woodpecker Pipeline was successful

This commit is contained in:
PapaTutuWawa 2024-01-02 18:23:46 +01:00
parent 2d4ecc40cb
commit 3692168346
4 changed files with 131 additions and 35 deletions

View File

@ -38,7 +38,7 @@ func RenewCertificate(old *CertificateWrapper, acmeClient *lego.Client) (Certifi
} }
wrapper := CertificateWrapper{ wrapper := CertificateWrapper{
TlsCertificate: &tlsCert, TlsCertificate: &tlsCert,
Domain: old.Domain, DomainKey: old.DomainKey,
NotAfter: time.Now().Add(time.Hour * 24 * 60), NotAfter: time.Now().Add(time.Hour * 24 * 60),
PrivateKeyEncoded: base64.StdEncoding.EncodeToString(new.PrivateKey), PrivateKeyEncoded: base64.StdEncoding.EncodeToString(new.PrivateKey),
Certificate: new.Certificate, Certificate: new.Certificate,
@ -47,7 +47,7 @@ func RenewCertificate(old *CertificateWrapper, acmeClient *lego.Client) (Certifi
return wrapper, nil return wrapper, nil
} }
func ObtainNewCertificate(domains []string, acmeClient *lego.Client) (CertificateWrapper, error) { func ObtainNewCertificate(domains []string, domainKey string, acmeClient *lego.Client) (CertificateWrapper, error) {
req := certificate.ObtainRequest{ req := certificate.ObtainRequest{
Domains: domains, Domains: domains,
Bundle: true, Bundle: true,
@ -64,7 +64,7 @@ func ObtainNewCertificate(domains []string, acmeClient *lego.Client) (Certificat
wrapper := CertificateWrapper{ wrapper := CertificateWrapper{
TlsCertificate: &tlsCert, TlsCertificate: &tlsCert,
Domain: cert.Domain, DomainKey: domainKey,
//NotAfter: tlsCert.Leaf.NotAfter, //NotAfter: tlsCert.Leaf.NotAfter,
NotAfter: time.Now().Add(time.Hour * 24 * 60), NotAfter: time.Now().Add(time.Hour * 24 * 60),
PrivateKeyEncoded: base64.StdEncoding.EncodeToString(cert.PrivateKey), PrivateKeyEncoded: base64.StdEncoding.EncodeToString(cert.PrivateKey),
@ -127,7 +127,7 @@ func MakeFallbackCertificate(pagesDomain string) (*CertificateWrapper, error) {
} }
return &CertificateWrapper{ return &CertificateWrapper{
TlsCertificate: &tlsCertificate, TlsCertificate: &tlsCertificate,
Domain: pagesDomain, DomainKey: "*." + pagesDomain,
NotAfter: notAfter, NotAfter: notAfter,
PrivateKeyEncoded: base64.StdEncoding.EncodeToString(certcrypto.PEMEncode(key)), PrivateKeyEncoded: base64.StdEncoding.EncodeToString(certcrypto.PEMEncode(key)),
Certificate: outBytes, Certificate: outBytes,

View File

@ -14,11 +14,22 @@ import (
// A convenience wrapper around a TLS certificate // A convenience wrapper around a TLS certificate
type CertificateWrapper struct { type CertificateWrapper struct {
// The parsed TLS certificate we can pass to the tls listener
TlsCertificate *tls.Certificate `json:"-"` TlsCertificate *tls.Certificate `json:"-"`
Domain string `json:"domain"`
// Key identifying for which domain(s) this certificate is valid.
DomainKey string `json:"domain"`
// Indicates at which point in time this certificate is no longer valid.
NotAfter time.Time `json:"not_after"` NotAfter time.Time `json:"not_after"`
// The encoded private key.
PrivateKeyEncoded string `json:"private_key"` PrivateKeyEncoded string `json:"private_key"`
// The PEM-encoded certificate.
Certificate []byte `json:"certificate"` Certificate []byte `json:"certificate"`
// The CSR provided when we requested the certificate.
CSR []byte `json:"csr"` CSR []byte `json:"csr"`
} }
@ -27,7 +38,7 @@ type CertificatesCache struct {
// The certificate to use as a fallback if all else fails. // The certificate to use as a fallback if all else fails.
FallbackCertificate *CertificateWrapper FallbackCertificate *CertificateWrapper
// Mapping of domain name to certificate. // Mapping of a domain's domain key to the certificate.
Certificates map[string]CertificateWrapper Certificates map[string]CertificateWrapper
} }
@ -83,7 +94,7 @@ func (c *CertificatesCache) FlushToDisk(path string) {
} }
func (c *CertificatesCache) AddCert(cert CertificateWrapper, path string) { func (c *CertificatesCache) AddCert(cert CertificateWrapper, path string) {
c.Certificates[cert.Domain] = cert c.Certificates[cert.DomainKey] = cert
c.FlushToDisk(path) c.FlushToDisk(path)
} }
@ -105,7 +116,7 @@ func CertificateCacheFromFile(path string) (CertificatesCache, error) {
certs := make(map[string]CertificateWrapper) certs := make(map[string]CertificateWrapper)
for _, cert := range store.Certificates { for _, cert := range store.Certificates {
cert.initTlsCertificate() cert.initTlsCertificate()
certs[cert.Domain] = cert certs[cert.DomainKey] = cert
} }
cache.Certificates = certs cache.Certificates = certs

View File

@ -43,10 +43,30 @@ func unlockDomain(domain string) {
delete(workingDomains, domain) delete(workingDomains, domain)
} }
func buildDomainList(domain, pagesDomain string) []string {
if domain == pagesDomain || strings.HasSuffix(domain, pagesDomain) {
return []string{
pagesDomain,
"*." + pagesDomain,
}
}
return []string{domain}
}
func getDomainKey(domain, pagesDomain string) string {
if domain == pagesDomain || strings.HasSuffix(domain, pagesDomain) {
return "*." + pagesDomain
}
return domain
}
func MakeTlsConfig(pagesDomain, cachePath string, cache *certificates.CertificatesCache, acmeClient *lego.Client, giteaClient *gitea.Client) *tls.Config { func MakeTlsConfig(pagesDomain, cachePath string, cache *certificates.CertificatesCache, acmeClient *lego.Client, giteaClient *gitea.Client) *tls.Config {
return &tls.Config{ return &tls.Config{
GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
// Validate that we should even care about this domain // Validate that we should even care about this domain
isPagesDomain := info.ServerName == pagesDomain
cname := "" cname := ""
if !strings.HasSuffix(info.ServerName, pagesDomain) { if !strings.HasSuffix(info.ServerName, pagesDomain) {
// Note: We do not check err here because err != nil // Note: We do not check err here because err != nil
@ -59,33 +79,27 @@ func MakeTlsConfig(pagesDomain, cachePath string, cache *certificates.Certificat
} }
} }
// If we want to access <user>.<pages domain>, then we can just
// use a wildcard certificate.
domain := info.ServerName
/*if strings.HasSuffix(info.ServerName, pagesDomain) {
domain = "*." + pagesDomain
}*/
// Figure out a username for later username checks // Figure out a username for later username checks
username := "" username := ""
if cname == "" { if cname == "" {
// domain ends on pagesDomain // domain ends on pagesDomain
username = strings.Split(domain, ".")[0] username = strings.Split(info.ServerName, ".")[0]
} else { } else {
// cname ends on pagesDomain // cname ends on pagesDomain
username = strings.Split(cname, ".")[0] username = strings.Split(cname, ".")[0]
} }
// Find the correct certificate // Find the correct certificate
cert, found := cache.Certificates[info.ServerName] domainKey := getDomainKey(info.ServerName, pagesDomain)
cert, found := cache.Certificates[domainKey]
if found { if found {
if cert.IsValid() { if cert.IsValid() {
return cert.TlsCertificate, nil return cert.TlsCertificate, nil
} else { } else {
if !repo.CanRequestCertificate(username, giteaClient) { if !isPagesDomain && !repo.CanRequestCertificate(username, giteaClient) {
log.Warnf( log.Warnf(
"Cannot renew certificate for %s because CanRequestCertificate(%s) returned false", "Cannot renew certificate for %s because CanRequestCertificate(%s) returned false",
domain, info.ServerName,
username, username,
) )
return cert.TlsCertificate, nil return cert.TlsCertificate, nil
@ -93,16 +107,16 @@ func MakeTlsConfig(pagesDomain, cachePath string, cache *certificates.Certificat
// If we're already working on the domain, // If we're already working on the domain,
// return the old certificate // return the old certificate
if lockIfUnlockedDomain(domain) { if lockIfUnlockedDomain(domainKey) {
return cert.TlsCertificate, nil return cert.TlsCertificate, nil
} }
defer unlockDomain(domain) defer unlockDomain(domainKey)
// Renew the certificate // Renew the certificate
log.Infof("Certificate for %s expired, renewing", domain) log.Infof("Certificate for %s expired, renewing", info.ServerName)
newCert, err := certificates.RenewCertificate(&cert, acmeClient) newCert, err := certificates.RenewCertificate(&cert, acmeClient)
if err != nil { if err != nil {
log.Errorf("Failed to renew certificate for %s: %v", domain, err) log.Errorf("Failed to renew certificate for %s: %v", info.ServerName, err)
return cert.TlsCertificate, nil return cert.TlsCertificate, nil
} }
@ -111,31 +125,33 @@ func MakeTlsConfig(pagesDomain, cachePath string, cache *certificates.Certificat
return newCert.TlsCertificate, nil return newCert.TlsCertificate, nil
} }
} else { } else {
if !repo.CanRequestCertificate(username, giteaClient) { if !isPagesDomain && !repo.CanRequestCertificate(username, giteaClient) {
log.Warnf( log.Warnf(
"Cannot request certificate for %s because CanRequestCertificate(%s) returned false", "Cannot request certificate for %s because CanRequestCertificate(%s) returned false",
domain, info.ServerName,
username, username,
) )
return cache.FallbackCertificate.TlsCertificate, nil return cache.FallbackCertificate.TlsCertificate, nil
} }
// Don't request if we're already requesting. // Don't request if we're already requesting.
if lockIfUnlockedDomain(domain) { key := getDomainKey(info.ServerName, pagesDomain)
if lockIfUnlockedDomain(domainKey) {
return cache.FallbackCertificate.TlsCertificate, nil return cache.FallbackCertificate.TlsCertificate, nil
} }
defer unlockDomain(domain) defer unlockDomain(key)
// Request new certificate // Request new certificate
log.Infof("Obtaining new certificate for %s...", domain) log.Infof("Obtaining new certificate for %s...", info.ServerName)
cert, err := certificates.ObtainNewCertificate( cert, err := certificates.ObtainNewCertificate(
[]string{domain}, buildDomainList(info.ServerName, pagesDomain),
domainKey,
acmeClient, acmeClient,
) )
if err != nil { if err != nil {
log.Errorf( log.Errorf(
"Failed to get certificate for %s: %v", "Failed to get certificate for %s: %v",
domain, info.ServerName,
err, err,
) )
return cache.FallbackCertificate.TlsCertificate, nil return cache.FallbackCertificate.TlsCertificate, nil

View File

@ -0,0 +1,69 @@
package server
import (
"testing"
)
const (
pagesDomain = "pages.local"
pagesDomainWildcard = "*.pages.local"
)
func equals(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i, _ := range a {
if a[i] != b[i] {
return false
}
}
return true
}
func TestDomainListBare(t *testing.T) {
expect := []string{pagesDomain, pagesDomainWildcard}
res := buildDomainList(pagesDomain, pagesDomain)
if !equals(res, expect) {
t.Fatalf("%v != %v", res, expect)
}
}
func TestDomainListSubdomain(t *testing.T) {
expect := []string{pagesDomain, pagesDomainWildcard}
res := buildDomainList("user."+pagesDomain, pagesDomain)
if !equals(res, expect) {
t.Fatalf("%v != %v", res, expect)
}
}
func TestDomainListCNAME(t *testing.T) {
expect := []string{"testdomain.example"}
res := buildDomainList("testdomain.example", pagesDomain)
if !equals(res, expect) {
t.Fatalf("%v != %v", res, expect)
}
}
func TestDomainKeyBare(t *testing.T) {
res := getDomainKey(pagesDomain, pagesDomain)
if res != pagesDomainWildcard {
t.Fatalf("%s != %s", res, pagesDomainWildcard)
}
}
func TestDomainKeySubdomain(t *testing.T) {
res := getDomainKey("user."+pagesDomain, pagesDomain)
if res != pagesDomainWildcard {
t.Fatalf("%s != %s", res, pagesDomainWildcard)
}
}
func TestDomainKeyCNAME(t *testing.T) {
res := getDomainKey("testdomain.example", pagesDomain)
if res != "testdomain.example" {
t.Fatalf("%s != %s", res, "testdomain.example")
}
}