package main import ( "bytes" "crypto/rand" "crypto/rsa" "crypto/tls" "crypto/x509" "crypto/x509/pkix" "encoding/base64" "encoding/json" "encoding/pem" "io/ioutil" "math/big" "net/http" "strings" "sync" "time" "github.com/go-acme/lego/v4/certcrypto" "github.com/go-acme/lego/v4/certificate" "github.com/go-acme/lego/v4/lego" log "github.com/sirupsen/logrus" ) const ( AcmeChallengePathPrefix = "/.well-known/acme-challenge/" ) var ( // Well-known -> Challenge solution runningChallenges = make(map[string]string) Certificates = CertificatesCache{ Certificates: make(map[string]CertificateWrapper), } // To access requestingDomains, first acquire the lock. requestingLock = sync.Mutex{} // Domain -> _. Check if domain is a key here to see if we're already requeting // a certificate for it. requestingDomains = make(map[string]bool) ) func lockIfUnlockedDomain(domain string) bool { requestingLock.Lock() defer requestingLock.Unlock() _, found := requestingDomains[domain] if !found { requestingDomains[domain] = true } return found } func unlockDomain(domain string) { requestingLock.Lock() defer requestingLock.Unlock() delete(requestingDomains, domain) } type CertificateWrapper struct { TlsCertificate *tls.Certificate `json:"-"` Domain string `json:"domain"` NotAfter time.Time `json:"not_after"` PrivateKeyEncoded string `json:"private_key"` Certificate []byte `json:"certificate"` IssuerCertificate []byte `json:"issuer_certificate"` CertificateUrl string `json:"certificate_url"` } func (c *CertificateWrapper) GetPrivateKey() *rsa.PrivateKey { data, _ := base64.StdEncoding.DecodeString(c.PrivateKeyEncoded) pk, _ := certcrypto.ParsePEMPrivateKey(data) return pk.(*rsa.PrivateKey) } type CertificatesCache struct { FallbackCertificate *CertificateWrapper Certificates map[string]CertificateWrapper } type CertificatesStore struct { FallbackCertificate CertificateWrapper `json:"fallback"` Certificates []CertificateWrapper `json:"certificates"` } func (c *CertificatesCache) toStoreData() string { certs := make([]CertificateWrapper, 0) for _, cert := range c.Certificates { certs = append(certs, cert) } result, err := json.Marshal(CertificatesStore{ FallbackCertificate: *c.FallbackCertificate, Certificates: certs, }) if err != nil { log.Errorf("Failed to Marshal cache: %v", err) } return string(result) } func (c *CertificateWrapper) initTlsCertificate() { pk, _ := base64.StdEncoding.DecodeString(c.PrivateKeyEncoded) tlsCert, _ := tls.X509KeyPair( c.Certificate, pk, ) c.TlsCertificate = &tlsCert } func CertificateFromStoreData(rawJson string) CertificatesCache { var store CertificatesStore _ = json.Unmarshal([]byte(rawJson), &store) store.FallbackCertificate.initTlsCertificate() cache := CertificatesCache{ FallbackCertificate: &store.FallbackCertificate, } certs := make(map[string]CertificateWrapper) for _, cert := range store.Certificates { cert.initTlsCertificate() certs[cert.Domain] = cert } cache.Certificates = certs return cache } func LoadCertificateStoreFromFile(path string) error { content, err := ioutil.ReadFile(path) if err != nil { return err } Certificates = CertificateFromStoreData(string(content)) return nil } func FlushCertificateStoreToFile(path string) { data := Certificates.toStoreData() ioutil.WriteFile(path, []byte(data), 0600) } func InitialiseFallbackCert(pagesDomain string) error { cert, err := fallbackCert(pagesDomain) Certificates.FallbackCertificate = cert return err } func fallbackCert(pagesDomain string) (*CertificateWrapper, error) { key, err := certcrypto.GeneratePrivateKey(certcrypto.RSA2048) if err != nil { return nil, err } notAfter := time.Now().Add(time.Hour * 24 * 7) cert := x509.Certificate{ SerialNumber: big.NewInt(1), Subject: pkix.Name{ CommonName: pagesDomain, Organization: []string{"Pages Server"}, }, NotAfter: notAfter, NotBefore: time.Now(), KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, BasicConstraintsValid: true, } certBytes, err := x509.CreateCertificate( rand.Reader, &cert, &cert, &key.(*rsa.PrivateKey).PublicKey, key, ) if err != nil { return nil, err } out := &bytes.Buffer{} err = pem.Encode(out, &pem.Block{ Bytes: certBytes, Type: "CERTIFICATE", }) if err != nil { return nil, err } outBytes := out.Bytes() res := &certificate.Resource{ PrivateKey: certcrypto.PEMEncode(key), Certificate: outBytes, IssuerCertificate: outBytes, Domain: pagesDomain, } tlsCertificate, err := tls.X509KeyPair(res.Certificate, res.PrivateKey) if err != nil { return nil, err } return &CertificateWrapper{ TlsCertificate: &tlsCertificate, Domain: pagesDomain, NotAfter: notAfter, PrivateKeyEncoded: base64.StdEncoding.EncodeToString(certcrypto.PEMEncode(key)), Certificate: outBytes, IssuerCertificate: outBytes, CertificateUrl: "localhost", }, nil } func isCertStillValid(cert CertificateWrapper) bool { return time.Now().Compare(cert.NotAfter) <= -1 } func makeTlsConfig(pagesDomain, path string, acmeClient *lego.Client) *tls.Config { return &tls.Config{ InsecureSkipVerify: true, GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { // Validate that we should even care about this domain if !strings.HasSuffix(info.ServerName, pagesDomain) { // Note: We do not check err here because err != nil // always implies that cname == "", which does not have // pagesDomain as a suffix. cname, _ := lookupCNAME(info.ServerName) if !strings.HasSuffix(cname, pagesDomain) { log.Warnf("Got ServerName for Domain %s that we're not responsible for", info.ServerName) return Certificates.FallbackCertificate.TlsCertificate, nil } } // If we want to access ., then we can just // use a wildcard certificate. domain := info.ServerName /*if strings.HasSuffix(info.ServerName, pagesDomain) { domain = "*." + pagesDomain }*/ cert, found := Certificates.Certificates[info.ServerName] if found { if isCertStillValid(cert) { return cert.TlsCertificate, nil } else { // If we're already working on the domain, // return the old certificate if lockIfUnlockedDomain(domain) { return cert.TlsCertificate, nil } defer unlockDomain(domain) // TODO: Renew log.Debugf("Certificate for %s expired, renewing", domain) } } else { // Don't request if we're already requesting. if lockIfUnlockedDomain(domain) { return Certificates.FallbackCertificate.TlsCertificate, nil } defer unlockDomain(domain) // Request new certificate log.Debugf("Obtaining new certificate for %s...", domain) err := ObtainNewCertificate( []string{domain}, path, acmeClient, ) if err != nil { log.Errorf( "Failed to get certificate for %s: %v", domain, err, ) return Certificates.FallbackCertificate.TlsCertificate, nil } cert, _ = Certificates.Certificates[domain] return cert.TlsCertificate, nil } log.Debugf("TLS ServerName: %s", info.ServerName) return Certificates.FallbackCertificate.TlsCertificate, nil }, NextProtos: []string{ "http/0.9", "http/1.0", "http/1.1", "h2", "h2c", }, MinVersion: tls.VersionTLS12, CipherSuites: []uint16{ tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, }, } } func addChallenge(id, token string) { runningChallenges[id] = token } func removeChallenge(id string) { delete(runningChallenges, id) } func getChallenge(id string) string { if value, found := runningChallenges[id]; found { return value } return "" } func handleLetsEncryptChallenge(w http.ResponseWriter, req *http.Request) bool { if !strings.HasPrefix(req.URL.Path, AcmeChallengePathPrefix) { return false } log.Debug("Handling ACME challenge path") id := strings.TrimPrefix(req.URL.Path, AcmeChallengePathPrefix) challenge := getChallenge(id) if id == "" { w.WriteHeader(404) return true } w.WriteHeader(200) w.Write([]byte(challenge)) removeChallenge(id) return true }