モデルデータを読込 tensorflow1.0.1

Androidアプリでモデルデータを読み込む
Tensorflowのversionが1.0.1だとAPIの仕様が変更されているのでその2の記事の方は動かない
以下がversion1.0.1仕様のMainActivity

package com.example.yamikachan.irisdetector;

import android.content.DialogInterface;
import android.content.res.AssetManager;
import android.support.v7.app.AlertDialog;
import android.support.v7.app.AppCompatActivity;
import android.os.Bundle;
import android.view.View;
import android.widget.Button;
import android.widget.EditText;
import android.widget.TextView;

import org.tensorflow.contrib.android.TensorFlowInferenceInterface;

public class MainActivity extends AppCompatActivity {
    private TextView ansView;
    private EditText editIrisFeature1;
    private EditText editIrisFeature2;
    private EditText editIrisFeature3;
    private EditText editIrisFeature4;
    private Button detectButton;
    static {
        System.loadLibrary("tensorflow_inference");
    }

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        ansView = (TextView) findViewById(R.id.answer_text);

        editIrisFeature1 = (EditText) findViewById(R.id.edit_Iris_feature1);
        editIrisFeature2 = (EditText) findViewById(R.id.edit_Iris_feature2);
        editIrisFeature3 = (EditText) findViewById(R.id.edit_Iris_feature3);
        editIrisFeature4 = (EditText) findViewById(R.id.edit_Iris_feature4);

        detectButton = (Button) findViewById(R.id.detect_Button);
        detectButton.setOnClickListener(new View.OnClickListener(){
            @Override
            public void onClick(View v) {
                float[] features = new float[4];
                String value1 = null;
                String value2 = null;
                String value3 = null;
                String value4 = null;
                Boolean flag = false;
                try{
                    value1 = editIrisFeature1.getText().toString();
                    value2 = editIrisFeature2.getText().toString();
                    value3 = editIrisFeature3.getText().toString();
                    value4 = editIrisFeature4.getText().toString();

                    features[0] = Float.valueOf(value1);
                    features[1] = Float.valueOf(value2);
                    features[2] = Float.valueOf(value3);
                    features[3] = Float.valueOf(value4);
                }catch (java.lang.NumberFormatException e){
                    flag = true;
                    ansView.setText("");
                    AlertDialog.Builder alertDialog = new AlertDialog.Builder(MainActivity.this);

                    alertDialog.setTitle("input error");
                    alertDialog.setMessage("Please input values");
                    alertDialog.setPositiveButton("OK", new DialogInterface.OnClickListener() {
                        public void onClick(DialogInterface dialog,int which) {
                        }
                    });

                    alertDialog.create();
                    alertDialog.show();
                }

                if(!flag){
                    onDetectClicked(features);
                }

            }
        });


    }
    private void onDetectClicked(float[] f) {
        AssetManager mAssetManager = getAssets();
        TensorFlowInferenceInterface mTensorFlowIF = new TensorFlowInferenceInterface(mAssetManager, "file:///android_asset/frozen_model.pb");

        ansView.setText("");

        //入力データを入れる
        //第一引数にはモデル作成時に指定した入力データを格納するPlaceholderのnameを指定する
        //第二引数には入力データ
        //第三引数にはshapeを指定する 今回は一つのサンプルデータを渡してテストするので(1,4)
        mTensorFlowIF.feed("input:0",f,new long[] {1,4});
        //判定結果を格納する配列
        float[] result_value = new float[3];
        //判定を行うこの時にモデル作成時の出力のnameを指定する
        mTensorFlowIF.run(new String[] {"output:0"});
        //result_valueに結果を代入する
        mTensorFlowIF.fetch("output:0", result_value);

        int ansIndex = getAnswer(result_value);
        switch (ansIndex){
            case 0:
                ansView.setText("Detected : Iris-Setosa");
                break;

            case 1:
                ansView.setText("Detected : Iris-versicolor");
                break;

            case 2:
                ansView.setText("Detected : Iris-virginica");
                break;
        }
    }

    private int getAnswer(float[] f){
        int argmax = 0;
        float max = f[0];
        for(int i=0;i<f.length;i++){
            if(max < f[i]){
                max = f[i];
                argmax = i;
            }
        }

        return argmax;
    }
}

activity_main.xml

<?xml version="1.0" encoding="utf-8"?>
<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:tools="http://schemas.android.com/tools"
    android:id="@+id/activity_main"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    android:paddingBottom="@dimen/activity_vertical_margin"
    android:paddingLeft="@dimen/activity_horizontal_margin"
    android:paddingRight="@dimen/activity_horizontal_margin"
    android:paddingTop="@dimen/activity_vertical_margin"
    tools:context="各自のproject名">

    <LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
        android:orientation="horizontal"
        android:layout_width="match_parent"
        android:layout_height="match_parent"
        android:paddingTop="16dp">
        <EditText
            android:id="@+id/edit_Iris_feature1"
            android:inputType="numberDecimal"
            android:layout_width="80dp"
            android:layout_height="wrap_content"
            android:background="#ffffff"
            android:layout_marginLeft="5dp"
            android:layout_marginRight="5dp" />
        <EditText
            android:id="@+id/edit_Iris_feature2"
            android:inputType="numberDecimal"
            android:layout_width="80dp"
            android:layout_height="wrap_content"
            android:background="#ffffff"
            android:layout_marginLeft="5dp"
            android:layout_marginRight="5dp" />
        <EditText
            android:id="@+id/edit_Iris_feature3"
            android:inputType="numberDecimal"
            android:layout_width="80dp"
            android:layout_height="wrap_content"
            android:background="#ffffff"
            android:layout_marginLeft="5dp"
            android:layout_marginRight="5dp" />
        <EditText
            android:id="@+id/edit_Iris_feature4"
            android:inputType="numberDecimal"
            android:layout_width="80dp"
            android:layout_height="wrap_content"
            android:background="#ffffff"
            android:layout_marginLeft="5dp"
            android:layout_marginRight="5dp" />
    </LinearLayout>

    <Button
        android:id="@+id/detect_Button"
        android:text="Button"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:layout_centerHorizontal="true"
        android:layout_marginTop="70dp" />

    <TextView
        android:id="@+id/answer_text"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:layout_below="@+id/detect_Button"
        android:textSize="30sp"
        android:layout_margin="30dp"
        android:gravity="center" />

</RelativeLayout>

build.gradle androidの所に下記を追加しておく

sourceSets {
    main {
        jniLibs.srcDirs = ['libs']
        assets.srcDirs = ['assets']
    }
}


データセットからサンプルを選び、正解ラベルが返ってきたら成功

results matching ""

    No results matching ""