决策树处理医疗数据


一、目标

将所给医疗数据进行预处理,根据病人的四个属性——性别、年龄、病程及术中责任血管,对其与即刻面抽之间的联系进行分析。利用决策树分类方法编程实现数据的分类并绘图。

二、算法思想

在数据预处理中进行如下处理:

性别(gender):F为1,M为2。年龄(age):0-40岁为1,40-50岁为2,50-60岁为3,60-70岁为4。病程(course):0-3年为1,3-6年为2,6-9年为3,9-12年为4,12年以上为5。术中责任血管(vessel):AICA为 1,PICA为2,AICA、PICA为3,AICA、VA为4,PICA、CA为5,其余为0。是否即刻面抽:是为1,否为0。

在决策树 ID3 算法中,通过求概率进而求信息熵,通过求信息熵进而求信息增益,以信息增益为标准挑选信息增益最大的因素作为节点。在创建树的过程中进行遍历,如果类别完全相同则停止划分,如果遍历完所有特征值则选取出现次数最多的类标签。最后以字典类型输出数据。

在决策树绘图中,首先计算出树中叶子节点的个数和树的深度,然后对节点和节点间部分的绘制分别进行函数定义。最后在整个决策树的绘制过程中,利用深度和叶子节点个数对图像区域进行划分,以防止出现绘图中节点过密的情况产生。

三、核心代码

1、决策树算法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import math as mh
import operator
from collections import Counter

def entropy(dataset):
    """Return the Shannon entropy (in bits) of the class labels in `dataset`.

    Each row of `dataset` is a list whose LAST element is the class label;
    all other elements are feature values and are ignored here.
    """
    num = len(dataset)
    # Counter replaces the manual `if key not in dict` counting loop.
    label = Counter(vec[-1] for vec in dataset)
    ent = 0.0
    for count in label.values():
        pro = count / num
        ent -= pro * mh.log2(pro)  # log2: entropy measured in bits
    return ent

def split(dataset, temp, value):
    """Return the rows of `dataset` whose column `temp` equals `value`,
    with that column removed from each returned row.
    """
    return [row[:temp] + row[temp + 1:]
            for row in dataset
            if row[temp] == value]

def choose(dataset):
    """Pick the feature index with the largest information gain (ID3).

    Returns -1 when no feature yields a strictly positive gain.
    """
    base_entropy = entropy(dataset)
    feature_count = len(dataset[0]) - 1  # last column is the class label
    best_gain, best_feature = 0.0, -1
    total = float(len(dataset))
    for idx in range(feature_count):
        # Conditional entropy of the class given a split on feature `idx`.
        cond_entropy = 0.0
        for val in {row[idx] for row in dataset}:
            subset = split(dataset, idx, val)
            cond_entropy += len(subset) / total * entropy(subset)
        gain = base_entropy - cond_entropy
        if gain > best_gain:
            best_gain, best_feature = gain, idx
    return best_feature

def most(listR):
    """Return the most frequent class label in `listR`.

    Ties are broken by first-encountered order, matching the original
    stable descending sort on counts.
    """
    # Counter.most_common documents the same tie-breaking (insertion order)
    # as the original sorted(..., reverse=True) over a dict built in order.
    return Counter(listR).most_common(1)[0][0]

def create(dataset, label):
    """Recursively build an ID3 decision tree as nested dicts.

    Parameters:
        dataset: rows of feature values, class label in the last column.
        label:   feature names aligned with the feature columns.
                 Fix: the caller's list is no longer mutated — the original
                 `del label[featureCH]` destroyed the caller's label list.

    Returns a nested dict {feature_name: {feature_value: subtree_or_label}},
    or a bare class label for a pure / feature-exhausted branch.
    """
    listR = [example[-1] for example in dataset]
    # All rows share one class: this branch is pure, stop splitting.
    if listR.count(listR[0]) == len(listR):
        return listR[0]
    # Only the class column remains: fall back to a majority vote.
    if len(dataset[0]) == 1:
        return most(listR)
    label = label[:]  # work on a copy so the caller's list survives
    featureCH = choose(dataset)
    featureLabel = label[featureCH]
    tree = {featureLabel: {}}
    del label[featureCH]
    for value in {example[featureCH] for example in dataset}:
        tree[featureLabel][value] = create(
            split(dataset, featureCH, value), label[:])
    return tree

