Bò sát: Một thuật toán siêu học có thể mở rộng

Chúng tôi đã phát triển một thuật toán siêu học đơn giản có tên là Reptile, thuật toán này hoạt động bằng cách lấy mẫu lặp lại một tác vụ, thực hiện giảm dần độ dốc ngẫu nhiên trên tác vụ đó và cập nhật các tham số ban đầu theo các tham số cuối cùng đã học được trên tác vụ đó. Reptile là ứng dụng của thuật toán Giảm dần ngắn nhất vào bối cảnh siêu học và về mặt toán học tương tự như MAML bậc nhất (là một phiên bản của thuật toán MAML nổi tiếng) chỉ cần truy cập hộp đen vào trình tối ưu hóa như SGD hoặc Adam, với hiệu suất và hiệu suất tính toán tương tự.

Học siêu dữ liệu là quá trình học cách học. Một thuật toán học siêu dữ liệu tiếp nhận một phân phối các nhiệm vụ, trong đó mỗi nhiệm vụ là một bài toán học tập, và nó tạo ra một người học nhanh—một người học có thể khái quát hóa từ một số lượng nhỏ các ví dụ. Một bài toán học siêu dữ liệu được nghiên cứu kỹ lưỡng là phân loại few-shot, trong đó mỗi nhiệm vụ là một bài toán phân loại mà người học chỉ nhìn thấy 1–5 ví dụ đầu vào-đầu ra từ mỗi lớp, và sau đó nó phải phân loại các đầu vào mới. Dưới đây, bạn có thể dùng thử bản demo tương tác của chúng tôi về phân loại 1-shot, sử dụng Reptile.

Bò sát hoạt động như thế nào

Giống như MAML, Reptile tìm kiếm một khởi tạo cho các tham số của mạng nơ-ron, sao cho mạng có thể được tinh chỉnh bằng một lượng nhỏ dữ liệu từ một tác vụ mới. Nhưng trong khi MAML mở rộng và phân biệt thông qua đồ thị tính toán của thuật toán giảm dần độ dốc, Reptile chỉ thực hiện  giảm dần độ dốc ngẫu nhiên (SGD)(mở trong cửa sổ mới) trên mỗi tác vụ theo cách chuẩn—nó không mở đồ thị tính toán hoặc tính toán bất kỳ đạo hàm bậc hai nào. Điều này khiến Reptile tốn ít tính toán và bộ nhớ hơn MAML. Mã giả như sau:

Khởi tạoFF, vector tham số ban đầu cho phép lặp1,2,3,…1 ,2 ,3 ,… LÀM Lấy mẫu ngẫu nhiên một nhiệm vụ $T$ Trình diễntôi>1tôi>1các bước của SGD trên nhiệm vụ $T$, bắt đầu bằng các tham sốFF, dẫn đến các tham sốTRONGTRONG Cập nhật:F←F+ϵ(TRONG−F)F←F+ϵ ( W−F )kết thúc cho
ReturnFF 

Như một giải pháp thay thế cho bước cuối cùng, chúng ta có thể xử lý F−TRONGF−TRONGnhư một gradient và đưa nó vào một trình tối ưu hóa tinh vi hơn như  Adam.

Lúc đầu thật ngạc nhiên khi phương pháp này lại có hiệu quả. Nếu tôi=1tôi=1, thuật toán này sẽ tương ứng với “huấn luyện chung”—thực hiện SGD trên hỗn hợp của tất cả các tác vụ. Trong khi huấn luyện chung có thể học được một khởi tạo hữu ích trong một số trường hợp, nó học được rất ít khi không thể học zero-shot (ví dụ khi các nhãn đầu ra được hoán đổi ngẫu nhiên). Reptile yêu cầu tôi>1tôi>1, trong đó bản cập nhật phụ thuộc vào các đạo hàm bậc cao của hàm mất mát; như chúng tôi trình bày trong bài báo, điều này hoạt động rất khác so với tôi=1tôi=1 (đào tạo chung).

