From b9cc7f30e8301a3a730e7bb5b52675a5b2fe6f72 Mon Sep 17 00:00:00 2001 From: "Alexander \"PapaTutuWawa" Date: Sat, 3 Feb 2024 16:42:08 +0100 Subject: [PATCH] feat: Fix and test header parsing --- internal/repo/repo.go | 19 ++++-- internal/repo/repo_test.go | 123 ++++++++++++++++++++++++++++++++++++- 2 files changed, 137 insertions(+), 5 deletions(-) diff --git a/internal/repo/repo.go b/internal/repo/repo.go index 40feec7..c7423f4 100644 --- a/internal/repo/repo.go +++ b/internal/repo/repo.go @@ -145,7 +145,7 @@ func CanRequestCertificate(username string, ctx *context.GlobalContext) bool { return hasUser } -func filterHeaders(headers map[string]string) map[string]string { +func filterHeaders(headers map[string]interface{}) map[string]string { newHeaders := make(map[string]string) for key, value := range headers { @@ -153,7 +153,10 @@ func filterHeaders(headers map[string]string) map[string]string { continue } - newHeaders[key] = value + switch value.(type) { + case string: + newHeaders[key] = value.(string) + } } return newHeaders @@ -187,7 +190,15 @@ func GetRepositoryInformation(owner, repoName string, ctx *context.GlobalContext headers, found := payload["headers"] if !found { log.Warnf("Did not find headers key in rio.json for %s/%s", owner, repoName) - headers = make(map[string]string) + headers = make(map[string]interface{}) + } else { + switch headers.(type) { + case map[string]interface{}: + // NOOP + default: + log.Warn("headers attribute has invalid data type") + headers = make(map[string]string) + } } cname, found := payload["CNAME"] @@ -204,7 +215,7 @@ func GetRepositoryInformation(owner, repoName string, ctx *context.GlobalContext } info := context.RepositoryInformation{ - Headers: filterHeaders(headers.(map[string]string)), + Headers: filterHeaders(headers.(map[string]interface{})), CNAME: cname.(string), } ctx.Cache.SetRepositoryInformation(owner, repoName, info) diff --git a/internal/repo/repo_test.go b/internal/repo/repo_test.go index 2153ff9..f3de300 100644 --- a/internal/repo/repo_test.go +++ b/internal/repo/repo_test.go @@ -14,7 +14,7 @@ import ( func TestHeaderFilter(t *testing.T) { map1 := filterHeaders( - map[string]string{ + map[string]interface{}{ "Content-Type": "hallo", "content-Type": "welt", "content-type": "uwu", @@ -427,3 +427,124 @@ func TestPickingRepositoryValidCNAMEWithTXTLookupAndSubdirectory(t *testing.T) { t.Fatalf("Invalid repository name returned: %s", repo.Name) } } + +func TestHeaderParsingEmpty(t *testing.T) { + // Test that we are correctly handling a repository with no headers. + log.SetLevel(log.DebugLevel) + client := gitea.GiteaClient{ + GetRepository: func(username, repositoryName string) (gitea.Repository, error) { + if username == "example-user" && repositoryName == "some-different-repository" { + return gitea.Repository{ + Name: "some-different-repository", + }, nil + } + + return gitea.Repository{}, errors.New("Unknown repository") + }, + HasBranch: func(username, repositoryName, branchName string) bool { + if username == "example-user" && repositoryName == "some-different-repository" && branchName == "pages" { + return true + } + + return false + }, + GetFile: func(username, repositoryName, branch, path string, since *time.Time) ([]byte, bool, error) { + if username == "example-user" && repositoryName == "some-different-repository" && branch == "pages" && path == "rio.json" { + return []byte("{\"CNAME\": \"example-user.local\"}"), true, nil + } + + t.Fatalf("Invalid file requested: %s/%s@%s:%s", username, repositoryName, branch, path) + return []byte{}, true, nil + }, + LookupCNAME: func(domain string) (string, error) { + return "", errors.New("No CNAME") + }, + LookupRepoTXT: func(domain string) (string, error) { + if domain == "example-user.local" { + return "some-different-repository", nil + } + return "", nil + }, + } + ctx := &context.GlobalContext{ + Gitea: &client, + Cache: &context.CacheContext{ + RepositoryInformationCache: context.MakeRepoInfoCache(), + RepositoryPathCache: context.MakeRepoPathCache(), + }, + } + + info := GetRepositoryInformation("example-user", "some-different-repository", ctx) + if info == nil { + t.Fatalf("No repository information returned") + } + + if len(info.Headers) > 0 { + t.Fatalf("Headers returned: %v", info.Headers) + } +} + +func TestHeaderParsing(t *testing.T) { + // Test that we are correctly handling a repository with no headers. + log.SetLevel(log.DebugLevel) + client := gitea.GiteaClient{ + GetRepository: func(username, repositoryName string) (gitea.Repository, error) { + if username == "example-user" && repositoryName == "some-different-repository" { + return gitea.Repository{ + Name: "some-different-repository", + }, nil + } + + return gitea.Repository{}, errors.New("Unknown repository") + }, + HasBranch: func(username, repositoryName, branchName string) bool { + if username == "example-user" && repositoryName == "some-different-repository" && branchName == "pages" { + return true + } + + return false + }, + GetFile: func(username, repositoryName, branch, path string, since *time.Time) ([]byte, bool, error) { + if username == "example-user" && repositoryName == "some-different-repository" && branch == "pages" && path == "rio.json" { + return []byte("{\"CNAME\": \"example-user.local\", \"headers\": {\"X-Cool-Header\": \"Very nice!\"}}"), true, nil + } + + t.Fatalf("Invalid file requested: %s/%s@%s:%s", username, repositoryName, branch, path) + return []byte{}, true, nil + }, + LookupCNAME: func(domain string) (string, error) { + return "", errors.New("No CNAME") + }, + LookupRepoTXT: func(domain string) (string, error) { + if domain == "example-user.local" { + return "some-different-repository", nil + } + return "", nil + }, + } + ctx := &context.GlobalContext{ + Gitea: &client, + Cache: &context.CacheContext{ + RepositoryInformationCache: context.MakeRepoInfoCache(), + RepositoryPathCache: context.MakeRepoPathCache(), + }, + } + + info := GetRepositoryInformation("example-user", "some-different-repository", ctx) + if info == nil { + t.Fatalf("No repository information returned") + } + + if len(info.Headers) != 1 { + t.Fatalf("len(info.Headers) != 1: %v", info.Headers) + } + + header, found := info.Headers["X-Cool-Header"] + if !found { + t.Fatal("Header X-Cool-Header not found") + } + + if header != "Very nice!" { + t.Fatalf("Invalid header value for X-Cool-Header: \"%s\"", header) + } +}