1. 데이터 확인
데이터 불러오기
- 판다스를 이용해 csv파일 읽어오기
traindata = pd.read_csv(datasetFolderPath+trainsetName) traindata.dtypes
- 널데이터 확인 및 제거데이터 분포 확인하기
traindata.isnull().sum() traindata.dropna(inplace=True)
- target과 다른 정보들간에 관계가 있는지 확인하기
sns.histplot(x="species",y="island",data=traindata)
sns.histplot(x="species",y="culmen_length_mm",data=traindata)
sns.histplot(x="species",y="culmen_depth_mm",data=traindata)
sns.histplot(x="species",y="flipper_length_mm",data=traindata)
sns.histplot(x="species",y="body_mass_g",data=traindata)
sns.histplot(x="species",y="sex",data=traindata)
- 연관 없는 데이터 드랍
traindata.drop("sex",inplace=True,axis=1)
2. 데이터 전처리
def makedata(path,Train=True) :
x = pd.read_csv(path)
x.dropna(inplace=True)
x.drop('sex',inplace=True,axis=1)
for var in ('species','island') :
_, x[var] = np.unique(x[var],return_inverse=True)
label = np.ravel(x['species'])
x.drop(['species'],inplace=True,axis=1)
x = (x-np.mean(x,axis=0))/np.std(x,axis=0)
x = torch.from_numpy(x.to_numpy(np.float32))
label = torch.from_numpy(label)
y = F.one_hot(label).float()
print(y.shape)
return x, y
3. 학습
train_x,train_y = makedata(datasetFolderPath+traindsetName)
test_x,test_y = makedata(datasetFolderPath+testsetName,Train=False)
model = nn.Linear(5,3)
optimizer = torch.optim.SGD(model.parameters(),lr=0.5)
lossfn = nn.CrossEntropyLoss()
epochs =10
for epoch in range(epochs) :
pred = model(train_x)
cost = lossfn(pred,train_y)
optimizer.zero_grad()
cost.backward()
optimizer.step()
rint(f'epoch : {epoch} | cost : {cost.item()}')
print("Done")
4. 테스트
model.eval()
test_loss, correct = 0, 0
with torch.no_grad() :
predtion = model(test_x)
test_loss=lossfn(predtion,test_y).item()
correct += (predtion.argmax(1) == test_y.argmax(1)).type(torch.float).sum().item()
correct /= len(test_y)
print(f'testset Loss : {test_loss:>.3f},Accuracy : {correct*100:>.3f}%')
https://www.wenyanet.com/opensource/ko/611377af0524566054290c16.html
'AI > MachineLearning' 카테고리의 다른 글
[기계학습] 1. 기계학습 개요 및 용어 (0) | 2022.03.22 |
---|