Skip to content

Commit

Permalink
Fix remaining python test failures
Browse files Browse the repository at this point in the history
  • Loading branch information
keithtensor committed Dec 5, 2024
1 parent 5c2a8b4 commit 82d0b32
Showing 1 changed file with 149 additions and 36 deletions.
185 changes: 149 additions & 36 deletions src/python_bindings.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
use core::str;
use std::borrow::Cow;

use crate::errors::{ConfigurationError, KeyFileError, PasswordError};
use crate::errors::{ConfigurationError, KeyFileError, PasswordError, WalletError};
use crate::keyfile;
use crate::keyfile::Keyfile as RustKeyfile;
use crate::keypair::Keypair as RustKeypair;
use crate::utils::is_valid_ss58_address;
use crate::wallet::Wallet as RustWallet;
use pyo3::exceptions::{PyException, PyValueError};
use pyo3::exceptions::{PyException, PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyModule, PyString, PyTuple, PyType};
use pyo3::wrap_pyfunction;
use pyo3::{create_exception, ffi};
use pyo3::{ffi, wrap_pyfunction};

#[pyclass]
#[derive(Clone)]
Expand Down Expand Up @@ -132,9 +130,30 @@ impl PyKeyfile {
.remove_password_from_env()
.map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))
}

#[pyo3(signature = (password=None))]
fn get_keypair(&self, password: Option<String>) -> PyResult<PyKeypair> {
self.inner
.get_keypair(password)
.map(|inner| PyKeypair { inner })
.map_err(|e| PyErr::new::<PyKeyFileError, _>(e))
}

#[pyo3(signature = (keypair, encrypt=true, overwrite=false, password=None))]
fn set_keypair(
&self,
keypair: PyKeypair,
encrypt: bool,
overwrite: bool,
password: Option<String>,
) -> PyResult<()> {
self.inner
.set_keypair(keypair.inner, encrypt, overwrite, password)
.map_err(|e| PyErr::new::<PyKeyFileError, _>(e))
}
}

