新增深度學習
This commit is contained in:
commit
c2516b73ac
4
.gitignore
vendored
Normal file
4
.gitignore
vendored
Normal file
@ -0,0 +1,4 @@
|
||||
.venv
|
||||
cifar-10/**/*
|
||||
cifar-10
|
||||
data
|
8
.idea/.gitignore
vendored
Normal file
8
.idea/.gitignore
vendored
Normal file
@ -0,0 +1,8 @@
|
||||
# Default ignored files
|
||||
/shelf/
|
||||
/workspace.xml
|
||||
# Editor-based HTTP Client requests
|
||||
/httpRequests/
|
||||
# Datasource local storage ignored files
|
||||
/dataSources/
|
||||
/dataSources.local.xml
|
66
.idea/inspectionProfiles/Project_Default.xml
Normal file
66
.idea/inspectionProfiles/Project_Default.xml
Normal file
@ -0,0 +1,66 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<profile version="1.0">
|
||||
<option name="myName" value="Project Default" />
|
||||
<inspection_tool class="HttpUrlsUsage" enabled="true" level="WEAK WARNING" enabled_by_default="true">
|
||||
<option name="ignoredUrls">
|
||||
<list>
|
||||
<option value="http://0.0.0.0" />
|
||||
<option value="http://127.0.0.1" />
|
||||
<option value="http://activemq.apache.org/schema/" />
|
||||
<option value="http://cxf.apache.org/schemas/" />
|
||||
<option value="http://java.sun.com/" />
|
||||
<option value="http://javafx.com/fxml" />
|
||||
<option value="http://javafx.com/javafx/" />
|
||||
<option value="http://json-schema.org/draft" />
|
||||
<option value="http://localhost" />
|
||||
<option value="http://maven.apache.org/POM/" />
|
||||
<option value="http://maven.apache.org/xsd/" />
|
||||
<option value="http://primefaces.org/ui" />
|
||||
<option value="http://schema.cloudfoundry.org/spring/" />
|
||||
<option value="http://schemas.xmlsoap.org/" />
|
||||
<option value="http://tiles.apache.org/" />
|
||||
<option value="http://www.dda5.com" />
|
||||
<option value="http://www.ibm.com/webservices/xsd" />
|
||||
<option value="http://www.jboss.com/xml/ns/" />
|
||||
<option value="http://www.jboss.org/j2ee/schema/" />
|
||||
<option value="http://www.springframework.org/schema/" />
|
||||
<option value="http://www.springframework.org/security/tags" />
|
||||
<option value="http://www.springframework.org/tags" />
|
||||
<option value="http://www.thymeleaf.org" />
|
||||
<option value="http://www.w3.org/" />
|
||||
<option value="http://xmlns.jcp.org/" />
|
||||
</list>
|
||||
</option>
|
||||
</inspection_tool>
|
||||
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
||||
<option name="ignoredPackages">
|
||||
<value>
|
||||
<list size="22">
|
||||
<item index="0" class="java.lang.String" itemvalue="clyent" />
|
||||
<item index="1" class="java.lang.String" itemvalue="protobuf" />
|
||||
<item index="2" class="java.lang.String" itemvalue="python-lsp-jsonrpc" />
|
||||
<item index="3" class="java.lang.String" itemvalue="atomicwrites" />
|
||||
<item index="4" class="java.lang.String" itemvalue="jsonpointer" />
|
||||
<item index="5" class="java.lang.String" itemvalue="et-xmlfile" />
|
||||
<item index="6" class="java.lang.String" itemvalue="PyQtWebEngine" />
|
||||
<item index="7" class="java.lang.String" itemvalue="pyasn1-modules" />
|
||||
<item index="8" class="java.lang.String" itemvalue="fonttools" />
|
||||
<item index="9" class="java.lang.String" itemvalue="patsy" />
|
||||
<item index="10" class="java.lang.String" itemvalue="pyls-spyder" />
|
||||
<item index="11" class="java.lang.String" itemvalue="appdirs" />
|
||||
<item index="12" class="java.lang.String" itemvalue="conda-repo-cli" />
|
||||
<item index="13" class="java.lang.String" itemvalue="munkres" />
|
||||
<item index="14" class="java.lang.String" itemvalue="backports.weakref" />
|
||||
<item index="15" class="java.lang.String" itemvalue="conda-verify" />
|
||||
<item index="16" class="java.lang.String" itemvalue="PyQt5" />
|
||||
<item index="17" class="java.lang.String" itemvalue="PyDispatcher" />
|
||||
<item index="18" class="java.lang.String" itemvalue="ply" />
|
||||
<item index="19" class="java.lang.String" itemvalue="webencodings" />
|
||||
<item index="20" class="java.lang.String" itemvalue="inflection" />
|
||||
<item index="21" class="java.lang.String" itemvalue="openpyxl" />
|
||||
</list>
|
||||
</value>
|
||||
</option>
|
||||
</inspection_tool>
|
||||
</profile>
|
||||
</component>
|
6
.idea/inspectionProfiles/profiles_settings.xml
Normal file
6
.idea/inspectionProfiles/profiles_settings.xml
Normal file
@ -0,0 +1,6 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<settings>
|
||||
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||
<version value="1.0" />
|
||||
</settings>
|
||||
</component>
|
10
.idea/learn-pytorch.iml
Normal file
10
.idea/learn-pytorch.iml
Normal file
@ -0,0 +1,10 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$">
|
||||
<excludeFolder url="file://$MODULE_DIR$/.venv" />
|
||||
</content>
|
||||
<orderEntry type="inheritedJdk" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
</module>
|
7
.idea/misc.xml
Normal file
7
.idea/misc.xml
Normal file
@ -0,0 +1,7 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="Black">
|
||||
<option name="sdkName" value="Python 3.9 (learn-pytorch)" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.9 (learn-pytorch)" project-jdk-type="Python SDK" />
|
||||
</project>
|
8
.idea/modules.xml
Normal file
8
.idea/modules.xml
Normal file
@ -0,0 +1,8 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectModuleManager">
|
||||
<modules>
|
||||
<module fileurl="file://$PROJECT_DIR$/.idea/learn-pytorch.iml" filepath="$PROJECT_DIR$/.idea/learn-pytorch.iml" />
|
||||
</modules>
|
||||
</component>
|
||||
</project>
|
6
.idea/other.xml
Normal file
6
.idea/other.xml
Normal file
@ -0,0 +1,6 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="PySciProjectComponent">
|
||||
<option name="PY_INTERACTIVE_PLOTS_SUGGESTED" value="true" />
|
||||
</component>
|
||||
</project>
|
6
.idea/vcs.xml
Normal file
6
.idea/vcs.xml
Normal file
@ -0,0 +1,6 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="VcsDirectoryMappings">
|
||||
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
||||
</component>
|
||||
</project>
|
130
classification.py
Normal file
130
classification.py
Normal file
@ -0,0 +1,130 @@
|
||||
import torch
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
from torch import optim
|
||||
from torch.utils.data import Dataset
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
data_path = "data/"
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.mps.is_available() else "cpu")
|
||||
|
||||
transforms = transforms.Compose([
|
||||
transforms.ToTensor(), # 將圖像轉換為 Tensor
|
||||
transforms.Normalize((0.4915, 0.4823, 0.4468), (0.2470, 0.2435, 0.2616))
|
||||
# 歸一化,第一個 tuple 代表 CIFAR-10 這個資料集 RGB 三個通道的平均值,第二個 tuple 代表標準差
|
||||
])
|
||||
|
||||
train_dataset = torchvision.datasets.CIFAR10(
|
||||
root=data_path,
|
||||
train=True,
|
||||
download=True,
|
||||
transform=transforms
|
||||
)
|
||||
|
||||
test_dataset = torchvision.datasets.CIFAR10(
|
||||
root=data_path,
|
||||
train=False,
|
||||
download=True,
|
||||
transform=transforms
|
||||
)
|
||||
|
||||
label_map = {
|
||||
0: 0, # 飛機
|
||||
2: 1, # 小鳥
|
||||
}
|
||||
|
||||
class_names = ["airplane", "bird"]
|
||||
|
||||
train_dataset = [(img, label_map[label]) for img, label in train_dataset if label in [0, 2]]
|
||||
test_dataset = [(img, label_map[label]) for img, label in test_dataset if label in [0, 2]]
|
||||
|
||||
|
||||
class ModelDataset(Dataset):
|
||||
def __init__(self, dataset):
|
||||
self.dataset = dataset
|
||||
|
||||
def __getitem__(self, index):
|
||||
img, label = self.dataset[index]
|
||||
return img, label
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
|
||||
|
||||
train_dataset = ModelDataset(train_dataset)
|
||||
test_dataset = ModelDataset(test_dataset)
|
||||
|
||||
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
|
||||
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=True)
|
||||
|
||||
# 宣告模型
|
||||
class Net(nn.Module):
|
||||
def __init__(self, n_chansl=16):
|
||||
super().__init__()
|
||||
|
||||
self.n_chansl = n_chansl
|
||||
|
||||
self.conv1 = nn.Conv2d(3, n_chansl, kernel_size=3, padding=1)
|
||||
self.conv1_batchnorm = nn.BatchNorm2d(n_chansl) # 歸一化器,輸入為通道數,輸出為相同的通道數
|
||||
self.conv2 = nn.Conv2d(n_chansl, n_chansl // 2, kernel_size=3, padding=1)
|
||||
self.conv2_batchnorm = nn.BatchNorm2d(n_chansl // 2)
|
||||
|
||||
self.fc1 = nn.Linear(8 * 8 * n_chansl // 2, 32)
|
||||
self.fc2 = nn.Linear(32, 2)
|
||||
|
||||
def forward(self, x):
|
||||
out = F.max_pool2d(self.conv1_batchnorm(torch.relu(self.conv1(x))), kernel_size=2)
|
||||
out = F.max_pool2d(self.conv2_batchnorm(torch.relu(self.conv2(out))), kernel_size=2)
|
||||
out = out.view(-1, 8 * 8 * self.n_chansl // 2)
|
||||
out = torch.relu(self.fc1(out))
|
||||
out = self.fc2(out)
|
||||
return out
|
||||
|
||||
model = Net().to(device)
|
||||
|
||||
optimizer = optim.SGD(model.parameters(), lr=1e-2)
|
||||
loss_fn = nn.CrossEntropyLoss()
|
||||
|
||||
def train(epoch):
|
||||
global loss
|
||||
for epoch in range(epoch):
|
||||
for (image, label) in train_loader:
|
||||
image = image.to(device)
|
||||
label = label.to(device)
|
||||
outputs = model(image)
|
||||
loss = loss_fn(outputs, label)
|
||||
|
||||
l2_lambda = 0.001
|
||||
l2_norm = sum(p.pow(2.0).sum() for p in model.parameters())
|
||||
loss = loss + l2_lambda * l2_norm
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
print(f"Epoch {epoch}, Loss {loss}")
|
||||
|
||||
|
||||
train(100)
|
||||
|
||||
|
||||
def test():
|
||||
correct = 0
|
||||
total = 0
|
||||
with torch.no_grad():
|
||||
for (image, label) in test_loader:
|
||||
image = image.to(device)
|
||||
label = label.to(device)
|
||||
outputs = model(image)
|
||||
_, predicted = torch.max(outputs, dim=1)
|
||||
total += label.size(0)
|
||||
correct += (predicted == label).sum().item()
|
||||
|
||||
print(f"Accuracy: {correct / total}")
|
||||
|
||||
|
||||
test()
|
||||
|
||||
# torch.save(model.state_dict(), "model/model.pt")
|
97
linear.py
Normal file
97
linear.py
Normal file
@ -0,0 +1,97 @@
|
||||
# 1. 收集數據
|
||||
import torch
|
||||
|
||||
t_c = [0.5, 14.0, 15.0, 28.0, 11.0, 8.0, 3.0, -4.0, 6.0, 13.0, 21.0]
|
||||
t_u = [35.7, 55.9, 58.2, 81.9, 56.3, 48.9, 33.9, 21.8, 48.4, 60.4, 68.4]
|
||||
|
||||
t_c = torch.tensor(t_c).unsqueeze(1) # 升維,作用是將資料轉化成單個樣本
|
||||
t_u = torch.tensor(t_u).unsqueeze(1)
|
||||
|
||||
n_samples = t_u.shape[0] # 樣本量
|
||||
n_test = int(n_samples * 0.2) # 測試集數量
|
||||
|
||||
shuffled_samples = torch.randperm(n_samples) # 隨機化樣本
|
||||
|
||||
train_indics = shuffled_samples[:-n_test] # 訓練集索引
|
||||
test_indics = shuffled_samples[-n_test:] # 測試集索引
|
||||
|
||||
# 訓練集
|
||||
t_u_train = t_u[train_indics]
|
||||
t_c_train = t_c[train_indics]
|
||||
# 測試集
|
||||
t_u_test = t_u[test_indics]
|
||||
t_c_test = t_c[test_indics]
|
||||
|
||||
print(t_u_train)
|
||||
print(t_c_train)
|
||||
print(t_u_test)
|
||||
print(t_c_test)
|
||||
|
||||
# 歸一化
|
||||
t_u_mean = t_u_train.mean()
|
||||
t_u_std = t_u_train.std()
|
||||
|
||||
t_u_train_norm = (t_u_train - t_u_mean) / t_u_std
|
||||
t_u_test_norm = (t_u_test - t_u_mean) / t_u_std
|
||||
|
||||
t_c_mean = t_c_train.mean()
|
||||
t_c_std = t_c_train.std()
|
||||
|
||||
t_c_train_norm = (t_c_train - t_c_mean) / t_c_std
|
||||
t_c_test_norm = (t_c_test - t_c_mean) / t_c_std
|
||||
|
||||
# 2. 搭建模型
|
||||
# import torch.nn as nn
|
||||
#
|
||||
# linear_model = nn.Linear(in_features = 1, out_features = 1) # in_features 表示輸入神經元的個數,out_features 表示輸出神經元的個數
|
||||
|
||||
import torch.nn as nn
|
||||
from collections import OrderedDict
|
||||
|
||||
# 構建一個多層神經網路,隱藏層有 13 個神經元,輸出層有 1 個神經元
|
||||
neural_network = nn.Sequential(OrderedDict([
|
||||
('hidden', nn.Linear(1, 13)), # 隱藏層
|
||||
('hidden_activation', nn.Tanh()), # 隱藏層激勵函數
|
||||
('output', nn.Linear(13, 1)) # 輸出層
|
||||
]))
|
||||
|
||||
# 3. 宣告優化器和損失函數
|
||||
optimizer = torch.optim.SGD(
|
||||
neural_network.parameters(),
|
||||
lr=1e-2
|
||||
)
|
||||
|
||||
def loss_fn(t_p, t_c):
|
||||
return ((t_p - t_c) ** 2).mean()
|
||||
|
||||
# 4. 宣告 train loop
|
||||
def train_loop(n_epochs, optimizer, model, loss_fun, t_u_train, t_u_test, t_c_train, t_c_test):
|
||||
for epoch in range(1, n_epochs + 1):
|
||||
t_p_train = model(t_u_train)
|
||||
loss_train = loss_fun(t_p_train, t_c_train)
|
||||
|
||||
t_p_test = model(t_u_test)
|
||||
loss_test = loss_fun(t_p_test, t_c_test)
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss_train.backward()
|
||||
optimizer.step()
|
||||
|
||||
if epoch == 1 or epoch % 10 == 0:
|
||||
print(f'Epoch {epoch}: Training Loss: {loss_train:.4f}')
|
||||
print(f'Test Loss: {loss_test:.4f}')
|
||||
|
||||
# 5. 開始訓練
|
||||
train_loop(
|
||||
n_epochs = 500,
|
||||
optimizer = optimizer,
|
||||
model = neural_network,
|
||||
loss_fun = loss_fn,
|
||||
t_u_train = t_u_train_norm,
|
||||
t_u_test = t_u_test_norm,
|
||||
t_c_train = t_c_train_norm,
|
||||
t_c_test = t_c_test_norm
|
||||
)
|
||||
|
||||
print('output', neural_network(t_u_test_norm) * t_c_std + t_c_mean)
|
||||
print('val', t_c_test)
|
Loading…
Reference in New Issue
Block a user