package server

import (
	"crypto/tls"
	"strings"
	"sync"

	"git.polynom.me/rio/internal/certificates"
	"git.polynom.me/rio/internal/dns"
	"github.com/go-acme/lego/v4/lego"
	log "github.com/sirupsen/logrus"
)

var (
	// domainsLock guards workingDomains; acquire it before any access.
	domainsLock = sync.Mutex{}

	// workingDomains records the domains for which we are currently requesting
	// or renewing a certificate. Only key presence is meaningful; the stored
	// value is always true.
	workingDomains = make(map[string]bool)
)

// lockIfUnlockedDomain atomically marks domain as "being worked on".
//
// Note the inverted sense of the result: it returns true if the domain was
// ALREADY marked (someone else owns it and the caller must back off), and
// false if the caller just acquired the mark and must release it later via
// unlockDomain.
func lockIfUnlockedDomain(domain string) bool {
	domainsLock.Lock()
	defer domainsLock.Unlock()

	_, found := workingDomains[domain]
	if !found {
		workingDomains[domain] = true
	}
	return found
}

// unlockDomain clears the mark previously set by lockIfUnlockedDomain.
func unlockDomain(domain string) {
	domainsLock.Lock()
	defer domainsLock.Unlock()

	delete(workingDomains, domain)
}

// MakeTlsConfig builds the server-side TLS configuration. The certificate for
// each handshake is chosen by the GetCertificate callback:
//
//   - A ServerName that is neither pagesDomain, a subdomain of it, nor a name
//     whose CNAME resolves to one of those is answered with
//     cache.FallbackCertificate.
//   - A cached certificate that reports IsValid() is served directly.
//   - A cached certificate that is no longer valid is renewed with
//     certificates.RenewCertificate. The old certificate is served if renewal
//     fails or another handshake is already working on the domain.
//   - An unknown domain gets a new certificate via
//     certificates.ObtainNewCertificate. The fallback certificate is served if
//     that fails or another handshake is already working on the domain.
//
// Renewed or newly obtained certificates are handed to
// cache.AddCert(cert, cachePath).
func MakeTlsConfig(pagesDomain, cachePath string, cache *certificates.CertificatesCache, acmeClient *lego.Client) *tls.Config {
	// Match on a DNS label boundary. A bare HasSuffix(name, pagesDomain) would
	// also accept unrelated hosts such as "evil"+pagesDomain, and we would then
	// try to order certificates for them.
	baseDomain := strings.TrimPrefix(pagesDomain, ".")
	subdomainSuffix := "." + baseDomain
	isOurs := func(name string) bool {
		return name != "" && (name == baseDomain || strings.HasSuffix(name, subdomainSuffix))
	}

	getCertificate := func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
		domain := info.ServerName
		log.Debugf("TLS ServerName: %s", domain)

		// Without SNI we cannot pick a site certificate; skip the DNS lookup.
		if domain == "" {
			return cache.FallbackCertificate.TlsCertificate, nil
		}

		// Validate that we should even care about this domain.
		if !isOurs(domain) {
			// err is deliberately not checked: err != nil implies cname == "",
			// which isOurs always rejects. A trailing root dot, if present, is
			// stripped so FQDN-style answers still match.
			cname, _ := dns.LookupCNAME(domain)
			if !isOurs(strings.TrimSuffix(cname, ".")) {
				log.Warnf("Got ServerName for Domain %s that we're not responsible for", domain)
				return cache.FallbackCertificate.TlsCertificate, nil
			}
		}

		// Wildcard certificates are currently disabled; every host gets its own
		// certificate. The previous idea was:
		//   if isOurs(domain) { domain = "*." + pagesDomain }

		// NOTE(review): cache.Certificates is read here while cache.AddCert may
		// write to it from a concurrent handshake. Confirm CertificatesCache
		// synchronises this; otherwise it is a data race.
		cert, found := cache.Certificates[domain]
		if found && cert.IsValid() {
			return cert.TlsCertificate, nil
		}

		// Only one handshake per domain talks to the ACME server at a time.
		// Others serve whatever is available right now: the stale certificate
		// if we have one, otherwise the fallback.
		if lockIfUnlockedDomain(domain) {
			if found {
				return cert.TlsCertificate, nil
			}
			return cache.FallbackCertificate.TlsCertificate, nil
		}
		defer unlockDomain(domain)

		if found {
			// Renew the stale certificate.
			log.Infof("Certificate for %s expired, renewing", domain)
			renewed, err := certificates.RenewCertificate(&cert, acmeClient)
			if err != nil {
				log.Errorf("Failed to renew certificate for %s: %v", domain, err)
				return cert.TlsCertificate, nil
			}
			log.Info("Successfully renewed certificate!")
			cache.AddCert(renewed, cachePath)
			return renewed.TlsCertificate, nil
		}

		// Request a brand-new certificate.
		log.Infof("Obtaining new certificate for %s...", domain)
		obtained, err := certificates.ObtainNewCertificate(
			[]string{domain},
			acmeClient,
		)
		if err != nil {
			log.Errorf(
				"Failed to get certificate for %s: %v",
				domain,
				err,
			)
			return cache.FallbackCertificate.TlsCertificate, nil
		}

		// Add to cache and flush.
		log.Info("Successfully obtained new certificate!")
		cache.AddCert(obtained, cachePath)
		return obtained.TlsCertificate, nil
	}

	return &tls.Config{
		// NOTE(review): per the crypto/tls docs, InsecureSkipVerify governs a
		// client verifying a server; it has no effect on a server-side config
		// like this one. Kept unchanged.
		InsecureSkipVerify: true,
		GetCertificate:     getCertificate,
		// NOTE(review): Go's server picks the first entry of this list that the
		// client also offers, so http/1.1 is preferred over h2 for typical
		// browsers. Confirm that is intended. "h2c" is cleartext HTTP/2 and is
		// never negotiated over TLS.
		NextProtos: []string{
			"http/0.9",
			"http/1.0",
			"http/1.1",
			"h2",
			"h2c",
		},
		MinVersion: tls.VersionTLS12,
		CipherSuites: []uint16{
			tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
			tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
			tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
			tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
			tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
			tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
		},
	}
}