Skip to content

Commit

Permalink
Refactor codes
Browse files Browse the repository at this point in the history
  • Loading branch information
minhduc0711 committed Mar 20, 2019
1 parent b80997c commit 76a0d00
Show file tree
Hide file tree
Showing 18 changed files with 150 additions and 145 deletions.
Binary file modified .idea/caches/build_file_checksums.ser
Binary file not shown.
9 changes: 3 additions & 6 deletions .idea/gradle.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

29 changes: 0 additions & 29 deletions .idea/misc.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Binary file added app/src/main/assets/decision_tree_x.pmml.ser
Binary file not shown.
Binary file added app/src/main/assets/decision_tree_y.pmml.ser
Binary file not shown.
Binary file removed app/src/main/assets/model.pb
Binary file not shown.
Binary file added app/src/main/assets/neural_net.pb
Binary file not shown.
Binary file removed app/src/main/assets/poly_x.pmml.ser
Binary file not shown.
Binary file removed app/src/main/assets/poly_y.pmml.ser
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,16 @@
import android.os.Bundle;
import android.support.v4.app.ActivityCompat;
import android.support.v7.app.AppCompatActivity;
import android.view.View;
import android.widget.AdapterView;
import android.widget.Spinner;

public class MainActivity extends AppCompatActivity implements SensorEventListener {
import com.example.minhduc0711.indoorlocalization.models.PredictiveModel;
import com.example.minhduc0711.indoorlocalization.models.ScikitModel;
import com.example.minhduc0711.indoorlocalization.models.TensorflowModel;

public class MainActivity extends AppCompatActivity implements SensorEventListener, AdapterView.OnItemSelectedListener {
private Spinner mModelSpinner;
private MapView mMapView;

private SensorManager mSensorManager;
Expand All @@ -32,8 +40,36 @@ protected void onCreate(Bundle savedInstanceState) {
mMapView.getHolder().setFormat(PixelFormat.TRANSPARENT);

mSensorManager = (SensorManager) getSystemService(SENSOR_SERVICE);

// mModelSpinner = findViewById(R.id.model_spinner);
// ArrayAdapter<CharSequence> adapter = ArrayAdapter.createFromResource(this,
// R.array.models_array, android.R.layout.simple_spinner_item);
// adapter.setDropDownViewResource(android.R.layout.simple_spinner_dropdown_item);
// mModelSpinner.setAdapter(adapter);
// mModelSpinner.setOnItemSelectedListener(this);
// mModelSpinner.setSelection(0);
}

@Override
public void onItemSelected(AdapterView<?> parent, View view, int position, long id) {
String selected = (String) parent.getItemAtPosition(position);
PredictiveModel newModel;
switch (selected) {
case "Neural network":
newModel = new TensorflowModel("neural_net.pb", this);
break;
case "Decision tree":
newModel = new ScikitModel("decision_tree_x.pmml.ser", "decision_tree_y.pmml.ser", this);
break;
default:
throw new RuntimeException();
}
mMapView.initializeModel(newModel);
}

@Override
public void onNothingSelected(AdapterView<?> parent) {}

@Override
protected void onResume() {
super.onResume();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,17 @@
import android.view.SurfaceView;
import android.widget.Toast;

import com.example.minhduc0711.indoorlocalization.models.PredictiveModel;
import com.example.minhduc0711.indoorlocalization.models.ScikitModel;
import com.example.minhduc0711.indoorlocalization.models.TensorflowModel;

import org.json.JSONException;
import org.json.JSONObject;

import java.util.List;

public class MapView extends SurfaceView implements SurfaceHolder.Callback {
private static final String TRAIN_IDX_DICT_PATH = "cleaned_idx_dict.json";
private static final String MODEL_NAME = "model";

private final static float X_MAX = 20;
private final static float Y_MAX = 40;
Expand Down Expand Up @@ -48,18 +51,23 @@ public MapView(Context context, AttributeSet attrs) {

mPositionIndicator = new PositionIndicator();

mWifiManager = (WifiManager) context.getSystemService(Context.WIFI_SERVICE);
mWifiManager = (WifiManager) context.getApplicationContext().getSystemService(Context.WIFI_SERVICE);
if (mWifiManager.getWifiState() == WifiManager.WIFI_STATE_DISABLED) {
Toast.makeText(getContext(), "Enabling Wifi...", Toast.LENGTH_LONG).show();
mWifiManager.setWifiEnabled(true);
}

mPredictiveModel = new PredictiveModel(context, PredictiveModel.PB_FILE_EXT, MODEL_NAME);
initializeModel(new TensorflowModel("neural_net.pb", context));
// initializeModel(new ScikitModel("decision_tree_x.pmml.ser", "decision_tree_y.pmml.ser", context));
trainIndexDict = Utils.loadJSONFromAsset(TRAIN_IDX_DICT_PATH, context);

mDrawThread = new DrawThread();
}

public void initializeModel(PredictiveModel newModel) {
mPredictiveModel = newModel;
}

/**
* Converts the wifi status to a feature vector
*/
Expand Down Expand Up @@ -89,18 +97,9 @@ private void updatePositionIndicator() {
if (mWifiManager.getWifiState() == WifiManager.WIFI_STATE_ENABLED) {
List<ScanResult> wifiResults = mWifiManager.getScanResults();

long startTime = System.nanoTime();
float[] input = toFeatureVector(wifiResults);
float[] output = mPredictiveModel.predict(input);
long endTime = System.nanoTime();

long duration = (endTime - startTime);
// Log.d("output", Arrays.toString(output));

// Random r = new Random();
// int x = r.nextInt((int) X_MAX);
// Random r1 = new Random();
// int y = r1.nextInt((int) Y_MAX);
mPositionIndicator.update(Math.round(output[0]), Math.round(output[1]));
}
}
Expand All @@ -112,13 +111,6 @@ public void updateOrientation(int angle) {
mDeviceOrientation = angle;
}

/**
* Performs the actual drawing on the SurfaceView
*/
private void doDraw() {

}

@Override
public void surfaceCreated(SurfaceHolder holder) {
threadRunning = true;
Expand Down Expand Up @@ -160,7 +152,7 @@ public void run() {
long endTime = System.currentTimeMillis();
long deltaTime = endTime - startTime;

if(deltaTime < 200)
if (deltaTime < 200)
{
try {
Thread.sleep(200 - deltaTime);
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,6 @@
import java.io.InputStream;

public class Utils {
public static Bitmap drawableToBitmap(Drawable drawable) {
if (drawable instanceof BitmapDrawable) {
return ((BitmapDrawable) drawable).getBitmap();
}

Bitmap bitmap = Bitmap.createBitmap(drawable.getIntrinsicWidth(), drawable.getIntrinsicHeight(), Bitmap.Config.ARGB_8888);
Canvas canvas = new Canvas(bitmap);
drawable.setBounds(0, 0, canvas.getWidth(), canvas.getHeight());
drawable.draw(canvas);

return bitmap;
}

public static JSONObject loadJSONFromAsset(String fname, Context context) {
String json;
try {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package com.example.minhduc0711.indoorlocalization.models;

public interface PredictiveModel {
float[] predict(float[] input);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package com.example.minhduc0711.indoorlocalization.models;

import android.content.Context;

import org.dmg.pmml.FieldName;
import org.jpmml.android.EvaluatorUtil;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.InputField;
import org.jpmml.evaluator.ModelEvaluator;

import java.io.InputStream;
import java.util.HashMap;
import java.util.Map;

public class ScikitModel implements PredictiveModel {
private ModelEvaluator<?> mEvaluatorX;
private ModelEvaluator<?> mEvaluatorY;

public ScikitModel(String fileNameX, String fileNameY, Context context) {
try {
InputStream is = context.getAssets().open(fileNameX);
mEvaluatorX = (ModelEvaluator<?>) EvaluatorUtil.createEvaluator(is);
is = context.getAssets().open(fileNameY);
mEvaluatorY = (ModelEvaluator<?>) EvaluatorUtil.createEvaluator(is);
} catch (Exception e) {
e.printStackTrace();
}
}

@Override
public float[] predict(float[] input) {
float[] output = new float[2];

int i = 0;
Map<FieldName, FieldValue> inputMap = new HashMap<>();
for (InputField inputField : mEvaluatorX.getInputFields()) {
FieldName fieldName = inputField.getName();
Object rawValue = input[i++];
FieldValue fieldValue = inputField.prepare(rawValue);
inputMap.put(fieldName, fieldValue);
}
Map<String, ?> resultsX = org.jpmml.evaluator.EvaluatorUtil.decode(mEvaluatorX.evaluate(inputMap));
Map<String, ?> resultsY = org.jpmml.evaluator.EvaluatorUtil.decode(mEvaluatorY.evaluate(inputMap));
output[0] = ((Double) resultsX.get("y")).floatValue();
output[1] = ((Double) resultsY.get("y")).floatValue();

return output;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package com.example.minhduc0711.indoorlocalization.models;

import android.content.Context;

import org.tensorflow.contrib.android.TensorFlowInferenceInterface;

public class TensorflowModel implements PredictiveModel {
private TensorFlowInferenceInterface mInferenceInterface;

public TensorflowModel(String fileName, Context context) {
mInferenceInterface = new TensorFlowInferenceInterface(context.getAssets(), fileName);
}

@Override
public float[] predict(float[] input) {
float[] output = new float[2];

String[] outputNames = {"output_node0"};
mInferenceInterface.feed("dense_55_input", input, 1, input.length);
mInferenceInterface.run(outputNames);
mInferenceInterface.fetch(outputNames[0], output);

return output;
}
}
17 changes: 17 additions & 0 deletions app/src/main/res/layout/activity_main.xml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,23 @@
android:layout_height="match_parent"
xmlns:android="http://schemas.android.com/apk/res/android">

<!--<LinearLayout-->
<!--android:layout_width="match_parent"-->
<!--android:layout_height="wrap_content"-->
<!--android:orientation="vertical">-->

<!--<TextView-->
<!--android:layout_width="wrap_content"-->
<!--android:layout_height="wrap_content"-->
<!--android:text="Select model"/>-->

<!--<Spinner-->
<!--android:id="@+id/model_spinner"-->
<!--android:layout_width="match_parent"-->
<!--android:layout_height="wrap_content" />-->

<!--</LinearLayout>-->

<FrameLayout android:layout_height="match_parent"
android:layout_width="match_parent"
android:id="@+id/frame_layout_main"
Expand Down
Loading

0 comments on commit 76a0d00

Please sign in to comment.