#[pyclass(name = "Keypair", subclass)]
#[pyclass(name = "Keypair")]
#[derive(Clone)]
pub struct PyKeypair {
inner: RustKeypair,
Expand Down Expand Up @@ -165,6 +184,7 @@ impl PyKeypair {
}

#[staticmethod]
#[pyo3(signature = (n_words=12))]
fn generate_mnemonic(n_words: usize) -> PyResult<String> {
RustKeypair::generate_mnemonic(n_words).map_err(|e| PyErr::new::<PyValueError, _>(e))
}
Expand Down Expand Up @@ -204,15 +224,81 @@ impl PyKeypair {
Ok(PyKeypair { inner: keypair })
}

fn sign(&self, data: Vec<u8>) -> PyResult<Vec<u8>> {
#[pyo3(signature = (data))]
fn sign(&self, data: PyObject, py: Python) -> PyResult<Cow<[u8]>> {
// Convert data to bytes (data can be a string, hex, or bytes)
let data_bytes = if let Ok(s) = data.extract::<String>(py) {
if s.starts_with("0x") {
hex::decode(s.trim_start_matches("0x")).map_err(|e| {
PyErr::new::<PyConfigurationError, _>(format!("Invalid hex string: {}", e))
})?
} else {
s.into_bytes()
}
} else if let Ok(bytes) = data.extract::<Vec<u8>>(py) {
bytes
} else if let Ok(py_scale_bytes) = data.extract::<&PyAny>(py) {
let scale_data: &PyAny = py_scale_bytes.getattr("data")?;
let scale_data_bytes: Vec<u8> = scale_data.extract()?;

scale_data_bytes.to_vec()
} else {
return Err(PyErr::new::<PyConfigurationError, _>(
"Keypair::sign: Unsupported data format. Expected str or bytes.",
));
};

self.inner
.sign(data)
.map_err(|e| PyErr::new::<PyValueError, _>(e))
}
.sign(data_bytes)
.map(Cow::from)
.map_err(|e| PyErr::new::<PyConfigurationError, _>(e))
}

#[pyo3(signature = (data, signature))]
fn verify(&self, data: PyObject, signature: PyObject, py: Python) -> PyResult<bool> {
// Convert data to bytes (data can be a string, hex, or bytes)
let data_bytes = if let Ok(s) = data.extract::<String>(py) {
if s.starts_with("0x") {
hex::decode(s.trim_start_matches("0x")).map_err(|e| {
PyErr::new::<PyValueError, _>(format!("Invalid hex string: {:?}", e))
})?
} else {
s.into_bytes()
}
} else if let Ok(bytes) = data.extract::<Vec<u8>>(py) {
bytes
} else if let Ok(py_scale_bytes) = data.extract::<&PyAny>(py) {
let scale_data: &PyAny = py_scale_bytes.getattr("data")?;
let scale_data_bytes: Vec<u8> = scale_data.extract()?;

scale_data_bytes.to_vec()
} else {
return Err(PyErr::new::<PyConfigurationError, _>(
"Keypair::verify: Unsupported data format. Expected str or bytes.",
));
};

// Convert signature to bytes
let signature_bytes = if let Ok(s) = signature.extract::<String>(py) {
if s.starts_with("0x") {
hex::decode(s.trim_start_matches("0x")).map_err(|e| {
PyErr::new::<PyValueError, _>(format!("Invalid hex string: {:?}", e))
})?
} else {
return Err(PyErr::new::<PyValueError, _>(
"Invalid signature format. Expected hex string.",
));
}
} else if let Ok(bytes) = signature.extract::<Vec<u8>>(py) {
bytes
} else {
return Err(PyErr::new::<PyTypeError, _>(
"Unsupported signature format. Expected str or bytes.",
));
};

fn verify(&self, data: Vec<u8>, signature: Vec<u8>) -> PyResult<bool> {
self.inner
.verify(data, signature)
.verify(data_bytes, signature_bytes)
.map_err(|e| PyErr::new::<PyValueError, _>(e))
}

Expand All @@ -222,9 +308,10 @@ impl PyKeypair {
}

#[getter]
fn public_key(&self) -> PyResult<Option<Vec<u8>>> {
fn public_key(&self) -> PyResult<Option<Cow<[u8]>>> {
self.inner
.public_key()
.map(|opt| opt.map(Cow::from))
.map_err(|e| PyErr::new::<PyValueError, _>(e))
}

Expand Down Expand Up @@ -262,7 +349,7 @@ impl PyKeypair {
}

// Error type bindings
#[pyclass(name = "KeyFileError")]
#[pyclass(name = "KeyFileError", extends = PyException)]
#[derive(Debug)]
pub struct PyKeyFileError {
inner: KeyFileError,
Expand All @@ -282,7 +369,13 @@ impl PyKeyFileError {
}
}

#[pyclass(name = "ConfigurationError")]
impl IntoPy<PyObject> for KeyFileError {
fn into_py(self, py: Python<'_>) -> PyObject {
Py::new(py, PyKeyFileError { inner: self }).unwrap().into_any()
}
}

#[pyclass(name = "ConfigurationError", extends = PyException)]
#[derive(Debug)]
pub struct PyConfigurationError {
inner: ConfigurationError,
Expand All @@ -302,7 +395,7 @@ impl PyConfigurationError {
}
}

#[pyclass(name = "PasswordError")]
#[pyclass(name = "PasswordError", extends = PyException)]
#[derive(Debug)]
pub struct PyPasswordError {
inner: PasswordError,
Expand All @@ -322,7 +415,31 @@ impl PyPasswordError {
}
}

create_exception!(errors, WalletError, PyException);
#[pyclass(name = "WalletError", extends = PyException)]
#[derive(Debug)]
pub struct PyWalletError {
inner: WalletError,
}

#[pymethods]
impl PyWalletError {
#[new]
fn new(msg: String) -> Self {
PyWalletError {
inner: WalletError::InvalidInput(msg),
}
}

fn __str__(&self) -> PyResult<String> {
Ok(self.inner.to_string())
}
}

impl IntoPy<PyObject> for WalletError {
fn into_py(self, py: Python<'_>) -> PyObject {
Py::new(py, PyWalletError { inner: self }).unwrap().into_any()
}
}

// Define the Python module using PyO3
#[pymodule]
Expand Down Expand Up @@ -353,10 +470,7 @@ fn register_config_module(main_module: Bound<'_, PyModule>) -> PyResult<()> {
fn register_errors_module(main_module: Bound<'_, PyModule>) -> PyResult<()> {
let errors_module = PyModule::new_bound(main_module.py(), "errors")?;
// Register the WalletError exception
errors_module.add(
"WalletError",
main_module.py().get_type_bound::<WalletError>(),
)?;
errors_module.add_class::<PyWalletError>()?;
errors_module.add_class::<PyConfigurationError>()?;
errors_module.add_class::<PyKeyFileError>()?;
errors_module.add_class::<PyPasswordError>()?;
Expand All @@ -368,26 +482,26 @@ fn register_errors_module(main_module: Bound<'_, PyModule>) -> PyResult<()> {
fn py_serialized_keypair_to_keyfile_data(py: Python, keypair: &PyKeypair) -> PyResult<PyObject> {
keyfile::serialized_keypair_to_keyfile_data(&keypair.inner)
.map(|bytes| PyBytes::new_bound(py, &bytes).into_py(py))
.map_err(|inner| PyErr::new::<PyKeyFileError, _>(PyKeyFileError { inner }))
.map_err(|e| PyErr::new::<PyKeyFileError, _>(e))
}

#[pyfunction(name = "deserialize_keypair_from_keyfile_data")]
fn py_deserialize_keypair_from_keyfile_data(keyfile_data: &[u8]) -> PyResult<PyKeypair> {
keyfile::deserialize_keypair_from_keyfile_data(keyfile_data)
.map(|inner| PyKeypair { inner })
.map_err(|inner| PyErr::new::<PyKeyFileError, _>(PyKeyFileError { inner }))
.map_err(|e| PyErr::new::<PyKeyFileError, _>(e))
}

#[pyfunction(name = "validate_password")]
fn py_validate_password(password: &str) -> PyResult<bool> {
keyfile::validate_password(password)
.map_err(|inner| PyErr::new::<PyKeyFileError, _>(PyKeyFileError { inner }))
.map_err(|e| PyErr::new::<PyKeyFileError, _>(e))
}

#[pyfunction(name = "ask_password")]
fn py_ask_password(validation_required: bool) -> PyResult<String> {
keyfile::ask_password(validation_required)
.map_err(|inner| PyErr::new::<PyKeyFileError, _>(PyKeyFileError { inner }))
.map_err(|e| PyErr::new::<PyKeyFileError, _>(e))
}

#[pyfunction(name = "legacy_encrypt_keyfile_data")]
Expand All @@ -397,21 +511,21 @@ fn py_legacy_encrypt_keyfile_data(
password: Option<String>,
) -> PyResult<Vec<u8>> {
keyfile::legacy_encrypt_keyfile_data(keyfile_data, password)
.map_err(|inner| PyErr::new::<PyKeyFileError, _>(PyKeyFileError { inner }))
.map_err(|e| PyErr::new::<PyKeyFileError, _>(e))
}

#[pyfunction(name = "get_password_from_environment")]
fn py_get_password_from_environment(env_var_name: String) -> PyResult<Option<String>> {
keyfile::get_password_from_environment(env_var_name)
.map_err(|inner| PyErr::new::<PyKeyFileError, _>(PyKeyFileError { inner }))
.map_err(|e| PyErr::new::<PyKeyFileError, _>(e))
}

#[pyfunction(name = "encrypt_keyfile_data")]
#[pyo3(signature = (keyfile_data, password=None))]
fn py_encrypt_keyfile_data(keyfile_data: &[u8], password: Option<String>) -> PyResult<Cow<[u8]>> {
keyfile::encrypt_keyfile_data(keyfile_data, password)
.map(Cow::from)
.map_err(|inner| PyErr::new::<PyKeyFileError, _>(PyKeyFileError { inner }))
.map_err(|e| PyErr::new::<PyKeyFileError, _>(e))
}

#[pyfunction(name = "decrypt_keyfile_data")]
Expand All @@ -423,7 +537,7 @@ fn py_decrypt_keyfile_data(
) -> PyResult<Cow<[u8]>> {
keyfile::decrypt_keyfile_data(keyfile_data, password, password_env_var)
.map(Cow::from)
.map_err(|inner| PyErr::new::<PyKeyFileError, _>(PyKeyFileError { inner }))
.map_err(|e| PyErr::new::<PyKeyFileError, _>(e))
}

// keyfile module with functions
Expand Down Expand Up @@ -552,7 +666,10 @@ fn py_is_valid_bittensor_address_or_public_key(address: &Bound<'_, PyAny>) -> bo

fn register_utils_module(main_module: Bound<'_, PyModule>) -> PyResult<()> {
let utils_module = PyModule::new_bound(main_module.py(), "utils")?;
utils_module.add_function(wrap_pyfunction!(is_valid_ss58_address, &utils_module)?)?;
utils_module.add_function(wrap_pyfunction!(
crate::utils::is_valid_ss58_address,
&utils_module
)?)?;

utils_module.add_function(wrap_pyfunction!(py_get_ss58_format, &utils_module)?)?;
utils_module.add_function(wrap_pyfunction!(
Expand Down Expand Up @@ -929,9 +1046,7 @@ impl Wallet {
save_coldkey_to_env.unwrap_or(false),
coldkey_password,
)
.map_err(|e| {
PyErr::new::<WalletError, _>(format!("Failed to regenerate coldkey: {:?}", e))
})?;
.map_err(|e| PyErr::new::<PyWalletError, _>(e))?;
self.inner = new_inner_wallet;
Ok(())
}
Expand All @@ -946,9 +1061,7 @@ impl Wallet {
let new_inner_wallet = self
.inner
.regenerate_coldkeypub(ss58_address, public_key, overwrite.unwrap_or(false))
.map_err(|e| {
PyErr::new::<WalletError, _>(format!("Failed to regenerate coldkeypub: {:?}", e))
})?;
.map_err(|e| PyErr::new::<PyWalletError, _>(e))?;
self.inner = new_inner_wallet;
Ok(())
}
Expand Down

0 comments on commit 82d0b32

Please sign in to comment.