diff --git a/README.md b/README.md index a2173bd..b29e994 100644 --- a/README.md +++ b/README.md @@ -67,6 +67,34 @@ Listing the contents of an archive: $ aar list -f archive.aarch ``` +Encrypting an archive: + +```bash +$ aar encrypt -f archive.aarch +Password: +Confirm password: +``` + +Where `` is the password you want to use to encrypt the archive, with a minimum length of 8 characters. +It removes the original _.aarch_ file and creates a new one with the encrypted data, with extension _.aarch.enc_. + +> [!NOTE] +> The encryption is done using the AES-256-GCM algorithm, and it only works for angel archives. + +Decrypting an archive: + +```bash +$ aar decrypt -f archive.aarch.enc +Password: +Confirm password: +``` + +Where `` is the password you used to encrypt the archive. +It removes the encrypted _.aarch.enc_ file and creates a new one with the decrypted data, with extension _.aarch_. + +> [!NOTE] +> The decryption is done using the AES-256-GCM algorithm, and it only works for encrypted angel archives. + ## File Format ### Archive Header diff --git a/archive/archive.go b/archive/archive.go index 0d8bd03..46345ed 100644 --- a/archive/archive.go +++ b/archive/archive.go @@ -1,6 +1,9 @@ package archive -import "io" +import ( + "bytes" + "io" +) // An Archive represents a collection of files stored in a single file. type Archive struct { @@ -20,6 +23,16 @@ func (a *Archive) TotalSize() uint64 { return total } +// GetBytes returns the archive as a byte slice. +func (a *Archive) GetBytes() ([]byte, error) { + data := new(bytes.Buffer) + if err := a.Write(data); err != nil { + return nil, err + } + + return data.Bytes(), nil +} + // Write writes the archive into the provided writer. func (a *Archive) Write(w io.Writer) error { if err := a.Header.Write(w); err != nil { diff --git a/archive/encryption.go b/archive/encryption.go new file mode 100644 index 0000000..aea9088 --- /dev/null +++ b/archive/encryption.go @@ -0,0 +1,150 @@ +package archive + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/sha256" + "io" + + "golang.org/x/crypto/pbkdf2" +) + +const ( + saltSize = 16 + nonceSize = 12 +) + +// An EncryptedArchive represents an encrypted archive. +type EncryptedArchive struct { + bytes []byte + salt []byte + nonce []byte +} + +// Write writes the encrypted archive into the provided writer. +// The encrypted archive is serialized as follows: +// +// 1. The magic field is serialized as a 4-byte sequence. +// 2. The salt field is serialized as a 16-byte sequence. +// 3. The nonce field is serialized as a sequence of bytes. +// 4. The encrypted data is serialized as a sequence of bytes. +func (a *EncryptedArchive) Write(w io.Writer) error { + // Write the magic (4 bytes) + if _, err := w.Write(encMagic); err != nil { + return err + } + + // Write the salt (16 bytes) + if _, err := w.Write(a.salt); err != nil { + return err + } + + // Write the nonce + if _, err := w.Write(a.nonce); err != nil { + return err + } + + // Write the encrypted data + if _, err := w.Write(a.bytes); err != nil { + return err + } + + return nil +} + +// ReadEncryptedArchive reads an encrypted archive from the provided reader. +func ReadEncryptedArchive(r io.Reader) (*EncryptedArchive, error) { + if err := mustReadEncryptedMagic(r); err != nil { + return nil, err + } + + // Read the salt (16 bytes) + salt := make([]byte, saltSize) + if _, err := io.ReadFull(r, salt); err != nil { + return nil, err + } + + // Read the nonce + nonce := make([]byte, nonceSize) + if _, err := io.ReadFull(r, nonce); err != nil { + return nil, err + } + + // Read the encrypted data + data, err := io.ReadAll(r) + if err != nil { + return nil, err + } + + return &EncryptedArchive{ + bytes: data, + salt: salt, + nonce: nonce, + }, nil +} + +// Encrypt encrypts the archive using AES-GCM with the provided password. +func (a *Archive) Encrypt(password string) (*EncryptedArchive, error) { + // Generate a salt for key derivation (PBKDF2) + salt := make([]byte, saltSize) + if _, err := rand.Read(salt); err != nil { + return nil, err + } + + aesGCM, err := newCipher(password, salt) + if err != nil { + return nil, err + } + + // Generate a nonce for AES-GCM (random IV) + // aesGCM.NonceSize() returns 12 bytes + nonce := make([]byte, nonceSize) + if _, err := rand.Read(nonce); err != nil { + return nil, err + } + + // Get the plaintext data + plaintext, err := a.GetBytes() + if err != nil { + return nil, err + } + + // Encrypt the data using AES-GCM + ciphertext := aesGCM.Seal(nil, nonce, plaintext, nil) + + return &EncryptedArchive{ + bytes: ciphertext, + salt: salt, + nonce: nonce, + }, nil +} + +// Decrypt decrypts the encrypted archive using AES-GCM with the provided password. +// If the password is incorrect, the process will fail as the Archive data will be +// corrupted. +func (a *EncryptedArchive) Decrypt(password string) (*Archive, error) { + aesGCM, err := newCipher(password, a.salt) + if err != nil { + return nil, err + } + + plaintext, err := aesGCM.Open(nil, a.nonce, a.bytes, nil) + if err != nil { + return nil, err + } + + return ReadArchive(bytes.NewReader(plaintext)) +} + +// newCipher creates a new AES-GCM cipher with the provided password and salt. +func newCipher(password string, salt []byte) (cipher.AEAD, error) { + key := pbkdf2.Key([]byte(password), salt, 4096, 32, sha256.New) + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + return cipher.NewGCM(block) +} diff --git a/archive/encryption_test.go b/archive/encryption_test.go new file mode 100644 index 0000000..b7ce7f0 --- /dev/null +++ b/archive/encryption_test.go @@ -0,0 +1,76 @@ +package archive + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestEncryptDecryptArchive(t *testing.T) { + archive := makeTestArchive() + + encrypted, err := archive.Encrypt("password") + assert.Nil(t, err) + + decrypted, err := encrypted.Decrypt("password") + assert.Nil(t, err) + + assert.Equal(t, archive, decrypted) +} + +func TestWriteAndReadEncryptedArchive(t *testing.T) { + var ( + archive = makeTestArchive() + encrypted, _ = archive.Encrypt("password") + w = new(bytes.Buffer) + ) + + err := encrypted.Write(w) + assert.Nil(t, err) + + r := bytes.NewReader(w.Bytes()) + readArchive, err := ReadEncryptedArchive(r) + assert.Nil(t, err) + + assert.Equal(t, encrypted, readArchive) +} + +func makeTestArchive() *Archive { + return &Archive{ + Header: &Header{ + HeaderLength: 46, + Entries: []*HeaderFileEntry{ + { + Name: "file1.txt", + Size: 12, + Offset: 28, + }, + { + Name: "file2.txt", + Size: 16, + Offset: 40, + }, + }, + }, + Files: []*ArchiveFile{ + { + FileName: "file1.txt", + CompressedBytes: []byte{ + 0x78, 0x9c, 0x4b, 0x4c, + 0x4f, 0x49, 0x2d, 0x2e, + 0x01, 0x00, 0x00, 0xff, + }, + }, + { + FileName: "file2.txt", + CompressedBytes: []byte{ + 0x78, 0x9c, 0x4b, 0x4c, + 0x4f, 0x49, 0x2d, 0x2e, + 0x01, 0x00, 0x00, 0xff, + 0x78, 0x9c, 0x4b, 0x4c, + }, + }, + }, + } +} diff --git a/archive/header.go b/archive/header.go index 19d1035..115155d 100644 --- a/archive/header.go +++ b/archive/header.go @@ -6,13 +6,6 @@ import ( "io" ) -// magic is a unique identifier for the archive format. -// It's the ASCII representation of "AAR?". -var magic = []byte{0x41, 0x41, 0x52, 0x3F} - -// magicLen is the length of the magic field in bytes. -const magicLen = uint32(4) - // byteOrder is the byte order used to serialize integers. var byteOrder = binary.LittleEndian diff --git a/archive/magic.go b/archive/magic.go new file mode 100644 index 0000000..b512a56 --- /dev/null +++ b/archive/magic.go @@ -0,0 +1,59 @@ +package archive + +import ( + "bytes" + "fmt" + "io" +) + +// magic is a unique identifier for the archive format. +// It's the ASCII representation of "AAR?". +var magic = []byte{0x41, 0x41, 0x52, 0x3F} + +// encMagic is a unique identifier for the encrypted archive format. +// It's the ASCII representation of "AARX". +var encMagic = []byte{0x41, 0x41, 0x52, 0x58} + +// magicLen is the length of the magic field in bytes. +const magicLen = uint32(4) + +// ErrInvalidMagic is returned when the magic field is not correct. +var ErrInvalidMagic = fmt.Errorf("invalid magic, expected %v", magic) + +// ErrInvalidEncMagic is returned when the magic field is not correct. +var ErrInvalidEncMagic = fmt.Errorf("invalid magic, expected %v", encMagic) + +// mustReadMagic reads the magic field from the provided reader. +// If the magic field is not correct, it returns an error. +func mustReadMagic(r io.Reader) error { + readMagic := make([]byte, 4) + + // Read the magic (4 bytes) + if _, err := io.ReadFull(r, readMagic); err != nil { + return err + } + + // Check if the magic is correct + if !bytes.Equal(magic, readMagic) { + return ErrInvalidMagic + } + + return nil +} + +// mustReadEncryptedMagic reads the magic field from the provided reader. +func mustReadEncryptedMagic(r io.Reader) error { + readMagic := make([]byte, 4) + + // Read the magic (4 bytes) + if _, err := io.ReadFull(r, readMagic); err != nil { + return err + } + + // Check if the magic is correct + if !bytes.Equal(encMagic, readMagic) { + return ErrInvalidEncMagic + } + + return nil +} diff --git a/archive/utils.go b/archive/utils.go deleted file mode 100644 index 5e361e6..0000000 --- a/archive/utils.go +++ /dev/null @@ -1,25 +0,0 @@ -package archive - -import ( - "bytes" - "fmt" - "io" -) - -// mustReadMagic reads the magic field from the provided reader. -// If the magic field is not correct, it returns an error. -func mustReadMagic(r io.Reader) error { - readMagic := make([]byte, 4) - - // Read the magic (4 bytes) - if _, err := io.ReadFull(r, readMagic); err != nil { - return err - } - - // Check if the magic is correct - if !bytes.Equal(magic, readMagic) { - return fmt.Errorf("invalid magic: got %v, expected %v", readMagic, magic) - } - - return nil -} diff --git a/cmd/aar/main.go b/cmd/aar/main.go index c27bd26..b7caf2d 100644 --- a/cmd/aar/main.go +++ b/cmd/aar/main.go @@ -19,6 +19,12 @@ func main() { listCmd = flag.NewFlagSet("list", flag.ExitOnError) listFileNameFlag = listCmd.String("f", "", "Filename of the archive to list") + + encryptCmd = flag.NewFlagSet("encrypt", flag.ExitOnError) + encryptFileNameFlag = encryptCmd.String("f", "", "Filename of the archive to encrypt") + + decryptCmd = flag.NewFlagSet("decrypt", flag.ExitOnError) + decryptFileNameFlag = decryptCmd.String("f", "", "Filename of the archive to decrypt") ) if len(os.Args) < 2 { @@ -48,6 +54,20 @@ func main() { validateFileName(*listFileNameFlag) cmd.ListArchive(*listFileNameFlag) + case "encrypt": + encryptCmd.Parse(os.Args[2:]) + validateFileName(*encryptFileNameFlag) + password := cmd.PromptPassword() + + cmd.EncryptArchive(*encryptFileNameFlag, password) + + case "decrypt": + decryptCmd.Parse(os.Args[2:]) + validateFileName(*decryptFileNameFlag) + password := cmd.PromptPassword() + + cmd.DecryptArchive(*decryptFileNameFlag, password) + default: fmt.Fprintf(os.Stderr, "Usage: aar [options]\n") os.Exit(1) diff --git a/cmd/cypt.go b/cmd/cypt.go new file mode 100644 index 0000000..33bc334 --- /dev/null +++ b/cmd/cypt.go @@ -0,0 +1,107 @@ +package cmd + +import ( + "fmt" + "os" + + "github.com/angelsolaorbaiceta/aar/archive" +) + +func EncryptArchive(fileName, password string) { + // Read the archive + reader, err := os.OpenFile(fileName, os.O_RDONLY, 0) + if err != nil { + fmt.Fprintf(os.Stderr, "Error opening archive file: %v\n", err) + os.Exit(1) + } + + arch, err := archive.ReadArchive(reader) + if err != nil { + fmt.Fprintf(os.Stderr, "Error reading archive: %v\n", err) + os.Exit(1) + } + + // Encrypt the archive + encArch, err := arch.Encrypt(password) + if err != nil { + fmt.Fprintf(os.Stderr, "Error encrypting archive: %v\n", err) + os.Exit(1) + } + + // Write the encrypted archive to disk + encFileName := fileName + ".enc" + encFile, err := os.Create(encFileName) + if err != nil { + fmt.Fprintf(os.Stderr, "Error creating encrypted archive file: %v\n", err) + os.Exit(1) + } + defer encFile.Close() + + if err := encArch.Write(encFile); err != nil { + fmt.Fprintf(os.Stderr, "Error writing encrypted archive file: %v\n", err) + os.Exit(1) + } + + fmt.Fprintf(os.Stderr, "Archive encrypted successfully to %s\n", encFileName) + + // Remove the original archive + if err := os.Remove(fileName); err != nil { + fmt.Fprintf(os.Stderr, "Error removing original archive: %v\n", err) + os.Exit(1) + } +} + +func DecryptArchive(fileName, password string) { + // Read the encrypted archive + reader, err := os.OpenFile(fileName, os.O_RDONLY, 0) + if err != nil { + fmt.Fprintf(os.Stderr, "Error opening encrypted archive file: %v\n", err) + os.Exit(1) + } + + encArch, err := archive.ReadEncryptedArchive(reader) + if err != nil { + fmt.Fprintf(os.Stderr, "Error reading encrypted archive: %v\n", err) + os.Exit(1) + } + + // Decrypt the archive + arch, err := encArch.Decrypt(password) + if err != nil { + fmt.Fprintf(os.Stderr, "Error decrypting archive: %v\n", err) + os.Exit(1) + } + + // Write the decrypted archive to disk + decFileName := decryptFileName(fileName) + decFile, err := os.Create(decFileName) + if err != nil { + fmt.Fprintf(os.Stderr, "Error creating decrypted archive file: %v\n", err) + os.Exit(1) + } + defer decFile.Close() + + if err := arch.Write(decFile); err != nil { + fmt.Fprintf(os.Stderr, "Error writing decrypted archive file: %v\n", err) + os.Exit(1) + } + + fmt.Fprintf(os.Stderr, "Archive decrypted successfully to %s\n", decFileName) + + // Remove the encrypted archive + if err := os.Remove(fileName); err != nil { + fmt.Fprintf(os.Stderr, "Error removing encrypted archive: %v\n", err) + os.Exit(1) + } +} + +// decryptFileName returns the decrypted file name from the encrypted file name. +// If the file name doesn't end with ".enc", it appends ".dec" to the file name. +// Otherwise, it removes the ".enc" extension. +func decryptFileName(fileName string) string { + if len(fileName) < 4 || fileName[len(fileName)-4:] != ".enc" { + return fileName + ".dec" + } + + return fileName[:len(fileName)-4] +} diff --git a/cmd/password.go b/cmd/password.go new file mode 100644 index 0000000..d7ec18a --- /dev/null +++ b/cmd/password.go @@ -0,0 +1,49 @@ +package cmd + +import ( + "fmt" + "os" + "syscall" + + "golang.org/x/term" +) + +func PromptPassword() string { + fmt.Print("Password: ") + passwordBytes, err := term.ReadPassword(int(syscall.Stdin)) + if err != nil { + fmt.Fprintf(os.Stderr, "Error reading password: %v\n", err) + os.Exit(1) + } + + // Move to the next line after password input + fmt.Println() + + password := string(passwordBytes) + validatePassword(password) + + fmt.Print("Confirm password: ") + passwordConfirmationBytes, err := term.ReadPassword(int(syscall.Stdin)) + if err != nil { + fmt.Fprintf(os.Stderr, "Error reading password: %v\n", err) + os.Exit(1) + } + + // Move to the next line after password input + fmt.Println() + + passwordConfirmation := string(passwordConfirmationBytes) + if password != passwordConfirmation { + fmt.Fprintf(os.Stderr, "Passwords do not match.\n") + os.Exit(1) + } + + return password +} + +func validatePassword(password string) { + if len(password) < 8 { + fmt.Fprintf(os.Stderr, "The password must be at least 8 characters long.\n") + os.Exit(1) + } +} diff --git a/go.mod b/go.mod index 743809e..678c409 100644 --- a/go.mod +++ b/go.mod @@ -5,10 +5,13 @@ go 1.22.0 require ( github.com/stretchr/testify v1.9.0 github.com/ulikunitz/xz v0.5.12 + golang.org/x/crypto v0.26.0 + golang.org/x/term v0.23.0 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/sys v0.23.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index b43ca2a..86283c5 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,12 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/ulikunitz/xz v0.5.12 h1:37Nm15o69RwBkXM0J6A5OlE67RZTfzUxTj8fB3dfcsc= github.com/ulikunitz/xz v0.5.12/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= +golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= +golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54= +golang.org/x/sys v0.23.0 h1:YfKFowiIMvtgl1UERQoTPPToxltDeZfbj4H7dVUCwmM= +golang.org/x/sys v0.23.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.23.0 h1:F6D4vR+EHoL9/sWAWgAR1H2DcHr4PareCbAaCo1RpuU= +golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=