2、决策树绘制

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import matplotlib.pyplot as plt

# Box style for internal (decision) nodes.
DN = dict(boxstyle="round4", fc="0.8")
# Box style for leaf nodes (same appearance as decision nodes here).
LN = dict(boxstyle="round4", fc="0.8")
# Arrow style for the edge drawn from a child node back to its parent.
line = dict(arrowstyle="<-")

def number(tree):
    """Count the leaf nodes of a nested-dict decision tree.

    `tree` is a single-key dict {feature_name: {value: subtree_or_label}};
    any non-dict child is a leaf (class label).
    """
    root = next(iter(tree))  # the single feature name at this node
    num = 0
    for child in tree[root].values():
        # isinstance replaces the fragile type(...).__name__ == 'dict' check.
        if isinstance(child, dict):
            num += number(child)
        else:
            num += 1
    return num

def depth(tree):
    """Return the number of decision levels in a nested-dict tree
    (1 for a node whose children are all leaves).
    """
    root = next(iter(tree))  # single feature name stored at this node
    depthmax = 0
    for child in tree[root].values():
        # isinstance replaces the fragile type(...).__name__ == 'dict' check.
        current = 1 + depth(child) if isinstance(child, dict) else 1
        if current > depthmax:
            depthmax = current
    return depthmax

def plotnode(name, center, parents, label):
    """Draw one boxed node with text `name` at axes-fraction position
    `center`, with an arrow pointing from it back to `parents`.

    `label` is the bbox style dict (DN for decision nodes, LN for leaves);
    the arrow style comes from the module-level `line` dict.  Relies on
    PLOT() having set PLOT.ax1 beforehand.
    """
    PLOT.ax1.annotate(name, xy=parents, xycoords='axes fraction',
                      xytext=center, textcoords='axes fraction',
                      va="center", ha="center", bbox=label, arrowprops=line)

def plotmid(center, parents, name):
    """Write the branch label `name` at the midpoint of the edge between
    a child node (`center`) and its parent (`parents`).
    """
    x_mid = (parents[0] - center[0]) / 2 + center[0]
    y_mid = (parents[1] - center[1]) / 2 + center[1]
    PLOT.ax1.text(x_mid, y_mid, name, va="center", ha="center", rotation=30)

def plottree(tree, parents, name):
    # Recursively draw the subtree rooted at `tree`, attaching it to the
    # parent node already drawn at axes-fraction position `parents`;
    # `name` is the branch value labelling the incoming edge.
    # Shared drawing state lives in function attributes set by PLOT():
    #   plottree.x / plottree.y  — the current drawing "cursor",
    #   plottree.num / plottree.depth — total leaves / levels, for spacing.
    num = number(tree)
    size1 = list(tree.keys())
    str1 = size1[0]  # feature name at this node (single-key dict)
    # Centre this node above the horizontal span its `num` leaves occupy.
    center = (plottree.x + (1 + float(num))/2/plottree.num, plottree.y)
    plotnode(str1, center, parents, DN)
    plotmid(center, parents, name)
    dict2 = tree[str1]
    plottree.y = plottree.y - 1/plottree.depth  # descend one level
    for key in dict2.keys():
        if type(dict2[key]).__name__=='dict':
            # Internal child: recurse with this node as the new parent.
            plottree(dict2[key],center,str(key))
        else:
            # Leaf: advance the x cursor one slot, draw label and edge text.
            plottree.x = plottree.x + 1/plottree.num
            plotnode(dict2[key], (plottree.x, plottree.y), center, LN)
            plotmid((plottree.x, plottree.y), center, str(key))
    plottree.y = plottree.y + 1/plottree.depth  # restore level on unwind

def PLOT(tree):
    """Render the nested-dict decision tree with matplotlib and show it."""
    no_ticks = dict(xticks=[], yticks=[])
    PLOT.ax1 = plt.subplot(frameon=False, **no_ticks)
    # Shared drawing state consumed by plottree(): total leaf count and
    # depth size the layout grid; (x, y) is the starting cursor position.
    plottree.num = float(number(tree))
    plottree.depth = float(depth(tree))
    plottree.x = -0.5 / plottree.num
    plottree.y = 1
    plottree(tree, (0.5, 1), '')
    plt.show()

四、结果图像

  1. 决策树

    决策树

  2. 绘制后决策树

    绘制后决策树