diff --git a/cache/hashmap.go b/cache/hashmap.go index b4a6c12..84cc24b 100644 --- a/cache/hashmap.go +++ b/cache/hashmap.go @@ -90,7 +90,7 @@ func (c *HashMapCache) Store(path []byte, host []byte, instance *commonInstance) } } -func (c *HashMapCache) Get(path []byte, host []byte) (string, io.Reader, int) { +func (c *HashMapCache) Get(path []byte, host []byte) (string, io.Reader, int, bool) { if !c.hosts { if instance, ok := c.data.Get(path); ok { return instance.(*commonInstance).Get(instance.(*commonInstance), nil) @@ -115,7 +115,7 @@ func (c *HashMapCache) Get(path []byte, host []byte) (string, io.Reader, int) { } } - return "", nil, 0 + return "", nil, 0, false } func (c *HashMapCache) Source() source.Source { diff --git a/cache/shared.go b/cache/shared.go index 973db39..83e9512 100644 --- a/cache/shared.go +++ b/cache/shared.go @@ -57,12 +57,12 @@ func index(source source.Source, f source.IndexFunc) (int64, error) { return totalSize, nil } -func load(c Cache) func(*commonInstance, []byte) (string, io.Reader, int) { - return func(instance *commonInstance, host []byte) (string, io.Reader, int) { - hijacker := c.Source().Get(instance.Instance.AbsolutePath, host) +func load(c Cache) func(*commonInstance, []byte) (string, io.Reader, int, bool) { + return func(instance *commonInstance, host []byte) (string, io.Reader, int, bool) { + hijacker, failed := c.Source().Get(instance.Instance.AbsolutePath, host) if hijacker == nil { - return "", nil, 0 + return "", nil, 0, failed } log.Debug().Msgf("Loaded file [%d][%s]: %s", hijacker.Size, hijacker.FileType(), instance.Instance.AbsolutePath) @@ -71,12 +71,12 @@ func load(c Cache) func(*commonInstance, []byte) (string, io.Reader, int) { instance.Instance.LoadTime = time.Now() instance.Instance.Data = hijacker.Buffer instance.Instance.ContentType = hijacker.FileType() - instance.Get = func(cache *commonInstance, _ []byte) (string, io.Reader, int) { - return cache.Instance.ContentType, bytes.NewReader(cache.Instance.Data), len(cache.Instance.Data) + instance.Get = func(cache *commonInstance, _ []byte) (string, io.Reader, int, bool) { + return cache.Instance.ContentType, bytes.NewReader(cache.Instance.Data), len(cache.Instance.Data), false } } - return hijacker.FileType(), hijacker, hijacker.Size + return hijacker.FileType(), hijacker, hijacker.Size, false } } diff --git a/cache/types.go b/cache/types.go index d065ab7..ed517ab 100644 --- a/cache/types.go +++ b/cache/types.go @@ -16,7 +16,7 @@ type CachedInstance struct { type commonInstance struct { Instance *CachedInstance - Get func(instance *commonInstance, host []byte) (string, io.Reader, int) + Get func(instance *commonInstance, host []byte) (string, io.Reader, int, bool) } type KeyValue struct { @@ -26,7 +26,7 @@ type KeyValue struct { type Cache interface { Index() (int64, error) - Get(path []byte, host []byte) (string, io.Reader, int) + Get(path []byte, host []byte) (string, io.Reader, int, bool) Source() source.Source Iter() <-chan KeyValue Store(path []byte, host []byte, instance *commonInstance) diff --git a/server/webserver.go b/server/webserver.go index 0ff987f..3f1426b 100644 --- a/server/webserver.go +++ b/server/webserver.go @@ -58,11 +58,16 @@ type Webserver struct { } func (h *Webserver) HandleFastHTTP(ctx *fasthttp.RequestCtx) { - if fileType, stream, size := h.Cache.Get(ctx.Path(), ctx.Host()); size > 0 { + fileType, stream, size, failed := h.Cache.Get(ctx.Path(), ctx.Host()) + if size > 0 { ctx.SetContentType(fileType) ctx.SetBodyStream(stream, size) } else { - ctx.SetStatusCode(404) + if failed { + ctx.SetStatusCode(500) + } else { + ctx.SetStatusCode(404) + } } } diff --git a/source/local.go b/source/local.go index a9cd07d..7ed2296 100644 --- a/source/local.go +++ b/source/local.go @@ -16,22 +16,22 @@ var _ Source = (*Local)(nil) type Local struct { } -func (l Local) Get(path string, host []byte) *utils.StreamHijacker { +func (l Local) Get(path string, host []byte) (*utils.StreamHijacker, bool) { file, err := os.OpenFile(path, os.O_RDONLY, 0664) if err != nil { log.Err(err).Msg("error reading file") - return nil + return nil, true } stat, err := file.Stat() if err != nil { log.Err(err).Msg("error reading file") - return nil + return nil, true } fileType := mime.TypeByExtension(filepath.Ext(filepath.Base(path))) - return utils.NewStreamHijacker(int(stat.Size()), fileType, file) + return utils.NewStreamHijacker(int(stat.Size()), fileType, file), false } func (l Local) IndexPath(dir string, f IndexFunc) (int64, int64, error) { diff --git a/source/s3.go b/source/s3.go index 864d216..6bedff8 100644 --- a/source/s3.go +++ b/source/s3.go @@ -39,7 +39,7 @@ func NewS3(bucket string, key string, secret string, endpoint string, region str }, nil } -func (s S3) Get(path string, _ []byte) *utils.StreamHijacker { +func (s S3) Get(path string, _ []byte) (*utils.StreamHijacker, bool) { return GetS3(s.S3Client, s.Bucket, path) } diff --git a/source/s3_redis.go b/source/s3_redis.go index 98582a3..57c9605 100644 --- a/source/s3_redis.go +++ b/source/s3_redis.go @@ -46,11 +46,11 @@ func NewS3Redis(network string, address string, username string, password string }, nil } -func (s S3Redis) Get(path string, host []byte) *utils.StreamHijacker { +func (s S3Redis) Get(path string, host []byte) (*utils.StreamHijacker, bool) { var s3Wrapper *S3Wrapper if host == nil { - return nil + return nil, true } if instance, ok := s.CredentialCache.Get(utils.ByteSliceToString(host)); ok { @@ -60,11 +60,11 @@ func (s S3Redis) Get(path string, host []byte) *utils.StreamHijacker { if get.Err() != nil { if errors.Is(get.Err(), redis.Nil) { log.Warn().Str("host", utils.ByteSliceToString(host)).Msg("no credentials found") - return nil + return nil, true } log.Error().Err(get.Err()).Msg("failed to get credentials") - return nil + return nil, true } s3Flat := yeet.GetRootAsS3(utils.UnsafeGetBytes(get.Val()), 0) @@ -79,13 +79,13 @@ func (s S3Redis) Get(path string, host []byte) *utils.StreamHijacker { if err != nil { log.Err(err).Msg("failed to create new S3 session") - return nil + return nil, true } cf, err := cuckoo.Decode(s3Flat.Filter()) if err != nil { log.Err(err).Msg("failed to decode filter") - return nil + return nil, true } s3Wrapper = &S3Wrapper{ @@ -100,7 +100,7 @@ func (s S3Redis) Get(path string, host []byte) *utils.StreamHijacker { return GetS3(s3Wrapper.S3Client, s3Wrapper.Bucket, path) } - return nil + return nil, false } func (s S3Redis) IndexPath(_ string, _ IndexFunc) (int64, int64, error) { diff --git a/source/s3_utils.go b/source/s3_utils.go index 712fce3..dc77f63 100644 --- a/source/s3_utils.go +++ b/source/s3_utils.go @@ -10,7 +10,7 @@ import ( "strings" ) -func GetS3(client *s3.S3, bucket string, path string) *utils.StreamHijacker { +func GetS3(client *s3.S3, bucket string, path string) (*utils.StreamHijacker, bool) { cleanedKey := strings.TrimPrefix(path, "/") object, err := client.GetObject(&s3.GetObjectInput{ @@ -20,10 +20,10 @@ func GetS3(client *s3.S3, bucket string, path string) *utils.StreamHijacker { if err != nil { log.Err(err).Msg("failed to get object") - return nil + return nil, true } fileType := mime.TypeByExtension(filepath.Ext(filepath.Base(path))) - return utils.NewStreamHijacker(int(*object.ContentLength), fileType, object.Body) + return utils.NewStreamHijacker(int(*object.ContentLength), fileType, object.Body), false } diff --git a/source/types.go b/source/types.go index e7c9daa..4f97069 100644 --- a/source/types.go +++ b/source/types.go @@ -7,7 +7,7 @@ import ( type IndexFunc = func(absolutePath string, cleanedPath string) int64 type Source interface { - Get(path string, host []byte) *utils.StreamHijacker + Get(path string, host []byte) (*utils.StreamHijacker, bool) IndexPath(dir string, f IndexFunc) (int64, int64, error) Watch() (<-chan WatchEvent, error) }