diff --git a/README.md b/README.md index e69de29..96ea66f 100644 --- a/README.md +++ b/README.md @@ -0,0 +1,16 @@ +# Testing + +With the go-get command + +```sh +$ go get -v -u jbowen.dev/cereal +get "jbowen.dev/cereal": found meta tag get.metaImport{Prefix:"jbowen.dev/cereal", VCS:"git", RepoRoot:"https://github.com/jamesbo13/cereal"} at //jbowen.dev/cereal?go-get=1 +jbowen.dev/cereal (download) +``` + +With the httpie command + +```sh +$ http --body "https://gopkg.in/yaml.v3?go-get=1" +... +``` diff --git a/app.go b/app.go new file mode 100644 index 0000000..5043e9d --- /dev/null +++ b/app.go @@ -0,0 +1,30 @@ +package main + +import ( + "fmt" + "net/http" +) + +type Application struct { + Config *Config + Mux *http.ServeMux + Log *LogFile +} + +func (app *Application) routes() { + m := http.NewServeMux() + + m.HandleFunc("/healthcheck", healthcheck) + m.HandleFunc("/reloadRules", reloadRules(app.Config)) + m.HandleFunc("/", serveLogger(app.Log)(serveRules(app.Config))) + + app.Mux = m +} + +func (app *Application) Setup(port string) *http.Server { + app.routes() + return &http.Server{ + Addr: fmt.Sprintf(":%s", port), + Handler: app.Mux, + } +} diff --git a/conf.go b/conf.go new file mode 100644 index 0000000..dd1be0a --- /dev/null +++ b/conf.go @@ -0,0 +1,62 @@ +package main + +import ( + "encoding/json" + "fmt" + "os" +) + +type Config struct { + MappingFilePath string + MappingRules ImportRulesMappings +} + +type ImportRulesMappings struct { + Mappings []struct { + Protocol string `json:"protocol"` + VanityUrl string `json:"vanity_url"` + RealUrl string `json:"real_url"` + } `json:"mappings"` +} + +type ImportRuleStruct struct { + VanityUrl string + Proto string + RepoUrl string +} + +// isFile - check if fp is a valid file +func isFile(fp string) bool { + info, err := os.Stat(fp) + if os.IsNotExist(err) || !info.Mode().IsRegular() { + return false + } + return true +} + +// load mapping file +func (c *Config) LoadMappingFile(fp string) error { + var mapping ImportRulesMappings + mappingFilePath := fp + if len(c.MappingFilePath) == 0 { + ok := isFile(mappingFilePath) + if !ok { + return fmt.Errorf("%s is not found", mappingFilePath) + } + } + mappingFile, err := os.Open(mappingFilePath) + if err != nil { + return err + } + defer mappingFile.Close() + + err = json.NewDecoder(mappingFile).Decode(&mapping) + if err != nil { + return err + } + c.MappingRules = mapping + if len(c.MappingFilePath) == 0 { + c.MappingFilePath = mappingFilePath + } + return nil +} diff --git a/embed.go b/embed.go new file mode 100644 index 0000000..8a32e90 --- /dev/null +++ b/embed.go @@ -0,0 +1,19 @@ +package main + +import ( + "embed" + "html/template" +) + +// go:embed templates/* +var tmpls embed.FS + +func GetServeHtml() *template.Template { + data, _ := tmpls.ReadFile("success.html") + return template.Must(template.New("main").Parse(string(data))) +} + +func GetDefaultHtml() []byte { + data, _ := tmpls.ReadFile("default.html") + return data +} diff --git a/handlers.go b/handlers.go new file mode 100644 index 0000000..139474c --- /dev/null +++ b/handlers.go @@ -0,0 +1,71 @@ +package main + +import ( + "bytes" + "net/http" + "strings" +) + +func healthcheck(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("ok")) +} + +func reloadRules(c *Config) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + err := c.LoadMappingFile("") + if err != nil { + errorLog.Printf("Cannot reload rules: %+v", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.Write([]byte("ok")) + } +} + +func serveRules(c *Config) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusInternalServerError) + return + } + + // if go-get param is absent, return nothing + if r.FormValue("go-get") != "1" { + w.Write(GetDefaultHtml()) + return + } + + nameOfPkg := r.Host + r.URL.Path + + var vanityUrl, proto, repoUrl string + for _, rule := range c.MappingRules.Mappings { + if strings.HasPrefix(strings.ToLower(nameOfPkg), strings.ToLower(rule.VanityUrl+"/")) { + repo := strings.Replace(strings.ToLower(nameOfPkg), strings.ToLower(rule.VanityUrl), "", -1) + repo = strings.Split(repo, "/")[1] + + vanityUrl = rule.VanityUrl + "/" + repo + repoUrl = rule.RealUrl + "/" + repo + proto = rule.Protocol + + break + } + } + + d := ImportRuleStruct{ + VanityUrl: vanityUrl, + Proto: proto, + RepoUrl: repoUrl, + } + tmpl := GetServeHtml() + + var buf bytes.Buffer + err := tmpl.Execute(&buf, &d) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + w.Header().Set("Cache-Control", "public, max-age=500") + w.Write(buf.Bytes()) + } +} diff --git a/logger.go b/logger.go new file mode 100644 index 0000000..b44afcb --- /dev/null +++ b/logger.go @@ -0,0 +1,148 @@ +package main + +import ( + "encoding/json" + "log" + "net/http" + "os" + "strings" +) + +// some headers not worth logging +var ( + hdrsToNotLog = []string{ + "Accept-Language", + "Cache-Control", + "Cf-Ray", + "CF-Visitor", + "CF-Connecting-IP", + "Cdn-Loop", + "Cookie", + "Connection", + "Dnt", + "If-Modified-Since", + "Sec-Fetch-Dest", + "Sec-Ch-Ua-Mobile", + // "Sec-Ch-Ua", + "Sec-Ch-Ua-Platform", + "Sec-Fetch-Site", + "Sec-Fetch-Mode", + "Sec-Fetch-User", + "Upgrade-Insecure-Requests", + "X-Request-Start", + "X-Forwarded-For", + "X-Forwarded-Proto", + "X-Forwarded-Host", + } + hdrsToNotLogMap map[string]bool +) + +type LogFile struct { + handle *os.File + logger *log.Logger + path string +} + +type LogFileRec struct { + Method string `json:"method"` + IpAddr string `json:"ipAddr"` + Url string `json:"url"` +} + +func newFileLogger(path string) (*LogFile, error) { + f, err := os.OpenFile(path, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0666) + if err != nil { + return nil, err + } + return &LogFile{ + handle: f, + logger: log.New(f, "", 0), + path: path, + }, nil +} + +func (f *LogFile) Close() error { + if f == nil { + return nil + } + err := f.handle.Close() + f.handle = nil + return err +} + +func extractFirstFragment(header *http.Header, headerName string) string { + s := header.Get(headerName) + if len(strings.TrimSpace(s)) == 0 { + return s + } + fragments := strings.Split(s, ",") + return strings.TrimSpace(fragments[0]) +} + +// Get Ip Address of the client +func extractIpAddress(r *http.Request) string { + var ipAddr string + if r == nil { + return "" + } + possibleIpHeaders := []string{"CF-Connecting-IP", "X-Real-Ip", "X-Forwarded-For"} + for _, header := range possibleIpHeaders { + ipAddr = extractFirstFragment(&r.Header, header) + if len(strings.TrimSpace(ipAddr)) != 0 { + return ipAddr + } + } + // pull ip from Request.RemoteAddr + if len(strings.TrimSpace(r.RemoteAddr)) != 0 { + index := strings.LastIndex(r.RemoteAddr, ";") + if index == -1 { + return r.RemoteAddr + } + ipAddr = r.RemoteAddr[:index] + } + return ipAddr +} + +func canSkipExtraHeaders(r *http.Request) bool { + ref := r.Header.Get("Referer") + if len(strings.TrimSpace(ref)) == 0 { + return false + } + return strings.Contains(ref, r.Host) +} + +func shouldLogHeader(s string) bool { + if hdrsToNotLogMap == nil { + hdrsToNotLogMap = map[string]bool{} + for _, h := range hdrsToNotLog { + h = strings.ToLower(h) + hdrsToNotLogMap[h] = true + } + } + s = strings.ToLower(s) + return !hdrsToNotLogMap[s] +} + +func (f *LogFile) WriteLog(r *http.Request) error { + if f == nil { + return nil + } + var rec = make(map[string]string) + rec["method"] = r.Method + rec["requestUri"] = r.RequestURI + rec["Host"] = r.Host + rec["ipAddr"] = extractIpAddress(r) + if !canSkipExtraHeaders(r) { + for key, val := range r.Header { + if shouldLogHeader(key) && len(val) > 0 { + rec[key] = val[0] + } + } + } + b, err := json.Marshal(rec) + if err != nil { + return err + } + f.logger.Println(string(b)) + return nil +} diff --git a/main.go b/main.go index 7905807..288ca4a 100644 --- a/main.go +++ b/main.go @@ -1,5 +1,115 @@ package main -func main() { +import ( + "context" + "errors" + "flag" + "fmt" + "log" + "net/http" + "os" + "os/signal" + "syscall" + "time" +) +var errorLog *log.Logger = log.New(os.Stderr, "", log.LstdFlags) + +const ( + DEFAULT_RULES_FILE string = "/var/lib/gocustomurls/rules.json" + DEFAULT_LOG_FILE string = "/var/log/gocustomurls/app.log" +) + +// flagsSet returns a set of all the flags what were actually set on the +// command line. +func flagsSet(flags *flag.FlagSet) map[string]bool { + s := make(map[string]bool) + flags.Visit(func(f *flag.Flag) { + s[f.Name] = true + }) + return s +} + +func main() { + programName := os.Args[0] + // errorLog = log.New(os.Stderr, "", log.LstdFlags) + + flags := flag.NewFlagSet(os.Args[0], flag.ExitOnError) + flags.Usage = func() { + out := flags.Output() + fmt.Fprintf(out, "Usage: %v [flags]\n\n", programName) + fmt.Fprint(out, " This utility serves vanity urls for the go get/install command.\n") + fmt.Fprint(out, " By default, the server listens on localhost:7070.\n") + flags.PrintDefaults() + } + + portFlag := flags.String("port", "7070", "port to listen to") + rulesFileFlag := flags.String("config", DEFAULT_RULES_FILE, "contains go-import mapping") + logFileFlag := flags.String("logfile", DEFAULT_LOG_FILE, "default log file") + flags.Parse(os.Args[1:]) + + if len(flags.Args()) > 1 { + errorLog.Println("Error: too many command-line arguments") + flags.Usage() + os.Exit(1) + } + + allSetFlags := flagsSet(flags) + + var port string + if allSetFlags["port"] { + port = *portFlag + } + + var rulesFile string + if allSetFlags["config"] { + rulesFile = *rulesFileFlag + } + + var logFile string + if allSetFlags["logFile"] { + logFile = *logFileFlag + } + + // load rules mapping + c := &Config{} + err := c.LoadMappingFile(rulesFile) + if err != nil { + errorLog.Println(err) + os.Exit(1) + } + l, err := newFileLogger(logFile) + if err != nil { + errorLog.Println(err) + os.Exit(1) + } + + app := &Application{ + Config: c, + Log: l, + } + srv := app.Setup(port) + + // For graceful shutdowns + go func() { + err := srv.ListenAndServe() + if !errors.Is(err, http.ErrServerClosed) { + errorLog.Printf("HTTP Server error: %+v\n", err) + os.Exit(1) + } + app.Log.logger.Println("Stopped serving new connections.") + }() + + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + <-sigChan + + shutdownCtx, shutdownRelease := context.WithTimeout(context.Background(), 10*time.Second) + defer shutdownRelease() + + if err := srv.Shutdown(shutdownCtx); err != nil { + errorLog.Printf("HTTP shutdown error: %+v\n", err) + os.Exit(1) + } + app.Log.logger.Println("Graceful shutdown complete.") } diff --git a/middleware.go b/middleware.go new file mode 100644 index 0000000..d4ed069 --- /dev/null +++ b/middleware.go @@ -0,0 +1,23 @@ +package main + +import ( + "net/http" +) + +// serveLogger is a logging middleware for serving. It generates logs for +// requests sent to the server. +func serveLogger(l *LogFile) func(http.HandlerFunc) http.HandlerFunc { + return func(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + l.WriteLog(r) + next(w, r) + } + } +} + +// func serveLogger(logger *LogFile, next http.Handler) http.Handler { +// return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +// logger.WriteLog(r) +// next.ServeHTTP(w, r) +// }) +// } diff --git a/templates/default.html b/templates/default.html new file mode 100644 index 0000000..76bdefa --- /dev/null +++ b/templates/default.html @@ -0,0 +1,6 @@ + +
+ +