会话构建

  • create_session:Session create_session(Graph **graph, int h, int w, int c, int truth_num, char **type*, char **path*)

​ 创建会话实例

graph

计算图

h

输入图像数据的height

w

输入图像数据的width

c

输入图像数据的channel

truth_num

离散标签数据个数

type

运行内核选择CPU/GPU

path

权重文件路径

  • init_session:void init_session(Session **sess*, char **data_path*, char **label_path*)

​ 初始化会话

sess

会话实例

data_path

数据路径

label_path

标签路径

  • set_train_params:void set_train_params(Session **sess*, int epoch, int batch, int subdivision, float learning_rate)

​ 设置训练超参数

sess

会话实例

epoch

训练轮次

batch

随机梯度下降批次大小

subdivision

批次分割大小

learning_rate

步长(学习率)

  • set_detect_params:void set_detect_params(Session **sess*)

​ 设置测试超参数

  • train:void train(Session **sess*)

​ 运行训练

  • detect_classification:void detect_classification(Session **sess*)

​ 运行测试