Xem thêm: mua tài khoản ChatGPT Plus chính hãng giá rẻ

Để phân tích lý do tại sao Reptile hoạt động, chúng tôi ước tính bản cập nhật bằng cách sử dụng  chuỗi Taylor(mở trong cửa sổ mới). Chúng tôi chỉ ra rằng bản cập nhật Reptile tối đa hóa tích bên trong giữa các gradient của các minibatch khác nhau từ cùng một tác vụ, tương ứng với sự khái quát hóa được cải thiện. Phát hiện này có thể có ý nghĩa bên ngoài bối cảnh học siêu dữ liệu để giải thích các đặc tính khái quát hóa của SGD. Phân tích của chúng tôi cho thấy Reptile và MAML thực hiện một bản cập nhật rất giống nhau, bao gồm hai thuật ngữ giống nhau với các trọng số khác nhau.

Trong các thí nghiệm của chúng tôi, chúng tôi chứng minh rằng Reptile và MAML mang lại hiệu suất tương tự trên Omniglot(mở trong cửa sổ mới) và  Mini-ImageNet(mở trong cửa sổ mới) chuẩn mực cho phân loại ít ảnh. Bò sát cũng hội tụ về giải pháp nhanh hơn vì bản cập nhật có độ biến thiên thấp hơn.

Phân tích của chúng tôi về Reptile cho thấy rất nhiều thuật toán khác nhau mà chúng ta có thể có được bằng cách sử dụng các kết hợp khác nhau của các gradient SGD. Trong hình bên dưới, giả sử rằng chúng ta thực hiện k bước SGD trên mỗi tác vụ bằng cách sử dụng các minibatch khác nhau, tạo ra các gradient g1,g2,…,gtôi g1​,g2​,…,gtôi​Hình bên dưới hiển thị đường cong học tập trên Omniglot thu được bằng cách sử dụng từng tổng làm siêu gradient. g2 g2​​ tương ứng với MAML bậc nhất, một thuật toán được đề xuất trong bài báo MAML gốc. Bao gồm nhiều gradient hơn sẽ giúp học nhanh hơn, do giảm phương sai. Lưu ý rằng chỉ cần sử dụng g1 g1​(tương ứng với tôi=1tôi=1) không mang lại tiến triển như dự đoán cho nhiệm vụ này vì hiệu suất không cần bắn không thể được cải thiện.

Triển khai

Việc triển khai Reptile của chúng tôi có  sẵn trên GitHub. Nó sử dụng TensorFlow cho các phép tính liên quan và bao gồm mã để sao chép các thí nghiệm trên Omniglot và Mini-ImageNet. Chúng tôi cũng đang phát hành  một triển khai JavaScript nhỏ hơn tinh chỉnh mô hình được đào tạo trước bằng TensorFlow—chúng tôi đã sử dụng điều này để tạo bản demo ở trên.

Cuối cùng, đây là một ví dụ tối thiểu về hồi quy ít lần, dự đoán một sóng sin ngẫu nhiên từ 10 (x,Và)( ,Và )cặp. Cái này sử dụng PyTorch và phù hợp với ý chính:



importnumpyasnp

importtorch

, fromtorchimportnnautogradasag

. importmatplotlibpyplotasplt

fromcopyimportdeepcopy



seed=0

plot=True

innerstepsize=0.02# stepsize in inner SGD

innerepochs=1# number of epochs of each inner SGD

outerstepsize0=0.1# stepsize of outer optimization, i.e., meta-optimization

niterations=30000# number of outer updates; each iteration we sample one task and update on it



..()rng=nprandomRandomStateseed

.()torchmanual_seedseed



# Define task distribution

.(, , )[:,] x_all=nplinspace-5550None# All of the x points

ntrain=10# Size of training minibatches

():defgen_task

