Skip to content

Commit

Permalink
Merge branch 'main' into metal_relu
Browse files Browse the repository at this point in the history
  • Loading branch information
jbochi authored Dec 31, 2023
2 parents 73156c3 + 1fb2dd9 commit b459774
Show file tree
Hide file tree
Showing 13 changed files with 527 additions and 184 deletions.
55 changes: 55 additions & 0 deletions candle-core/src/metal_backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,11 @@ impl BackendStorage for MetalStorage {
(ReduceOp::Max, DType::I64) => ("fast_max_i64_strided", true, false),
(ReduceOp::ArgMin, DType::I64) => ("fast_argmin_i64_strided", true, true),
(ReduceOp::ArgMax, DType::I64) => ("fast_argmax_i64_strided", true, true),
(ReduceOp::Sum, DType::U8) => ("fast_sum_u8_strided", false, false),
(ReduceOp::Min, DType::U8) => ("fast_min_u8_strided", true, false),
(ReduceOp::Max, DType::U8) => ("fast_max_u8_strided", true, false),
(ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8_strided", true, true),
(ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8_strided", true, true),
(k, dtype) => crate::bail!("Metal reduce op {k:?} {dtype:?} not implemented"),
};
if check_empty && layout.shape().elem_count() == 0 {
Expand Down Expand Up @@ -660,6 +665,7 @@ impl BackendStorage for MetalStorage {
("ugelu", DType::F32) => contiguous::gelu::FLOAT,
("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
("uerf", DType::F32) => contiguous::erf::FLOAT,
("uabs", DType::F32) => contiguous::abs::FLOAT,
("uceil", DType::F32) => contiguous::ceil::FLOAT,
("ufloor", DType::F32) => contiguous::floor::FLOAT,
("uround", DType::F32) => contiguous::round::FLOAT,
Expand All @@ -676,6 +682,7 @@ impl BackendStorage for MetalStorage {
("ugelu", DType::F16) => contiguous::gelu::HALF,
("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
("uerf", DType::F16) => contiguous::erf::HALF,
("uabs", DType::F16) => contiguous::abs::HALF,
("uceil", DType::F16) => contiguous::ceil::HALF,
("ufloor", DType::F16) => contiguous::floor::HALF,
("uround", DType::F16) => contiguous::round::HALF,
Expand Down Expand Up @@ -709,6 +716,7 @@ impl BackendStorage for MetalStorage {
("ugelu", DType::F32) => strided::gelu::FLOAT,
("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT,
("uerf", DType::F32) => strided::erf::FLOAT,
("uabs", DType::F32) => strided::abs::FLOAT,
("uceil", DType::F32) => strided::ceil::FLOAT,
("ufloor", DType::F32) => strided::floor::FLOAT,
("urelu", DType::F32) => strided::relu::FLOAT,
Expand All @@ -723,6 +731,7 @@ impl BackendStorage for MetalStorage {
("ugelu", DType::F16) => strided::gelu::HALF,
("ugelu_erf", DType::F16) => strided::gelu_erf::HALF,
("uerf", DType::F16) => strided::erf::HALF,
("uabs", DType::F16) => strided::abs::HALF,
("uceil", DType::F16) => strided::ceil::HALF,
("ufloor", DType::F16) => strided::floor::HALF,
("urelu", DType::F16) => strided::relu::HALF,
Expand Down Expand Up @@ -783,6 +792,8 @@ impl BackendStorage for MetalStorage {
(DType::U8, DType::F32) => "where_u8_f32",
(DType::U8, DType::F16) => "where_u8_f16",
(DType::U8, DType::I64) => "where_u8_i64",
(DType::U8, DType::U32) => "where_u8_u32",
(DType::U8, DType::U8) => "where_u8_u8",
(left, right) => crate::bail!("Metal where_cond {left:?} {right:?} not implemented"),
};
candle_metal_kernels::call_where_cond_strided(
Expand Down Expand Up @@ -1327,6 +1338,26 @@ impl MetalStorage {
("lt", DType::I64) => (contiguous::lt::I64, DType::U8),
("ge", DType::I64) => (contiguous::ge::I64, DType::U8),
("gt", DType::I64) => (contiguous::gt::I64, DType::U8),
("add", DType::U32) => (contiguous::add::U32, self.dtype),
("sub", DType::U32) => (contiguous::sub::U32, self.dtype),
("mul", DType::U32) => (contiguous::mul::U32, self.dtype),
("div", DType::U32) => (contiguous::div::U32, self.dtype),
("eq", DType::U32) => (contiguous::eq::U32, DType::U8),
("ne", DType::U32) => (contiguous::ne::U32, DType::U8),
("le", DType::U32) => (contiguous::le::U32, DType::U8),
("lt", DType::U32) => (contiguous::lt::U32, DType::U8),
("ge", DType::U32) => (contiguous::ge::U32, DType::U8),
("gt", DType::U32) => (contiguous::gt::U32, DType::U8),
("add", DType::U8) => (contiguous::add::U8, self.dtype),
("sub", DType::U8) => (contiguous::sub::U8, self.dtype),
("mul", DType::U8) => (contiguous::mul::U8, self.dtype),
("div", DType::U8) => (contiguous::div::U8, self.dtype),
("eq", DType::U8) => (contiguous::eq::U8, DType::U8),
("ne", DType::U8) => (contiguous::ne::U8, DType::U8),
("le", DType::U8) => (contiguous::le::U8, DType::U8),
("lt", DType::U8) => (contiguous::lt::U8, DType::U8),
("ge", DType::U8) => (contiguous::ge::U8, DType::U8),
("gt", DType::U8) => (contiguous::gt::U8, DType::U8),
(name, dtype) => {
crate::bail!("Metal contiguous binary {name} {dtype:?} not implemented")
}
Expand Down Expand Up @@ -1384,6 +1415,30 @@ impl MetalStorage {
("lt", DType::I64) => (strided::lt::I64, DType::U8),
("ge", DType::I64) => (strided::ge::I64, DType::U8),
("gt", DType::I64) => (strided::gt::I64, DType::U8),
("badd", DType::U32) => (strided::add::U32, self.dtype),
("bsub", DType::U32) => (strided::sub::U32, self.dtype),
("bmul", DType::U32) => (strided::mul::U32, self.dtype),
("bdiv", DType::U32) => (strided::div::U32, self.dtype),
("bminimum", DType::U32) => (strided::min::U32, self.dtype),
("bmaximum", DType::U32) => (strided::max::U32, self.dtype),
("eq", DType::U32) => (strided::eq::U32, DType::U8),
("ne", DType::U32) => (strided::ne::U32, DType::U8),
("le", DType::U32) => (strided::le::U32, DType::U8),
("lt", DType::U32) => (strided::lt::U32, DType::U8),
("ge", DType::U32) => (strided::ge::U32, DType::U8),
("gt", DType::U32) => (strided::gt::U32, DType::U8),
("badd", DType::U8) => (strided::add::U8, self.dtype),
("bsub", DType::U8) => (strided::sub::U8, self.dtype),
("bmul", DType::U8) => (strided::mul::U8, self.dtype),
("bdiv", DType::U8) => (strided::div::U8, self.dtype),
("bminimum", DType::U8) => (strided::min::U8, self.dtype),
("bmaximum", DType::U8) => (strided::max::U8, self.dtype),
("eq", DType::U8) => (strided::eq::U8, DType::U8),
("ne", DType::U8) => (strided::ne::U8, DType::U8),
("le", DType::U8) => (strided::le::U8, DType::U8),
("lt", DType::U8) => (strided::lt::U8, DType::U8),
("ge", DType::U8) => (strided::ge::U8, DType::U8),
("gt", DType::U8) => (strided::gt::U8, DType::U8),
(name, dtype) => {
crate::bail!("Metal strided binary {name} {dtype:?} not implemented")
}
Expand Down
11 changes: 9 additions & 2 deletions candle-examples/examples/llama/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ enum Which {
V2,
#[value(name = "solar-10.7b")]
Solar10_7B,
#[value(name = "tiny-llama-1.1b-chat")]
TinyLlama1_1BChat,
}

#[derive(Parser, Debug)]
Expand Down Expand Up @@ -124,6 +126,7 @@ fn main() -> Result<()> {
Which::V1 => "Narsil/amall-7b".to_string(),
Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
});
println!("loading the model weights from {model_id}");
let revision = args.revision.unwrap_or("main".to_string());
Expand All @@ -134,8 +137,12 @@ fn main() -> Result<()> {
let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let config = config.into_config(args.use_flash_attn);

let filenames =
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?;
let filenames = match args.which {
Which::V1 | Which::V2 | Which::Solar10_7B => {
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
}
Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
};
println!("building the model");
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;

Expand Down
11 changes: 9 additions & 2 deletions candle-examples/examples/reinforcement-learning/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,16 @@ Python package with:
pip install "gymnasium[accept-rom-license]"
```

In order to run the example, use the following command. Note the additional
In order to run the examples, use the following commands. Note the additional
`--package` flag to ensure that there is no conflict with the `candle-pyo3`
crate.

For the Policy Gradient example:
```bash
cargo run --example reinforcement-learning --features=pyo3 --package candle-examples -- pg
```

For the Deep Deterministic Policy Gradient example:
```bash
cargo run --example reinforcement-learning --features=pyo3 --package candle-examples
cargo run --example reinforcement-learning --features=pyo3 --package candle-examples -- ddpg
```
105 changes: 105 additions & 0 deletions candle-examples/examples/reinforcement-learning/ddpg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ use candle_nn::{
};
use rand::{distributions::Uniform, thread_rng, Rng};

use super::gym_env::GymEnv;

pub struct OuNoise {
mu: f64,
theta: f64,
Expand Down Expand Up @@ -449,3 +451,106 @@ impl DDPG<'_> {
Ok(())
}
}

// The impact of the q value of the next state on the current state's q value.
const GAMMA: f64 = 0.99;
// The weight for updating the target networks.
const TAU: f64 = 0.005;
// The capacity of the replay buffer used for sampling training data.
const REPLAY_BUFFER_CAPACITY: usize = 100_000;
// The training batch size for each training iteration.
const TRAINING_BATCH_SIZE: usize = 100;
// The total number of episodes.
const MAX_EPISODES: usize = 100;
// The maximum length of an episode.
const EPISODE_LENGTH: usize = 200;
// The number of training iterations after one episode finishes.
const TRAINING_ITERATIONS: usize = 200;

// Ornstein-Uhlenbeck process parameters.
const MU: f64 = 0.0;
const THETA: f64 = 0.15;
const SIGMA: f64 = 0.1;

const ACTOR_LEARNING_RATE: f64 = 1e-4;
const CRITIC_LEARNING_RATE: f64 = 1e-3;

pub fn run() -> Result<()> {
let env = GymEnv::new("Pendulum-v1")?;
println!("action space: {}", env.action_space());
println!("observation space: {:?}", env.observation_space());

let size_state = env.observation_space().iter().product::<usize>();
let size_action = env.action_space();

let mut agent = DDPG::new(
&Device::Cpu,
size_state,
size_action,
true,
ACTOR_LEARNING_RATE,
CRITIC_LEARNING_RATE,
GAMMA,
TAU,
REPLAY_BUFFER_CAPACITY,
OuNoise::new(MU, THETA, SIGMA, size_action)?,
)?;

let mut rng = rand::thread_rng();

for episode in 0..MAX_EPISODES {
// let mut state = env.reset(episode as u64)?;
let mut state = env.reset(rng.gen::<u64>())?;

let mut total_reward = 0.0;
for _ in 0..EPISODE_LENGTH {
let mut action = 2.0 * agent.actions(&state)?;
action = action.clamp(-2.0, 2.0);

let step = env.step(vec![action])?;
total_reward += step.reward;

agent.remember(
&state,
&Tensor::new(vec![action], &Device::Cpu)?,
&Tensor::new(vec![step.reward as f32], &Device::Cpu)?,
&step.state,
step.terminated,
step.truncated,
);

if step.terminated || step.truncated {
break;
}
state = step.state;
}

println!("episode {episode} with total reward of {total_reward}");

for _ in 0..TRAINING_ITERATIONS {
agent.train(TRAINING_BATCH_SIZE)?;
}
}

println!("Testing...");
agent.train = false;
for episode in 0..10 {
// let mut state = env.reset(episode as u64)?;
let mut state = env.reset(rng.gen::<u64>())?;
let mut total_reward = 0.0;
for _ in 0..EPISODE_LENGTH {
let mut action = 2.0 * agent.actions(&state)?;
action = action.clamp(-2.0, 2.0);

let step = env.step(vec![action])?;
total_reward += step.reward;

if step.terminated || step.truncated {
break;
}
state = step.state;
}
println!("episode {episode} with total reward of {total_reward}");
}
Ok(())
}
Loading

0 comments on commit b459774

Please sign in to comment.