package server import ( "crypto/tls" "errors" "strings" "sync" "git.polynom.me/rio/internal/certificates" "git.polynom.me/rio/internal/context" "git.polynom.me/rio/internal/dns" "git.polynom.me/rio/internal/repo" "github.com/go-acme/lego/v4/lego" log "github.com/sirupsen/logrus" ) var ( // To access requestingDomains, first acquire the lock. domainsLock = sync.Mutex{} // Domain -> _. Check if domain is a key here to see if we're already requesting // or renewing a certificate for that domain. workingDomains = make(map[string]bool) ) func lockIfUnlockedDomain(domain string) bool { domainsLock.Lock() defer domainsLock.Unlock() _, found := workingDomains[domain] if !found { workingDomains[domain] = true } return found } func unlockDomain(domain string) { domainsLock.Lock() defer domainsLock.Unlock() delete(workingDomains, domain) } func buildDomainList(domain, pagesDomain string) []string { if domain == pagesDomain || strings.HasSuffix(domain, pagesDomain) { return []string{ pagesDomain, "*." + pagesDomain, } } return []string{domain} } func getDomainKey(domain, pagesDomain string) string { if domain == pagesDomain || strings.HasSuffix(domain, pagesDomain) { return "*." + pagesDomain } return domain } func getUsername(sni, pagesDomain string) (string, error) { if !strings.HasSuffix(sni, pagesDomain) { log.Debugf("'%s' is not a subdomain of '%s'", sni, pagesDomain) // Note: We do not check err here because err != nil // always implies that cname == "", which does not have // pagesDomain as a suffix. query, err := dns.LookupCNAME(sni) if !strings.HasSuffix(query, pagesDomain) { log.Warnf("Got ServerName for Domain %s that we're not responsible for. CNAME '%s', err: %v", sni, query, err) return "", errors.New("CNAME does not resolve to subdomain of pages domain") } return dns.ExtractUsername(pagesDomain, query), nil } return dns.ExtractUsername(pagesDomain, sni), nil } func MakeTlsConfig(pagesDomain, cachePath string, cache *certificates.CertificatesCache, acmeClient *lego.Client, ctx *context.GlobalContext) *tls.Config { return &tls.Config{ GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { // Validate that we should even care about this domain isPagesDomain := info.ServerName == pagesDomain username, err := getUsername(info.ServerName, pagesDomain) if err != nil { log.Warnf("Failed to get username for %s: %v", info.ServerName, err) return cache.FallbackCertificate.TlsCertificate, nil } // Find the correct certificate domainKey := getDomainKey(info.ServerName, pagesDomain) cert, found := cache.Certificates[domainKey] if found { if cert.IsValid() { return cert.TlsCertificate, nil } else { if !isPagesDomain && !repo.CanRequestCertificate(username, ctx) { log.Warnf( "Cannot renew certificate for %s because CanRequestCertificate(%s) returned false", info.ServerName, username, ) return cert.TlsCertificate, nil } // If we're already working on the domain, // return the old certificate if lockIfUnlockedDomain(domainKey) { return cert.TlsCertificate, nil } defer unlockDomain(domainKey) // Renew the certificate log.Infof("Certificate for %s expired, renewing", info.ServerName) newCert, err := certificates.RenewCertificate(&cert, acmeClient) if err != nil { log.Errorf("Failed to renew certificate for %s: %v", info.ServerName, err) return cert.TlsCertificate, nil } log.Info("Successfully renewed certificate!") cache.AddCert(newCert, cachePath) return newCert.TlsCertificate, nil } } else { if !isPagesDomain && !repo.CanRequestCertificate(username, ctx) { log.Warnf( "Cannot request certificate for %s because CanRequestCertificate(%s) returned false", info.ServerName, username, ) return cache.FallbackCertificate.TlsCertificate, nil } // Don't request if we're already requesting. key := getDomainKey(info.ServerName, pagesDomain) if lockIfUnlockedDomain(domainKey) { return cache.FallbackCertificate.TlsCertificate, nil } defer unlockDomain(key) // Request new certificate log.Infof("Obtaining new certificate for %s...", info.ServerName) cert, err := certificates.ObtainNewCertificate( buildDomainList(info.ServerName, pagesDomain), domainKey, acmeClient, ) if err != nil { log.Errorf( "Failed to get certificate for %s: %v", info.ServerName, err, ) return cache.FallbackCertificate.TlsCertificate, nil } // Add to cache and flush log.Info("Successfully obtained new certificate!") cache.AddCert(cert, cachePath) return cert.TlsCertificate, nil } }, NextProtos: []string{ "http/0.9", "http/1.0", "http/1.1", "h2", "h2c", }, MinVersion: tls.VersionTLS12, CipherSuites: []uint16{ tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, }, } }