会话构建
create_session:Session create_session(Graph **graph, int h, int w, int c, int truth_num, char **type*, char **path*)
创建会话实例
graph |
计算图 |
h |
输入图像数据的height |
w |
输入图像数据的width |
c |
输入图像数据的channel |
truth_num |
离散标签数据个数 |
type |
运行内核选择CPU/GPU |
path |
权重文件路径 |
初始化会话
sess |
会话实例 |
data_path |
数据路径 |
label_path |
标签路径 |
set_train_params:void set_train_params(Session **sess*, int epoch, int batch, int subdivision, float learning_rate)
设置训练超参数
sess |
会话实例 |
epoch |
训练轮次 |
batch |
随机梯度下降批次大小 |
subdivision |
批次分割大小 |
learning_rate |
步长(学习率) |
set_detect_params:void set_detect_params(Session **sess*)
设置测试超参数
train:void train(Session **sess*)
运行训练
detect_classification:void detect_classification(Session **sess*)
运行测试