package server

import (
	"crypto/tls"
	"strings"
	"sync"

	"git.polynom.me/rio/internal/certificates"
	"git.polynom.me/rio/internal/dns"
	"git.polynom.me/rio/internal/repo"

	"github.com/go-acme/lego/v4/lego"
	log "github.com/sirupsen/logrus"
)

var (
	// domainsLock guards workingDomains. Acquire it before reading or
	// writing the map.
	domainsLock = sync.Mutex{}

	// workingDomains holds a key for every domain key for which a certificate
	// is currently being requested or renewed. Only the presence of a key
	// matters; the stored value is always true.
	workingDomains = make(map[string]bool)
)

// lockIfUnlockedDomain marks domain as "being worked on" unless it already is.
// It returns true if the domain was ALREADY marked, meaning another caller owns
// it and this caller must not proceed. It returns false if this call placed the
// mark; the caller then owns it and must release it with unlockDomain.
func lockIfUnlockedDomain(domain string) bool {
	domainsLock.Lock()
	defer domainsLock.Unlock()

	_, found := workingDomains[domain]
	if !found {
		workingDomains[domain] = true
	}
	return found
}

// unlockDomain removes the mark placed by a successful lockIfUnlockedDomain.
func unlockDomain(domain string) {
	domainsLock.Lock()
	defer domainsLock.Unlock()

	delete(workingDomains, domain)
}

// isSubdomainOf reports whether domain equals parent or is a subdomain of it.
// The match respects DNS label boundaries, so "evilexample.com" is NOT
// considered to be under "example.com". A single trailing dot on domain
// (fully-qualified form, as returned e.g. by net.LookupCNAME) is ignored.
func isSubdomainOf(domain, parent string) bool {
	domain = strings.TrimSuffix(domain, ".")
	return domain == parent || strings.HasSuffix(domain, "."+parent)
}

// buildDomainList returns the names a newly requested certificate for domain
// must cover. Every name under pagesDomain shares one certificate for
// pagesDomain plus its wildcard; any other domain gets a certificate of its own.
func buildDomainList(domain, pagesDomain string) []string {
	// TODO: For wildcards, we MUST use DNS01
	if isSubdomainOf(domain, pagesDomain) {
		return []string{
			pagesDomain,
			"*." + pagesDomain,
		}
	}
	return []string{domain}
}

// getDomainKey returns the key under which the certificate serving domain is
// cached (and under which work on that certificate is locked). All names under
// pagesDomain map to the shared "*.<pagesDomain>" key.
//
// NOTE(review): a wildcard only matches a single label, so names nested more
// than one level below pagesDomain map to a certificate that does not cover
// them.
func getDomainKey(domain, pagesDomain string) string {
	// TODO: For wildcards, we MUST use DNS01
	if isSubdomainOf(domain, pagesDomain) {
		return "*." + pagesDomain
	}
	return domain
}

