Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix load context bug #137

Merged
merged 10 commits into from
Jan 5, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions cmd/arcaflow/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,15 +130,16 @@ Options:
RequiredFileKeyWorkflow: workflowFile,
}

requiredFilesAbsPaths, err := loadfile.ContextAbsFilepaths(dir, requiredFiles)
requiredFilesAbsPaths, err := loadfile.AbsPathsWithContext(dir, requiredFiles)
if err != nil {
flag.Usage()
tempLogger.Errorf("Failed to determine absolute path of arcaflow context directory %s (%v)", dir, err)
os.Exit(ExitCodeInvalidData)
}

var configData any = map[any]any{}
if configFile != "" {
configData, err = loadYamlFile(configFile)
configData, err = loadYamlFile(requiredFilesAbsPaths[RequiredFileKeyConfig])
if err != nil {
tempLogger.Errorf("Failed to load configuration file %s (%v)", configFile, err)
flag.Usage()
Expand All @@ -151,17 +152,20 @@ Options:
flag.Usage()
os.Exit(ExitCodeInvalidData)
}
cfg.Log.Stdout = os.Stderr

// now we are ready to instantiate our main logger
cfg.Log.Stdout = os.Stderr
logger := log.New(cfg.Log).WithLabel("source", "main")

var requiredFilesAbsSlice = make([]string, len(requiredFiles))
var requiredFilesAbsSlice = make([]string, len(requiredFilesAbsPaths))
var j int
for _, f := range requiredFilesAbsPaths {
requiredFilesAbsSlice = append(requiredFilesAbsSlice, f)
requiredFilesAbsSlice[j] = f
j++
}
mfleader marked this conversation as resolved.
Show resolved Hide resolved
dirContext, err := loadfile.LoadContext(requiredFilesAbsSlice)
if err != nil {
logger.Errorf("Failed to load configuration file %s (%v)", configFile, err)
logger.Errorf("Failed to load required files into context (%v)", err)
flag.Usage()
os.Exit(ExitCodeInvalidData)
}
Expand Down
17 changes: 9 additions & 8 deletions loadfile/loadfile.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,32 +7,33 @@ import (
"path/filepath"
)

// LoadContext reads the contents at each file into a map where the key
// is the absolute filepath and file contents is the value.
// LoadContext reads the content of each file into a map where the key
// is the absolute filepath and the file content is the value.
func LoadContext(neededFilepaths []string) (map[string][]byte, error) {
result := map[string][]byte{}
var err error
for _, filePath := range neededFilepaths {
absPath, err := filepath.Abs(filePath)
if err != nil {
return nil, fmt.Errorf("failed to obtain absolute path of file %s (%w)", filepath.Base(filePath), err)
return nil, fmt.Errorf("error obtaining absolute path of file %s (%w)",
filePath, err)
}
fileData, err := os.ReadFile(absPath) //nolint:gosec
fileData, err := os.ReadFile(filepath.Clean(absPath))
dbutenhof marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return nil, fmt.Errorf("failed to read file from context directory: %s (%w)", absPath, err)
return nil, fmt.Errorf("error reading file %s (%w)", absPath, err)
}
result[absPath] = fileData
}
return result, err
}

// ContextAbsFilepaths creates a map of absolute filepaths. If a required
// AbsPathsWithContext creates a map of absolute filepaths. If a required
// file is not provided with an absolute path, then it is joined with the
// root directory.
func ContextAbsFilepaths(rootDir string, requiredFiles map[string]string) (map[string]string, error) {
func AbsPathsWithContext(rootDir string, requiredFiles map[string]string) (map[string]string, error) {
absDir, err := filepath.Abs(rootDir)
if err != nil {
return nil, err
return nil, fmt.Errorf("error determining context directory absolute path %s (%w)", rootDir, err)
}
requiredFilesAbs := map[string]string{}
for key, f := range requiredFiles {
Expand Down
73 changes: 60 additions & 13 deletions loadfile/loadfile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,39 +22,86 @@ func TestLoadContext(t *testing.T) {

assert.NoError(t, os.MkdirAll(testdir, os.ModePerm))

// create a directory and a file
// create a directory
dirname := "mydir"
dirpath := filepath.Join(testdir, dirname)
filename := "myfile"
assert.NoError(t, os.MkdirAll(dirpath, os.ModePerm))
f, err := os.CreateTemp(testdir, filename)

// create a file
filename := "myfile"
filePath := filepath.Join(testdir, filename)
f, err := os.Create(filepath.Clean(filePath))
assert.NoError(t, err)
tempfilepath := f.Name()
assert.NoError(t, f.Close())

// create symlinks to the above directory and file
// create symlink to the directory
symlinkDirname := dirname + "_sym"
symlinkFilepath := tempfilepath + "_sym"
symlinkDirpath := filepath.Join(testdir, symlinkDirname)
assert.NoError(t, os.Symlink(dirpath, symlinkDirpath))
assert.NoError(t, os.Symlink(tempfilepath, symlinkFilepath))

// create symlink to the file
symlinkFilepath := filePath + "_sym"
assert.NoError(t, os.Symlink(filePath, symlinkFilepath))

neededFiles := []string{
tempfilepath,
filePath,
symlinkFilepath,
}
filemap, err := loadfile.LoadContext(neededFiles)
filemapExp := map[string][]byte{
tempfilepath: {},
symlinkFilepath: {},
}
// assert no error on attempting to read files
// that cannot be read
assert.NoError(t, err)

// assert only the regular file will be loaded
// assert only the regular and symlinked file are loaded
filemapExp := map[string][]byte{
filePath: {},
symlinkFilepath: {},
}
assert.Equals(t, filemap, filemapExp)

// error on loading a directory
neededFiles = []string{
dirpath,
}
_, err = loadfile.LoadContext(neededFiles)
assert.Error(t, err)
mfleader marked this conversation as resolved.
Show resolved Hide resolved

// error on loading a symlink directory
neededFiles = []string{
symlinkDirpath,
}
_, err = loadfile.LoadContext(neededFiles)
assert.Error(t, err)

t.Cleanup(func() {
assert.NoError(t, os.RemoveAll(testdir))
})
}

// This tests AbsPathsWithContext joins relative paths with the
// context (root) directory, and passes through absolute paths
// unmodified.
mfleader marked this conversation as resolved.
Show resolved Hide resolved
func TestContextAbsFilepaths(t *testing.T) {
testdir, err := os.MkdirTemp(os.TempDir(), "")
assert.NoError(t, err)

testFilepaths := map[string]string{
"a": "a.yaml",
"b": "/b.toml",
"c": "../rel/subdir/c.txt",
}

absPathsExp := map[string]string{
"a": filepath.Join(testdir, testFilepaths["a"]),
// since the 'b' file has an absolute path, it should be unmodified
"b": "/b.toml",
mfleader marked this conversation as resolved.
Show resolved Hide resolved
"c": filepath.Join(testdir, testFilepaths["c"]),
}
mfleader marked this conversation as resolved.
Show resolved Hide resolved

absPathsGot, err := loadfile.AbsPathsWithContext(testdir, testFilepaths)
assert.NoError(t, err)
assert.Equals(t, absPathsExp, absPathsGot)

t.Cleanup(func() {
assert.NoError(t, os.RemoveAll(testdir))
})
Expand Down
Loading