"Generate classification problem"

.(, .)phase=rnguniformlow=0high=2*nppi

.(, )ampl=rnguniform0.15

: .( ) f_randomsine=lambdaxnpsinx+phase*ampl

returnf_randomsine



# Define model. Reptile paper uses ReLU, but Tanh gives slightly better results

.(model=nnSequential

.(, ),nnLinear164

.(),nnTanh

.(, ),nnLinear6464

.(),nnTanh

.(, ),nnLinear641

)



():deftotorchx

.(.())returnagVariabletorchTensorx



(, ):deftrain_on_batchxy

()x=totorchx

()y=totorchy

.()modelzero_grad

()ypred=modelx

( ).().()loss=ypred-ypow2mean

.()lossbackward

.():forparaminmodelparameters

. ..paramdata-=innerstepsize*paramgraddata



():defpredictx

()x=totorchx

()..()returnmodelxdatanumpy



# Choose a fixed task and minibatch for visualization

()f_plot=gen_task

[.((), )]xtrain_plot=x_allrngchoicelenx_allsize=ntrain



# Reptile training loop

():foriterationinrangeniterations

(.())weights_before=deepcopymodelstate_dict

# Generate task

()f=gen_task

()y_all=fx_all

# Do SGD on this task

.(())inds=rngpermutationlenx_all

():for_inrangeinnerepochs

(, (), ):forstartinrange0lenx_allntrain

[:]mbinds=indsstartstart+ntrain

([], [])train_on_batchx_allmbindsy_allmbinds

# Interpolate between current weights and trained weights from this task

# I.e. (weights_before - weights_after) is the meta-gradient

.()weights_after=modelstate_dict

( ) outerstepsize=outerstepsize0*1-iteration/niterations# linear schedule

.({ :modelload_state_dictname

[] ([] []) weights_beforename+weights_aftername-weights_beforename*outerstepsize

})fornameinweights_before



# Periodically plot the results on a particular task and minibatch

() :ifplotanditeration==0oriteration+1%1000==0

.()pltcla

f=f_plot

(.()) weights_before=deepcopymodelstate_dict# save snapshot before evaluation

.(, (), , (,,))pltplotx_allpredictx_alllabel="pred after 0"color=001

():forinneriterinrange32

(, ())train_on_batchxtrain_plotfxtrain_plot

() :ifinneriter+1%8==0

() frac=inneriter+1/32

.(, (), (), (, , ))pltplotx_allpredictx_alllabel="pred after %i"%inneriter+1color=frac01-frac

.(, (), , (,,))pltplotx_allfx_alllabel="true"color=010

.(() ()).()lossval=npsquarepredictx_all-fx_allmean

.(, (), , , )pltplotxtrain_plotfxtrain_plot"x"label="train"color="k"

.(,)pltylim-44

.()pltlegendloc="lower right"

.()pltpause0.01

.() modelload_state_dictweights_before# restore from snapshot

()printf"-----------------------------"

()printf"iteration "{iteration+1}

() printf"loss on plotted curve ":.3f{lossval}# would be better to average loss over a set of examples, but this is optimized for brevity

 

Một số người đã chỉ ra với chúng tôi rằng MAML bậc nhất và Reptile có liên quan chặt chẽ hơn MAML và Reptile. Các thuật toán này có những góc nhìn khác nhau về vấn đề này, nhưng cuối cùng lại tính toán các bản cập nhật tương tự—và cụ thể là, đóng góp của Reptile dựa trên lịch sử của cả Shortest Descent và tránh đạo hàm bậc hai  trong siêu dữ liệu(mở trong cửa sổ mới)- học hỏi. Chúng tôi đã cập nhật đoạn văn đầu tiên để phản ánh điều này.

Hot Deal

Họ tên (*)

Số điện thoại (*)

Email (*)

Dịch vụ

Đăng ký để nhận bản tin mới nhất !