// MakeTlsConfig builds the server's TLS configuration. The certificate for each
// handshake is chosen from the SNI server name:
//
//   - A name that is neither under pagesDomain nor a CNAME pointing below
//     pagesDomain gets cache.FallbackCertificate.
//   - A valid cached certificate is served as-is.
//   - An invalid cached certificate is renewed, and a missing one is obtained,
//     through acmeClient. The new certificate is stored with
//     cache.AddCert(..., cachePath).
//
// Issuing is skipped when repo.CanRequestCertificate rejects the user derived
// from the name (pagesDomain itself is exempt from that check). It is also
// skipped when another handshake is already working on the same domain key.
// Whenever issuing is skipped or fails, the existing certificate (or the
// fallback certificate when none exists) is served, so GetCertificate never
// returns an error.
func MakeTlsConfig(pagesDomain, cachePath string, cache *certificates.CertificatesCache, acmeClient *lego.Client, giteaClient *repo.GiteaClient) *tls.Config {
	return &tls.Config{
		GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
			serverName := info.ServerName
			isPagesDomain := serverName == pagesDomain

			// Validate that we should even care about this domain.
			cname := ""
			if !isSubdomainOf(serverName, pagesDomain) {
				// Note: We do not check err here because err != nil
				// always implies that cname == "", which is not under
				// pagesDomain and therefore leads to the fallback below.
				cname, _ = dns.LookupCNAME(serverName)
				if !isSubdomainOf(cname, pagesDomain) {
					log.Warnf("Got ServerName for Domain %s that we're not responsible for", serverName)
					return cache.FallbackCertificate.TlsCertificate, nil
				}
			}

			// The username is the first label of whichever name lies under
			// pagesDomain: the server name itself, or its CNAME target.
			ownerName := serverName
			if cname != "" {
				ownerName = cname
			}
			username := strings.Split(ownerName, ".")[0]

			// Find the correct certificate.
			domainKey := getDomainKey(serverName, pagesDomain)
			// NOTE(review): this map is read without a lock while concurrent
			// handshakes may call cache.AddCert; confirm CertificatesCache
			// makes that safe.
			cert, found := cache.Certificates[domainKey]
			if found && cert.IsValid() {
				return cert.TlsCertificate, nil
			}

			if found {
				// A certificate exists but is no longer valid: renew it.
				// Whenever renewing is not possible, keep serving the old one.
				if !isPagesDomain && !repo.CanRequestCertificate(username, giteaClient) {
					log.Warnf(
						"Cannot renew certificate for %s because CanRequestCertificate(%s) returned false",
						serverName,
						username,
					)
					return cert.TlsCertificate, nil
				}

				// If we're already working on the domain,
				// return the old certificate.
				if lockIfUnlockedDomain(domainKey) {
					return cert.TlsCertificate, nil
				}
				defer unlockDomain(domainKey)

				log.Infof("Certificate for %s expired, renewing", serverName)
				renewed, err := certificates.RenewCertificate(&cert, acmeClient)
				if err != nil {
					log.Errorf("Failed to renew certificate for %s: %v", serverName, err)
					return cert.TlsCertificate, nil
				}

				log.Info("Successfully renewed certificate!")
				cache.AddCert(renewed, cachePath)
				return renewed.TlsCertificate, nil
			}

			// No certificate cached yet: obtain a new one, serving the
			// fallback certificate whenever that is not possible.
			if !isPagesDomain && !repo.CanRequestCertificate(username, giteaClient) {
				log.Warnf(
					"Cannot request certificate for %s because CanRequestCertificate(%s) returned false",
					serverName,
					username,
				)
				return cache.FallbackCertificate.TlsCertificate, nil
			}

			// Don't request if we're already requesting. Lock and unlock
			// must use the same key.
			if lockIfUnlockedDomain(domainKey) {
				return cache.FallbackCertificate.TlsCertificate, nil
			}
			defer unlockDomain(domainKey)

			log.Infof("Obtaining new certificate for %s...", serverName)
			obtained, err := certificates.ObtainNewCertificate(
				buildDomainList(serverName, pagesDomain),
				domainKey,
				acmeClient,
			)
			if err != nil {
				log.Errorf(
					"Failed to get certificate for %s: %v",
					serverName,
					err,
				)
				return cache.FallbackCertificate.TlsCertificate, nil
			}

			// Add to cache and flush.
			log.Info("Successfully obtained new certificate!")
			cache.AddCert(obtained, cachePath)
			return obtained.TlsCertificate, nil
		},
		// "h2c" (cleartext HTTP/2) is deliberately not listed: its ALPN
		// identifier is reserved for use without TLS.
		// NOTE(review): crypto/tls picks the first entry of this list that the
		// client also offers. With "h2" after "http/1.1", clients that offer
		// both get HTTP/1.1. Confirm that this is intended.
		NextProtos: []string{
			"http/0.9",
			"http/1.0",
			"http/1.1",
			"h2",
		},
		MinVersion: tls.VersionTLS12,
		CipherSuites: []uint16{
			tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
			tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
			tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
			tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
			tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
			tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
		},
	}
}