summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorShulhan <m.shulhan@gmail.com>2020-07-26 22:11:24 +0700
committerShulhan <m.shulhan@gmail.com>2020-07-26 22:38:03 +0700
commit73666b02edc55997ab187d0eefea48b662c9de6d (patch)
tree7a5c4146cd73ae81f515165116b12378f588c9cc
parent6cbf38421d453c070cb79ddffde388ad395f0505 (diff)
downloadrescached-73666b02edc55997ab187d0eefea48b662c9de6d.tar.xz
httpd: fetch the hosts block file if its not exist when enabled
While at it, return an error when doing update on hosts block.
-rw-r--r--hostsblock.go21
-rw-r--r--httpd.go18
2 files changed, 23 insertions, 16 deletions
diff --git a/hostsblock.go b/hostsblock.go
index c83cffd..b879d05 100644
--- a/hostsblock.go
+++ b/hostsblock.go
@@ -52,21 +52,16 @@ func (hb *hostsBlock) init(sources []string) {
hb.initLastUpdated()
}
-func (hb *hostsBlock) update() bool {
- if !hb.IsEnabled {
- return false
- }
-
+func (hb *hostsBlock) update() (err error) {
if !hb.isOld() {
- return false
+ return nil
}
fmt.Printf("hostsBlock %s: updating ...\n", hb.Name)
res, err := http.Get(hb.URL)
if err != nil {
- log.Printf("hostsBlock.update %q: %s", hb.Name, err)
- return false
+ return fmt.Errorf("hostsBlock.update %q: %w", hb.Name, err)
}
defer func() {
err := res.Body.Close()
@@ -77,19 +72,19 @@ func (hb *hostsBlock) update() bool {
body, err := ioutil.ReadAll(res.Body)
if err != nil {
- log.Printf("hostsBlock.update %q: %s", hb.Name, err)
- return false
+ return fmt.Errorf("hostsBlock.update %q: %w", hb.Name, err)
}
body = bytes.ReplaceAll(body, []byte("\r\n"), []byte("\n"))
err = ioutil.WriteFile(hb.file, body, 0644)
if err != nil {
- log.Printf("hostsBlock.update %q: %s", hb.Name, err)
- return false
+ return fmt.Errorf("hostsBlock.update %q: %w", hb.Name, err)
}
- return true
+ hb.initLastUpdated()
+
+ return nil
}
func (hb *hostsBlock) hide() (err error) {
diff --git a/httpd.go b/httpd.go
index 7dd04a1..5f422c4 100644
--- a/httpd.go
+++ b/httpd.go
@@ -6,6 +6,7 @@ package rescached
import (
"encoding/json"
+ "errors"
"fmt"
"log"
stdhttp "net/http"
@@ -266,9 +267,18 @@ func (srv *Server) apiHostsBlockUpdate(
}
func (srv *Server) hostsBlockEnable(hb *hostsBlock) (err error) {
+ hb.IsEnabled = true
+
err = hb.unhide()
if err != nil {
- return err
+ if !errors.Is(err, os.ErrNotExist) {
+ return err
+ }
+ // File not exist, fetch new from serfer.
+ err = hb.update()
+ if err != nil {
+ return err
+ }
}
hfile, err := dns.ParseHostsFile(filepath.Join(dirHosts, hb.Name))
@@ -278,8 +288,10 @@ func (srv *Server) hostsBlockEnable(hb *hostsBlock) (err error) {
srv.dns.PopulateCaches(hfile.Messages)
- hb.IsEnabled = true
- hb.update()
+ err = hb.update()
+ if err != nil {
+ return err
+ }
srv.env.HostsFiles = append(srv.env.HostsFiles, convertHostsFile(hfile))
return nil