feat: First steps towards using wildcard certificates
All checks were successful
ci/woodpecker/push/woodpecker Pipeline was successful
All checks were successful
ci/woodpecker/push/woodpecker Pipeline was successful
This commit is contained in:
parent
2d4ecc40cb
commit
3692168346
@ -38,7 +38,7 @@ func RenewCertificate(old *CertificateWrapper, acmeClient *lego.Client) (Certifi
|
|||||||
}
|
}
|
||||||
wrapper := CertificateWrapper{
|
wrapper := CertificateWrapper{
|
||||||
TlsCertificate: &tlsCert,
|
TlsCertificate: &tlsCert,
|
||||||
Domain: old.Domain,
|
DomainKey: old.DomainKey,
|
||||||
NotAfter: time.Now().Add(time.Hour * 24 * 60),
|
NotAfter: time.Now().Add(time.Hour * 24 * 60),
|
||||||
PrivateKeyEncoded: base64.StdEncoding.EncodeToString(new.PrivateKey),
|
PrivateKeyEncoded: base64.StdEncoding.EncodeToString(new.PrivateKey),
|
||||||
Certificate: new.Certificate,
|
Certificate: new.Certificate,
|
||||||
@ -47,7 +47,7 @@ func RenewCertificate(old *CertificateWrapper, acmeClient *lego.Client) (Certifi
|
|||||||
return wrapper, nil
|
return wrapper, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ObtainNewCertificate(domains []string, acmeClient *lego.Client) (CertificateWrapper, error) {
|
func ObtainNewCertificate(domains []string, domainKey string, acmeClient *lego.Client) (CertificateWrapper, error) {
|
||||||
req := certificate.ObtainRequest{
|
req := certificate.ObtainRequest{
|
||||||
Domains: domains,
|
Domains: domains,
|
||||||
Bundle: true,
|
Bundle: true,
|
||||||
@ -64,7 +64,7 @@ func ObtainNewCertificate(domains []string, acmeClient *lego.Client) (Certificat
|
|||||||
|
|
||||||
wrapper := CertificateWrapper{
|
wrapper := CertificateWrapper{
|
||||||
TlsCertificate: &tlsCert,
|
TlsCertificate: &tlsCert,
|
||||||
Domain: cert.Domain,
|
DomainKey: domainKey,
|
||||||
//NotAfter: tlsCert.Leaf.NotAfter,
|
//NotAfter: tlsCert.Leaf.NotAfter,
|
||||||
NotAfter: time.Now().Add(time.Hour * 24 * 60),
|
NotAfter: time.Now().Add(time.Hour * 24 * 60),
|
||||||
PrivateKeyEncoded: base64.StdEncoding.EncodeToString(cert.PrivateKey),
|
PrivateKeyEncoded: base64.StdEncoding.EncodeToString(cert.PrivateKey),
|
||||||
@ -127,7 +127,7 @@ func MakeFallbackCertificate(pagesDomain string) (*CertificateWrapper, error) {
|
|||||||
}
|
}
|
||||||
return &CertificateWrapper{
|
return &CertificateWrapper{
|
||||||
TlsCertificate: &tlsCertificate,
|
TlsCertificate: &tlsCertificate,
|
||||||
Domain: pagesDomain,
|
DomainKey: "*." + pagesDomain,
|
||||||
NotAfter: notAfter,
|
NotAfter: notAfter,
|
||||||
PrivateKeyEncoded: base64.StdEncoding.EncodeToString(certcrypto.PEMEncode(key)),
|
PrivateKeyEncoded: base64.StdEncoding.EncodeToString(certcrypto.PEMEncode(key)),
|
||||||
Certificate: outBytes,
|
Certificate: outBytes,
|
||||||
|
@ -14,12 +14,23 @@ import (
|
|||||||
|
|
||||||
// A convenience wrapper around a TLS certificate
|
// A convenience wrapper around a TLS certificate
|
||||||
type CertificateWrapper struct {
|
type CertificateWrapper struct {
|
||||||
TlsCertificate *tls.Certificate `json:"-"`
|
// The parsed TLS certificate we can pass to the tls listener
|
||||||
Domain string `json:"domain"`
|
TlsCertificate *tls.Certificate `json:"-"`
|
||||||
NotAfter time.Time `json:"not_after"`
|
|
||||||
PrivateKeyEncoded string `json:"private_key"`
|
// Key identifying for which domain(s) this certificate is valid.
|
||||||
Certificate []byte `json:"certificate"`
|
DomainKey string `json:"domain"`
|
||||||
CSR []byte `json:"csr"`
|
|
||||||
|
// Indicates at which point in time this certificate is no longer valid.
|
||||||
|
NotAfter time.Time `json:"not_after"`
|
||||||
|
|
||||||
|
// The encoded private key.
|
||||||
|
PrivateKeyEncoded string `json:"private_key"`
|
||||||
|
|
||||||
|
// The PEM-encoded certificate.
|
||||||
|
Certificate []byte `json:"certificate"`
|
||||||
|
|
||||||
|
// The CSR provided when we requested the certificate.
|
||||||
|
CSR []byte `json:"csr"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// A structure to store all the certificates we know of in.
|
// A structure to store all the certificates we know of in.
|
||||||
@ -27,7 +38,7 @@ type CertificatesCache struct {
|
|||||||
// The certificate to use as a fallback if all else fails.
|
// The certificate to use as a fallback if all else fails.
|
||||||
FallbackCertificate *CertificateWrapper
|
FallbackCertificate *CertificateWrapper
|
||||||
|
|
||||||
// Mapping of domain name to certificate.
|
// Mapping of a domain's domain key to the certificate.
|
||||||
Certificates map[string]CertificateWrapper
|
Certificates map[string]CertificateWrapper
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -83,7 +94,7 @@ func (c *CertificatesCache) FlushToDisk(path string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *CertificatesCache) AddCert(cert CertificateWrapper, path string) {
|
func (c *CertificatesCache) AddCert(cert CertificateWrapper, path string) {
|
||||||
c.Certificates[cert.Domain] = cert
|
c.Certificates[cert.DomainKey] = cert
|
||||||
c.FlushToDisk(path)
|
c.FlushToDisk(path)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -105,7 +116,7 @@ func CertificateCacheFromFile(path string) (CertificatesCache, error) {
|
|||||||
certs := make(map[string]CertificateWrapper)
|
certs := make(map[string]CertificateWrapper)
|
||||||
for _, cert := range store.Certificates {
|
for _, cert := range store.Certificates {
|
||||||
cert.initTlsCertificate()
|
cert.initTlsCertificate()
|
||||||
certs[cert.Domain] = cert
|
certs[cert.DomainKey] = cert
|
||||||
}
|
}
|
||||||
cache.Certificates = certs
|
cache.Certificates = certs
|
||||||
|
|
||||||
|
@ -43,10 +43,30 @@ func unlockDomain(domain string) {
|
|||||||
delete(workingDomains, domain)
|
delete(workingDomains, domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func buildDomainList(domain, pagesDomain string) []string {
|
||||||
|
if domain == pagesDomain || strings.HasSuffix(domain, pagesDomain) {
|
||||||
|
return []string{
|
||||||
|
pagesDomain,
|
||||||
|
"*." + pagesDomain,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return []string{domain}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getDomainKey(domain, pagesDomain string) string {
|
||||||
|
if domain == pagesDomain || strings.HasSuffix(domain, pagesDomain) {
|
||||||
|
return "*." + pagesDomain
|
||||||
|
}
|
||||||
|
|
||||||
|
return domain
|
||||||
|
}
|
||||||
|
|
||||||
func MakeTlsConfig(pagesDomain, cachePath string, cache *certificates.CertificatesCache, acmeClient *lego.Client, giteaClient *gitea.Client) *tls.Config {
|
func MakeTlsConfig(pagesDomain, cachePath string, cache *certificates.CertificatesCache, acmeClient *lego.Client, giteaClient *gitea.Client) *tls.Config {
|
||||||
return &tls.Config{
|
return &tls.Config{
|
||||||
GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
// Validate that we should even care about this domain
|
// Validate that we should even care about this domain
|
||||||
|
isPagesDomain := info.ServerName == pagesDomain
|
||||||
cname := ""
|
cname := ""
|
||||||
if !strings.HasSuffix(info.ServerName, pagesDomain) {
|
if !strings.HasSuffix(info.ServerName, pagesDomain) {
|
||||||
// Note: We do not check err here because err != nil
|
// Note: We do not check err here because err != nil
|
||||||
@ -59,33 +79,27 @@ func MakeTlsConfig(pagesDomain, cachePath string, cache *certificates.Certificat
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we want to access <user>.<pages domain>, then we can just
|
|
||||||
// use a wildcard certificate.
|
|
||||||
domain := info.ServerName
|
|
||||||
/*if strings.HasSuffix(info.ServerName, pagesDomain) {
|
|
||||||
domain = "*." + pagesDomain
|
|
||||||
}*/
|
|
||||||
|
|
||||||
// Figure out a username for later username checks
|
// Figure out a username for later username checks
|
||||||
username := ""
|
username := ""
|
||||||
if cname == "" {
|
if cname == "" {
|
||||||
// domain ends on pagesDomain
|
// domain ends on pagesDomain
|
||||||
username = strings.Split(domain, ".")[0]
|
username = strings.Split(info.ServerName, ".")[0]
|
||||||
} else {
|
} else {
|
||||||
// cname ends on pagesDomain
|
// cname ends on pagesDomain
|
||||||
username = strings.Split(cname, ".")[0]
|
username = strings.Split(cname, ".")[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find the correct certificate
|
// Find the correct certificate
|
||||||
cert, found := cache.Certificates[info.ServerName]
|
domainKey := getDomainKey(info.ServerName, pagesDomain)
|
||||||
|
cert, found := cache.Certificates[domainKey]
|
||||||
if found {
|
if found {
|
||||||
if cert.IsValid() {
|
if cert.IsValid() {
|
||||||
return cert.TlsCertificate, nil
|
return cert.TlsCertificate, nil
|
||||||
} else {
|
} else {
|
||||||
if !repo.CanRequestCertificate(username, giteaClient) {
|
if !isPagesDomain && !repo.CanRequestCertificate(username, giteaClient) {
|
||||||
log.Warnf(
|
log.Warnf(
|
||||||
"Cannot renew certificate for %s because CanRequestCertificate(%s) returned false",
|
"Cannot renew certificate for %s because CanRequestCertificate(%s) returned false",
|
||||||
domain,
|
info.ServerName,
|
||||||
username,
|
username,
|
||||||
)
|
)
|
||||||
return cert.TlsCertificate, nil
|
return cert.TlsCertificate, nil
|
||||||
@ -93,16 +107,16 @@ func MakeTlsConfig(pagesDomain, cachePath string, cache *certificates.Certificat
|
|||||||
|
|
||||||
// If we're already working on the domain,
|
// If we're already working on the domain,
|
||||||
// return the old certificate
|
// return the old certificate
|
||||||
if lockIfUnlockedDomain(domain) {
|
if lockIfUnlockedDomain(domainKey) {
|
||||||
return cert.TlsCertificate, nil
|
return cert.TlsCertificate, nil
|
||||||
}
|
}
|
||||||
defer unlockDomain(domain)
|
defer unlockDomain(domainKey)
|
||||||
|
|
||||||
// Renew the certificate
|
// Renew the certificate
|
||||||
log.Infof("Certificate for %s expired, renewing", domain)
|
log.Infof("Certificate for %s expired, renewing", info.ServerName)
|
||||||
newCert, err := certificates.RenewCertificate(&cert, acmeClient)
|
newCert, err := certificates.RenewCertificate(&cert, acmeClient)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to renew certificate for %s: %v", domain, err)
|
log.Errorf("Failed to renew certificate for %s: %v", info.ServerName, err)
|
||||||
return cert.TlsCertificate, nil
|
return cert.TlsCertificate, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -111,31 +125,33 @@ func MakeTlsConfig(pagesDomain, cachePath string, cache *certificates.Certificat
|
|||||||
return newCert.TlsCertificate, nil
|
return newCert.TlsCertificate, nil
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if !repo.CanRequestCertificate(username, giteaClient) {
|
if !isPagesDomain && !repo.CanRequestCertificate(username, giteaClient) {
|
||||||
log.Warnf(
|
log.Warnf(
|
||||||
"Cannot request certificate for %s because CanRequestCertificate(%s) returned false",
|
"Cannot request certificate for %s because CanRequestCertificate(%s) returned false",
|
||||||
domain,
|
info.ServerName,
|
||||||
username,
|
username,
|
||||||
)
|
)
|
||||||
return cache.FallbackCertificate.TlsCertificate, nil
|
return cache.FallbackCertificate.TlsCertificate, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Don't request if we're already requesting.
|
// Don't request if we're already requesting.
|
||||||
if lockIfUnlockedDomain(domain) {
|
key := getDomainKey(info.ServerName, pagesDomain)
|
||||||
|
if lockIfUnlockedDomain(domainKey) {
|
||||||
return cache.FallbackCertificate.TlsCertificate, nil
|
return cache.FallbackCertificate.TlsCertificate, nil
|
||||||
}
|
}
|
||||||
defer unlockDomain(domain)
|
defer unlockDomain(key)
|
||||||
|
|
||||||
// Request new certificate
|
// Request new certificate
|
||||||
log.Infof("Obtaining new certificate for %s...", domain)
|
log.Infof("Obtaining new certificate for %s...", info.ServerName)
|
||||||
cert, err := certificates.ObtainNewCertificate(
|
cert, err := certificates.ObtainNewCertificate(
|
||||||
[]string{domain},
|
buildDomainList(info.ServerName, pagesDomain),
|
||||||
|
domainKey,
|
||||||
acmeClient,
|
acmeClient,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf(
|
log.Errorf(
|
||||||
"Failed to get certificate for %s: %v",
|
"Failed to get certificate for %s: %v",
|
||||||
domain,
|
info.ServerName,
|
||||||
err,
|
err,
|
||||||
)
|
)
|
||||||
return cache.FallbackCertificate.TlsCertificate, nil
|
return cache.FallbackCertificate.TlsCertificate, nil
|
||||||
|
69
internal/server/tls_test.go
Normal file
69
internal/server/tls_test.go
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
pagesDomain = "pages.local"
|
||||||
|
pagesDomainWildcard = "*.pages.local"
|
||||||
|
)
|
||||||
|
|
||||||
|
func equals(a, b []string) bool {
|
||||||
|
if len(a) != len(b) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, _ := range a {
|
||||||
|
if a[i] != b[i] {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDomainListBare(t *testing.T) {
|
||||||
|
expect := []string{pagesDomain, pagesDomainWildcard}
|
||||||
|
res := buildDomainList(pagesDomain, pagesDomain)
|
||||||
|
if !equals(res, expect) {
|
||||||
|
t.Fatalf("%v != %v", res, expect)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDomainListSubdomain(t *testing.T) {
|
||||||
|
expect := []string{pagesDomain, pagesDomainWildcard}
|
||||||
|
res := buildDomainList("user."+pagesDomain, pagesDomain)
|
||||||
|
if !equals(res, expect) {
|
||||||
|
t.Fatalf("%v != %v", res, expect)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDomainListCNAME(t *testing.T) {
|
||||||
|
expect := []string{"testdomain.example"}
|
||||||
|
res := buildDomainList("testdomain.example", pagesDomain)
|
||||||
|
if !equals(res, expect) {
|
||||||
|
t.Fatalf("%v != %v", res, expect)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDomainKeyBare(t *testing.T) {
|
||||||
|
res := getDomainKey(pagesDomain, pagesDomain)
|
||||||
|
if res != pagesDomainWildcard {
|
||||||
|
t.Fatalf("%s != %s", res, pagesDomainWildcard)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDomainKeySubdomain(t *testing.T) {
|
||||||
|
res := getDomainKey("user."+pagesDomain, pagesDomain)
|
||||||
|
if res != pagesDomainWildcard {
|
||||||
|
t.Fatalf("%s != %s", res, pagesDomainWildcard)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDomainKeyCNAME(t *testing.T) {
|
||||||
|
res := getDomainKey("testdomain.example", pagesDomain)
|
||||||
|
if res != "testdomain.example" {
|
||||||
|
t.Fatalf("%s != %s", res, "testdomain.example")
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user